├── LICENSE.txt ├── Makefile ├── README ├── experimental ├── Makefile ├── sample-mm-main.cc ├── sample-mm.cc ├── sample-mm.h ├── sample-vmf-dp-mixture.cc └── sample-vmf-dp-mixture.h ├── gibbs-base.cc ├── gibbs-base.h ├── ncrp-base.cc ├── ncrp-base.h ├── sample-clustered-lda-main.cc ├── sample-clustered-lda.cc ├── sample-clustered-lda.h ├── sample-crosscat-mm-main.cc ├── sample-crosscat-mm.cc ├── sample-crosscat-mm.h ├── sample-fixed-ncrp-main.cc ├── sample-fixed-ncrp.cc ├── sample-fixed-ncrp.h ├── sample-gem-ncrp.cc ├── sample-gem-ncrp.h ├── sample-mult-ncrp.cc ├── sample-mult-ncrp.h ├── sample-precomputed-fixed-ncrp-main.cc ├── sample-precomputed-fixed-ncrp.cc ├── sample-precomputed-fixed-ncrp.h ├── sample-soft-crosscat-main.cc ├── sample-soft-crosscat.cc ├── sample-soft-crosscat.h ├── sampleCrossCatMixtureModel-Local.sh ├── sampleCrossCatMixtureModel.sh ├── sampleMultLDA-Local.sh ├── sampleMultLDA.sh ├── sampleMultNCRP-Local.sh ├── sampleMultNCRP.sh ├── sampleMultNCRPPRixFixe-Local.sh ├── sampleMultNCRPPRixFixe.sh ├── samplePrecomputedFixedNCRP-Local.sh ├── samplePrecomputedFixedNCRP.sh ├── sampleSoftCrossCatMixtureModel.sh ├── scripts ├── model_file_utils.py └── summarize_ncrp_model.py ├── strutil.cc ├── strutil.h └── test_data ├── testW ├── testWclustered ├── testWnoisy └── testWtree /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # XXXXXXXXXX this has to change for 64/32 2 | #INCLUDES = -I include/ -I /projects/nn/joeraii/local_libraries_mastodon/include/ 3 | # INCLUDES = -I include/ -I /projects/nn/joeraii/local_libraries/include/ 4 | INCLUDES = -I include/ -I /scratch/cluster/joeraii/ncrp/local_libraries/include/ 5 | CC = g++ 6 | # LDFLAGS = -L/p/lib -L/projects/nn/joeraii/local_libraries_mastodon/lib/ -L/projects/nn/joeraii/local_libraries/lib/ -L/p/lib/ 7 | LDFLAGS = -L/p/lib -L/scratch/cluster/joeraii/ncrp/local_libraries/lib/ -L/p/lib/ 8 | LIBRARIES = -lglog -lgflags -lpthread -lboost_iostreams-mt 9 | EXECUTABLES = sampleSoftCrossCatMixtureModel sampleMultNCRP sampleGEMNCRP sampleFixedNCRP samplePrecomputedFixedNCRP sampleClusteredLDA sampleCrossCatMixtureModel 10 | OBJECTS = dSFMT.o strutil.o gibbs-base.o ncrp-base.o sample-clustered-lda.o sample-precomputed-fixed-ncrp.o sample-fixed-ncrp.o sample-gem-ncrp.o sample-mult-ncrp.o sample-crosscat-mm.o sample-soft-crosscat.o 11 | MTFLAGS = -msse2 -DDSFMT_MEXP=521 -DHAVE_SSE2 --param max-inline-insns-single=1800 --param inline-unit-growth=500 --param large-function-growth=900 12 | CFLAGS = -O3 $(MTFLAGS) -DUSE_MT_RANDOM 13 | COMPILE = $(CC) $(CFLAGS) $(INCLUDES) 14 | FULLCOMPILE = $(COMPILE) $(LDFLAGS) $(LIBRARIES) 15 | 16 | all: $(OBJECTS) $(EXECUTABLES) 17 | 18 | dSFMT.o: dSFMT-src-2.0/dSFMT.c 19 | $(COMPILE) -c dSFMT-src-2.0/dSFMT.c -o dSFMT.o 20 | strutil.o: strutil.cc 21 | $(COMPILE) -c strutil.cc -o strutil.o 22 | gibbs-base.o: gibbs-base.cc 23 | $(COMPILE) -c gibbs-base.cc -o gibbs-base.o 24 | ncrp-base.o: ncrp-base.cc 25 | $(COMPILE) -c ncrp-base.cc -o ncrp-base.o 26 | sample-fixed-ncrp.o: sample-fixed-ncrp.h sample-fixed-ncrp.cc 27 | $(COMPILE) -c sample-fixed-ncrp.cc -o sample-fixed-ncrp.o 28 | sample-precomputed-fixed-ncrp.o: sample-precomputed-fixed-ncrp.h sample-precomputed-fixed-ncrp.cc 29 | $(COMPILE) -c sample-precomputed-fixed-ncrp.cc -o sample-precomputed-fixed-ncrp.o 30 | sample-clustered-lda.o: sample-clustered-lda.h sample-clustered-lda.cc 31 | $(COMPILE) -c sample-clustered-lda.cc -o sample-clustered-lda.o 32 | sample-gem-ncrp.o: sample-gem-ncrp.cc 33 | $(COMPILE) -c sample-gem-ncrp.cc -o sample-gem-ncrp.o 34 | sample-mult-ncrp.o: sample-mult-ncrp.cc 35 | $(COMPILE) -c sample-mult-ncrp.cc -o sample-mult-ncrp.o 36 | sample-crosscat-mm.o: sample-crosscat-mm.cc 37 | $(COMPILE) -c sample-crosscat-mm.cc -o sample-crosscat-mm.o 38 | sample-soft-crosscat.o: sample-soft-crosscat.cc 39 | $(COMPILE) -c sample-soft-crosscat.cc -o sample-soft-crosscat.o 40 | 41 | sampleMultNCRP: strutil.o dSFMT.o ncrp-base.o gibbs-base.o sample-mult-ncrp.cc sample-mult-ncrp.o 42 | $(FULLCOMPILE) strutil.o dSFMT.o sample-mult-ncrp.o ncrp-base.o gibbs-base.o -o sampleMultNCRP 43 | sampleGEMNCRP: strutil.o dSFMT.o ncrp-base.o gibbs-base.o sample-gem-ncrp.cc sample-gem-ncrp.o 44 | $(FULLCOMPILE) strutil.o dSFMT.o sample-gem-ncrp.o ncrp-base.o gibbs-base.o -o sampleGEMNCRP 45 | sampleFixedNCRP: strutil.o dSFMT.o ncrp-base.o gibbs-base.o sample-gem-ncrp.cc sample-fixed-ncrp.o 46 | $(FULLCOMPILE) sample-fixed-ncrp-main.cc strutil.o dSFMT.o sample-fixed-ncrp.o ncrp-base.o gibbs-base.o -o sampleFixedNCRP 47 | samplePrecomputedFixedNCRP: strutil.o dSFMT.o ncrp-base.o gibbs-base.o sample-gem-ncrp.cc sample-precomputed-fixed-ncrp.o sample-fixed-ncrp.o 48 | $(FULLCOMPILE) 
sample-precomputed-fixed-ncrp-main.cc strutil.o dSFMT.o gibbs-base.o sample-fixed-ncrp.o sample-precomputed-fixed-ncrp.o ncrp-base.o -o samplePrecomputedFixedNCRP 49 | sampleClusteredLDA: strutil.o dSFMT.o ncrp-base.o gibbs-base.o sample-clustered-lda.cc 50 | $(FULLCOMPILE) sample-clustered-lda-main.cc strutil.o dSFMT.o sample-clustered-lda.o ncrp-base.o gibbs-base.o -o sampleClusteredLDA 51 | sampleCrossCatMixtureModel: strutil.o dSFMT.o gibbs-base.o sample-crosscat-mm.cc 52 | $(FULLCOMPILE) sample-crosscat-mm-main.cc strutil.o dSFMT.o sample-crosscat-mm.o gibbs-base.o -o sampleCrossCatMixtureModel 53 | sampleSoftCrossCatMixtureModel: strutil.o dSFMT.o gibbs-base.o sample-soft-crosscat.o sample-soft-crosscat-main.cc 54 | $(FULLCOMPILE) sample-soft-crosscat-main.cc strutil.o dSFMT.o sample-soft-crosscat.o gibbs-base.o -o sampleSoftCrossCatMixtureModel 55 | sampleNonconjugateDP: strutil.o dSFMT.o gibbs-base.o sample-nonconjugate-dp.cc 56 | $(FULLCOMPILE) sample-nonconjugate-dp.cc strutil.o dSFMT.o sample-nonconjugate-dp.o gibbs-base.o -o sampleNonconjugateDP 57 | 58 | clean: 59 | -rm -f *.o *.so *.pyc *~ 60 | 61 | deepclean: clean 62 | -rm -f $(OBJECTS) 63 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | UTML Latent Variable Modeling Toolkit 2 | 3 | Joseph Reisinger 4 | joeraii@cs.utexas.edu 5 | 6 | ver. 0.00000...01 (alarmingly alpha) 7 | 8 | 9 | OVERVIEW 10 | 11 | Implements a bunch of multinomial-dirichlet latent variable models in C++, 12 | including: 13 | 14 | - Dirichlet Process Mixture Model 15 | - Latent Dirichlet Allocation 16 | - Nested Chinese Restaurant Process (hierarchical LDA) 17 | * fixed-depth multinomial 18 | * arbitrary depth w/ GEM sampler 19 | - Labeled LDA / Fixed-Structure hLDA 20 | - Tiered Clustering 21 | - Cross-Cutting Categorization 22 | - Soft Cross-Cutting Categorization 23 | - (EXPERIMENTAL) Clustered LDA (latent word model) 24 | 25 | I'm releasing this not because we need another Topic Modeling package, but 26 | because it includes cross-cutting categorization and tiered clustering, neither 27 | of which have packages I'm aware of. Also just in case people want to try and 28 | duplicate my research. 
29 | 30 | If you're looking to do straight-up topic modeling, there are several far more 31 | mature, faster, excellent, better-looking packages: 32 | 33 | - MALLET (java): http://mallet.cs.umass.edu/ 34 | - R LDA (R): http://cran.r-project.org/web/packages/lda/ 35 | - Stanford Topic Modeling Toolkit (scala): http://nlp.stanford.edu/software/tmt/tmt-0.3/ 36 | - LDA-C (C): http://www.cs.princeton.edu/~blei/lda-c/index.html 37 | 38 | Also if you want to just do vanilla nCRP/hLDA, Dave's code is probably more 39 | reliable: 40 | 41 | - hLDA: http://www.cs.princeton.edu/~blei/downloads/hlda-c.tgz 42 | 43 | A lot of the common math/stats routines were ripped off from the samplib and 44 | stats source files included in Hal's Hierarchical Bayes Compiler: 45 | 46 | - hbc: http://www.umiacs.umd.edu/~hal/HBC/ 47 | 48 | 49 | COMPILING 50 | 51 | You're going to need several packages, freely available: 52 | 53 | - google-logging (glog): http://code.google.com/p/google-glog/ 54 | - google-gflags (gflags): http://code.google.com/p/google-gflags/ 55 | - google-sparsehash: http://code.google.com/p/google-sparsehash/ 56 | - Fast Mersenne Twister (DSFMT): http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/SFMT/ 57 | - (INCLUDED) strutil.h from google-protobuf: http://code.google.com/p/protobuf/ 58 | 59 | Build all those and install in the normal way. 60 | 61 | Then DONT just type 'make' ; first, go look at the Makefile and check all the 62 | paths are right. They're not, unless you're me. Fix those. Then just type 'make' 63 | 64 | 65 | RUNNING 66 | 67 | To see stdout from the various models you need GLOG_logtostderr=1 set in your 68 | environment. Here are some example invocations: 69 | 70 | (ncrp w/ depth 5 and depth-dependent eta scaling) 71 | 72 | GLOG_logtostderr=1 ./sampleMultNCRP \ 73 | --ncrp_datafile=data.txt \ 74 | --ncrp_depth=5 \ 75 | --ncrp_eta_depth_scale=0.5 76 | 77 | 78 | (lda w/ 50 topics) 79 | 80 | GLOG_logtostderr=1 ./sampleMultNCRP \ 81 | --ncrp_datafile=data.txt \ 82 | --ncrp_depth=50 \ 83 | --ncrp_max_branches=1 84 | 85 | 86 | (soft cross-cat) 87 | 88 | GLOG_logtostderr=1 ./sampleSoftCrossCatMixtureModel \ 89 | --mm_alpha=1.0 \ 90 | --eta=1.0 \ 91 | --mm_datafile=data.txt \ 92 | --M=2 \ 93 | --implementation=marginal 94 | 95 | All the scripts can be called with --help to list the various flags. 96 | 97 | The datafile format for all the models is: 98 | 99 | [NAME] [term_1]:[count] [term_2]:[count] ... [term_N]:[count] 100 | 101 | Where NAME is the (not necessarily unique) name of the document. All separating 102 | whitespace should be \t not ' ', as ' ' can be included in term names. 103 | 104 | 105 | 106 | THE FUTURE 107 | 108 | I have a ton of python scripts for dealing with the various samples generated by 109 | these models. Eventually I'll include those as well. Feel free to ask for them; 110 | more user demand = faster turn around. 111 | 112 | 113 | LICENSE 114 | 115 | Copyright 2010 Joseph Reisinger 116 | 117 | Licensed under the Apache License, Version 2.0 (the "License"); 118 | you may not use this file except in compliance with the License. 119 | You may obtain a copy of the License at 120 | 121 | http://www.apache.org/licenses/LICENSE-2.0 122 | 123 | Unless required by applicable law or agreed to in writing, software 124 | distributed under the License is distributed on an "AS IS" BASIS, 125 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 126 | See the License for the specific language governing permissions and 127 | limitations under the License. 
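EXAMPLE DATAFILE

For concreteness, here is a tiny made-up input in the format described under RUNNING above (two documents; the names and terms are invented for illustration). Every field is separated by a single tab character, written here as <TAB>; the multi-word term shows why spaces cannot be used as the separator:

    animal_bird<TAB>has wings:2<TAB>flies:3<TAB>beak:1
    animal_fish<TAB>swims:4<TAB>has fins:2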
128 | 129 | -------------------------------------------------------------------------------- /experimental/Makefile: -------------------------------------------------------------------------------- 1 | 2 | sample-mm.o: sample-mm.cc 3 | $(COMPILE) -c sample-mm.cc -o sample-mm.o 4 | sampleMixtureModel: strutil.o dSFMT.o gibbs-base.o sample-mm.cc 5 | $(FULLCOMPILE) sample-mm-main.cc strutil.o dSFMT.o sample-mm.o gibbs-base.o -o sampleMixtureModel 6 | # samplevMFDPMixture: strutil.o dSFMT.o sample-vmf-dp-mixture.cc gibbs-base.o 7 | # $(FULLCOMPILE) sample-vmf-dp-mixture.cc strutil.o dSFMT.o gibbs-base.o -o samplevMFDPMixture 8 | -------------------------------------------------------------------------------- /experimental/sample-mm-main.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "sample-mm.h" 17 | 18 | int main(int argc, char **argv) { 19 | google::InitGoogleLogging(argv[0]); 20 | google::ParseCommandLineFlags(&argc, &argv, true); 21 | 22 | init_random(); 23 | 24 | MM h = MM(); 25 | h.load_data(FLAGS_mm_datafile); 26 | h.initialize(); 27 | h.run(); 28 | } 29 | -------------------------------------------------------------------------------- /experimental/sample-mm.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // Samples from a multinomial mixture with clutster proportion smoother xi 17 | // and data likelihood smoother eta. 18 | // 19 | // Also builds in support for explicit feature selection via the prix-fixe topic 20 | // model 21 | // 22 | 23 | #include 24 | #include 25 | 26 | #include "gibbs-base.h" 27 | #include "sample-mm.h" 28 | 29 | // the number of clusters 30 | DEFINE_int32(K, 31 | 50, 32 | "Number of clusters."); 33 | 34 | // Prior over cluster sizes 35 | DEFINE_string(mm_prior, 36 | "dirichlet", 37 | "Prior over clusters. 
Can be dirichlet, dirichlet-process, or uniform."); 38 | 39 | // Smoother on clustering 40 | DEFINE_double(mm_xi, 41 | 1.0, 42 | "Smoother on the cluster assignments."); 43 | 44 | // Smoother on cluster likelihood 45 | DEFINE_double(mm_beta, 46 | 1.0, 47 | "Smoother on the cluster likelihood."); 48 | 49 | // Smoother on data/noise assignment 50 | DEFINE_double(mm_alpha, 51 | 1.0, 52 | "Smoother on the data/noise assignment."); 53 | 54 | // Number of noise topics 55 | DEFINE_int32(N, 56 | 0, 57 | "Number of noise topics"); 58 | 59 | // File holding the data 60 | DEFINE_string(mm_datafile, 61 | "", 62 | "Docify holding the data to be clustered."); 63 | 64 | const string kDirichletProcess = "dirichlet-process"; 65 | const string kDirichletMixture = "dirichlet"; 66 | const string kUniformMixture = "uniform"; 67 | 68 | void MM::initialize() { 69 | LOG(INFO) << "initialize"; 70 | 71 | CHECK(FLAGS_mm_prior == kDirichletMixture || 72 | FLAGS_mm_prior == kDirichletProcess || 73 | FLAGS_mm_prior == kUniformMixture); 74 | 75 | _current_component = FLAGS_K; 76 | 77 | _N = FLAGS_N+1; // Total number of topic components 78 | 79 | // Initialize the per-topic dirichlet parameters 80 | // NOTE: in reality this would actually have to be /per topic/ as in one 81 | // parameter per node in the hierarchy. But since its not used for now its 82 | // ok. 83 | for (int l = 0; l < FLAGS_K; l++) { 84 | _xi.push_back(FLAGS_mm_xi); 85 | } 86 | _xi_sum = FLAGS_K*FLAGS_mm_xi; 87 | 88 | // Topic Smoother (there are N+1 topics, noise + signal) 89 | for (int l = 0; l < _N; l++) { 90 | _alpha.push_back(FLAGS_mm_alpha); 91 | } 92 | _alpha_sum = _N*FLAGS_mm_alpha; 93 | 94 | 95 | // Set up likelihood smoother 96 | for (int l = 0; l < _lV; l++) { 97 | _beta.push_back(FLAGS_mm_beta); 98 | } 99 | _beta_sum = FLAGS_mm_beta*_lV; 100 | 101 | // Set up clusters 102 | for (int l = 0; l < FLAGS_K; l++) { 103 | _phi.insert(pair(l, CRP())); 104 | } 105 | 106 | // Add each document to a cluster and optionally allocate some of its words 107 | // to the noise topics 108 | for (DocumentMap::iterator d_itr = _D.begin(); d_itr != _D.end(); d_itr++) { 109 | unsigned d = d_itr->first; // = document number 110 | 111 | // set a random cluster assignment for this guy 112 | _c[d] = sample_integer(FLAGS_K); 113 | 114 | // Initial word assignments 115 | for (int n = 0; n < d_itr->second.size(); n++) { 116 | unsigned w = d_itr->second[n]; 117 | 118 | // Topic assignment 119 | _z[d][n] = sample_integer(_N); 120 | 121 | if (_z[d][n] == 0) { 122 | // test the initialization of maps 123 | CHECK(_phi[_c[d]].nw.find(w) != _phi[_c[d]].nw.end() 124 | || _phi[_c[d]].nw[w] == 0); 125 | 126 | _phi[_c[d]].add(w, d); 127 | } else { 128 | _phi_noise[_z[d][n]].add(w, d); 129 | } 130 | } 131 | if (d > 0 && FLAGS_mm_prior != kDirichletProcess) { // interacts badly with DP, at least asserts do 132 | CHECK((_phi.size() == FLAGS_K) || FLAGS_mm_prior == kDirichletProcess) << "doc " << d; 133 | resample_posterior_c_for(d); 134 | resample_posterior_z_for(d); 135 | CHECK((_phi.size() == FLAGS_K) || FLAGS_mm_prior == kDirichletProcess) << "doc " << d; 136 | } 137 | } 138 | 139 | // Cull the DP assignments 140 | if (FLAGS_mm_prior == kDirichletProcess) { 141 | for (google::dense_hash_map::iterator itr = _phi.begin(); 142 | itr != _phi.end(); 143 | itr++) { 144 | unsigned l = itr->first; 145 | if (itr->second.ndsum == 0) { // empty component 146 | _phi.erase(itr); 147 | VLOG(1) << "erasing component"; 148 | } 149 | } 150 | } 151 | 152 | _ll = compute_log_likelihood(); 153 | } 
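// NOTE: summary of the next function, inferred from its body.
// document_slice_log_likelihood(d, l) scores only the "signal" words of
// document d (those with _z[d][n] == 0) under cluster l, using the smoothed
// counts log(_phi[l].nw[w] + _beta[w]) - log(_phi[l].nwsum + _beta_sum),
// plus, when noise topics are enabled (FLAGS_N > 0), a proportion term
// log(_alpha[0] + _phi[l].nd[d]) - log(_alpha_sum + _nd[d]). It is shared by
// resample_posterior_c_for() and compute_log_likelihood().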
154 | 155 | double MM::document_slice_log_likelihood(unsigned d, unsigned l) { 156 | // XXX: all of this stuff might be wrong basically; should be ratios of 157 | // gammas because we need to integrate over all possible orderings? 158 | double log_lik = 0; 159 | CHECK((l < FLAGS_K) || FLAGS_mm_prior == kDirichletProcess); 160 | for (int n = 0; n < _D[d].size(); n++) { 161 | if (_z[d][n] == 0) { 162 | // likelihood of drawing this word 163 | unsigned w = _D[d][n]; 164 | log_lik += log(_phi[l].nw[w]+_beta[w]) - log(_phi[l].nwsum+_beta_sum); 165 | 166 | // Topic model part 167 | if (FLAGS_N > 0) { 168 | log_lik += log(_alpha[0] + _phi[l].nd[d]) - log(_alpha_sum + _nd[d]); 169 | } 170 | // CHECK_LE(_phi[l].nd[d], _nd[d]); 171 | // CHECK_LE(_phi[l].nw[w], _phi[l].nwsum); 172 | // CHECK_LE(_alpha[0], _alpha_sum); 173 | // CHECK_LE(_beta[w], _beta_sum); 174 | } 175 | } 176 | return log_lik; 177 | } 178 | 179 | // Reallocate this document's words between the noise and signal topics 180 | void MM::resample_posterior_z_for(unsigned d) { 181 | for (int n = 0; n < _D[d].size(); n++) { 182 | unsigned w = _D[d][n]; 183 | // Remove this document and word from the counts 184 | if (_z[d][n] == 0) { 185 | _phi[_c[d]].remove(w, d); 186 | } else { 187 | _phi_noise[_z[d][n]].remove(w, d); 188 | } 189 | 190 | vector lp_z_dn; 191 | for (int l = 0; l < _N; l++) { 192 | if (l == 0) { 193 | // TODO: check the normalizer on documents here; do we normalize 194 | // to document length or to topic docs? 195 | 196 | lp_z_dn.push_back(log(_beta[w] + _phi[_c[d]].nw[w]) - 197 | log(_beta_sum + _phi[_c[d]].nwsum) + 198 | log(_alpha[0] + _phi[_c[d]].nd[d]) - 199 | log(_alpha_sum + _nd[d])); 200 | } else { 201 | lp_z_dn.push_back(log(_beta[w] + _phi_noise[l].nw[w]) - 202 | log(_beta_sum + _phi_noise[l].nwsum) + 203 | log(_alpha[l] + _phi_noise[l].nd[d]) - 204 | log(_alpha_sum + _nd[d])); 205 | } 206 | } 207 | 208 | // Update the assignment 209 | _z[d][n] = sample_unnormalized_log_multinomial(&lp_z_dn); 210 | 211 | // Update the counts 212 | if (_z[d][n] == 0) { 213 | _phi[_c[d]].add(w, d); 214 | } else { 215 | _phi_noise[_z[d][n]].add(w, d); 216 | } 217 | } 218 | } 219 | 220 | // Performs a single document's level assignment resample step 221 | void MM::resample_posterior_c_for(unsigned d) { 222 | unsigned old_assignment = _c[d]; 223 | CHECK_LT(d, _lD); 224 | 225 | // Remove document d from the clustering 226 | for (int n = 0; n < _D[d].size(); n++) { 227 | if (_z[d][n] == 0) { 228 | unsigned w = _D[d].at(n); 229 | _phi[_c[d]].remove(w,d); 230 | } 231 | } 232 | 233 | CHECK_LE(_phi[_c[d]].ndsum, _lD); 234 | 235 | vector > lp_c_d; 236 | 237 | unsigned test_ndsum = 0; 238 | for (google::dense_hash_map::iterator itr = _phi.begin(); 239 | itr != _phi.end(); 240 | itr++) { 241 | unsigned l = itr->first; 242 | 243 | double log_lik = 0; 244 | 245 | // First add in the prior over the clusters 246 | if (FLAGS_mm_prior == kDirichletMixture) { 247 | log_lik += log(_xi[l] + _phi[l].ndsum) - log(_xi_sum + _lD); 248 | } else if (FLAGS_mm_prior == kDirichletProcess) { 249 | log_lik += log(_phi[l].ndsum) - log(_lD - 1 + FLAGS_mm_xi); 250 | test_ndsum += _phi[l].ndsum; 251 | } 252 | 253 | // Now account for the likelihood of the data (marginal posterior of 254 | // DP-Mult) 255 | CHECK((l < FLAGS_K) || FLAGS_mm_prior == kDirichletProcess) << "A"; 256 | log_lik += document_slice_log_likelihood(d, l); 257 | 258 | lp_c_d.push_back(pair(l, log_lik)); 259 | } 260 | // CHECK_EQ(test_ndsum, _lD-1); 261 | 262 | // Add an additional new component if DP 
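// (Sketch of the step below: the candidate empty cluster gets prior mass
// xi / (D - 1 + xi), and the document's signal words are scored under the
// prior predictive of an empty cluster, i.e. _beta[w] / _beta_sum per word.)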
263 | if (FLAGS_mm_prior == kDirichletProcess) { 264 | double log_lik = log(FLAGS_mm_xi) - log(_lD - 1 + FLAGS_mm_xi); 265 | for (int n = 0; n < _D[d].size(); n++) { 266 | if (_z[d][n] == 0) { 267 | unsigned w = _D[d][n]; 268 | log_lik += log(_beta[w]) - log(_beta_sum); 269 | } 270 | } 271 | 272 | lp_c_d.push_back(pair(_current_component, log_lik)); 273 | } 274 | 275 | // Update the assignment 276 | _c[d] = sample_unnormalized_log_multinomial(&lp_c_d); 277 | VLOG(1) << "resampling posterior c for " << d << ": " << old_assignment << "->" << _c[d]; 278 | 279 | 280 | // Add document d back to the clustering at the new assignment 281 | for (int n = 0; n < _D[d].size(); n++) { 282 | if (_z[d][n] == 0) { 283 | unsigned w = _D[d].at(n); 284 | _phi[_c[d]].add(w,d); 285 | } 286 | } 287 | 288 | CHECK_LE(_phi[_c[d]].ndsum, _lD); 289 | 290 | // Clean up for the DPMM 291 | if (FLAGS_mm_prior == kDirichletProcess) { 292 | if (_phi[old_assignment].ndsum == 0) { // empty component 293 | _phi.erase(old_assignment); 294 | } 295 | // Make room for a new component if we selected the new one 296 | if (_c[d] == _current_component) { 297 | _current_component += 1; 298 | } 299 | } 300 | } 301 | 302 | 303 | double MM::compute_log_likelihood() { 304 | // Compute the log likelihood for the tree 305 | double log_lik = 0; 306 | 307 | // Compute the log likelihood of the level assignments (correctly?) 308 | for (DocumentMap::const_iterator d_itr = _D.begin(); d_itr != _D.end(); d_itr++) { 309 | unsigned d = d_itr->first; 310 | 311 | // Add in prior over clusters 312 | if (FLAGS_mm_prior == kDirichletMixture) { 313 | log_lik += log(_phi[_c[d]].ndsum+_xi[_c[d]]) - log(_lD +_xi_sum); 314 | CHECK_LT(_c[d], FLAGS_K); 315 | CHECK_LE(_phi[_c[d]].ndsum, _lD); 316 | CHECK_LE(_xi[_c[d]], _xi_sum); 317 | } else if (FLAGS_mm_prior == kDirichletProcess) { 318 | log_lik += log(_phi[_c[d]].ndsum) - log(_lD - 1 + FLAGS_mm_xi); 319 | } 320 | 321 | // Log-likelihood of the slice of the document accounted for by _c[d] 322 | CHECK((_c[d] < FLAGS_K) || FLAGS_mm_prior == kDirichletProcess) << "B"; 323 | log_lik += document_slice_log_likelihood(d, _c[d]); 324 | 325 | // Account for the noise assignments 326 | for (int n = 0; n < _D[d].size(); n++) { 327 | if (_z[d][n] > 0) { 328 | // likelihood of drawing this word 329 | unsigned w = _D[d][n]; 330 | log_lik += log(_phi_noise[_z[d][n]].nw[w]+_beta[w]) - log(_phi_noise[_z[d][n]].nwsum+_beta_sum) 331 | + log(_alpha[_z[d][n]] + _phi_noise[_z[d][n]].nd[d]) - log(_alpha_sum + _nd[d]); 332 | CHECK_LE(_phi_noise[_z[d][n]].nd[d], _nd[d]); 333 | } 334 | } 335 | CHECK_LE(log_lik, 0) << "hello"; 336 | 337 | 338 | } 339 | return log_lik; 340 | } 341 | 342 | string MM::current_state() { 343 | _output_filename = FLAGS_mm_datafile; 344 | _output_filename += StringPrintf("-xi%f-beta%f", 345 | _xi_sum / (double)_xi.size(), 346 | _beta_sum / (double)_beta.size()); 347 | 348 | return StringPrintf( 349 | "ll = %f (%f at %d) xi = %f beta = %f alpha = %f K = %d N = %d", 350 | _ll, _best_ll, _best_iter, 351 | _xi_sum / (double)_xi.size(), 352 | _beta_sum / (double)_beta.size(), 353 | _alpha_sum / (double)_alpha.size(), 354 | _phi.size(), FLAGS_N); 355 | } 356 | 357 | void MM::resample_posterior() { 358 | CHECK_GT(_lV, 0); 359 | CHECK_GT(_lD, 0); 360 | CHECK_GT(_phi.size(), 0); 361 | 362 | // Interleaved version 363 | for (DocumentMap::const_iterator d_itr = _D.begin(); d_itr != _D.end(); d_itr++) { 364 | unsigned d = d_itr->first; 365 | 366 | VLOG(1) << " resampling document " << d; 367 | CHECK((_phi.size() == 
FLAGS_K) || FLAGS_mm_prior == kDirichletProcess); 368 | resample_posterior_c_for(d); 369 | resample_posterior_z_for(d); 370 | } 371 | 372 | print_cluster_summary(); 373 | print_noise_summary(); 374 | } 375 | 376 | // Write out all the data in an intermediate format 377 | void MM::write_data(string prefix) { 378 | // string filename = StringPrintf("%s-%d-%s.hlda.bz2", get_base_name(_output_filename).c_str(), FLAGS_random_seed, 379 | // prefix.c_str()); 380 | string filename = StringPrintf("%s-%s-N%d-%d-%s.hlda", get_base_name(_output_filename).c_str(), 381 | FLAGS_mm_prior.c_str(), FLAGS_N, FLAGS_random_seed, 382 | prefix.c_str()); 383 | VLOG(1) << "writing data to [" << filename << "]"; 384 | 385 | ofstream f(filename.c_str(), ios_base::out | ios_base::binary); 386 | 387 | // get_bz2_ostream(filename, f); 388 | 389 | f << current_state() << endl; 390 | 391 | 392 | for (DocumentMap::const_iterator d_itr = _D.begin(); d_itr != _D.end(); d_itr++) { 393 | unsigned d = d_itr->first; 394 | 395 | f << _c[d] << "\t" << _document_name[d] << endl; 396 | } 397 | 398 | f << "NOISE ASSIGNMENT" << endl; 399 | for (DocumentMap::const_iterator d_itr = _D.begin(); d_itr != _D.end(); d_itr++) { 400 | unsigned d = d_itr->first; 401 | 402 | f << _document_name[d]; 403 | for (int n = 0; n < _D[d].size(); n++) { 404 | unsigned w = _D[d][n]; 405 | f << "\t" << _word_id_to_name[w] << ":" << _z[d][n]; 406 | } 407 | f << endl; 408 | } 409 | } 410 | 411 | 412 | // Prints out the top few features from each cluster 413 | void MM::print_noise_summary() { 414 | for (clustering::iterator itr = _phi_noise.begin(); 415 | itr != _phi_noise.end(); 416 | itr++) { 417 | unsigned l = itr->first; 418 | 419 | if (l == 0) { 420 | CHECK_EQ(itr->second.nwsum, 0); 421 | } else { 422 | string buffer = show_chopped_sorted_nw(itr->second.nw); 423 | 424 | // Convert the ublas vector into a vector of pairs for sorting 425 | LOG(INFO) << "N[" << l << "] (" << StringPrintf("%.3f\%", itr->second.nwsum / (double)_total_word_count) << ") " << " " << buffer; 426 | } 427 | } 428 | } 429 | 430 | // Prints out the top few features from each cluster 431 | void MM::print_cluster_summary() { 432 | // Write the current cluster sizes to the console 433 | string s; 434 | for (google::dense_hash_map::iterator itr = _phi.begin(); 435 | itr != _phi.end(); 436 | itr++) { 437 | unsigned l = itr->first; 438 | s += StringPrintf("%d:%d ", l, _phi[l].ndsum); 439 | } 440 | LOG(INFO) << s; 441 | 442 | // Show the contents of the clusters 443 | for (clustering::iterator itr = _phi.begin(); 444 | itr != _phi.end(); 445 | itr++) { 446 | unsigned l = itr->first; 447 | 448 | string buffer = show_chopped_sorted_nw(itr->second.nw); 449 | 450 | // Convert the ublas vector into a vector of pairs for sorting 451 | LOG(INFO) << "C[" << l << "] (d " << itr->second.ndsum << ") " << " " << buffer; 452 | } 453 | } 454 | -------------------------------------------------------------------------------- /experimental/sample-mm.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // Basic implementation of a multinomial mixture model (MM) with optional noise topics 17 | 18 | #ifndef SAMPLE_MM_H_ 19 | #define SAMPLE_MM_H_ 20 | 21 | #include <string> 22 | #include <vector> 23 | 24 | #include "gibbs-base.h" 25 | 26 | // Number of clusters 27 | DECLARE_int32(K); 28 | 29 | // Prior over cluster sizes 30 | DECLARE_string(mm_prior); 31 | 32 | // Smoother on clustering 33 | DECLARE_double(mm_xi); 34 | 35 | // Smoother on cluster likelihood 36 | DECLARE_double(mm_beta); 37 | 38 | // Smoother on data/noise assignment 39 | DECLARE_double(mm_alpha); 40 | 41 | // Number of noise topics 42 | DECLARE_int32(N); 43 | 44 | // File holding the data 45 | DECLARE_string(mm_datafile); 46 | 47 | // Implements several kinds of mixture models (uniform prior, Dirichlet prior, 48 | // DPMM), all with DP-Multinomial likelihood. 49 | class MM : public GibbsSampler { 50 | public: 51 | MM() { 52 | _c.set_empty_key(kEmptyUnsignedKey); 53 | _z.set_empty_key(kEmptyUnsignedKey); 54 | _phi.set_empty_key(kEmptyUnsignedKey); 55 | _phi.set_deleted_key(kDeletedUnsignedKey); 56 | _phi_noise.set_empty_key(kEmptyUnsignedKey); 57 | _phi_noise.set_deleted_key(kDeletedUnsignedKey); 58 | } 59 | 60 | // Set up initial assignments and load the doc->word and word->feature maps 61 | void initialize(); 62 | 63 | double compute_log_likelihood(); 64 | 65 | void write_data(string prefix); 66 | protected: 67 | void resample_posterior(); 68 | void resample_posterior_c_for(unsigned d); 69 | void resample_posterior_z_for(unsigned d); 70 | 71 | string current_state(); 72 | 73 | double document_slice_log_likelihood(unsigned d, unsigned l); 74 | 75 | void print_cluster_summary(); 76 | void print_noise_summary(); 77 | protected: 78 | // Maps documents to clusters 79 | cluster_map _c; // Map data point -> cluster 80 | clustering _phi; // Map [w][z] -> CRP 81 | 82 | topic_map _z; // Map word to noise or data 83 | clustering _phi_noise; // Distribution for the noise 84 | 85 | vector<double> _beta; // Smoother for the word likelihood within a cluster 86 | double _beta_sum; 87 | 88 | vector<double> _alpha; // Smoother for noise / data assignment 89 | double _alpha_sum; 90 | 91 | vector<double> _xi; // Smoother for the cluster assignment prior 92 | double _xi_sum; 93 | 94 | unsigned _N; // Total number of topic components (noise + 1) 95 | 96 | // Base names of the flags 97 | string _word_features_moniker; 98 | string _datafile_moniker; 99 | 100 | unsigned _ndsum; 101 | 102 | unsigned _current_component; // # of clusters currently 103 | 104 | string _output_filename; 105 | }; 106 | 107 | #endif // SAMPLE_MM_H_ 108 | -------------------------------------------------------------------------------- /experimental/sample-vmf-dp-mixture.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // Neal's Algorithm 8 for vMF-DP mixture 17 | // G0 = vMF(mu, kappa) 18 | // c_k ~ DP(G0, alpha) 19 | // phi_k ~ vMF(c_k, xi) 20 | 21 | 22 | #include 23 | #include 24 | #include 25 | 26 | #include "sample-vmf-dp-mixture.h" 27 | 28 | 29 | #define BOUNDLO(x) (((x)(1-FLAGS_epsilon_value))?(1-FLAGS_epsilon_value):(x)) 31 | #define BOUND(x) (BOUNDLO(BOUNDHI(x))) 32 | #define BOUNDPROB(x) (((x)<-300)?(-300):((((x)>300)?(300):(x)))) 33 | 34 | using namespace boost::numeric; 35 | 36 | const string kDirichletProcess = "dirichlet-process"; 37 | const string kDirichletMixture = "dirichlet"; 38 | const string kUniformMixture = "uniform"; 39 | 40 | 41 | const unsigned kM = 4; 42 | 43 | // The number of MH iterations to run for. 44 | DEFINE_int32(iterations, 45 | 99999, 46 | "Number of MH sampling iterations to run."); 47 | 48 | // Initial number of clusters 49 | DEFINE_int32(K_initial, 50 | 10, 51 | "Number of clusters initially."); 52 | 53 | // The input data set. Consists of one document per line, with words separated 54 | // by tabs, and appended with their counts. The first word in each document is 55 | // taken as the document name. 56 | DEFINE_string(datafile, 57 | "", 58 | "the input data set, words arranged in documents"); 59 | 60 | // Alpha controls the topic smoothing, with higher alpha causing more "uniform" 61 | // distributions over topics. 62 | DEFINE_double(alpha, 63 | 1.0, 64 | "Topic smoothing."); 65 | 66 | // Kappa controls the width of the vMF clusters. 67 | DEFINE_double(kappa, 68 | 100, 69 | "vMF cluster concentration."); 70 | 71 | // xi is the corpus concentration 72 | DEFINE_double(xi, 73 | 0.001, 74 | "Corpus mean concentration parameter."); 75 | 76 | // How many steps to sample before we tune the MH random walk probabilities 77 | DEFINE_int32(tune_interval, 78 | 100, 79 | "how many MH iterations to perform before tuning proposals"); 80 | 81 | // Restricts the topics (phis) to live on the positive orthant by chopping all 82 | // negative entries in the vector (sparsifying). 83 | DEFINE_bool(restrict_topics, 84 | false, 85 | "if enabled, chops out negative entries in the topics"); 86 | 87 | // A unique identifier to add to the output moniker (useful for condor) 88 | DEFINE_string(my_id, 89 | "null", 90 | "a unique identifier to add to the output file moniker"); 91 | 92 | // Returns a string summarizing the current state of the sampler. 
93 | string vMFDPMixture::current_state() { 94 | return StringPrintf("ll = %f (%f at %d) xi=%.3f || accept %: phi %.3f", _ll, _best_ll, 95 | _best_iter, _xi, 1.0-_rejected_phi/(double)_proposed_phi); 96 | } 97 | 98 | void vMFDPMixture::initialize() { 99 | LOG(INFO) << "initialize"; 100 | LOG(INFO) << "Neal's Algorithm 8 nonconjugate DP sampler using m=" << kM; 101 | 102 | _best_ll = 0; 103 | _best_iter = 0; 104 | 105 | _K = FLAGS_K_initial; 106 | 107 | _filename = FLAGS_datafile; 108 | 109 | // Set up the mean hyperparameter 110 | _mu.resize(_lV); 111 | for (int i = 0; i < _lV; i++) { 112 | _mu[i] = 1.0; 113 | } 114 | _mu /= norm_2(_mu); 115 | CHECK_LT(fabs(1.0-norm_2(_mu)), FLAGS_epsilon_value); 116 | 117 | // XXX: Maybe better to do this 118 | // Topic distribution mean 119 | // _mu = sample_spherical_gaussian(_mu_mean, 1.0/_mu_kappa); 120 | // CHECK_LT(fabs(1.0-norm_2(_mu)), epsilonValue); 121 | 122 | // Set up the per topic phis 123 | for (int t = 0; t < _K; t++) { 124 | // Originally we sampled this from a vMF, but its hellishly slow 125 | // _phi[t] = sample_vmf(_mu, _xi); 126 | // _phi[t] = sample_spherical_gaussian(_mu, _xi); 127 | // CHECK_LT(fabs(1.0-norm_2(_phi[t])), epsilonValue); 128 | _phi[t] = propose_new_phi(sample_spherical_gaussian(_mu, _xi)); 129 | CHECK_LT(fabs(1.0-norm_2(_phi[t])), FLAGS_epsilon_value); 130 | } 131 | _proposal_variance_phi = 1.0; 132 | _rejected_phi = 0; 133 | 134 | 135 | // Initialize the DP part 136 | _current_component = _K; 137 | 138 | for (int l = 0; l < _K; l++) { 139 | _c.insert(pair(l, CRP())); 140 | } 141 | 142 | // For each document, allocate a topic path for it there are several ways to 143 | // do this, e.g. a single linear chain, incremental conditional sampling and 144 | // random tree 145 | for (int d = 0; d < _lD; d++) { 146 | // set a random topic assignment for this guy 147 | _z[d] = sample_integer(_K); 148 | _c[_z[d]].nd[0] += 1; // # of words in doc d with topic z 149 | 150 | // resample_posterior_z_for(d); 151 | } 152 | 153 | // Cull the DP assignments 154 | for (clustering::iterator itr = _c.begin(); 155 | itr != _c.end(); 156 | itr++) { 157 | unsigned l = itr->first; 158 | if (itr->second.nd[0] == 0) { // empty component 159 | _c.erase(itr); 160 | _phi.erase(l); 161 | } 162 | } 163 | 164 | // DCHECK(tree_is_consistent()); 165 | _ll = compute_log_likelihood(); 166 | 167 | _iteration = 0; 168 | } 169 | 170 | // This is the main workhorse function, calling each of the node samplers in 171 | // turn. 
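// Roughly: each sweep (i) resamples every document's cluster assignment z_d
// with Neal's Algorithm 8, proposing kM auxiliary empty components drawn from
// (a Gaussian approximation of) the vMF base measure, then (ii) runs one
// Metropolis-Hastings update per surviving cluster mean phi_l via
// resample_posterior_phi().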
172 | void vMFDPMixture::resample_posterior() { 173 | for (int d = 0; d < _lD; d++) { 174 | resample_posterior_c(d); 175 | } 176 | for (vmf_clustering::iterator itr = _phi.begin(); 177 | itr != _phi.end(); 178 | itr++) { 179 | unsigned l = itr->first; 180 | 181 | VLOG(1) << "resampling phi_" << l << "..."; 182 | resample_posterior_phi(l); 183 | } 184 | LOG(INFO) << "total clusters = " << _c.size(); 185 | 186 | print_clustering_summary(); 187 | 188 | } 189 | 190 | // Prints out the top few features from each cluster 191 | void vMFDPMixture::print_clustering_summary() { 192 | for (clustering::iterator itr = _c.begin(); 193 | itr != _c.end(); 194 | itr++) { 195 | unsigned l = itr->first; 196 | ublas::vector doc_sum; 197 | doc_sum.resize(_lV); 198 | for (int d = 0; d < _lD; d++) { 199 | if (_z[d] == l) { 200 | for (int k = 0; k < _lV; k++) { 201 | if (_v[d][k] > 0) { 202 | doc_sum[k] += 1; 203 | } 204 | } 205 | } 206 | } 207 | 208 | // Convert the ublas vector into a vector of pairs for sorting 209 | vector sorted; 210 | for (int k = 0; k < _lV; k++) { 211 | if (doc_sum[k] > 0) { 212 | sorted.push_back(make_pair(_word_id_to_name[k], doc_sum[k])); 213 | } 214 | } 215 | 216 | sort(sorted.begin(), sorted.end(), word_score_comp); 217 | 218 | // Finally print out the summary 219 | string buffer = ""; 220 | for (int k = 0; k < min((int)sorted.size(), 5); k++) { 221 | buffer += StringPrintf("%s %d ", sorted[k].first.c_str(), sorted[k].second); 222 | } 223 | 224 | LOG(INFO) << "_c[" << l << "] (size " << _c[l].nd[0] << ") nzf = " 225 | << sorted.size() << " " << buffer; 226 | 227 | 228 | } 229 | } 230 | 231 | // Get new cluster assignments 232 | void vMFDPMixture::resample_posterior_c(unsigned d) { 233 | unsigned old_assignment = _z[d]; 234 | 235 | vector > lp_z_d; 236 | 237 | for (google::dense_hash_map::iterator itr = _c.begin(); 238 | itr != _c.end(); 239 | itr++) { 240 | unsigned l = itr->first; 241 | 242 | unsigned top = _c[l].nd[0]; 243 | if (_z[d] == l) { 244 | top -= 1; 245 | } 246 | 247 | // First add in the prior over the clusters, Neal's algorithm 8 248 | double sum = log(top) - log(_lD - 1 + FLAGS_alpha) 249 | + logp_vmf(_v[d], _phi[l], FLAGS_kappa, true); 250 | 251 | lp_z_d.push_back(pair(l, sum)); 252 | } 253 | 254 | // Add some additional new components if DP 255 | vmf_clustering new_phi; 256 | new_phi.set_empty_key(kEmptyUnsignedKey); 257 | unsigned temp_current_component = _current_component; 258 | for (int m = 0; m < kM; m++) { 259 | new_phi[temp_current_component] = propose_new_phi(sample_spherical_gaussian(_mu, _xi)); 260 | double sum = log(FLAGS_alpha/(double)kM) - log(_lD - 1 + FLAGS_alpha) 261 | + logp_vmf(_v[d], new_phi[temp_current_component], FLAGS_kappa, true); 262 | lp_z_d.push_back(pair(temp_current_component, sum)); 263 | temp_current_component += 1; 264 | } 265 | 266 | 267 | // Update the assignment 268 | _z[d] = sample_unnormalized_log_multinomial(&lp_z_d); 269 | _c[_z[d]].nd[0] += 1; 270 | _c[old_assignment].nd[0] -= 1; 271 | VLOG(1) << "resampling posterior z for " << d << ": " << old_assignment << "->" << _z[d]; 272 | 273 | 274 | // Add the new one if necessary 275 | if (_z[d] >= _current_component) { 276 | _phi[_z[d]] = new_phi[_z[d]]; 277 | _current_component = _z[d] + 1; 278 | _K += 1; 279 | } 280 | 281 | // Clean up for the DPMM 282 | if (_c[old_assignment].nd[0] == 0) { // empty component 283 | _c.erase(old_assignment); 284 | _phi.erase(old_assignment); 285 | _K -= 1; 286 | } 287 | } 288 | 289 | // The resampling routines below are all basically boilerplate 
copies of 290 | // each other. Probably should templatize this or something 291 | // 292 | // The basic structure is: 293 | // (1) calculate the probability for the current setting 294 | // (2) sample a new setting from the proposal distribution 295 | // (3) calculate the new probability 296 | // (4) test the ratio 297 | // (5) reject by returning to the previous value, if necessary 298 | 299 | void vMFDPMixture::resample_posterior_phi(unsigned index) { 300 | double logp_orig = logp_phi(index, false) + log_likelihood_phi(index, false); 301 | ublas::vector old_phi(_phi[index]); // Check copy semantics 302 | 303 | VLOG(1) << "proposing new phi_" << index; 304 | 305 | _phi[index] = propose_new_phi(_phi[index]); 306 | double logp_new = logp_phi(index, false) + log_likelihood_phi(index, false); 307 | 308 | if (log(sample_uniform()) > logp_new - logp_orig) { // reject 309 | VLOG(1) << "... rejected."; 310 | 311 | _phi[index] = old_phi; 312 | _rejected_phi += 1; 313 | } 314 | _proposed_phi += 1; 315 | } 316 | 317 | // Computes the log posterior likelihood, e.g. the probability of observing the 318 | // documents that we do, given the model parameters. 319 | double vMFDPMixture::compute_log_likelihood() { 320 | double ll = 0; 321 | for (int i = 0; i < _lD; i++) { 322 | ll += logp_v(i, true); 323 | } 324 | LOG(INFO) << ll; 325 | 326 | ll += logp_dirichlet_process(_c, FLAGS_alpha); 327 | LOG(INFO) << logp_dirichlet_process(_c, FLAGS_alpha); 328 | return ll; 329 | } 330 | 331 | // Log likelihoods for internal nodes are computed as the sum of the children's logp's, given 332 | // the setting for the internal node. 333 | 334 | double vMFDPMixture::log_likelihood_phi(unsigned index, bool normalize) { 335 | // phi's children are the v's 336 | double ll = 0; 337 | for (int d = 0; d < _lD; d++) { 338 | ll += logp_v(d, normalize); 339 | } 340 | return ll; 341 | } 342 | 343 | // Direct log probabilities of assignments 344 | // 345 | double vMFDPMixture::logp_phi(unsigned index, bool normalize) { 346 | VLOG(1) << "computing logp of phi_" << index; 347 | return logp_vmf(_phi[index], _mu, _xi, normalize); 348 | // return logp_vmf(_phi.at(index), _alpha); 349 | } 350 | 351 | double vMFDPMixture::logp_v(unsigned index, bool normalize) { 352 | double result = 0; 353 | result = logp_vmf(_v[index], _phi[_z[index]], FLAGS_kappa, normalize); 354 | 355 | VLOG(2) << " logp_v_" << index << " = " << result; 356 | return result; 357 | } 358 | 359 | // Proposal distributions 360 | 361 | // Phi is calculated by drawing from a spherical gaussian and then mapping 362 | // it onto the hypersphere. 
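// Concretely, the proposal below is a random walk: add zero-mean Gaussian
// noise with scale _proposal_variance_phi to every coordinate, optionally zero
// out negative entries when --restrict_topics is set, and renormalize to unit
// L2 length. The walk scale is retuned every FLAGS_tune_interval iterations
// in tune(), pymc-style.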
363 | ublas::vector vMFDPMixture::propose_new_phi(const ublas::vector& phi) { 364 | ublas::vector result = phi + sample_gaussian_vector(0, _proposal_variance_phi, phi.size()); 365 | 366 | // Force the draws onto the positive orthant by setting all negative entires 367 | // to zero 368 | if (FLAGS_restrict_topics) { 369 | for (int i = 0; i < result.size(); i++) { 370 | if (result[i] < 0) { 371 | result[i] = 0; 372 | } 373 | } 374 | } 375 | 376 | return result / norm_2(result); 377 | } 378 | 379 | double vMFDPMixture::get_new_proposal_variance(string var, double current, double reject_rate) { 380 | double accept_rate = 1.0 - reject_rate; 381 | 382 | LOG(INFO) << StringPrintf("TUNING: original variance: %s=%f", var.c_str(), current); 383 | 384 | // To handle our mu blowing up 385 | if (current > 10000) { 386 | current = 10000; 387 | } 388 | // This voodoo is pulled directly from pymc 389 | if (accept_rate < 0.001) { 390 | current *= 0.1; // reduce by 90 percent 391 | } else if (accept_rate < 0.05) { 392 | current *= 0.5; // reduce by 50 percent 393 | } else if (accept_rate < 0.2) { 394 | current *= 0.9; // reduce by ten percent 395 | } else if (accept_rate > 0.95) { 396 | current *= 10.0; // increase by factor of ten 397 | } else if (accept_rate > 0.75) { 398 | current *= 2.0; // increase by double 399 | } else if (accept_rate > 0.5) { 400 | current *= 1.1; // increase by ten percent 401 | } 402 | LOG(INFO) << StringPrintf("TUNING: new variance: %s=%f", var.c_str(), current); 403 | return current; 404 | } 405 | 406 | // Tune the proposal distributions 407 | void vMFDPMixture::tune() { 408 | _proposal_variance_phi = get_new_proposal_variance("phi", _proposal_variance_phi, _rejected_phi / (double)_proposed_phi); 409 | 410 | // Reset the reject counts to zero 411 | _rejected_phi = 0; 412 | _proposed_phi = 0; 413 | } 414 | 415 | // Actually does all the sampling and takes care of accounting stuff 416 | void vMFDPMixture::run(int iterations) { 417 | LOG(INFO) << "begin sampling..."; 418 | bool found_first_best = false; // HACK for getting nan on the first round 419 | for (; _iteration < iterations; _iteration++) { 420 | if (_iteration > 0 && _iteration % FLAGS_tune_interval == 0) { 421 | tune(); 422 | } 423 | 424 | resample_posterior(); 425 | 426 | _ll = compute_log_likelihood(); 427 | 428 | if (!isnan(_ll) && (_ll > _best_ll || !found_first_best)) { 429 | found_first_best = true; 430 | _best_ll = _ll; 431 | _best_iter = _iteration; 432 | 433 | LOG(INFO) << "Resampling iter = " << _iteration << " " << current_state() << " *"; 434 | 435 | write_data("best"); 436 | } else { 437 | LOG(INFO) << "Resampling iter = " << _iteration << " " << current_state(); 438 | } 439 | 440 | write_data(StringPrintf("last", _iteration)); 441 | 442 | 443 | if (_iteration % FLAGS_sample_lag == 0 && FLAGS_sample_lag > 0) { 444 | write_data(StringPrintf("sample-%05d", _iteration)); 445 | } 446 | } 447 | } 448 | 449 | 450 | void vMFDPMixture::load_data(const string& input_file_name) { 451 | LOG(INFO) << "Loading document " << input_file_name; 452 | 453 | _lD = 0; 454 | unsigned unique_word_count = 0; 455 | 456 | map word_to_id; 457 | 458 | std::vector > temp_docs; 459 | 460 | _V.clear(); 461 | _word_id_to_name.clear(); 462 | 463 | CHECK_STRNE(FLAGS_datafile.c_str(), ""); 464 | 465 | ifstream input_file(input_file_name.c_str()); 466 | CHECK(input_file.is_open()); 467 | 468 | string curr_line; 469 | while (true) { 470 | if (input_file.eof()) { 471 | break; 472 | } 473 | getline(input_file, curr_line); 474 | std::vector words; 
475 | curr_line = StringReplace(curr_line, "\n", "", true); 476 | 477 | SplitStringUsing(curr_line, "\t", &words); 478 | 479 | if (words.size() == 0) { 480 | continue; 481 | } 482 | 483 | // TODO(jsr) simplify this 484 | temp_docs.push_back(ublas::compressed_vector(words.size()-1, words.size()-1)); 485 | 486 | for (int i = 0; i < words.size(); i++) { 487 | CHECK_STRNE(words[i].c_str(), ""); 488 | 489 | if (i == 0) { 490 | VLOG(1) << "found new document [" << words[i] << "]"; 491 | continue; 492 | } 493 | 494 | VLOG(2) << words.at(i); 495 | 496 | std::vector word_tokens; 497 | SplitStringUsing(words.at(i), ":", &word_tokens); 498 | CHECK_EQ(word_tokens.size(), 2); 499 | 500 | string word = word_tokens.at(0); 501 | double freq = atof(word_tokens.at(1).c_str()); 502 | VLOG(2) << word << " " << freq; 503 | 504 | // Is this a new word? 505 | if (word_to_id.find(word) == word_to_id.end()) { 506 | word_to_id[word] = unique_word_count; 507 | unique_word_count += 1; 508 | _word_id_to_name[word_to_id[word]] = word; 509 | _V.insert(word_to_id[word]); 510 | } 511 | 512 | // This bit is pretty gross, truly; dynamically resize the sparse vector as needed 513 | // since there is no push_back or the equivalent 514 | if (temp_docs.at(_lD).size() <= word_to_id[word]) { 515 | temp_docs.at(_lD).resize(word_to_id[word]+1, true); 516 | } 517 | 518 | // TODO(jsr) this forces us to use each word only once per doc in the input file 519 | temp_docs.at(_lD).insert_element(word_to_id[word], freq); 520 | } 521 | CHECK_GT(sum(temp_docs.at(_lD)), 0); 522 | temp_docs.at(_lD) /= norm_2(temp_docs.at(_lD)); // L2 norm 523 | 524 | _lD += 1; 525 | } 526 | 527 | _lV = _V.size(); 528 | 529 | // Copy the temp docs over 530 | for (int d = 0; d < _lD; d++) { 531 | _v[d].resize(_lV); 532 | for (int k = 0; k < temp_docs[d].size(); k++) { 533 | _v[d][k] = temp_docs[d][k]; 534 | } 535 | } 536 | 537 | LOG(INFO) << "Loaded " << _lD << " documents with " 538 | << _V.size() << " unique words from " 539 | << input_file_name; 540 | } 541 | 542 | 543 | // Write out all the data in an intermediate format 544 | void vMFDPMixture::write_data(string prefix) { 545 | string filename = StringPrintf("%s-%s-%s.params", _filename.c_str(), FLAGS_my_id.c_str(), prefix.c_str()); 546 | 547 | ofstream f; 548 | f.open(filename.c_str(), ios::out); 549 | 550 | f << current_state() << endl; 551 | 552 | f << "iteration " << _iteration << endl; 553 | f << "best_ll " << _best_ll << endl; 554 | f << "best_iter " << _best_iter << endl;; 555 | f << "proposal_variance_phi " << _proposal_variance_phi << endl; 556 | 557 | f << "vocab "; 558 | for (WordCode::iterator itr = _word_id_to_name.begin(); itr != _word_id_to_name.end(); itr++) { 559 | f << itr->first << ":" << itr->second << "\t"; 560 | } 561 | f << endl; 562 | 563 | f << "alpha " << FLAGS_alpha << endl; 564 | f << "kappa " << FLAGS_kappa << endl; 565 | f << "xi " << FLAGS_xi << endl; 566 | 567 | for (vmf_clustering::iterator itr = _phi.begin(); 568 | itr != _phi.end(); 569 | itr++) { 570 | unsigned l = itr->first; 571 | 572 | f << "phi_" << l << " " << _phi[l] << endl; 573 | } 574 | for (clustering::iterator itr = _c.begin(); 575 | itr != _c.end(); 576 | itr++) { 577 | unsigned l = itr->first; 578 | 579 | f << "c_" << l << " " << _c[l].nd[0] << endl; 580 | } 581 | for (int k = 0; k < _lD; k++) { 582 | f << "z_" << k << " " << _z[k] << endl; 583 | } 584 | 585 | f.close(); 586 | } 587 | 588 | 589 | // Generate a vector of samples from the same gaussian 590 | ublas::vector sample_gaussian_vector(double mu, double 
si2, size_t dim) { 591 | ublas::vector result(dim); 592 | for (int i = 0; i < dim; i++) { 593 | result[i] = si2*sample_gaussian()+mu; 594 | } 595 | return result; 596 | } 597 | 598 | // Sample from a Gaussian (spherical) and then normalize the resulting draw onto 599 | // the unit hypersphere 600 | template 601 | ublas::vector sample_spherical_gaussian(const vec_t& mean, double si2) { 602 | ublas::vector result(mean.size()); 603 | for (int i = 0; i < mean.size(); i++) { 604 | result[i] = si2*sample_gaussian()+mean[i]; 605 | } 606 | return result / norm_2(result); 607 | } 608 | 609 | 610 | // For now always assume mu and v can differ in underlying vector representation 611 | template 612 | double logp_vmf(const vec_t1& v, const vec_t2& mu, double kappa, bool normalize) { 613 | CHECK_EQ(v.size(), mu.size()); 614 | // TODO(jsr) figure out how to do this properly 615 | DCHECK_LT(fabs(1.0-norm_2(v)), FLAGS_epsilon_value); 616 | DCHECK_LT(fabs(1.0-norm_2(mu)), FLAGS_epsilon_value); 617 | 618 | unsigned p = v.size(); 619 | 620 | if (normalize) { 621 | // Compute an approximate log modified bessel function of the first kind 622 | double l_bessel = approx_log_iv(p / 2.0 - 1.0, kappa); 623 | 624 | return kappa * inner_prod(mu, v) 625 | + (p/2.0)*log(kappa / (2*M_PI)) 626 | - log(kappa) 627 | - l_bessel; 628 | } else { // in the unnormalized case, don't compute terms only involving kappa 629 | return kappa * inner_prod(mu, v); 630 | } 631 | } 632 | // Computes the Abramowitz and Stegum approximation to the log modified bessel 633 | // funtion of the first kind -- stable for high values of nu. See Chris Elkan's 634 | double approx_log_iv(double nu, double z) { 635 | double alpha = 1 + pow(z / nu, 2); 636 | double eta = sqrt(alpha) + log(z / nu) - log(1+sqrt(alpha)); 637 | return -log(sqrt(2*M_PI*nu)) + nu*eta - 0.25 * log(alpha); 638 | } 639 | 640 | 641 | // Computes the log prob of a value in a symmetric dirichlet 642 | // or dirichlet process 643 | double logp_dirichlet_process(clustering& value, double alpha) { 644 | unsigned dim = 0; 645 | double l; 646 | double s = 0; 647 | for (clustering::iterator itr = value.begin(); 648 | itr != value.end(); 649 | itr++) { 650 | s += itr->second.nd[0] - alpha; 651 | dim += 1; 652 | } 653 | 654 | // TODO: we shouldn't have to compute the sum 655 | double alpha_sum = alpha*((double)dim); 656 | l = gammaln(alpha_sum) - ((double)dim) * gammaln(alpha); 657 | 658 | for (clustering::iterator itr = value.begin(); 659 | itr != value.end(); 660 | itr++) { 661 | 662 | l += (alpha-1) * log(BOUND((itr->second.nd[0]-alpha)/s)); 663 | } 664 | return BOUNDPROB(l); 665 | } 666 | 667 | int main(int argc, char **argv) { 668 | google::InitGoogleLogging(argv[0]); 669 | google::ParseCommandLineFlags(&argc, &argv, true); 670 | 671 | init_random(); 672 | 673 | vMFDPMixture h = vMFDPMixture(FLAGS_xi); 674 | h.load_data(FLAGS_datafile); 675 | h.initialize(); 676 | 677 | h.run(FLAGS_iterations); 678 | } 679 | 680 | 681 | 682 | -------------------------------------------------------------------------------- /experimental/sample-vmf-dp-mixture.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // Metropolis-Hastings sampler for an admixture of von Mises-Fisher distributions 17 | 18 | #ifndef VMF_DP_MIXTURE_H_ 19 | #define VMF_DP_MIXTURE_H_ 20 | 21 | #include 22 | #include 23 | #include 24 | 25 | #include 26 | 27 | #include 28 | #include 29 | #include 30 | #include 31 | 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | 39 | #include 40 | #include 41 | 42 | #include "gibbs-base.h" 43 | #include "strutil.h" 44 | 45 | using namespace std; 46 | using namespace boost::numeric; 47 | using namespace google::protobuf; 48 | 49 | 50 | // the number of iterations to run for. 51 | DECLARE_int32(iterations); 52 | 53 | // Initial number of clusters 54 | DECLARE_int32(K_initial); 55 | 56 | // The input data set. Consists of one document per line, with words separated 57 | // by tabs, and appended with their counts. The first word in each document is 58 | // taken as the document name. 59 | DECLARE_string(datafile); 60 | 61 | // Alpha controls the topic smoothing, with higher alpha causing more "uniform" 62 | // distributions over topics. 63 | DECLARE_double(alpha); 64 | 65 | // Kappa controls the width of the vMF clusters. 66 | DECLARE_double(kappa); 67 | 68 | // xi is the corpus concentration 69 | DECLARE_double(xi); 70 | 71 | // Restricts the topics (phis) to live on the positive orthant by chopping all 72 | // negative entries in the vector (sparsifying). 73 | DECLARE_bool(restrict_cluster_means); 74 | 75 | // A unique identifier to add to the output moniker (useful for condor) 76 | DECLARE_string(my_id); 77 | 78 | typedef google::dense_hash_map< unsigned, ublas::vector > vmf_data; 79 | typedef google::dense_hash_map< unsigned, ublas::compressed_vector > vmf_clustering; 80 | 81 | class vMFDPMixture : public GibbsSampler { 82 | public: 83 | vMFDPMixture(double xi) 84 | : _xi(xi) { 85 | _z.set_empty_key(kEmptyUnsignedKey); 86 | _c.set_empty_key(kEmptyUnsignedKey); 87 | _c.set_deleted_key(kDeletedUnsignedKey); 88 | _word_id_to_name.set_empty_key(kEmptyUnsignedKey); 89 | _v.set_empty_key(kEmptyUnsignedKey); 90 | _phi.set_empty_key(kEmptyUnsignedKey); 91 | _phi.set_deleted_key(kDeletedUnsignedKey); 92 | } 93 | virtual ~vMFDPMixture() { /* TODO: free memory! */ } 94 | 95 | // Set up all the data structures and initialize distributions to random values 96 | void initialize(); 97 | 98 | // Run the MH sampler for iterations iterations. This is the main 99 | // workhorse loop. 
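// (Roughly, per the implementation in the .cc: each iteration tunes the
// proposal variances every FLAGS_tune_interval steps, resamples the posterior,
// recomputes the log likelihood, and writes "best", "last" and lagged
// "sample-NNNNN" snapshots via write_data.)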
100 | void run(int iterations); 101 | 102 | // Load a data file 103 | void load_data(const string& filename); 104 | 105 | // Dump the sufficient statistics of the model 106 | void write_data(string prefix); 107 | 108 | // return a string describing the current state of the hyperparameters 109 | // and log likelihood 110 | string current_state(); 111 | 112 | // Prints out the top features of each cluster 113 | void print_clustering_summary(); 114 | 115 | 116 | protected: 117 | // Performs an entire MH step, updating the posterior 118 | void resample_posterior(); 119 | 120 | // The various component distributions 121 | void resample_posterior_phi(unsigned index); 122 | void resample_posterior_c(unsigned d); 123 | 124 | // Posterior likelihood of the data conditional on the model parameters 125 | double compute_log_likelihood(); 126 | 127 | // Log-likelihood computation (sum of the log probabilities of the nodes children) 128 | double log_likelihood_phi(unsigned index, bool normalize); 129 | 130 | // Direct log probabilities of internal nodes given the settings of their parents 131 | double logp_phi(unsigned index, bool normalize); 132 | double logp_v(unsigned index, bool normalize); 133 | 134 | // Proposal distributions 135 | ublas::vector propose_new_phi(const ublas::vector& phi); 136 | 137 | // Tune the proposal steps 138 | double get_new_proposal_variance(string var, double current, double reject_rate); 139 | void tune(); 140 | 141 | protected: 142 | cluster_map _z; // Map data point -> cluster 143 | clustering _c; // Map [w][z] -> CRP 144 | 145 | // Parameters 146 | unsigned _K; // Number of clusters 147 | 148 | unsigned _current_component; 149 | 150 | set _V; // vocabulary 151 | unsigned _lV; // size of vocab 152 | unsigned _lD; // number of documents 153 | 154 | WordCode _word_id_to_name; // uniqe_id to string 155 | 156 | double _ll; // current log-likelihood 157 | double _best_ll; 158 | int _best_iter; 159 | unsigned _iteration; // Hold the global sampling step 160 | 161 | string _filename; // output file name 162 | 163 | // Model hyperparameters 164 | ublas::vector _mu; // uniform mean vector constant 165 | double _xi; 166 | 167 | // Model parameters 168 | vmf_clustering _phi; 169 | 170 | // The actual observed documents (sparse matrix) 171 | vmf_data _v; 172 | 173 | // Tallies for accounting and performing adaptive variance updates 174 | unsigned _rejected_xi; 175 | unsigned _rejected_phi; 176 | unsigned _proposed_phi; 177 | 178 | // Current proposal variance settings 179 | double _proposal_variance_xi; 180 | double _proposal_variance_phi; 181 | }; 182 | 183 | ublas::vector sample_gaussian_vector(double mean, double si2, unsigned dim); 184 | template 185 | ublas::vector sample_spherical_gaussian(const vec_t& mean, double si2); 186 | 187 | double approx_log_iv(double nu, double z); 188 | 189 | // For now always assume mu and v can differ in underlying vector representation 190 | template 191 | double logp_vmf(const vec_t1& v, const vec_t2& mu, double kappa, bool normalize); 192 | 193 | double logp_dirichlet_process(clustering& value, double alpha); 194 | 195 | 196 | #endif // VMF_DP_MIXTURE_H_ 197 | -------------------------------------------------------------------------------- /gibbs-base.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // This file contains all of the common code shared between the Mutinomial hLDA 17 | // sampler (fixed-depth tree) and the GEM hLDA sampler (infinite-depth tree), as 18 | // well as the fixed topic structure sampler used to learn WN senses. 19 | // 20 | // CRPNode is the basic data structure for a single node in the topic tree (dag) 21 | 22 | #include 23 | #include 24 | 25 | #include "dSFMT-src-2.0/dSFMT.h" 26 | 27 | #include "gibbs-base.h" 28 | 29 | using namespace std; 30 | 31 | // For vanilla LDA, we can save the state of the model directly in the raw 32 | // docify, using an extra : for each term denoting the topic assignment. This 33 | // allows us to do some fancy things like checkpoint and update documents on the 34 | // fly. 35 | DEFINE_int32(preassigned_topics, 36 | 0, 37 | "Topics are preassigned in docify (vanilla only)."); 38 | 39 | // Streaming makes some major structural changes to the code, moving the main loop into load_data; 40 | // This variable sets how many documents should be kept in memory at any one time 41 | DEFINE_int32(streaming, 42 | 0, 43 | "The number of documents to remember"); 44 | 45 | // Eta controls the amount of smoothing within the per-topic word distributions. 46 | // Higher eta = more smoothing. Also used in the GEM sampler. 47 | DEFINE_double(eta, 48 | 0.1, 49 | "hyperparameter eta, controls the word smoothing"); 50 | 51 | // We can incorporate prior ranking information via eta. One way to do this is 52 | // to assume that eta is proportional to some exponential of the average word 53 | // rank, thus more highly ranked words have higher priors. eta_prior_scale 54 | // represents the degree of confidence (how fast to decay) in each rank, higher 55 | // meaning less informative prior. 56 | DEFINE_double(eta_prior_exponential, 57 | 1.0, 58 | "How should we decay the eta prior?"); 59 | 60 | 61 | // the number of gibbs iterations to run for. 62 | DEFINE_int32(max_gibbs_iterations, 63 | 99999, 64 | "Number of Gibbs sampling iterations to run."); 65 | 66 | 67 | // One problem with Gibbs sampling is that nearby samples are highly 68 | // correlated, throwing off the empirical distribution. In practice you need to 69 | // wait some amount of time before reading each (independent) sample. 70 | DEFINE_int32(sample_lag, 100, "how many Gibbs iterations to perform per sample (0 no samples)"); 71 | 72 | // The random seed 73 | DEFINE_int32(random_seed, 101, "what seed value to use"); 74 | 75 | // Our tolerance for numerical (under)overflow. 76 | DEFINE_double(epsilon_value, 1e-4, "tolerance for underflow"); 77 | 78 | // How many steps of not writing a best do we need before we declare 79 | // convergence? 80 | DEFINE_int32(convergence_interval, 81 | 0, 82 | "How many steps should we wait before declaring convergence? 0 = off"); 83 | 84 | // Binarize the feature counts 85 | DEFINE_bool(binarize, 86 | false, 87 | "Binarize the word counts."); 88 | 89 | // Should the best sample get output? 90 | DEFINE_bool(output_best, 91 | true, 92 | "Output the best sample"); 93 | 94 | // Should the last sample get output? 
95 | DEFINE_bool(output_last, 96 | false, 97 | "Output the last sample"); 98 | 99 | // The Mersenne Twister 100 | dsfmt_t dsfmt; 101 | 102 | void safe_remove_crp(vector* domain, const CRP* target) { 103 | vector::iterator p = find(domain->begin(), domain->end(), target); 104 | // must have existed 105 | CHECK(p != domain->end()) << "[" << target->label << "] didn't exist"; 106 | domain->erase(p); 107 | } 108 | 109 | CRP::~CRP() { 110 | // TODO: want this check for the learned versions... 111 | // CHECK_LE(m, 1); // don't delete without removing docs 112 | 113 | // LOG(INFO) << "about to remove from parents"; 114 | if (!prev.empty()) { 115 | remove_from_parents(); 116 | } 117 | 118 | // LOG(INFO) << "about to delete " << tables.size() << " tables"; 119 | // remove all of our children nodes (note that this is recursive) 120 | for (int i = 0; i < tables.size(); i++) { 121 | delete tables[i]; 122 | } 123 | // LOG(INFO) << "phew!!"; 124 | } 125 | 126 | // Remove this node from the list of nodes stored at prev 127 | void CRP::remove_from_parents() { 128 | // can only be called on an interior node (i.e. with a parent) 129 | CHECK_GT(prev.size(), 0); 130 | 131 | for (int i = 0; i < prev.size(); i++) { 132 | safe_remove_crp(&prev[i]->tables, this); 133 | } 134 | } 135 | 136 | bool GibbsSampler::sample_and_check_for_convergence() { 137 | if (_iter > 0) { // Don't resample at first iteration, so we can accurately record the starting state 138 | resample_posterior(); 139 | _converged_iterations += 1; 140 | } 141 | //if (i % FLAGS_sample_lag == 0) { // calculate ll every 100 142 | _ll = compute_log_likelihood(); 143 | //} 144 | CHECK_LE(_ll, 0) << "log likelihood cannot be positive!"; 145 | 146 | if (_ll > _best_ll || _iter == 0) { 147 | _best_ll = _ll; 148 | _best_iter = _iter; 149 | _converged_iterations = 0; 150 | 151 | LOG(INFO) << "Resampling iter = " << _iter << " " << current_state() << " *"; 152 | 153 | write_data("best"); 154 | } else { 155 | LOG(INFO) << "Resampling iter = " << _iter << " " << current_state(); 156 | } 157 | 158 | if (FLAGS_output_best && FLAGS_sample_lag > 0 && _iter % FLAGS_sample_lag == 0) { 159 | write_data(StringPrintf("sample-%05d", _iter)); 160 | } 161 | 162 | if (FLAGS_output_last) { 163 | write_data("last"); 164 | } 165 | 166 | _iter++; 167 | 168 | return FLAGS_convergence_interval > 0 && _converged_iterations >= FLAGS_convergence_interval; 169 | } 170 | 171 | void GibbsSampler::run() { 172 | _ll = compute_log_likelihood(); 173 | 174 | while (_iter < FLAGS_max_gibbs_iterations) { 175 | 176 | if (sample_and_check_for_convergence()) { 177 | LOG(INFO) << "CONVERGED!"; 178 | write_data("converged"); 179 | break; 180 | } 181 | } 182 | } 183 | 184 | bool GibbsSampler::process_document_line(const string& curr_line, unsigned line_no) { 185 | vector words; 186 | vector encoded_words; 187 | vector topics; 188 | //CHECK_EQ(x, 0); 189 | 190 | SplitStringUsing(StringReplace(curr_line, "\n", "", true), "\t", &words); 191 | 192 | // V->insert(words.begin(), words.end()); 193 | if (words.empty()) { 194 | LOG(WARNING) << "EMPTY LINE"; 195 | return false; 196 | } 197 | 198 | // the name of the document 199 | if (words.size() == 1) { 200 | LOG(WARNING) << "empty document " << words[0]; 201 | return false; 202 | } 203 | 204 | _document_name[line_no] = words[0]; 205 | VLOG(1) << "found new document [" << words[0] << "] " << line_no; 206 | _document_id[words[0]] = line_no; 207 | _nd[line_no] = 0; 208 | 209 | for (int i = 1; i < words.size(); i++) { 210 | CHECK_STRNE(words[i].c_str(), ""); 
211 | // if (!(i == 0 && (HasPrefixString(words[i], "rpl_") || 212 | // HasPrefixString(words[i], "RPL_")))) { 213 | vector word_tokens; 214 | //VLOG(2) << words.at(i); 215 | SplitStringUsing(words.at(i), ":", &word_tokens); 216 | 217 | int topic; 218 | int freq; 219 | 220 | if (FLAGS_preassigned_topics == 1) { 221 | topic = atoi(word_tokens.back().c_str()); 222 | word_tokens.pop_back(); 223 | } 224 | 225 | freq = atoi(word_tokens.back().c_str()); 226 | word_tokens.pop_back(); 227 | 228 | if (FLAGS_preassigned_topics == 1) { 229 | CHECK_EQ(freq, 1); // Each term gets a unique assignment 230 | } 231 | 232 | 233 | string word = JoinStrings(word_tokens, ":"); 234 | 235 | VLOG(1) << word << " " << freq; 236 | if (_word_name_to_id.find(word) == _word_name_to_id.end()) { 237 | _word_name_to_id[word] = _unique_word_count; 238 | _word_id_to_name[_unique_word_count] = word; 239 | 240 | _unique_word_count += 1; 241 | 242 | _eta.push_back(FLAGS_eta); 243 | _eta_sum += FLAGS_eta; 244 | } 245 | _V[_word_name_to_id[word]] += freq; 246 | if (FLAGS_binarize) { 247 | freq = 1; 248 | } 249 | for (int f = 0; f < freq; f++) { 250 | encoded_words.push_back(_word_name_to_id[word]); 251 | topics.push_back(topic); 252 | } 253 | _total_word_count += freq; 254 | _nd[line_no] += freq; 255 | } 256 | _D[line_no] = encoded_words; 257 | _initial_topic_assignment[line_no] = topics; 258 | 259 | _lD = _D.size(); 260 | _lV = _V.size(); 261 | 262 | // Make sure eta is in a reasonable range 263 | CHECK_LT(_eta_sum, 1000000); 264 | 265 | if (FLAGS_streaming > 0) { 266 | streaming_step(line_no); 267 | } 268 | 269 | return true; 270 | } 271 | 272 | void GibbsSampler::streaming_step(unsigned new_d) { 273 | allocate_document(new_d); 274 | 275 | if (_D.size() > FLAGS_streaming) { 276 | deallocate_document(); 277 | _lD = _D.size(); 278 | _lV = _V.size(); 279 | sample_and_check_for_convergence(); 280 | } 281 | } 282 | 283 | void GibbsSampler::load_data(const string& filename) { 284 | _D.clear(); 285 | _V.clear(); 286 | _word_id_to_name.clear(); 287 | 288 | _initial_topic_assignment.clear(); 289 | 290 | 291 | LOG(INFO) << "loading data from [" << filename << "]"; 292 | 293 | // XXX: turn this block into a function somehow 294 | ifstream ii(filename.c_str(), ios_base::in | ios_base::binary); 295 | CHECK(ii.is_open()); 296 | boost::iostreams::filtering_streambuf in; 297 | if (is_gz_file(filename)) { 298 | in.push(boost::iostreams::gzip_decompressor()); 299 | } 300 | in.push(ii); 301 | istream input_file(&in); 302 | /////////////////////////////////////////////// 303 | 304 | string curr_line; 305 | unsigned line_no = 0; 306 | while (true) { 307 | if (input_file.eof()) { 308 | break; 309 | } 310 | getline(input_file, curr_line); 311 | if (process_document_line(curr_line, line_no)) { 312 | line_no += 1; 313 | } 314 | } 315 | 316 | // Allocate documents 317 | if (FLAGS_streaming == 0) { 318 | batch_allocation(); 319 | 320 | LOG(INFO) << "Loaded " << _lD << " documents with " 321 | << _total_word_count << " words (" << _V.size() << " unique) from " 322 | << filename; 323 | } 324 | //delete input_file; 325 | } 326 | 327 | // Machinering for printing out the tops of multinomials 328 | typedef std::pair word_score; 329 | bool word_score_comp(const word_score& left, const word_score& right) { 330 | return left.second > right.second; 331 | } 332 | 333 | string GibbsSampler::show_chopped_sorted_nw(const WordToCountMap& nw) { 334 | vector sorted; 335 | for (WordToCountMap::const_iterator nw_itr = nw.begin(); 336 | nw_itr != nw.end(); 337 | 
nw_itr++) { 338 | unsigned w = nw_itr->first; 339 | unsigned c = nw_itr->second; 340 | if (c > 0) { 341 | sorted.push_back(make_pair(_word_id_to_name[w], c)); 342 | } 343 | } 344 | 345 | sort(sorted.begin(), sorted.end(), word_score_comp); 346 | 347 | // Finally print out the summary 348 | string buffer = ""; 349 | for (int k = 0; k < min((int)sorted.size(), 10); k++) { 350 | buffer += StringPrintf("%s %d ", sorted[k].first.c_str(), sorted[k].second); 351 | } 352 | 353 | return buffer; 354 | } 355 | 356 | 357 | void init_random() { 358 | #ifdef USE_MT_RANDOM 359 | dsfmt_init_gen_rand(&dsfmt, FLAGS_random_seed); 360 | #else 361 | srand(FLAGS_random_seed); 362 | #endif 363 | 364 | } 365 | 366 | // Logarithm of the gamma function. 367 | // 368 | // References: 369 | // 370 | // 1) W. J. Cody and K. E. Hillstrom, 'Chebyshev Approximations for 371 | // the Natural Logarithm of the Gamma Function,' Math. Comp. 21, 372 | // 1967, pp. 198-203. 373 | // 374 | // 2) K. E. Hillstrom, ANL/AMD Program ANLC366S, DGAMMA/DLGAMA, May, 375 | // 1969. 376 | // 377 | // 3) Hart, Et. Al., Computer Approximations, Wiley and sons, New 378 | // York, 1968. 379 | // 380 | // From matlab/gammaln.m 381 | double gammaln(double x) { 382 | double result, y, xnum, xden; 383 | int i; 384 | static double d1 = -5.772156649015328605195174e-1; 385 | static double p1[] = { 386 | 4.945235359296727046734888e0, 2.018112620856775083915565e2, 387 | 2.290838373831346393026739e3, 1.131967205903380828685045e4, 388 | 2.855724635671635335736389e4, 3.848496228443793359990269e4, 389 | 2.637748787624195437963534e4, 7.225813979700288197698961e3 390 | }; 391 | static double q1[] = { 392 | 6.748212550303777196073036e1, 1.113332393857199323513008e3, 393 | 7.738757056935398733233834e3, 2.763987074403340708898585e4, 394 | 5.499310206226157329794414e4, 6.161122180066002127833352e4, 395 | 3.635127591501940507276287e4, 8.785536302431013170870835e3 396 | }; 397 | static double d2 = 4.227843350984671393993777e-1; 398 | static double p2[] = { 399 | 4.974607845568932035012064e0, 5.424138599891070494101986e2, 400 | 1.550693864978364947665077e4, 1.847932904445632425417223e5, 401 | 1.088204769468828767498470e6, 3.338152967987029735917223e6, 402 | 5.106661678927352456275255e6, 3.074109054850539556250927e6 403 | }; 404 | static double q2[] = { 405 | 1.830328399370592604055942e2, 7.765049321445005871323047e3, 406 | 1.331903827966074194402448e5, 1.136705821321969608938755e6, 407 | 5.267964117437946917577538e6, 1.346701454311101692290052e7, 408 | 1.782736530353274213975932e7, 9.533095591844353613395747e6 409 | }; 410 | static double d4 = 1.791759469228055000094023e0; 411 | static double p4[] = { 412 | 1.474502166059939948905062e4, 2.426813369486704502836312e6, 413 | 1.214755574045093227939592e8, 2.663432449630976949898078e9, 414 | 2.940378956634553899906876e10, 1.702665737765398868392998e11, 415 | 4.926125793377430887588120e11, 5.606251856223951465078242e11 416 | }; 417 | static double q4[] = { 418 | 2.690530175870899333379843e3, 6.393885654300092398984238e5, 419 | 4.135599930241388052042842e7, 1.120872109616147941376570e9, 420 | 1.488613728678813811542398e10, 1.016803586272438228077304e11, 421 | 3.417476345507377132798597e11, 4.463158187419713286462081e11 422 | }; 423 | static double c[] = { 424 | -1.910444077728e-03, 8.4171387781295e-04, 425 | -5.952379913043012e-04, 7.93650793500350248e-04, 426 | -2.777777777777681622553e-03, 8.333333333333333331554247e-02, 427 | 5.7083835261e-03 428 | }; 429 | static double a = 0.6796875; 430 | 431 | if ((x <= 0.5) || ((x 
> a) && (x <= 1.5))) { 432 | if (x <= 0.5) { 433 | result = -log(x); 434 | /* Test whether X < machine epsilon. */ 435 | if (x+1 == 1) { 436 | return result; 437 | } 438 | } else { 439 | result = 0; 440 | x = (x - 0.5) - 0.5; 441 | } 442 | xnum = 0; 443 | xden = 1; 444 | for (i = 0; i < 8; i++) { 445 | xnum = xnum * x + p1[i]; 446 | xden = xden * x + q1[i]; 447 | } 448 | result += x * (d1 + x * (xnum / xden)); 449 | } else if ((x <= a) || ((x > 1.5) && (x <= 4))) { 450 | if (x <= a) { 451 | result = -log(x); 452 | x = (x - 0.5) - 0.5; 453 | } else { 454 | result = 0; 455 | x -= 2; 456 | } 457 | xnum = 0; 458 | xden = 1; 459 | for (i = 0; i < 8 ;i++) { 460 | xnum = xnum * x + p2[i]; 461 | xden = xden * x + q2[i]; 462 | } 463 | result += x * (d2 + x * (xnum / xden)); 464 | } else if (x <= 12) { 465 | x -= 4; 466 | xnum = 0; 467 | xden = -1; 468 | for (i = 0; i < 8; i++) { 469 | xnum = xnum * x + p4[i]; 470 | xden = xden * x + q4[i]; 471 | } 472 | result = d4 + x*(xnum/xden); 473 | } else { 474 | // X > 12 475 | y = log(x); 476 | result = x * (y - 1) - y * 0.5 + .9189385332046727417803297; 477 | x = 1/x; 478 | y = x*x; 479 | xnum = c[6]; 480 | for (i = 0; i < 6; i++) { 481 | xnum = xnum * y + c[i]; 482 | } 483 | xnum *= x; 484 | result += xnum; 485 | } 486 | return result; 487 | } 488 | 489 | 490 | long double addLog(long double x, long double y) { 491 | if (x == 0) { 492 | return y; 493 | } 494 | if (y == 0) { 495 | return x; 496 | } 497 | 498 | if (x-y > 16) { 499 | return x; 500 | } else if (x > y) { 501 | return x + log(1 + exp(y-x)); 502 | } else if (y-x > 16) { 503 | return y; 504 | } else { 505 | return y + log(1 + exp(x-y)); 506 | } 507 | } 508 | 509 | void normalizeLog(vector*x) { 510 | long double s; 511 | int i; 512 | s = 0; 513 | 514 | long double normalized_sum = 0; 515 | 516 | for (i = 0; i < x->size(); i++) { 517 | s = addLog(s, x->at(i)); 518 | } 519 | for (i = 0; i < x->size(); i++) { 520 | (*x)[i] = exp(x->at(i) - s); 521 | normalized_sum += (*x)[i]; 522 | } 523 | 524 | // CHECK(MathUtil::NearByMargin(normalized_sum,1.0)); 525 | // LOG(INFO) << "normalized sum " << normalized_sum; 526 | CHECK_GT(normalized_sum, 0); // for nan 527 | CHECK_LT(fabs(normalized_sum - 1.0), FLAGS_epsilon_value); 528 | } 529 | void normalizeLog(vector >*x) { 530 | long double s; 531 | int i; 532 | s = 0; 533 | 534 | long double normalized_sum = 0; 535 | 536 | for (i = 0; i < x->size(); i++) { 537 | s = addLog(s, x->at(i).second); 538 | } 539 | for (i = 0; i < x->size(); i++) { 540 | (*x)[i] = pair(x->at(i).first, exp(x->at(i).second - s)); 541 | normalized_sum += x->at(i).second; 542 | } 543 | 544 | // CHECK(MathUtil::NearByMargin(normalized_sum,1.0)); 545 | // LOG(INFO) << "normalized sum " << normalized_sum; 546 | CHECK_GT(normalized_sum, 0); // for nan 547 | CHECK_LT(fabs(normalized_sum - 1.0), FLAGS_epsilon_value); 548 | } 549 | 550 | double sample_uniform() { 551 | #ifdef USE_MT_RANDOM 552 | return dsfmt_genrand_close_open(&dsfmt); 553 | #else 554 | return random() / (double)RAND_MAX; 555 | #endif 556 | 557 | } 558 | 559 | // Given a multinomial distribution of the form {label:prob}, return a label 560 | // with that probability. 
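// A minimal usage sketch (hedged example for the vector<double> overload; the
// probabilities are made up and the caller must have normalized them to sum to 1):
//
//   vector<double> p;
//   p.push_back(0.25);
//   p.push_back(0.75);
//   int label = sample_normalized_multinomial(&p);  // 0 w.p. 0.25, 1 w.p. 0.75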
561 | inline int sample_normalized_multinomial(vector*d) { 562 | double cut = sample_uniform(); 563 | CHECK_LE(cut, 1.0); 564 | CHECK_GE(cut, 0.0); 565 | 566 | for (int i = 0; i < d->size(); i++) { 567 | cut -= d->at(i); 568 | 569 | if (cut < 0) { 570 | return i; 571 | } 572 | } 573 | 574 | CHECK(false) << "improperly normalized distribution " << cut; 575 | return 0; 576 | } 577 | // Given a multinomial distribution of the form {label:prob}, return a label 578 | // with that probability. 579 | inline int sample_normalized_multinomial(vector >*d) { 580 | double cut = sample_uniform(); 581 | CHECK_LE(cut, 1.0); 582 | CHECK_GE(cut, 0.0); 583 | 584 | for (int i = 0; i < d->size(); i++) { 585 | cut -= d->at(i).second; 586 | 587 | if (cut < 0) { 588 | return d->at(i).first; 589 | } 590 | } 591 | 592 | CHECK(false) << "improperly normalized distribution " << cut; 593 | return -1; 594 | } 595 | 596 | 597 | // Assume that the data coming in are log probs and that they need to be 598 | // appropriately normalized. 599 | // XXX: sample_unnormalized_log_multinomial changes d into normal p space 600 | unsigned sample_unnormalized_log_multinomial(vector*d) { 601 | double cut = sample_uniform(); 602 | CHECK_LE(cut, 1.0); 603 | CHECK_GE(cut, 0.0); 604 | 605 | unsigned int i; 606 | long double s = 0; 607 | for (i = 0; i < d->size(); i++) { 608 | s = addLog(s, d->at(i)); 609 | } 610 | for (i = 0; i < d->size(); i++) { 611 | cut -= exp(d->at(i) - s); 612 | 613 | if (cut < 0) { 614 | return i; 615 | } 616 | } 617 | 618 | CHECK(false) << "improperly normalized distribution " << cut; 619 | return 0; 620 | } 621 | unsigned sample_unnormalized_log_multinomial(vector >*d) { 622 | double cut = sample_uniform(); 623 | CHECK_LE(cut, 1.0); 624 | CHECK_GE(cut, 0.0); 625 | 626 | int i; 627 | long double s = 0; 628 | 629 | for (i = 0; i < d->size(); i++) { 630 | s = addLog(s, d->at(i).second); 631 | } 632 | for (i = 0; i < d->size(); i++) { 633 | cut -= exp(d->at(i).second - s); 634 | 635 | if (cut < 0) { 636 | return d->at(i).first; 637 | } 638 | } 639 | 640 | CHECK(false) << "improperly normalized distribution " << cut; 641 | return -1; 642 | } 643 | 644 | sampler_entry NEW_sample_unnormalized_log_multinomial(vector*d) { 645 | double cut = sample_uniform(); 646 | CHECK_LE(cut, 1.0); 647 | CHECK_GE(cut, 0.0); 648 | 649 | int i; 650 | long double s = 0; 651 | 652 | for (i = 0; i < d->size(); i++) { 653 | s = addLog(s, d->at(i).score); 654 | } 655 | for (i = 0; i < d->size(); i++) { 656 | double score = exp(d->at(i).score - s); 657 | cut -= score; 658 | 659 | if (cut < 0) { 660 | // Return a new entry with the normalized score 661 | return sampler_entry(d->at(i).index, score); 662 | } 663 | } 664 | 665 | CHECK(false) << "improperly normalized distribution " << cut; 666 | return d->at(0); 667 | } 668 | 669 | int SAFE_sample_unnormalized_log_multinomial(vector*d) { 670 | normalizeLog(d); 671 | return sample_normalized_multinomial(d); 672 | } 673 | int SAFE_sample_unnormalized_log_multinomial(vector >*d) { 674 | normalizeLog(d); 675 | return sample_normalized_multinomial(d); 676 | } 677 | 678 | unsigned sample_integer(unsigned range) { 679 | return (unsigned)(sample_uniform() * range); 680 | } 681 | 682 | double sample_gaussian() { 683 | double x1, x2, w, y1; 684 | 685 | static bool returned = false; 686 | static double y2 = 0.0; 687 | 688 | if (returned) { 689 | returned = false; 690 | do { 691 | x1 = 2.0 * sample_uniform() - 1.0; 692 | x2 = 2.0 * sample_uniform() - 1.0; 693 | w = x1 * x1 + x2 * x2; 694 | } while ( w >= 1.0 
); 695 | 696 | w = sqrt((-2.0 * log(w)) / w); 697 | y1 = x1 * w; 698 | y2 = x2 * w; 699 | return y1; 700 | } else { 701 | returned = true; 702 | return y2; 703 | } 704 | } 705 | 706 | 707 | // Returns the file part of the path s 708 | string get_base_name(const string& s) { 709 | vector tokens; 710 | SplitStringUsing(s, "/", &tokens); 711 | return tokens.back(); 712 | } 713 | 714 | // Test whether this string ends in bz2 715 | bool is_bz2_file(const string& s) { 716 | vector tokens; 717 | SplitStringUsing(s, ".", &tokens); 718 | 719 | return tokens.back() == "bz2"; 720 | } 721 | 722 | // Ditto for gz 723 | bool is_gz_file(const string& s) { 724 | vector tokens; 725 | SplitStringUsing(s, ".", &tokens); 726 | 727 | return tokens.back() == "gz"; 728 | } 729 | 730 | 731 | void open_or_gz(string filename, istream* result_stream) { 732 | ifstream input_file(filename.c_str(), ios_base::in | ios_base::binary); 733 | CHECK(input_file.is_open()); 734 | boost::iostreams::filtering_streambuf in; 735 | if (is_gz_file(filename)) { 736 | in.push(boost::iostreams::gzip_decompressor()); 737 | } 738 | in.push(input_file); 739 | result_stream = new istream(&in); 740 | } 741 | -------------------------------------------------------------------------------- /gibbs-base.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | /* 17 | Copyright 2010 Joseph Reisinger 18 | 19 | Licensed under the Apache License, Version 2.0 (the "License"); 20 | you may not use this file except in compliance with the License. 21 | You may obtain a copy of the License at 22 | 23 | http://www.apache.org/licenses/LICENSE-2.0 24 | 25 | Unless required by applicable law or agreed to in writing, software 26 | distributed under the License is distributed on an "AS IS" BASIS, 27 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 28 | See the License for the specific language governing permissions and 29 | limitations under the License. 30 | */ 31 | // Contains basic statistics routines and data structures for a discrete Gibbs 32 | // sampler over some kind of document data. 33 | 34 | #ifndef GIBBS_BASE_H_ 35 | #define GIBBS_BASE_H_ 36 | 37 | #include 38 | #include 39 | 40 | #include 41 | 42 | #include 43 | 44 | #include 45 | #include 46 | #include 47 | #include 48 | #include 49 | #include 50 | 51 | #include 52 | 53 | #include 54 | #include 55 | 56 | #include 57 | #include 58 | 59 | #include 60 | #include 61 | 62 | 63 | #include "strutil.h" 64 | 65 | using namespace std; 66 | using namespace google::protobuf; 67 | 68 | // For vanilla LDA, we can save the state of the model directly in the raw 69 | // docify, using an extra : for each term denoting the topic assignment. This 70 | // allows us to do some fancy things like checkpoint and update documents on the 71 | // fly. 
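// A hedged sketch of one such docify line (document and word names are
// hypothetical); the trailing field is the preassigned topic, and
// process_document_line requires each preassigned term to carry count 1:
//
//   doc_42<TAB>apple:1:3<TAB>banana:1:7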
72 | DECLARE_int32(preassigned_topics); 73 | 74 | // Streaming makes some major structural changes to the code, moving the main loop into load_data; 75 | // This variable sets how many documents should be kept in memory at any one time 76 | DECLARE_int32(streaming); 77 | 78 | // Eta controls the amount of smoothing within the per-topic word distributions. 79 | // Higher eta = more smoothing. Also used in the GEM sampler. 80 | DECLARE_double(eta); 81 | 82 | // We can incorporate prior ranking information via eta. One way to do this is 83 | // to assume that eta is proportional to some exponential of the average word 84 | // rank, thus more highly ranked words have higher priors. eta_prior_scale 85 | // represents the degree of confidence (how fast to decay) in each rank, higher 86 | // meaning less informative prior. 87 | DECLARE_double(eta_prior_exponential); 88 | 89 | // the number of gibbs iterations to run for. 90 | DECLARE_int32(max_gibbs_iterations); 91 | 92 | // One problem with Gibbs sampling is that nearby samples are highly 93 | // correlated, throwing off the empirical distribution. In practice you need to 94 | // wait some amount of time before reading each (independent) sample. 95 | DECLARE_int32(sample_lag); 96 | 97 | // The random seed 98 | DECLARE_int32(random_seed); 99 | 100 | // Our tolerance for numerical (under)overflow. 101 | DECLARE_double(epsilon_value); 102 | 103 | // How many steps of not writing a best do we need before we declare 104 | // convergence? 105 | DECLARE_int32(convergence_interval); 106 | 107 | // Binarize the feature counts 108 | DECLARE_bool(binarize); 109 | 110 | // Should the best sample get output? 111 | DECLARE_bool(output_best); 112 | 113 | // Should the last sample get output? 114 | DECLARE_bool(output_last); 115 | 116 | class CRP; 117 | 118 | typedef google::sparse_hash_map WordToCountMap; 119 | typedef google::sparse_hash_map DocToWordCountMap; 120 | typedef google::dense_hash_map > DocToTopicChain; 121 | typedef google::dense_hash_map Docsize; 122 | typedef google::dense_hash_map DocWordToCount; 123 | 124 | typedef vector Document; 125 | 126 | typedef google::dense_hash_map DocumentMap; 127 | typedef google::dense_hash_map DocIDToTitle; 128 | typedef google::dense_hash_map TitleToDocID; 129 | 130 | typedef google::dense_hash_map WordCode; 131 | 132 | 133 | typedef google::dense_hash_map LevelWordToCountMap; 134 | typedef google::dense_hash_map LevelToCountMap; 135 | 136 | typedef google::dense_hash_map cluster_map; 137 | typedef map multiple_cluster_map; 138 | typedef google::dense_hash_map clustering; 139 | typedef map multiple_clustering; 140 | 141 | typedef google::dense_hash_map topic_map; 142 | 143 | const string kEmptyStringKey = "$$$EMPTY$$$"; 144 | const unsigned kEmptyUnsignedKey = UINT_MAX; 145 | const unsigned kDeletedUnsignedKey = UINT_MAX-1; 146 | const string kDeletedStringKey = "$$$DELETED$$$"; 147 | 148 | class sampler_entry { 149 | public: 150 | sampler_entry(unsigned index, double score) 151 | : index(index), score(score) { } 152 | double score; 153 | unsigned index; 154 | }; 155 | 156 | // A single node in the nCRP, corresponds to a table and also contains a list of 157 | // children, e.g. the tables in the restaurant that it points to. 
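// A rough usage sketch of the count bookkeeping (word id w and doc id d are
// placeholders):
//
//   CRP node(/*level=*/0, /*customers=*/0);
//   node.add(w, d);     // bumps nw[w], nwsum and nd[d]; bumps ndsum if d is new here
//   node.remove(w, d);  // undoes add(), erasing nd[d] and decrementing ndsum at zero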
158 | class CRP { 159 | public: 160 | CRP() : nwsum(0), label(""), ndsum(0) { 161 | nw.set_deleted_key(kDeletedUnsignedKey); 162 | nd.set_deleted_key(kDeletedUnsignedKey); 163 | } 164 | CRP(unsigned l, unsigned customers) 165 | : level(l), nwsum(0), lp(0), label(""), ndsum(customers) { 166 | nw.set_deleted_key(kDeletedUnsignedKey); 167 | nd.set_deleted_key(kDeletedUnsignedKey); 168 | // nw.set_empty_key(kEmptyUnsignedKey); 169 | // nd.set_empty_key(kEmptyUnsignedKey); 170 | } 171 | CRP(unsigned l, unsigned customers, CRP* p) 172 | : level(l), nwsum(0), lp(0), label(""), ndsum(customers) { 173 | prev.push_back(p); 174 | nw.set_deleted_key(kDeletedUnsignedKey); 175 | nd.set_deleted_key(kDeletedUnsignedKey); 176 | // nw.set_empty_key(kEmptyUnsignedKey); 177 | // nd.set_empty_key(kEmptyUnsignedKey); 178 | } 179 | ~CRP(); 180 | 181 | // Update ndsum to reflect the actual document assignments 182 | void add(unsigned w, unsigned d) { 183 | add_no_ndsum(w,d); 184 | 185 | if (nd[d] == 1) { // Added from a new doc 186 | ndsum += 1; 187 | } 188 | } 189 | void add_no_ndsum(unsigned w, unsigned d) { 190 | CHECK_GE(nw[w], 0); 191 | CHECK_GE(nw[w], 0); 192 | CHECK_GE(nd[d], 0); 193 | nw[w] += 1; 194 | nwsum += 1; 195 | nd[d] += 1; 196 | } 197 | 198 | void remove(unsigned w, unsigned d) { 199 | remove_no_ndsum(w,d); 200 | 201 | if (nd[d] == 0) { // Added from a new doc 202 | nd.erase(d); 203 | ndsum -= 1; 204 | } 205 | 206 | CHECK_GE(ndsum, 0); 207 | } 208 | 209 | void remove_no_ndsum(unsigned w, unsigned d) { 210 | nw[w] -= 1; 211 | nwsum -= 1; 212 | nd[d] -= 1; 213 | CHECK_GE(nw[w], 0); 214 | CHECK_GE(nwsum, 0); 215 | CHECK_GE(nd[d], 0); 216 | if (nw[w] == 0) { 217 | nw.erase(w); 218 | } 219 | } 220 | 221 | void remove_doc(const Document& D, unsigned d) { 222 | for (int n = 0; n < D.size(); n++) { 223 | unsigned w = D.at(n); 224 | // Remove this document and word from the counts 225 | remove(w,d); 226 | } 227 | } 228 | 229 | void add_doc(const Document& D, unsigned d) { 230 | for (int n = 0; n < D.size(); n++) { 231 | unsigned w = D.at(n); 232 | // Remove this document and word from the counts 233 | add(w,d); 234 | } 235 | } 236 | 237 | 238 | void remove_from_parents(); 239 | 240 | public: 241 | WordToCountMap nw; // number of words equal to w in this node 242 | DocToWordCountMap nd; // number of words from doc d in this node 243 | 244 | unsigned level; 245 | 246 | unsigned nwsum; // number of words in this node 247 | unsigned ndsum; // number of docs in this node (same as m) 248 | 249 | vector prev; // the parents of this node in the DAG 250 | 251 | vector tables; // the tables in the next restaurant 252 | 253 | double lp; // probability of reaching this node, used in posterior_c 254 | 255 | // the node label from WN or whatever hierarchy (defaults to none) 256 | string label; 257 | }; 258 | 259 | // The hLDA base class, contains code common to the Multinomial (fixed-depth) 260 | // and GEM (infinite-depth) samplers 261 | class GibbsSampler { 262 | public: 263 | // Initialization routines; must be called before run. Sets up the tree 264 | // and list views of the nCRP and does the initial level assignment. The 265 | // tree is grown incrementally from a single branch using the same 266 | // agglomerative method as resample_posterior_c. This procedure is 267 | // recommended by Blei. 
268 | GibbsSampler() { 269 | _unique_word_count = 0; 270 | _total_word_count = 0; 271 | _lD = 0; 272 | _lV = 0; 273 | 274 | _iter = 0; 275 | _best_ll = 0; 276 | _best_iter = 0; 277 | 278 | _eta_sum = 0; 279 | 280 | _converged_iterations = 0; 281 | 282 | _eta_sum = 0; // fix a particularly nasty bug 283 | 284 | _D.set_empty_key(kEmptyUnsignedKey); 285 | _D.set_deleted_key(kDeletedUnsignedKey); 286 | _word_name_to_id.set_empty_key(kEmptyStringKey); 287 | _word_id_to_name.set_empty_key(kEmptyUnsignedKey); 288 | _document_name.set_empty_key(kEmptyUnsignedKey); 289 | _document_id.set_empty_key(kEmptyStringKey); 290 | _nd.set_empty_key(kEmptyUnsignedKey); 291 | _initial_topic_assignment.set_empty_key(kEmptyUnsignedKey); 292 | _initial_topic_assignment.set_deleted_key(kDeletedUnsignedKey); 293 | } 294 | virtual ~GibbsSampler() { /* TODO: free memory! */ } 295 | 296 | // Allocate all documents at once 297 | virtual void batch_allocation() { 298 | LOG(FATAL) << "NYI batch_allocation"; 299 | } 300 | 301 | // Allocate a document (called right after the document is read from the file) 302 | virtual void allocate_document(unsigned d) { 303 | LOG(FATAL) << "NYI allocate_document"; 304 | } 305 | // Deallocate a document to conserve memory 306 | virtual void deallocate_document() { 307 | LOG(FATAL) << "NYI deallocate_document"; 308 | } 309 | 310 | // Allocate new_d into the model and (maybe) kick out an old document 311 | void streaming_step(unsigned new_d); 312 | 313 | // Check log-likelihood and write out some samples if necessary 314 | bool sample_and_check_for_convergence(); 315 | 316 | // Run the Gibbs sampler for max_gibbs_iterations iterations. This is the main 317 | // workhorse loop. 318 | void run(); 319 | 320 | // Load a data file 321 | virtual void load_data(const string& filename); 322 | 323 | // Process a single document line from the file 324 | bool process_document_line(const string& curr_line, unsigned line_no); 325 | 326 | // Write some summary of the output 327 | virtual void write_data(string prefix) = 0; 328 | 329 | // return a string describing the current state of the hyperparameters 330 | // and log likelihood 331 | virtual string current_state() = 0; 332 | 333 | string show_chopped_sorted_nw(const WordToCountMap& nw); 334 | 335 | protected: 336 | virtual void resample_posterior() = 0; 337 | 338 | virtual double compute_log_likelihood() = 0; 339 | 340 | protected: 341 | WordToCountMap _V; // vocabulary keys mapping to corpus counts 342 | unsigned _lV; // size of vocab 343 | unsigned _lD; // number of documents 344 | 345 | unsigned _unique_word_count; 346 | unsigned _total_word_count; 347 | 348 | WordCode _word_id_to_name; // uniqe_id to string 349 | google::dense_hash_map _word_name_to_id; 350 | 351 | DocumentMap _D; // documents indexed by unique # 352 | DocumentMap _initial_topic_assignment; // initial term-topic assignment when FLAGS_preassigned_topics=1 353 | 354 | DocIDToTitle _document_name; // doc_number to title 355 | TitleToDocID _document_id; 356 | 357 | Docsize _nd; // number of words in document d 358 | 359 | double _ll; // current log-likelihood 360 | double _best_ll; 361 | int _best_iter; 362 | int _iter; 363 | 364 | vector _eta; // Smoother for document likelihood 365 | double _eta_sum; 366 | 367 | // Test for convergence 368 | unsigned _converged_iterations; 369 | }; 370 | 371 | typedef std::pair word_score; 372 | bool word_score_comp(const word_score& left, const word_score& right); 373 | 374 | void init_random(); 375 | 376 | // Safely remove an element from a list 
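// (Hedged usage note: e.g. safe_remove_crp(&parent->tables, child) erases child
// from parent->tables and CHECK-fails if it was not present.)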
377 | void safe_remove_crp(vector* v, const CRP*); 378 | 379 | // These are adapted from Hal Daume's HBC: 380 | 381 | // Logarithm of the gamma function. 382 | double gammaln(double x); 383 | 384 | // Log factorial 385 | inline double factln(double x) { return gammaln(x+1); } 386 | 387 | long double addLog(long double x, long double y); 388 | void normalizeLog(vector*x); 389 | void normalizeLog(vector >*x); 390 | 391 | // Given a multinomial distribution of the form {label:prob}, return a label 392 | // with that probability. 393 | inline int sample_normalized_multinomial(vector*d); 394 | inline int sample_normalized_multinomial(vector >*d); 395 | unsigned sample_unnormalized_log_multinomial(vector*d); 396 | unsigned sample_unnormalized_log_multinomial(vector >*d); 397 | sampler_entry NEW_sample_unnormalized_log_multinomial(vector*d); 398 | int SAFE_sample_unnormalized_log_multinomial(vector*d); 399 | int SAFE_sample_unnormalized_log_multinomial(vector >*d); 400 | 401 | unsigned sample_integer(unsigned range); 402 | double sample_gaussian(); 403 | double sample_uniform(); 404 | 405 | string get_base_name(const string& s); 406 | // filtering_ostream get_bz2_ostream(const string& filename); 407 | bool is_bz2_file(const string& s); 408 | bool is_gz_file(const string& s); 409 | void open_or_gz(string filename, istream* result_stream); 410 | #endif // GIBBS_BASE_H_ 411 | -------------------------------------------------------------------------------- /ncrp-base.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // This file contains all of the common code shared between the Mutinomial hLDA 17 | // sampler (fixed-depth tree) and the GEM hLDA sampler (infinite-depth tree), as 18 | // well as the fixed topic structure sampler used to learn WN senses. 19 | // 20 | // CRPNode is the basic data structure for a single node in the topic tree (dag) 21 | 22 | 23 | #ifndef CRP_BASE_H_ 24 | #define CRP_BASE_H_ 25 | 26 | // TODO:: currently we have these depencies because I couldn't figure out 27 | // how to (easily) write plain ascii files to the disk. One solution is to 28 | // figure out how to do that, a better, more long-term friendly solution is to 29 | // design a protocol buffer to store the output samples. This would be more 30 | // space efficient, among other things. 31 | #include 32 | #include 33 | 34 | #include 35 | 36 | #include 37 | //#include 38 | #include 39 | #include 40 | #include 41 | #include 42 | 43 | #include "gibbs-base.h" 44 | 45 | #include "strutil.h" 46 | 47 | using namespace std; 48 | using namespace google::protobuf; 49 | 50 | // number of times to resample z (the level assignments) per iteration of c (the 51 | // tree sampling) 52 | DECLARE_int32(ncrp_z_per_iteration); 53 | 54 | // this is the actual data file containing a list of attributes on each line, 55 | // tab separated, with the first entry being the class label. 
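// A hedged example of one such line (the class label and attribute names are
// hypothetical; each attribute is followed by its count):
//
//   animal_dog<TAB>fur:3<TAB>barks:1<TAB>tail:2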
56 | DECLARE_string(ncrp_datafile); 57 | 58 | // Alpha controls the topic smoothing, with higher alpha causing more "uniform" 59 | // distributions over topics. This is replaced by m and pi in the GEM sampler. 60 | DECLARE_double(ncrp_alpha); 61 | 62 | 63 | // Gamma controls the probability of creating new brances in both the 64 | // Multinomial and GEM sampler; has no effect in the fixed-structure sampler. 65 | DECLARE_double(ncrp_gamma); 66 | 67 | // Setting this to true interleaves Metropolis-Hasting steps in between the 68 | // Gibbs steps to update the hyperparameters. Currently it is only implemented 69 | // in the basic version. 70 | DECLARE_bool(ncrp_update_hyperparameters); 71 | 72 | // Setting this to true causes the hyperparameter gamma to be scaled by m, the 73 | // number of documents attached to the node. This makes branching into a 74 | // constant proportion (roughly \gamma / (\gamma + 1)) indepedent of node size 75 | // (slighly more intuitive behavior). If this isn't set, you're likely to get 76 | // long chains instead of branches 77 | DECLARE_bool(ncrp_m_dependent_gamma); 78 | 79 | // This places an (artificial) cap on the number of branches possible from each 80 | // node, reducing the width of the tree, but sacrificing the generative 81 | // semantics of the model. -1 is the default for no capping. 82 | DECLARE_int32(ncrp_max_branches); 83 | 84 | // Parameter controlling the depth of the tree. Any interior node can have an 85 | // arbitrary number of branches, but paths down to the leaves are constrained 86 | // to be exactly this length. 87 | DECLARE_int32(ncrp_depth); 88 | 89 | // If set to true, don't assign any words to the root node; this still maintains 90 | // the generative semantics of the model, but gives us a free implementation of 91 | // the dirichlet process (L=2, skip root) and as well as mixture of ncrps. 92 | DECLARE_bool(ncrp_skip_root); 93 | 94 | // Setting this forces the topic topology to consist of a length L-1 chain 95 | // followed by a set of leaves at the end. Basically the idea is to get a set of 96 | // L-1 "noise" topics and a single "signal" topic; so this is really an 97 | // implementation of prix-fixe with more than one noise. 98 | DECLARE_bool(ncrp_prix_fixe); 99 | 100 | // Eta depth scale: multiply eta by eta_depth_scale**depth for nodes at that 101 | // depth; essentially eta_depth_scale=0.5 will lead to more mass at higher 102 | // nodes, as opposed to leaves 103 | DECLARE_double(ncrp_eta_depth_scale); 104 | 105 | // The hLDA base class, contains code common to the Multinomial (fixed-depth) 106 | // and GEM (infinite-depth) samplers 107 | class NCRPBase : public GibbsSampler { 108 | public: 109 | // Initialization routines; must be called before run. Sets up the tree 110 | // and list views of the nCRP and does the initial level assignment. The 111 | // tree is grown incrementally from a single branch using the same 112 | // agglomerative method as resample_posterior_c. This procedure is 113 | // recommended by Blei. 114 | NCRPBase(); 115 | virtual ~NCRPBase() { /* TODO: free memory! 
*/ } 116 | 117 | // Allocate all the documents at once (called for non-streaming) 118 | void batch_allocation(); 119 | 120 | // Allocate a single document; can be called during load for streaming 121 | void allocate_document(unsigned d); 122 | 123 | // Deallocate a random document from the model 124 | void deallocate_document(); 125 | 126 | // Write out the learned tree in dot format 127 | virtual void write_data(string prefix); 128 | 129 | // return a string describing the current state of the hyperparameters 130 | // and log likelihood 131 | virtual string current_state() = 0; 132 | 133 | protected: 134 | virtual void resample_posterior_z_for(unsigned d, bool remove) = 0; 135 | void resample_posterior_c_for(unsigned d); 136 | 137 | void calculate_path_probabilities_for_subtree(CRP* root, 138 | unsigned d, 139 | unsigned max_depth, 140 | LevelWordToCountMap& nw_removed, 141 | LevelToCountMap& nwsum_removed, 142 | vector* lp_c_d, 143 | vector* c_d); 144 | 145 | virtual double compute_log_likelihood() = 0; 146 | 147 | bool tree_is_consistent(); // check the consistency of the tree 148 | 149 | // Returns a list of nodes in this path. If node is internal, then it grows 150 | // down to depth L. If node is a leaf, then it just returns the path to the 151 | // root 152 | void graft_path_at(CRP* node, vector* chain, unsigned depth); 153 | 154 | void print_summary(); 155 | 156 | protected: 157 | 158 | // Parameters 159 | unsigned _L; 160 | 161 | // eta = "scale" of the topic model, higher eta = more general / fewer topics. 162 | vector _alpha; 163 | double _gamma; // hyperparams 164 | 165 | double _alpha_sum; // normalization constants 166 | 167 | DocWordToCount _z; // level assignments per document, word 168 | DocToTopicChain _c; // CRP nodes for a document m 169 | 170 | CRP* _ncrp_root; // tree representation of the nCRP. 171 | CRP* _reject_node; // special node containing rejected attributes. 172 | 173 | unsigned _unique_nodes; // number of leaves in the tree 174 | 175 | string _filename; // output file name 176 | 177 | unsigned _total_words; // total number of words added 178 | }; 179 | 180 | #endif // CRP_BASE_H_ 181 | -------------------------------------------------------------------------------- /sample-clustered-lda-main.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "ncrp-base.h" 17 | #include "sample-clustered-lda.h" 18 | 19 | int main(int argc, char **argv) { 20 | google::InitGoogleLogging(argv[0]); 21 | google::ParseCommandLineFlags(&argc, &argv, true); 22 | 23 | init_random(); 24 | 25 | ClusteredLDA h = ClusteredLDA(); 26 | h.initialize(); 27 | 28 | h.run(); 29 | } 30 | -------------------------------------------------------------------------------- /sample-clustered-lda.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // Basic implementation of clustered LDA with multinomial likelihood 17 | 18 | #ifndef SAMPLE_CLUSTERED_LDA_H_ 19 | #define SAMPLE_CLUSTERED_LDA_H_ 20 | 21 | #include 22 | #include 23 | 24 | #include "ncrp-base.h" 25 | 26 | // Number of topics 27 | DECLARE_int32(T); 28 | 29 | // Number of clusters 30 | DECLARE_int32(K); 31 | 32 | // Smoother on clusters 33 | DECLARE_double(xi); 34 | 35 | // File containing the words and their features 36 | DECLARE_string(word_features_file); 37 | 38 | // Should clustering be constrained to only guys that have the same type? I.e. 39 | // whereas before we would draw our word-clusters across the entire word-feature 40 | // space, now the clustering is type-dependent. 
41 | DECLARE_bool(type_dependent_clustering); 42 | 43 | class Feature { 44 | public: 45 | Feature(unsigned feature_id, unsigned count) : _feature_id(feature_id), _count(count) { } 46 | 47 | public: 48 | unsigned _feature_id; 49 | unsigned _count; 50 | }; 51 | 52 | class NestedDocument { 53 | public: 54 | class WordFeatures { 55 | public: 56 | WordFeatures(unsigned word_id, string word_name, unsigned word_type_id) 57 | : _word_id(word_id), _word_name(word_name), _word_type_id(word_type_id), _topic_indicator(0) { } 58 | 59 | void uniform_initialization() { 60 | _topic_indicator = sample_integer(FLAGS_T); 61 | _cluster_indicator = sample_integer(FLAGS_K); 62 | } 63 | 64 | public: 65 | unsigned _word_id; 66 | unsigned _word_type_id; 67 | string _word_name; 68 | 69 | unsigned _topic_indicator; 70 | unsigned _cluster_indicator; 71 | }; 72 | 73 | NestedDocument(string doc_name, unsigned doc_id) : _doc_name(doc_name), _doc_id(doc_id) { } 74 | 75 | public: 76 | string _doc_name; 77 | unsigned _doc_id; 78 | vector _words; 79 | }; 80 | 81 | typedef google::dense_hash_map > WordIDToFeatures; 82 | 83 | // Clustered Topic model with dirichlet-multinomial likelihood 84 | class ClusteredLDA : public NCRPBase { 85 | public: 86 | ClusteredLDA() { } 87 | 88 | // Set up initial assignments and load the doc->word and word->feature maps 89 | void initialize(); 90 | 91 | // Write out a static dictionary required for decoding Gibbs samples 92 | void write_dictionary(); 93 | 94 | // Write out the Gibbs sample 95 | void write_data(string prefix); 96 | 97 | double compute_log_likelihood(); 98 | protected: 99 | // Load a document -> word file 100 | void load_documents(const string& filename); 101 | 102 | // Load a word -> features file 103 | void load_words(const string& filename); 104 | 105 | void resample_posterior(); 106 | void resample_posterior_z_for(unsigned d, bool remove); 107 | void resample_posterior_w_for(unsigned d); 108 | 109 | string current_state(); 110 | 111 | protected: 112 | vector _D; 113 | vector _master_cluster; 114 | google::dense_hash_map > _cluster; 115 | vector _topic; 116 | 117 | unsigned _unique_type_count; 118 | 119 | unsigned _unique_feature_count; 120 | unsigned _total_feature_count; 121 | 122 | WordCode _feature_id_to_name; // uniqe_id to string 123 | google::dense_hash_map _feature_name_to_id; 124 | 125 | WordCode _type_id_to_name; // uniqe_id to string 126 | google::dense_hash_map _type_name_to_id; 127 | google::dense_hash_map _word_id_to_type_id; 128 | 129 | WordIDToFeatures _features; 130 | 131 | 132 | vector _xi; 133 | double _xi_sum; 134 | 135 | // Base names of the flags 136 | string _word_features_moniker; 137 | string _datafile_moniker; 138 | }; 139 | 140 | #endif // SAMPLE_CLUSTERED_LDA_H_ 141 | -------------------------------------------------------------------------------- /sample-crosscat-mm-main.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "sample-crosscat-mm.h" 17 | 18 | int main(int argc, char **argv) { 19 | google::InitGoogleLogging(argv[0]); 20 | google::ParseCommandLineFlags(&argc, &argv, true); 21 | 22 | init_random(); 23 | 24 | CrossCatMM h = CrossCatMM(); 25 | h.load_data(FLAGS_mm_datafile); 26 | h.run(); 27 | } 28 | -------------------------------------------------------------------------------- /sample-crosscat-mm.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // Basic implementation of clustered LDA with multinomial likelihood 17 | 18 | #ifndef SAMPLE_CROSSCAT_MM_H_ 19 | #define SAMPLE_CROSSCAT_MM_H_ 20 | 21 | #include 22 | #include 23 | #include 24 | 25 | #include "gibbs-base.h" 26 | 27 | // the number of feature clusters 28 | DECLARE_int32(M); 29 | 30 | // the maximum number of clusters 31 | DECLARE_int32(KMAX); 32 | 33 | // Smoother on clustering 34 | DECLARE_double(mm_alpha); 35 | 36 | // Smoother on cross cat clustering 37 | DECLARE_double(cc_xi); 38 | 39 | // File holding the data 40 | DECLARE_string(mm_datafile); 41 | 42 | // Number of feature moves to make 43 | DECLARE_double(cc_feature_move_rate); 44 | 45 | // If toggled, the first view will be constrained to a single cluster 46 | DECLARE_bool(cc_include_noise_view); 47 | 48 | // Basically controls whether and how we should do cross-cat on the features. 49 | // Implemented using MH steps. 50 | DECLARE_string(cross_cat_prior); 51 | 52 | const string kCrossCatOff = "off"; 53 | 54 | // typedef google::dense_hash_map collapsed_document; 55 | typedef map collapsed_document; 56 | // typedef google::dense_hash_map collapsed_document_collection; 57 | typedef map collapsed_document_collection; 58 | 59 | // Implements several kinds of mixture models (uniform prior, Dirichlet prior, 60 | // DPCrossCatMM all with DP-Multinomial likelihood. 
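// The collapsed_document typedefs above represent a document as a map from
// word id to count (assuming the map is keyed by word id with an unsigned
// count), so the mixture-model updates can weight each word by its count
// instead of iterating over individual token occurrences. A minimal sketch of
// collapsing a token vector into that form:
static collapsed_document collapse_tokens(const vector<unsigned>& tokens) {
    collapsed_document doc;
    for (size_t i = 0; i < tokens.size(); ++i) {
        doc[tokens[i]] += 1;   // accumulate this word id's per-document count
    }
    return doc;
}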
61 | class CrossCatMM : public GibbsSampler { 62 | public: 63 | CrossCatMM() { } 64 | 65 | // Allocate all the documents at once (called for non-streaming) 66 | void batch_allocation(); 67 | 68 | double compute_log_likelihood(); 69 | 70 | void write_data(string prefix); 71 | protected: 72 | void resample_posterior(); 73 | void resample_posterior_z_for(unsigned d, unsigned m, bool remove); 74 | void resample_posterior_m(double percent); 75 | void resample_posterior_m_for(unsigned tw); 76 | 77 | 78 | string current_state(); 79 | 80 | double compute_log_likelihood_for(unsigned m, clustering& cm); 81 | double cross_cat_clustering_log_likelihood(unsigned w, unsigned m); 82 | double cross_cat_reassign_features(unsigned old_m, unsigned new_m, unsigned w); 83 | 84 | protected: 85 | // Maps documents to clusters 86 | multiple_cluster_map _z; // Map (data point, clustering) -> cluster_id 87 | multiple_clustering _c; // Map [w][z] -> CRP 88 | 89 | cluster_map _m; // Map vocab -> cluster 90 | clustering _b; 91 | 92 | // Base names of the flags 93 | string _word_features_moniker; 94 | string _datafile_moniker; 95 | 96 | unsigned _ndsum; 97 | 98 | vector _current_component; // # of clusters currently 99 | 100 | // Document clustering smoother 101 | vector _alpha; 102 | double _alpha_sum; 103 | 104 | // Cross-cat smoother 105 | vector _xi; 106 | double _xi_sum; 107 | 108 | string _output_filename; 109 | 110 | // Count the number of feature movement proposals and the number that 111 | // fail 112 | unsigned _m_proposed; 113 | unsigned _m_failed; 114 | 115 | // Count cluster moves 116 | unsigned _c_proposed; 117 | unsigned _c_failed; 118 | 119 | // Data structure to hold the documents optimzed for mixture model 120 | // computation (i.e. we don't need to break up each feature into 121 | // occurrences, and can instead treat the count directly) 122 | collapsed_document_collection _DD; 123 | }; 124 | 125 | #endif // SAMPLE_CROSSCAT_MM_H_ 126 | -------------------------------------------------------------------------------- /sample-fixed-ncrp-main.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "ncrp-base.h" 17 | #include "sample-fixed-ncrp.h" 18 | 19 | int main(int argc, char **argv) { 20 | google::InitGoogleLogging(argv[0]); 21 | google::ParseCommandLineFlags(&argc, &argv, true); 22 | 23 | init_random(); 24 | 25 | GEMNCRPFixed h = GEMNCRPFixed(FLAGS_gem_m, FLAGS_gem_pi); 26 | h.load_data(FLAGS_ncrp_datafile); 27 | h.load_tree_structure(FLAGS_tree_structure_file); 28 | 29 | h.run(); 30 | } 31 | -------------------------------------------------------------------------------- /sample-fixed-ncrp.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // Samples from the Nested Chinese Restaurant Process using a fixed topic 17 | // structure. This model is more expressive in that the topic structure is not 18 | // constrained to be a tree, but rather a digraph. 19 | 20 | #ifndef SAMPLE_GEM_FIXED_NCRP_H_ 21 | #define SAMPLE_GEM_FIXED_NCRP_H_ 22 | 23 | #include 24 | #include 25 | 26 | #include "ncrp-base.h" 27 | 28 | // The GEM(m, \pi) distribution hyperparameter m, controls the "proportion of 29 | // general words relative to specific words" 30 | DECLARE_double(gem_m); 31 | 32 | // The GEM(m, \pi) hyperparameter \pi: reflects how strictly we expect the 33 | // documents to adhere to the m proportions. 34 | DECLARE_double(gem_pi); 35 | 36 | // The file path from which to load the topic structure. The file must be 37 | // encoded as one connection per line, child parent. 38 | DECLARE_string(tree_structure_file); 39 | 40 | // Whether or not to use the GEM sampler. The Multinomial sampler currently is 41 | // more flexible as it allows the tree structure to be a DAG; the GEM sampler 42 | // might not work yet with DAGs. 43 | DECLARE_bool(gem_sampler); 44 | 45 | // If unset, then just throw away extra edges that cause nodes to have multiple 46 | // parents. Enforcing a tree topology. 47 | DECLARE_bool(use_dag); 48 | 49 | // Should non-WN class nodes have words assigned to them? If not, then all 50 | // topics will start with wn_ 51 | DECLARE_bool(fold_non_wn); 52 | 53 | // Should we perform variable selection (i.e. attribute rejection) based on 54 | // adding a single "REJECT" node with a uniform distribution over the 55 | // vocabulary to each topic list? 56 | DECLARE_bool(use_reject_option); 57 | 58 | // Should the hyperparameters on the vocabulary Dirichlet (eta) be learned. For 59 | // now this uses moment matchin to perform the updates. 60 | DECLARE_bool(learn_eta); 61 | 62 | // Should all the path combinations to the root be separated out into different 63 | // documents? DAG only. 64 | DECLARE_bool(separate_path_assignments); 65 | 66 | // Should we try to learn a single best sense from a list of senses? 
67 | DECLARE_bool(sense_selection); 68 | 69 | typedef google::dense_hash_map DocSenseToTopicChain; 70 | typedef google::dense_hash_map > DocSenseWordToCount; 71 | 72 | typedef google::dense_hash_map > NodeLogFrequencyMap; 73 | 74 | // This version differs from the normal GEM sampler in that the tree structure 75 | // is fixed a priori. Hence there is no resampling of c, the path allocations. 76 | class GEMNCRPFixed : public NCRPBase { 77 | public: 78 | GEMNCRPFixed(double m, double pi); 79 | ~GEMNCRPFixed() { /* TODO: free memory! */ } 80 | 81 | string current_state(); 82 | 83 | void load_tree_structure(const string& filename); 84 | void load_precomputed_tree_structure(const string& filename); 85 | 86 | protected: 87 | void resample_posterior(); 88 | void resample_posterior_z_for(unsigned d, bool remove) { resample_posterior_z_for(d, _c[d], _z[d]); } 89 | void resample_posterior_z_for(unsigned d, vector& cd, WordToCountMap& zd); 90 | void resample_posterior_c_for(unsigned d); // used in sense selection 91 | void resample_posterior_eta(); 92 | 93 | double compute_log_likelihood(); 94 | 95 | void contract_tree(); 96 | 97 | void build_path_assignments(CRP* node, vector* c, int sense_index); 98 | void build_separate_path_assignments(CRP* node, vector< vector >* paths); 99 | 100 | // Assume that all the words for document d have been assigned using the level 101 | // assignment zd, now remove them all. 102 | void remove_all_words_from(unsigned d, vector& cd, WordToCountMap& zd); 103 | 104 | // Assume that all the words for document d have been removed, now add them 105 | // back using the level assignment zd, now remove them all. 106 | void add_all_words_from(unsigned d, vector& cd, WordToCountMap& zd); 107 | 108 | // Returns the (unnormalized) path probability for document d given the current 109 | // set of _z assignments 110 | double compute_path_probability_for(unsigned d, vector& cd); 111 | 112 | 113 | protected: 114 | double _gem_m; 115 | double _pi; 116 | 117 | // These hold other possible sense attachments that are not currently 118 | // in use. 119 | DocSenseWordToCount _z_shadow; // level assignments per document, word 120 | DocSenseToTopicChain _c_shadow; // CRP nodes for a document m 121 | 122 | NodeLogFrequencyMap _log_node_freq; // Gives the frequency of a sense attachment 123 | 124 | unsigned _maxL; 125 | }; 126 | 127 | #endif // SAMPLE_GEM_FIXED_NCRP_H_ 128 | -------------------------------------------------------------------------------- /sample-gem-ncrp.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // This implements (I think) the GEM distributed hLDA, which has an infinitely 17 | // deep tree, where each node can have (possibly) infinitely many branches. 
18 | // There are some issues with implementing this sampler that I (and others) 19 | // have raised on the Princeton topic models list, but so far no one has come 20 | // up with any answers. In any case the sampler /seems/ to work just fine. 21 | 22 | #include "ncrp-base.h" 23 | #include "sample-gem-ncrp.h" 24 | 25 | // The GEM(m, \pi) distribution hyperparameter m, controls the "proportion of 26 | // general words relative to specific words" 27 | DEFINE_double(gem_m, 28 | 0.1, 29 | "m reflects the proportion of general words to specific words"); 30 | 31 | // The GEM(m, \pi) hyperparameter \pi: reflects how strictly we expect the 32 | // documents to adhere to the m proportions. 33 | DEFINE_double(gem_pi, 34 | 0.1, 35 | "reflects our confidence in the setting m"); 36 | 37 | GEMNCRP::GEMNCRP(double m, double pi) 38 | : _gem_m(m), _pi(pi), _maxL(0) { 39 | _L = 3; // for now we need an initial depth (just for the first data pt) 40 | _maxL = _L; 41 | 42 | CHECK_GE(_gem_m, 0.0); 43 | CHECK_LE(_gem_m, 1.0); 44 | } 45 | 46 | 47 | string GEMNCRP::current_state() { 48 | // HACK: put this in here for now since it needs to get updated whenever 49 | // maxL changes 50 | _filename = FLAGS_ncrp_datafile; 51 | _filename += StringPrintf("-L%d-m%f-pi%f-gamma%f-eta%f-zpi%d-best.data", 52 | _maxL, _gem_m, _pi, _gamma, 53 | _eta_sum / (double)_eta.size(), 54 | FLAGS_ncrp_z_per_iteration); 55 | 56 | return StringPrintf( 57 | "ll = %f (%f at %d) %d m = %f pi = %f eta = %f gamma = %f L = %d", 58 | _ll, _best_ll, _best_iter, _unique_nodes, _gem_m, _pi, 59 | _eta_sum / (double)_eta.size(), _gamma, 60 | _maxL); 61 | } 62 | 63 | // Performs a single document's level assignment resample step There is some 64 | // significant subtly to this compared to the fixed depth version. First, we no 65 | // longer have the guarantee that all documents are attached to leaves. Some 66 | // documents may now stop at interior nodes of our current tree. This is ok, 67 | // since the it just means that we haven't assigned any words to lower levels, 68 | // and hence we haven't needed to actually assert what the path is. Anyway, 69 | // since the level assignments can effectively change length here, we need to 70 | // get more child nodes from the nCRP on the fly. 
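// A minimal sketch of the on-the-fly depth extension described above: when the
// sampled level falls past the end of a document's current path, the new
// maximum level is drawn by repeated Bernoulli(m) coin flips. The helper below
// is illustrative only (it assumes sample_uniform() from gibbs-base.h is in
// scope); the sampler inlines the same loop inside resample_posterior_z_for
// and then calls graft_path_at to grow the nCRP tree to the chosen depth.
static unsigned propose_new_max_level(unsigned current_depth, double gem_m) {
    unsigned level = current_depth + 1;        // one level past the current path
    while (sample_uniform() < gem_m) {
        level += 1;                            // descend another level with probability m
    }
    return level;
}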
71 | void GEMNCRP::resample_posterior_z_for(unsigned d, bool remove) { 72 | // CHECK_EQ(_L, -1); // HACK to make sure we're not using _L 73 | CHECK(false) << "implement remove"; 74 | 75 | CHECK(!FLAGS_ncrp_skip_root); 76 | for (int n = 0; n < _D[d].size(); n++) { // loop over every word 77 | unsigned w = _D[d][n]; 78 | // Compute the new level assignment # 79 | _c[d][_z[d][n]]->nw[w] -= 1; // number of words in topic z equal to w 80 | _c[d][_z[d][n]]->nd[d] -= 1; // number of words in doc d with topic z 81 | _c[d][_z[d][n]]->nwsum -= 1; // number of words in topic z 82 | _nd[d] -= 1; // number of words in doc d 83 | 84 | 85 | CHECK_GE(_c[d][_z[d][n]]->nwsum, 0); 86 | CHECK_GE(_c[d][_z[d][n]]->nw[w], 0); 87 | CHECK_GE(_nd[d], 0); 88 | CHECK_GT(_c[d][_z[d][n]]->ndsum, 0); 89 | 90 | // ndsum_above[k] is #[z_{d,-n} >= k] 91 | vector ndsum_above; 92 | ndsum_above.resize(_c[d].size()); 93 | ndsum_above[_c[d].size()-1] = _c[d].back()->nd[d]; 94 | // LOG(INFO) << "ndsum_above[" << _c[d].size()-1 << "] = " 95 | // << ndsum_above.back(); 96 | for (int l = _c[d].size()-2; l >= 0; l--) { 97 | // TODO:: optimize this 98 | ndsum_above[l] = _c[d][l]->nd[d] + ndsum_above[l+1]; 99 | // LOG(INFO) << "ndsum_above[" << l << "] = " << ndsum_above[l]; 100 | } 101 | 102 | // Here we assign probabilities to all the "finite" options, e.g. all 103 | // the levels up to the current maximum level for this document. TODO:: 104 | // this can be optimized quite extensively 105 | vector lposterior_z_dn; 106 | double lp_z_dn_sum = 0; 107 | double V_j_sum = 0; 108 | unsigned total_nd = 0; 109 | for (int l = 0; l < _c[d].size(); l++) { 110 | // check that ["doesnt exist"]->0 111 | // DCHECK(_c[d][l]->nw.find(w) != _c[d][l]->nw.end() || _c[d][l]->nw[w] == 0); 112 | // DCHECK(_c[d][l]->nd.find(d) != _c[d][l]->nd.end() || _c[d][l]->nd[d] == 0); 113 | total_nd += _c[d][l]->nd[d]; 114 | 115 | double lp_w_dn = log(_eta[w] + _c[d][l]->nw[w]) - 116 | log(_eta_sum + _c[d][l]->nwsum); 117 | double lp_z_dn = log(_pi*(1-_gem_m) + _c[d][l]->nd[d]) - 118 | log(_pi + ndsum_above[l]) + V_j_sum; 119 | 120 | lposterior_z_dn.push_back(lp_w_dn + lp_z_dn); 121 | 122 | if (l < _c[d].size()-1) { 123 | V_j_sum += log(_gem_m*_pi + ndsum_above[l+1]) - 124 | log(_pi + ndsum_above[l]); 125 | } 126 | lp_z_dn_sum = addLog(lp_z_dn_sum, lp_z_dn); 127 | } 128 | // DCHECK_EQ(total_nd, _nd[d]); 129 | 130 | 131 | // If the "new" entry is sampled, we have to determine actually at what 132 | // level to attach the word. This is done by repeatedly sampling from a 133 | // binomial, that results in sampling from the GEM. 134 | // XXX: this is from the earlier version of the paper... 135 | unsigned new_max_level = _c[d].size()+1; 136 | while (sample_uniform() < _gem_m) { 137 | new_max_level += 1; 138 | } 139 | 140 | // The next big block computes lp_w_dn for the new level assignment 141 | // (hopefully without taking too much computation). 142 | double lp_w_dn = 0; 143 | CRP* new_leaf = _c[d].back(); 144 | 145 | // if there are things attached below here 146 | if (_c[d].back()->tables.size() > 0) { 147 | // then sample from them 148 | 149 | // We now need to have selected a branch to depth new_max_level 150 | // Theoretically, we've already done this, since the GEM distribution and 151 | // nCRP draws are infinite. However, obviously, this is not the case. 152 | // Instead, we need to proceed as if we alread know down what branch to 153 | // extend the level allocations. 
We do this by post-hoc selecting a 154 | // subtree path starting at the old end of c[d] and augmenting it to get 155 | // a path to new_max_level 156 | 157 | // Keep track of the removed counts for computing the likelihood of 158 | // the data 159 | // HACK: we only remove one at a time, so this can be optimized.... 160 | LevelWordToCountMap nw_removed; 161 | LevelToCountMap nwsum_removed; 162 | nw_removed[_z[d][n]][w] += 1; 163 | nwsum_removed[_z[d][n]] += 1; 164 | 165 | vector lp_c_d; // log-probability of this branch c_d 166 | vector c_d; // the actual branch c_d 167 | calculate_path_probabilities_for_subtree(_c[d].back(), d, new_max_level, 168 | nw_removed, nwsum_removed, 169 | &lp_c_d, &c_d); 170 | 171 | // Choose a new leaf node 172 | int index = sample_unnormalized_log_multinomial(&lp_c_d); 173 | 174 | new_leaf = c_d[index]; // keep a pointer around for later 175 | 176 | // If we choose to create a new branch at a level less than our 177 | // desired level, then it can have no words already added, hence the 178 | // word probability is the default. 179 | if (c_d[index]->level < new_max_level) { 180 | // then there are no words below here 181 | lp_w_dn = log(_eta[w]) - log(_eta_sum); 182 | } else { 183 | // TODO: this is slow 184 | // replay back to level level 185 | CRP* current = c_d[index]; 186 | while (current->level > new_max_level-1) { 187 | CHECK(false); // shouldn't get here 188 | current = current->prev[0]; 189 | } 190 | lp_w_dn = log(_eta[w] + current->nw[w]) - log(_eta_sum + current->nwsum); 191 | } 192 | 193 | } else { // then there are no words below here 194 | lp_w_dn = log(_eta[w]) - log(_eta_sum); 195 | } 196 | 197 | // Add the probability of sampling a new level, defined as the above 198 | lposterior_z_dn.push_back(log(1.0-exp(lp_z_dn_sum))+lp_w_dn); 199 | // LOG(INFO) << lposterior_z_dn[lposterior_z_dn.size()-1]; 200 | 201 | // Update the assignment 202 | _z[d][n] = sample_unnormalized_log_multinomial(&lposterior_z_dn); 203 | 204 | 205 | CHECK_LE(_z[d][n], _c[d].size()); 206 | 207 | // Take care of the m assignment if we shrink 208 | // TODO: this is probably inefficient 209 | if (_z[d][n] < _c[d].size()) { 210 | // detach this document from the lower nodes, if nd[d] is zero 211 | for (unsigned l = _c[d].size()-1; l > _z[d][n]; l--) { 212 | if (_c[d][l]->nd[d] == 0) { 213 | CHECK_GT(_c[d][l]->ndsum, 0); 214 | _c[d][l]->ndsum -= 1; 215 | _c[d].pop_back(); 216 | CHECK(_c[d].size() == l); 217 | } else { 218 | break; // break off early to allow nd=1 -> nd=0 -> nd=1 219 | } 220 | } 221 | } else if (_z[d][n] == _c[d].size()) { // sampled the new entry 222 | _z[d][n] = new_max_level-1; 223 | 224 | // Support the new level by adding to the CRP tree 225 | unsigned old_size = _c[d].size(); 226 | graft_path_at(new_leaf, &_c[d], new_max_level); 227 | 228 | // Add to all the document counts 229 | for (int l = old_size; l < _c[d].size(); l++) { 230 | _c[d][l]->ndsum += 1; 231 | } 232 | } 233 | 234 | 235 | // Update the maximum depth if necessary 236 | if (_z[d][n] > _maxL) { 237 | _maxL = _z[d][n]; 238 | } 239 | 240 | // Update the counts 241 | 242 | // Check to see that the default dictionary insertion works like we 243 | // expect 244 | // DCHECK(_c[d][_z[d][n]]->nw.find(w) != _c[d][_z[d][n]]->nw.end() || _c[d][_z[d][n]]->nw[w] == 0); 245 | // DCHECK(_c[d][_z[d][n]]->nd.find(d) != _c[d][_z[d][n]]->nd.end() || _c[d][_z[d][n]]->nd[d] == 0); 246 | 247 | _c[d][_z[d][n]]->nw[w] += 1; // number of words in topic z equal to w 248 | _c[d][_z[d][n]]->nd[d] += 1; // number of words in 
doc d with topic z 249 | _c[d][_z[d][n]]->nwsum += 1; // number of words in topic z 250 | _nd[d] += 1; // number of words in doc d 251 | 252 | CHECK_GT(_c[d][_z[d][n]]->ndsum, 0); 253 | 254 | // TODO:: delete excess leaves If we reassigned the levels we might 255 | // have caused some of the lower nodes in the tree to become empty. If so, 256 | // then we should remove them. TODO:: how to do this in a way that 257 | // updates all the things pointng to that node (without having to keep 258 | // track of that?) This is probably just purely an efficiency thing, so we 259 | // only need to worry if memory gets to be an issue 260 | } 261 | } 262 | // Resamples the level allocation variables z_{d,n} given the path assignments 263 | // c and the path assignments given the level allocations 264 | void GEMNCRP::resample_posterior() { 265 | CHECK_GT(_lV, 0); 266 | CHECK_GT(_lD, 0); 267 | CHECK_GT(_L, 0); 268 | 269 | // Interleaved version 270 | _maxL = 0; 271 | for (DocumentMap::const_iterator d_itr = _D.begin(); d_itr != _D.end(); d_itr++) { 272 | unsigned d = d_itr->first; 273 | 274 | // LOG(INFO) << " resampling document " << d; 275 | resample_posterior_c_for(d); 276 | // DCHECK(tree_is_consistent()); 277 | for (int z = 0; z < FLAGS_ncrp_z_per_iteration; z++) { 278 | resample_posterior_z_for(d, true); 279 | // DCHECK(tree_is_consistent()); 280 | } 281 | } 282 | } 283 | 284 | double GEMNCRP::compute_log_likelihood() { 285 | // Compute the log likelihood for the tree 286 | double log_lik = 0; 287 | _unique_nodes = 0; // recalculate the tree size 288 | // Compute the log likelihood of the tree 289 | deque node_queue; 290 | node_queue.push_back(_ncrp_root); 291 | 292 | while (!node_queue.empty()) { 293 | CRP* current = node_queue.front(); 294 | node_queue.pop_front(); 295 | 296 | _unique_nodes += 1; 297 | 298 | // Should never have words attached but no documents 299 | CHECK(!(current->ndsum == 0 && current->nwsum > 0)); 300 | 301 | if (current->tables.size() > 0) { 302 | for (int i = 0; i < current->tables.size(); i++) { 303 | if (current->tables[i]->ndsum > 0) { // TODO: how to delete nodes? 304 | log_lik += log(current->tables[i]->ndsum) - log(current->ndsum+_gamma-1); 305 | } 306 | } 307 | 308 | node_queue.insert(node_queue.end(), current->tables.begin(), 309 | current->tables.end()); 310 | } 311 | } 312 | 313 | // Compute the log likelihood of the level assignments (correctly?) 314 | for (DocumentMap::const_iterator d_itr = _D.begin(); d_itr != _D.end(); d_itr++) { 315 | unsigned d = d_itr->first; 316 | 317 | // TODO: inefficient? 318 | // ndsum_above[k] is #[z_{d,-n} >= k] 319 | vector ndsum_above; 320 | for (int l = 0; l < _c[d].size(); l++) { 321 | // TODO: optimize this 322 | ndsum_above.push_back(0); 323 | // count the total words attached for this document below here 324 | for (int ll = l; ll < _c[d].size(); ll++) { 325 | ndsum_above[l] += _c[d][ll]->nd[d]; 326 | } 327 | // LOG(INFO) << "ndsum_above[" << l << "] = " << ndsum_above[l]; 328 | } 329 | 330 | for (int n = 0; n < _D[d].size(); n++) { 331 | // likelihood of drawing this word 332 | unsigned w = _D[d][n]; 333 | log_lik += log(_c[d][_z[d][n]]->nw[w]+_eta[w]) - 334 | log(_c[d][_z[d][n]]->nwsum+_eta_sum); 335 | 336 | // likelihood of the topic? 
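// The block below accumulates the GEM(m, pi) level probability used by this
// sampler: roughly
//   log[ ((1-m)*pi + n_d^{z}) / (pi + n_d^{>=z}) ]
//     + sum_{j<z} log[ (m*pi + n_d^{>j}) / (pi + n_d^{>=j}) ],
// where n_d^{>=l} is ndsum_above[l], the number of words of document d
// assigned to level l or deeper.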
337 | // TODO: this is heinously inefficient 338 | double V_j_sum = 0; 339 | for (int l = 0; l < _z[d][n]; l++) { 340 | if (l < _c[d].size()-1) { 341 | V_j_sum += log(_gem_m*_pi + ndsum_above[l+1]) - 342 | log(_pi + ndsum_above[l]); 343 | } 344 | } 345 | 346 | log_lik += log((1-_gem_m)*_pi + _c[d][_z[d][n]]->nd[d]) - 347 | log(_pi + ndsum_above[_z[d][n]]) + V_j_sum; 348 | } 349 | } 350 | return log_lik; 351 | } 352 | 353 | int main(int argc, char **argv) { 354 | GEMNCRP h = GEMNCRP(FLAGS_gem_m, FLAGS_gem_pi); 355 | h.load_data(FLAGS_ncrp_datafile); 356 | 357 | h.run(); 358 | } 359 | -------------------------------------------------------------------------------- /sample-gem-ncrp.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // This implements (I think) the GEM distributed hLDA, which has an infinitely 17 | // deep tree, where each node can have (possibly) infinitely many branches. 18 | // There are some issues with implementing this sampler that I (and others) 19 | // have raised on the Princeton topic models list, but so far no one has come 20 | // up with any answers. In any case the sampler /seems/ to work just fine. 21 | 22 | 23 | #ifndef SAMPLE_GEM_NCRP_H_ 24 | #define SAMPLE_GEM_NCRP_H_ 25 | 26 | #include 27 | 28 | #include "ncrp-base.h" 29 | 30 | // The GEM(m, \pi) distribution hyperparameter m, controls the "proportion of 31 | // general words relative to specific words" 32 | DECLARE_double(gem_m); 33 | 34 | // The GEM(m, \pi) hyperparameter \pi: reflects how strictly we expect the 35 | // documents to adhere to the m proportions. 36 | DECLARE_double(gem_pi); 37 | 38 | class GEMNCRP : public NCRPBase { 39 | public: 40 | GEMNCRP(double m, double pi); 41 | ~GEMNCRP() { /* TODO: free memory! */ } 42 | 43 | string current_state(); 44 | private: 45 | void resample_posterior(); 46 | void resample_posterior_z_for(unsigned d, bool remove); 47 | 48 | double compute_log_likelihood(); 49 | 50 | private: 51 | double _gem_m; 52 | double _pi; 53 | 54 | unsigned _maxL; 55 | }; 56 | 57 | #endif // SAMPLE_GEM_NCRP_H_ 58 | -------------------------------------------------------------------------------- /sample-mult-ncrp.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | // Samples from the Nested Chinese Restaurant Process using multinomials. Input 17 | // data is assumed to be a newline-delimited list of documents, each of which is 18 | // composed of some number of tokens. 19 | // 20 | // This version uses a fixed-depth tree. 21 | 22 | #include 23 | #include 24 | 25 | #include "ncrp-base.h" 26 | #include "sample-mult-ncrp.h" 27 | 28 | // Performs a single document's level assignment resample step 29 | void FixedDepthNCRP::resample_posterior_z_for(unsigned d, bool remove) { 30 | VLOG(1) << "resample posterior z for " << d; 31 | 32 | for (int n = 0; n < _D[d].size(); n++) { 33 | unsigned w = _D[d][n]; 34 | 35 | if (remove) { 36 | // Remove this document and word from the counts 37 | _c[d][_z[d][n]]->remove_no_ndsum(w,d); 38 | } 39 | 40 | vector lp_z_dn; 41 | 42 | unsigned start = FLAGS_ncrp_skip_root ? 1 : 0; 43 | for (int l = start; l < _L; l++) { 44 | // check that ["doesnt exist"]->0 45 | DCHECK(_c[d][l]->nw.find(w) != _c[d][l]->nw.end() || _c[d][l]->nw[w] == 0); 46 | DCHECK(_c[d][l]->nd.find(d) != _c[d][l]->nd.end() || _c[d][l]->nd[d] == 0); 47 | 48 | lp_z_dn.push_back(log(_eta[w] + _c[d][l]->nw[w]) - 49 | log(_eta_sum + _c[d][l]->nwsum) + 50 | log(_alpha[l] + _c[d][l]->nd[d]) - 51 | log(_alpha_sum + _nd[d]-1)); 52 | } 53 | 54 | // Update the assignment 55 | // _z[d][n] = SAFE_sample_unnormalized_log_multinomial(&lp_z_dn) + start; 56 | _z[d][n] = sample_unnormalized_log_multinomial(&lp_z_dn) + start; 57 | 58 | // Update the counts 59 | 60 | // Check to see that the default dictionary insertion works like we 61 | // expect 62 | DCHECK(_c[d][_z[d][n]]->nw.find(w) != _c[d][_z[d][n]]->nw.end() 63 | || _c[d][_z[d][n]]->nw[w] == 0); 64 | DCHECK(_c[d][_z[d][n]]->nd.find(d) != _c[d][_z[d][n]]->nd.end() 65 | || _c[d][_z[d][n]]->nd[d] == 0); 66 | 67 | _c[d][_z[d][n]]->add_no_ndsum(w,d); 68 | } 69 | } 70 | 71 | // Resamples the level allocation variables z_{d,n} given the path assignments 72 | // c and the path assignments given the level allocations 73 | void FixedDepthNCRP::resample_posterior() { 74 | CHECK_GT(_lV, 0); 75 | CHECK_GT(_lD, 0); 76 | CHECK_GT(_L, 0); 77 | 78 | if (FLAGS_ncrp_update_hyperparameters) { 79 | // TODO:: for now this is basically assuming a uniform distribution 80 | // over hyperparameters (bad!) 
with a truncated gaussian as the proposal 81 | // distribution 82 | // double old_ll = _ll; 83 | // double old_eta = _eta; 84 | // double old_alpha = _alpha; 85 | // double old_gamma = _gamma; 86 | // _eta += sample_gaussian() / 1000.0; 87 | // _alpha += sample_gaussian() / 1.0; 88 | // _gamma += sample_gaussian() / 1000.0; 89 | // _eta = max(0.00000001, _eta); 90 | // _alpha = max(0.00001, _alpha); 91 | // _gamma = max(0.00001, _gamma); 92 | 93 | // double new_ll = compute_log_likelihood(); 94 | // double k = log(sample_uniform()); 95 | 96 | // if (k < new_ll - old_ll) { 97 | // _ll = new_ll; 98 | // } else { 99 | // _eta = old_eta; 100 | // _alpha = old_alpha; 101 | // _gamma = old_gamma; 102 | // _ll = old_ll; 103 | // } 104 | } 105 | // Interleaved version 106 | for (DocumentMap::const_iterator d_itr = _D.begin(); d_itr != _D.end(); d_itr++) { 107 | unsigned d = d_itr->first; 108 | 109 | if (FLAGS_ncrp_depth > 1 && FLAGS_ncrp_max_branches != 1) { 110 | resample_posterior_c_for(d); 111 | } 112 | // BAD BAD: skipping check 113 | // DCHECK(tree_is_consistent()); 114 | for (int z = 0; z < FLAGS_ncrp_z_per_iteration; z++) { 115 | resample_posterior_z_for(d, true); 116 | // DCHECK(tree_is_consistent()); 117 | } 118 | } 119 | 120 | print_summary(); 121 | } 122 | 123 | 124 | double FixedDepthNCRP::compute_log_likelihood() { 125 | // VLOG(1) << "compute log likelihood " ; 126 | // Compute the log likelihood for the tree 127 | double log_lik = 0; 128 | _unique_nodes = 0; // recalculate the tree size 129 | // Compute the log likelihood of the tree 130 | deque node_queue; 131 | node_queue.push_back(_ncrp_root); 132 | 133 | while (!node_queue.empty()) { 134 | CRP* current = node_queue.front(); 135 | node_queue.pop_front(); 136 | 137 | _unique_nodes += 1; 138 | 139 | if (current->tables.size() > 0) { 140 | for (int i = 0; i < current->tables.size(); i++) { 141 | CHECK_GT(current->tables[i]->ndsum, 0); 142 | if (FLAGS_ncrp_m_dependent_gamma) { 143 | log_lik += log(current->tables[i]->ndsum) - log(current->ndsum * (_gamma + 1) - 1); 144 | } else { 145 | log_lik += log(current->tables[i]->ndsum) - log(current->ndsum+_gamma-1); 146 | } 147 | } 148 | 149 | node_queue.insert(node_queue.end(), current->tables.begin(), 150 | current->tables.end()); 151 | } 152 | } 153 | 154 | // Compute the log likelihood of the level assignments (correctly?) 155 | for (DocumentMap::const_iterator d_itr = _D.begin(); d_itr != _D.end(); d_itr++) { 156 | unsigned d = d_itr->first; 157 | 158 | double lndsumd = log(_nd[d]+_alpha_sum); 159 | for (int n = 0; n < _D[d].size(); n++) { 160 | // likelihood of drawing this word 161 | unsigned w = _D[d][n]; 162 | log_lik += log(_c[d][_z[d][n]]->nw[w]+_eta[w]) - 163 | log(_c[d][_z[d][n]]->nwsum+_eta_sum); 164 | // likelihood of the topic? 
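// i.e. the collapsed Dirichlet-multinomial probability of the level assignment:
// log(n_d^{z_{d,n}} + alpha_{z_{d,n}}) - log(n_d + alpha_sum), where the
// denominator lndsumd = log(_nd[d] + _alpha_sum) is precomputed above.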
165 | log_lik += log(_c[d][_z[d][n]]->nd[d]+_alpha[_z[d][n]]) - lndsumd; 166 | } 167 | } 168 | return log_lik; 169 | } 170 | 171 | string FixedDepthNCRP::current_state() { 172 | _filename = FLAGS_ncrp_datafile; 173 | _filename += StringPrintf("-L%d-gamma%.2f-alpha%.2f-eta%.2f-eds%.2f-zpi%d-best.data", 174 | _L, _gamma, _alpha_sum / (double)_alpha.size(), 175 | _eta_sum / (double)_eta.size(), 176 | FLAGS_ncrp_eta_depth_scale, 177 | FLAGS_ncrp_z_per_iteration); 178 | 179 | return StringPrintf( 180 | "ll = %f (%f at %d) %d alpha = %f eta = %f gamma = %f L = %d", 181 | _ll, _best_ll, _best_iter, _unique_nodes, 182 | _alpha_sum / (double)_alpha.size(), 183 | _eta_sum / (double)_eta.size(), _gamma, _L); 184 | } 185 | 186 | int main(int argc, char **argv) { 187 | google::InitGoogleLogging(argv[0]); 188 | google::ParseCommandLineFlags(&argc, &argv, true); 189 | 190 | init_random(); 191 | 192 | FixedDepthNCRP h = FixedDepthNCRP(); 193 | h.load_data(FLAGS_ncrp_datafile); 194 | 195 | h.run(); 196 | } 197 | -------------------------------------------------------------------------------- /sample-mult-ncrp.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // This implements the multinomial sampler for FixedDepthNCRP, which constrains the 17 | // topics to be a tree with fixed depth. 18 | 19 | 20 | #ifndef SAMPLE_MULT_NCRP_H_ 21 | #define SAMPLE_MULT_NCRP_H_ 22 | 23 | class FixedDepthNCRP : public NCRPBase { 24 | public: 25 | FixedDepthNCRP() { } 26 | ~FixedDepthNCRP() { /* TODO: free memory! */ } 27 | 28 | string current_state(); 29 | private: 30 | void resample_posterior(); 31 | void resample_posterior_z_for(unsigned d, bool remove); 32 | 33 | double compute_log_likelihood(); 34 | }; 35 | 36 | #endif // SAMPLE_MULT_NCRP_H_ 37 | -------------------------------------------------------------------------------- /sample-precomputed-fixed-ncrp-main.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "ncrp-base.h" 17 | #include "sample-fixed-ncrp.h" 18 | #include "sample-precomputed-fixed-ncrp.h" 19 | 20 | int main(int argc, char **argv) { 21 | google::InitGoogleLogging(argv[0]); 22 | google::ParseCommandLineFlags(&argc, &argv, true); 23 | 24 | init_random(); 25 | 26 | NCRPPrecomputedFixed h = NCRPPrecomputedFixed(FLAGS_gem_m, FLAGS_gem_pi); 27 | h.load_data(FLAGS_ncrp_datafile); 28 | 29 | h.run(); 30 | } 31 | -------------------------------------------------------------------------------- /sample-precomputed-fixed-ncrp.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // Like sample-fixed-ncrp but reads in the topic structure in a much more 17 | // flexible way (assumes a file containing doctopictopic). 18 | // then load topic assignments directly from the topic_assignments_file 19 | // instead of inferring them from the hierarchy. 20 | 21 | #include 22 | #include 23 | 24 | #include "ncrp-base.h" 25 | #include "sample-fixed-ncrp.h" 26 | #include "sample-precomputed-fixed-ncrp.h" 27 | 28 | // Topic assignments file, if we've precomputed the topic assignments 29 | DEFINE_string(topic_assignments_file, 30 | "", 31 | "file holding documenttopictopic... for every document"); 32 | 33 | // Number of additional topics to use over the labeled ones 34 | DEFINE_int32(additional_noise_topics, 35 | 0, 36 | "number of additional topics to include in addition to the labeled ones"); 37 | 38 | // Should we cull topics that only have one document? 39 | DEFINE_bool(cull_unique_topics, 40 | true, 41 | "should we cull topics that only have one document?"); 42 | 43 | void NCRPPrecomputedFixed::add_crp_node(const string& name) { 44 | if (_node_to_crp.find(name) == _node_to_crp.end()) { 45 | // havent made a CRP node for this yet 46 | _node_to_crp[name] = new CRP(0, 0); 47 | _node_to_crp[name]->label = name; 48 | CHECK_EQ(_node_to_crp[name]->prev.size(), 0); 49 | VLOG(1) << "creating node [" << name << "]"; 50 | _unique_nodes += 1; 51 | } 52 | } 53 | 54 | void NCRPPrecomputedFixed::load_precomputed_tree_structure(const string& filename) { 55 | LOG(INFO) << "loading tree"; 56 | _unique_nodes = 0; 57 | _node_to_crp.set_empty_key(kEmptyStringKey); 58 | 59 | CHECK(!FLAGS_ncrp_skip_root); 60 | 61 | CHECK_STRNE(filename.c_str(), ""); 62 | 63 | // These all cause problems. 
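// (Concretely: this sampler requires --use_dag, and does not support separate
// path assignments, sense selection, eta learning, or the GEM sampler, so
// those configurations are checked and rejected here.)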
64 | CHECK(FLAGS_use_dag); 65 | CHECK(!FLAGS_separate_path_assignments); 66 | CHECK(!FLAGS_sense_selection); 67 | CHECK(!FLAGS_learn_eta); 68 | CHECK(!FLAGS_gem_sampler); 69 | 70 | ifstream input_file(filename.c_str(), ios_base::in | ios_base::binary); 71 | 72 | // First add in the noise topics 73 | for (int i = 0; i < FLAGS_additional_noise_topics; i++) { 74 | add_crp_node(StringPrintf("NOISE_%d", i)); 75 | } 76 | 77 | // First pass over the entire tree structure and create CRP nodes for each 78 | // of the unique node names 79 | string curr_line; 80 | while (true) { 81 | getline(input_file, curr_line); 82 | 83 | if (input_file.eof()) { 84 | break; 85 | } 86 | vector tokens; 87 | curr_line = StringReplace(curr_line, "\n", "", true); 88 | 89 | // LOG(INFO) << curr_line; 90 | SplitStringUsing(curr_line, "\t", &tokens); 91 | 92 | string doc_name = tokens[0]; 93 | 94 | if (_document_id.find(doc_name) != _document_id.end()) { 95 | // Add new nodes for each unseen topic and then build the _c vector for 96 | // the document, if necessary (adds a node specific to the document 97 | // as well) 98 | for (int i = 0; i < tokens.size(); i++) { 99 | add_crp_node(tokens.at(i)); 100 | } 101 | // Build the _c vector 102 | unsigned d = _document_id[doc_name]; 103 | for (int i = 1; i < tokens.size(); i++) { 104 | _c[d].push_back(_node_to_crp[tokens[i]]); 105 | } 106 | // Add in the noise topics 107 | for (int i = 0; i < FLAGS_additional_noise_topics; i++) { 108 | _c[d].push_back(_node_to_crp[StringPrintf("NOISE_%d", i)]); 109 | } 110 | // Resize _L and the alpha vector as needed 111 | if (_c[d].size() > _L) { 112 | _L = _c[d].size(); 113 | _maxL = _c[d].size(); 114 | // Find the deepest branch 115 | if (_L > _alpha.size()) { 116 | // Take care of resizing the vector of alpha hyperparameters, even 117 | // though we currently don't learn them 118 | for (int l = _alpha.size(); l < _L; l++) { 119 | _alpha.push_back(FLAGS_ncrp_alpha); 120 | } 121 | _alpha_sum = _L*FLAGS_ncrp_alpha; 122 | } 123 | } 124 | for (int l = 0; l < _c[d].size(); l++) { 125 | // TODO: what is level used for? 
126 | //_c[d][l]->level = l; 127 | _c[d][l]->ndsum += 1; 128 | } 129 | } else { 130 | VLOG(1) << "document [" << doc_name << "] from topics file is not in raw data source"; 131 | } 132 | } 133 | 134 | VLOG(1) << "init tree done..."; 135 | } 136 | 137 | void NCRPPrecomputedFixed::allocate_document(unsigned d) { 138 | // Do some sanity checking 139 | // Attach the words to this path 140 | 141 | string doc_name = _document_name[d]; 142 | // CHECK(_node_to_crp.find(doc_name) != _node_to_crp.end()) 143 | // << "missing document [" << doc_name << "] in topics file"; 144 | _total += 1; 145 | if(_node_to_crp.find(doc_name) == _node_to_crp.end()) { 146 | VLOG(1) << "missing document [" << doc_name << "] in topics file"; 147 | _missing += 1; 148 | return; 149 | } 150 | // CHECK(_c[d].size() > 0); 151 | // 152 | _nd[d] = 0; 153 | 154 | // Cull out all the topics where m==1 (we're the only document 155 | // referencing this topic) 156 | vector new_topics; 157 | for (int l = 0; l < _c[d].size(); l++) { 158 | CHECK_GT(_c[d][l]->ndsum, 0); 159 | if (_c[d][l]->ndsum == 1 && _c[d][l]->label != doc_name && FLAGS_cull_unique_topics) { 160 | VLOG(1) << "culling topic [" << _c[d][l]->label << "] from document [" << doc_name << "]"; 161 | new_topics[0]->label += " | " + _c[d][l]->label; 162 | VLOG(1) << " [" << new_topics[0]->label << "]"; 163 | } else { 164 | LOG(INFO) << "keeping topic [" << _c[d][l]->label << "] for document [" << doc_name << "] size " << _c[d][l]->ndsum; 165 | new_topics.push_back(_c[d][l]); 166 | } 167 | } 168 | 169 | _c[d] = new_topics; 170 | 171 | if (_c[d].empty()) { 172 | LOG(INFO) << "removing document [" << doc_name << "] since no non-unique topics"; 173 | _missing += 1; 174 | // CHECK_GT(_c[d].size(), 0) << "failed topic check."; 175 | return; 176 | } 177 | 178 | for (int n = 0; n < _D[d].size(); n++) { 179 | // CHECK_GT(_c[d].size(), 0) << "[" << _document_name[d] << "] has a zero length path"; 180 | unsigned w = _D[d][n]; 181 | 182 | // set a random topic assignment for this guy 183 | _z[d][n] = sample_integer(_c[d].size()); 184 | 185 | // test the initialization of maps 186 | CHECK(_c[d][_z[d][n]]->nw.find(w) != _c[d][_z[d][n]]->nw.end() 187 | || _c[d][_z[d][n]]->nw[w] == 0); 188 | CHECK(_c[d][_z[d][n]]->nd.find(d) != _c[d][_z[d][n]]->nd.end() 189 | || _c[d][_z[d][n]]->nd[d] == 0); 190 | 191 | _c[d][_z[d][n]]->nw[w] += 1; // number of words in topic z equal to w 192 | _c[d][_z[d][n]]->nd[d] += 1; // number of words in doc d with topic z 193 | _c[d][_z[d][n]]->nwsum += 1; // number of words in topic z 194 | _nd[d] += 1; // number of words in doc d 195 | 196 | _total_words += 1; 197 | } 198 | // Incrementally reconstruct the tree (each time we add a document, 199 | // update its tree assignment based only on the previously added 200 | // documents; this results in a "fuller" initial tree, instead of one 201 | // fat trunk (fat trunks cause problems for mixing) 202 | if (d > 0) { 203 | LOG(INFO) << "resample posterior for " << doc_name; 204 | resample_posterior_z_for(d, true); 205 | } 206 | } 207 | 208 | void NCRPPrecomputedFixed::batch_allocation() { 209 | LOG(INFO) << "Doing precomputed tree ncrp batch allocation..."; 210 | 211 | _missing = 0; 212 | _total = 0; 213 | 214 | // Load the precomputed tree structure (after we've loaded the document; 215 | // before we allocate the document) 216 | load_precomputed_tree_structure(FLAGS_topic_assignments_file); 217 | 218 | // Allocate the document 219 | for (DocumentMap::const_iterator d_itr = _D.begin(); d_itr != _D.end(); d_itr++) { 220 
| unsigned d = d_itr->first; 221 | allocate_document(d); 222 | } 223 | LOG(INFO) << "missing " << _missing << " of " << _total; 224 | 225 | // NOTE: we need to do this in order to get the filename right... 226 | LOG(INFO) << "Initial state: " << current_state(); 227 | 228 | VLOG(1) << "writing dictionary"; 229 | write_dictionary(); 230 | } 231 | 232 | 233 | // Write out a static dictionary required for decoding Gibbs samples 234 | void NCRPPrecomputedFixed::write_dictionary() { 235 | string filename = StringPrintf("%s-%s-noise%d-%d.dictionary", get_base_name(_filename).c_str(), get_base_name(FLAGS_topic_assignments_file).c_str(), FLAGS_additional_noise_topics, FLAGS_random_seed); 236 | LOG(INFO) << "writing dictionary to [" << filename << "]"; 237 | 238 | ofstream f(filename.c_str(), ios_base::out | ios_base::binary); 239 | 240 | // First write out all the topics for each document that was kept in the 241 | // sampler 242 | set visited_nodes; // keep track of visited things for the DAG case 243 | for (DocumentMap::const_iterator d_itr = _D.begin(); d_itr != _D.end(); d_itr++) { 244 | unsigned d = d_itr->first; 245 | 246 | if (!_c[d].empty()) { 247 | f << _document_name[d] << "\t" << d; 248 | for (int l = 0; l < _c[d].size(); l++) { 249 | f << "\t" << _c[d][l]; 250 | } 251 | f << endl; 252 | // Write out any unvisited topics 253 | for (int l = 0; l < _c[d].size(); l++) { 254 | if (visited_nodes.find(_c[d][l]) == visited_nodes.end()) { 255 | visited_nodes.insert(_c[d][l]); 256 | 257 | CRP* current = _c[d][l]; 258 | 259 | f << current << "\t" << current->label << endl; 260 | } 261 | } 262 | } 263 | } 264 | 265 | f << endl; 266 | // Now write out the term dictionary 267 | for (WordCode::const_iterator itr = _word_id_to_name.begin(); itr != _word_id_to_name.end(); itr++) { 268 | f << itr->first << "\t" << itr->second << endl; 269 | } 270 | } 271 | 272 | 273 | // Write out all the data in an intermediate format 274 | void NCRPPrecomputedFixed::write_data(string prefix) { 275 | string filename = StringPrintf("%s-%s-noise%d-%d-%s.hlda", get_base_name(_filename).c_str(), 276 | get_base_name(FLAGS_topic_assignments_file).c_str(), 277 | FLAGS_additional_noise_topics, 278 | FLAGS_random_seed, 279 | prefix.c_str()); 280 | VLOG(1) << "writing data to [" << filename << "]"; 281 | 282 | ofstream f(filename.c_str(), ios_base::out | ios_base::binary); 283 | 284 | f << current_state() << endl; 285 | 286 | // Write out all the topic-term distributions 287 | set visited_nodes; // keep track of visited things for the DAG case 288 | for (DocumentMap::const_iterator d_itr = _D.begin(); d_itr != _D.end(); d_itr++) { 289 | unsigned d = d_itr->first; 290 | if (!_c[d].empty()) { 291 | // Write out any unvisited topics 292 | for (int l = 0; l < _c[d].size(); l++) { 293 | if (visited_nodes.find(_c[d][l]) == visited_nodes.end()) { 294 | visited_nodes.insert(_c[d][l]); 295 | 296 | CRP* current = _c[d][l]; 297 | 298 | // Write out the node contents 299 | f << current << "\t||\t" << current->ndsum << "\t||"; 300 | 301 | for (WordToCountMap::iterator nw_itr = current->nw.begin(); 302 | nw_itr != current->nw.end(); nw_itr++) { 303 | if (nw_itr->second > 0) { // sparsify 304 | f << "\t" << nw_itr->first << ":" << nw_itr->second; 305 | } 306 | } 307 | f << "\t||"; 308 | for (DocToWordCountMap::iterator nd_itr = current->nd.begin(); 309 | nd_itr != current->nd.end(); nd_itr++) { 310 | if (nd_itr->second > 0) { // sparsify 311 | f << "\t" << nd_itr->first << ":" << nd_itr->second; 312 | } 313 | } 314 | f << endl; 315 | // end 
writing out node contents 316 | } 317 | } 318 | } 319 | } 320 | } 321 | 322 | -------------------------------------------------------------------------------- /sample-precomputed-fixed-ncrp.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | // Like sample-fixed-ncrp but reads in the topic structure in a much more 17 | // flexible way (assumes a file containing doctopictopic). 18 | 19 | #ifndef SAMPLE_PRECOMPUTED_FIXED_NCRP_H_ 20 | #define SAMPLE_PRECOMPUTED_FIXED_NCRP_H_ 21 | 22 | #include 23 | #include 24 | 25 | #include "ncrp-base.h" 26 | #include "sample-fixed-ncrp.h" 27 | #include "sample-precomputed-fixed-ncrp.h" 28 | 29 | // Topic assignments file, if we've precomputed the topic assignments 30 | DECLARE_string(topic_assignments_file); 31 | 32 | // Number of additional topics to use over the labeled ones 33 | DECLARE_int32(additional_noise_topics); 34 | 35 | // Should we cull topics that only have one document? 36 | DECLARE_bool(cull_unique_topics); 37 | 38 | // This version differs from the normal GEM sampler in that the tree structure 39 | // is fixed a priori. Hence there is no resampling of c, the path allocations. 40 | class NCRPPrecomputedFixed : public GEMNCRPFixed { 41 | public: 42 | NCRPPrecomputedFixed(double m, double pi) : GEMNCRPFixed(m, pi) { } 43 | 44 | // Load the tree structure 45 | void load_precomputed_tree_structure(const string& filename); 46 | 47 | // Allocate all the documents at once (called for non-streaming) 48 | void batch_allocation(); 49 | 50 | // Allocate a single document; can be called during load for streaming 51 | void allocate_document(unsigned d); 52 | 53 | // Write out a static dictionary required for decoding Gibbs samples 54 | void write_dictionary(); 55 | 56 | // Write out the Gibbs sample 57 | void write_data(string prefix); 58 | 59 | protected: 60 | void add_crp_node(const string& name); 61 | 62 | protected: 63 | google::dense_hash_map _node_to_crp; 64 | 65 | unsigned _total; // number of total documents on load 66 | unsigned _missing; // number of missing documents on load 67 | }; 68 | 69 | #endif // SAMPLE_PRECOMPUTED_FIXED_NCRP_H_ 70 | -------------------------------------------------------------------------------- /sample-soft-crosscat-main.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "sample-soft-crosscat.h" 17 | 18 | int main(int argc, char **argv) { 19 | google::InitGoogleLogging(argv[0]); 20 | google::ParseCommandLineFlags(&argc, &argv, true); 21 | 22 | init_random(); 23 | 24 | SoftCrossCatMM h = SoftCrossCatMM(); 25 | h.load_data(FLAGS_mm_datafile); 26 | h.run(); 27 | } 28 | -------------------------------------------------------------------------------- /sample-soft-crosscat.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2010 Joseph Reisinger 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef SAMPLE_SOFT_CROSSCAT_MM_H_ 17 | #define SAMPLE_SOFT_CROSSCAT_MM_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | #include "gibbs-base.h" 24 | 25 | // Two main implementations: 26 | // (1) normal: treat the topic model part for a document as the set of 27 | // clusters picked out by the document (one per view) 28 | // (2) marginal: treat each "topic" as the marginal over all clusters inside 29 | // it; this model seems to make more sense to me 30 | DECLARE_string(implementation); 31 | 32 | // the number of clusterings (topics) 33 | DECLARE_int32(M); 34 | 35 | // the maximum number of clusters 36 | DECLARE_int32(KMAX); 37 | 38 | // Smoother on clustering 39 | DECLARE_double(mm_alpha); 40 | 41 | // Smoother on cross cat / topic model 42 | DECLARE_double(cc_xi); 43 | 44 | // File holding the data 45 | DECLARE_string(mm_datafile); 46 | 47 | // If toggled, the first view will be constrained to a single cluster 48 | DECLARE_bool(cc_include_noise_view); 49 | 50 | // If toggled, will resume from the best model written so far 51 | DECLARE_bool(cc_resume_from_best); 52 | 53 | DECLARE_string(cc_fixed_topic_seed); 54 | 55 | // Implements several kinds of mixture models (uniform prior, Dirichlet prior, 56 | // DPCrossCatMM all with DP-Multinomial likelihood. 
57 | class SoftCrossCatMM : public GibbsSampler { 58 | public: 59 | SoftCrossCatMM() { } 60 | 61 | // Allocate all the documents at once (called for non-streaming) 62 | void batch_allocation(); 63 | 64 | // Initialize the model cleanly 65 | void clean_initialization(); 66 | 67 | double compute_log_likelihood(); 68 | 69 | void write_data(string prefix); 70 | 71 | // Restore from the intermediate model 72 | bool restore_data_from_prefix(string prefix); 73 | bool restore_data_from_file(string filename, bool seed); 74 | protected: 75 | void resample_posterior(); 76 | void resample_posterior_c_for(unsigned d); 77 | void resample_posterior_z_for(unsigned d); 78 | 79 | 80 | string current_state(); 81 | 82 | protected: 83 | // Maps documents to clusters 84 | multiple_cluster_map _c; // Map [d][m] -> cluster_id 85 | multiple_cluster_map _z; // Map [d][n] -> view_id 86 | multiple_clustering _cluster; // Map [m][z] -> chinese restaurant 87 | clustering _cluster_marginal; // Map [m] -> chinese restaurant (marginal realization of _cluster) 88 | 89 | // Base names of the flags 90 | string _word_features_moniker; 91 | string _datafile_moniker; 92 | 93 | vector _current_component; // # of clusters currently 94 | 95 | string _output_filename; 96 | 97 | // Count the number of feature movement proposals and the number that 98 | // fail 99 | unsigned _m_proposed; 100 | unsigned _m_failed; 101 | 102 | // Count cluster moves 103 | unsigned _c_proposed; 104 | unsigned _c_failed; 105 | 106 | // Cluster marginal model? 107 | bool is_cluster_marginal; 108 | // Fixed topic assignments? 109 | bool is_fixed_topics; 110 | 111 | double _temp_log_lik; 112 | }; 113 | 114 | #endif // SAMPLE_SOFT_CROSSCAT_MM_H_ 115 | -------------------------------------------------------------------------------- /sampleCrossCatMixtureModel-Local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATAFILE=$1 4 | M=$2 5 | ALPHA=$3 6 | ETA=$4 7 | XI=$5 8 | CONVERGENCE=$6 9 | SEED=$7 10 | NOISE=${8} 11 | 12 | MEM=4G 13 | 14 | echo $K CLUSTERS 15 | echo $M VIEWS 16 | echo $SEED SEED 17 | echo $NOISE NOISE 18 | 19 | if [ $MEM = "4G" ]; then 20 | BINARY=/scratch/cluster/joeraii/ncrp/sampleCrossCatMixtureModel 21 | echo Using 4G machines 22 | elif [ $MEM = "8G" ]; then 23 | BINARY=/scratch/cluster/joeraii/ncrp/sampleCrossCatMixtureModel64 24 | echo Using 8G machines 25 | else 26 | echo Error parsing memory requirement 27 | exit 28 | fi 29 | 30 | SHORN_DATAFILE=${DATAFILE##*/} 31 | 32 | BASE_RUN_PATH=XCAT-$SHORN_DATAFILE-${M}M-$ALPHA-$ETA-$XI-noise_is_${NOISE} 33 | FULL_RUN_PATH=$BASE_RUN_PATH/$SEED 34 | 35 | ORIGINAL_PATH=`pwd` 36 | 37 | mkdir $BASE_RUN_PATH 38 | cd $BASE_RUN_PATH 39 | mkdir $SEED 40 | cd $SEED 41 | 42 | 43 | GLOG_logtostderr=1 $BINARY \ 44 | --mm_datafile=$ORIGINAL_PATH/$DATAFILE --M=$M --mm_alpha=$ALPHA --eta=$ETA \ 45 | --max_gibbs_iterations=2000 --cc_xi=$XI --random_seed=$SEED \ 46 | --sample_lag=0 \ 47 | --cc_include_noise_view=$NOISE \ 48 | --convergence_interval=$CONVERGENCE 49 | -------------------------------------------------------------------------------- /sampleCrossCatMixtureModel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATAFILE=$1 4 | M=$2 5 | ALPHA=$3 6 | ETA=$4 7 | XI=$5 8 | CONVERGENCE=$6 9 | SEED=$7 10 | NOISE=${8} 11 | KMAX=${9} 12 | 13 | MEM=4G 14 | 15 | echo $KMAX MAX_CLUSTERS 16 | echo $M VIEWS 17 | echo $SEED SEED 18 | echo $NOISE NOISE 19 | 20 | if [ $MEM = "4G" ]; then 21 | 
CONDORIZER=/projects/nn/joeraii/condorizer.py 22 | BINARY=/scratch/cluster/joeraii/ncrp/sampleCrossCatMixtureModel 23 | echo Using 4G machines 24 | elif [ $MEM = "8G" ]; then 25 | CONDORIZER=/projects/nn/joeraii/condorizer-8G.py 26 | BINARY=/scratch/cluster/joeraii/ncrp/sampleCrossCatMixtureModel64 27 | echo Using 8G machines 28 | else 29 | echo Error parsing memory requirement 30 | exit 31 | fi 32 | 33 | SHORN_DATAFILE=${DATAFILE##*/} 34 | 35 | BASE_RUN_PATH=XCAT-$SHORN_DATAFILE-${KMAX}KMAX-${M}M-$ALPHA-$ETA-$XI-noise_is_${NOISE} 36 | FULL_RUN_PATH=$BASE_RUN_PATH/$SEED 37 | 38 | ORIGINAL_PATH=`pwd` 39 | 40 | mkdir $BASE_RUN_PATH 41 | cd $BASE_RUN_PATH 42 | mkdir $SEED 43 | cd $SEED 44 | 45 | 46 | python $CONDORIZER $BINARY \ 47 | --mm_datafile=$ORIGINAL_PATH/$DATAFILE --KMAX=$KMAX --M=$M --mm_alpha=$ALPHA --eta=$ETA \ 48 | --max_gibbs_iterations=2000 --cc_xi=$XI --random_seed=$SEED \ 49 | --sample_lag=0 \ 50 | --cc_include_noise_view=$NOISE \ 51 | --cc_feature_move_rate=1.0 \ 52 | --convergence_interval=$CONVERGENCE \ 53 | out 54 | -------------------------------------------------------------------------------- /sampleMultLDA-Local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BINARY=~/Desktop/metamx/ncrp/sampleMultNCRP 4 | 5 | DATAFILE=$1 6 | DEPTH=$2 7 | CRP_ALPHA=$3 8 | CRP_ETA=$4 9 | SEED=$5 10 | 11 | SHORN_DATAFILE=${DATAFILE##*/} 12 | 13 | RESULTS_PATH=RESULTS_LDA 14 | BASE_RUN_PATH=RUN-$SHORN_DATAFILE-LDA$DEPTH-$CRP_ALPHA-$CRP_ETA 15 | FULL_RUN_PATH=$RESULTS_PATH/$BASE_RUN_PATH/$SEED 16 | 17 | ORIGINAL_PATH=`pwd` 18 | 19 | mkdir $RESULTS_PATH 20 | cd $RESULTS_PATH 21 | mkdir $BASE_RUN_PATH 22 | cd $BASE_RUN_PATH 23 | mkdir $SEED 24 | cd $SEED 25 | 26 | #GLOG_logtostderr=1 $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --ncrp_max_branches=1 --ncrp_depth=$DEPTH --sample_lag=100 --random_seed=$SEED out 27 | GLOG_logtostderr=1 $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --ncrp_max_branches=1 --ncrp_depth=$DEPTH --sample_lag=5 --random_seed=$SEED out 28 | -------------------------------------------------------------------------------- /sampleMultLDA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATAFILE=$1 4 | TOPICS=$2 5 | CRP_ALPHA=$3 6 | CRP_ETA=$4 7 | SEED=$5 8 | 9 | CONDORIZER=/projects/nn/joeraii/condorizer.py 10 | BINARY=/scratch/cluster/joeraii/ncrp/sampleMultNCRP 11 | 12 | SHORN_DATAFILE=${DATAFILE##*/} 13 | 14 | RESULTS_PATH=RESULTS_LDA 15 | BASE_RUN_PATH=RUN-$SHORN_DATAFILE-LDA$TOPICS-$CRP_ALPHA-$CRP_ETA 16 | FULL_RUN_PATH=$RESULTS_PATH/$BASE_RUN_PATH/$SEED 17 | 18 | ORIGINAL_PATH=`pwd` 19 | 20 | mkdir $RESULTS_PATH 21 | cd $RESULTS_PATH 22 | mkdir $BASE_RUN_PATH 23 | cd $BASE_RUN_PATH 24 | mkdir $SEED 25 | cd $SEED 26 | 27 | #GLOG_logtostderr=1 $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --ncrp_max_branches=1 --ncrp_depth=$TOPICS --sample_lag=100 --random_seed=$SEED out 28 | python $CONDORIZER $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --ncrp_max_branches=1 --ncrp_depth=$TOPICS --sample_lag=100 --convergence_interval=2000 --random_seed=$SEED out 29 | -------------------------------------------------------------------------------- /sampleMultNCRP-Local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 
BINARY=/scratch/cluster/joeraii/ncrp/sampleMultNCRP 4 | 5 | DATAFILE=$1 6 | DEPTH=$2 7 | CRP_ALPHA=$3 8 | CRP_ETA=$4 9 | SEED=$5 10 | 11 | SHORN_DATAFILE=${DATAFILE##*/} 12 | 13 | RESULTS_PATH=RESULTS_NCRP 14 | BASE_RUN_PATH=RUN-$SHORN_DATAFILE-DEPTH$DEPTH-$CRP_ALPHA-$CRP_ETA 15 | FULL_RUN_PATH=$RESULTS_PATH/$BASE_RUN_PATH/$SEED 16 | 17 | ORIGINAL_PATH=`pwd` 18 | 19 | mkdir $RESULTS_PATH 20 | cd $RESULTS_PATH 21 | mkdir $BASE_RUN_PATH 22 | cd $BASE_RUN_PATH 23 | mkdir $SEED 24 | cd $SEED 25 | 26 | #GLOG_logtostderr=1 GLOG_v=2 $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --ncrp_depth=$DEPTH --sample_lag=100 --random_seed=$SEED out 27 | GLOG_logtostderr=1 $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_gamma=1.0 --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --ncrp_depth=$DEPTH --sample_lag=5 --random_seed=$SEED out 28 | -------------------------------------------------------------------------------- /sampleMultNCRP.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATAFILE=$1 4 | DEPTH=$2 5 | CRP_ALPHA=$3 6 | CRP_ETA=$4 7 | SEED=$5 8 | 9 | CONDORIZER=/projects/nn/joeraii/condorizer.py 10 | BINARY=/scratch/cluster/joeraii/ncrp/sampleMultNCRP 11 | 12 | SHORN_DATAFILE=${DATAFILE##*/} 13 | 14 | RESULTS_PATH=RESULTS_NCRP 15 | BASE_RUN_PATH=RUN-$SHORN_DATAFILE-DEPTH$DEPTH-$CRP_ALPHA-$CRP_ETA 16 | FULL_RUN_PATH=$RESULTS_PATH/$BASE_RUN_PATH/$SEED 17 | 18 | ORIGINAL_PATH=`pwd` 19 | 20 | mkdir $RESULTS_PATH 21 | cd $RESULTS_PATH 22 | mkdir $BASE_RUN_PATH 23 | cd $BASE_RUN_PATH 24 | mkdir $SEED 25 | cd $SEED 26 | 27 | #GLOG_logtostderr=1 GLOG_v=2 $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --ncrp_depth=$DEPTH --sample_lag=100 --random_seed=$SEED out 28 | python $CONDORIZER $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_gamma=1.0 --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --ncrp_depth=$DEPTH --sample_lag=100 --random_seed=$SEED --convergence_interval=2000 out 29 | -------------------------------------------------------------------------------- /sampleMultNCRPPRixFixe-Local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATAFILE=$1 4 | DEPTH=$2 5 | CRP_ALPHA=$3 6 | CRP_ETA=$4 7 | SEED=$5 8 | MEM=$6 9 | 10 | if [ $MEM = "4G" ]; then 11 | BINARY=/scratch/cluster/joeraii/ncrp/sampleMultNCRP 12 | echo Using 4G machines 13 | elif [ $MEM = "8G" ]; then 14 | BINARY=/scratch/cluster/joeraii/ncrp/sampleMultNCRP64 15 | echo Using 8G machines 16 | else 17 | echo Error parsing memory requirement 18 | exit 19 | fi 20 | 21 | SHORN_DATAFILE=${DATAFILE##*/} 22 | 23 | RESULTS_PATH=RESULTS_PRIX_FIXE 24 | BASE_RUN_PATH=RUN-$SHORN_DATAFILE-DEPTH$DEPTH-$CRP_ALPHA-$CRP_ETA 25 | FULL_RUN_PATH=$RESULTS_PATH/$BASE_RUN_PATH/$SEED 26 | 27 | ORIGINAL_PATH=`pwd` 28 | 29 | mkdir $RESULTS_PATH 30 | cd $RESULTS_PATH 31 | mkdir $BASE_RUN_PATH 32 | cd $BASE_RUN_PATH 33 | mkdir $SEED 34 | cd $SEED 35 | 36 | #GLOG_logtostderr=1 GLOG_v=2 $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --ncrp_depth=$DEPTH --sample_lag=100 --random_seed=$SEED out 37 | GLOG_logtostderr=1 $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_prix_fixe=true --ncrp_gamma=1.0 --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --ncrp_depth=$DEPTH --sample_lag=100 --random_seed=$SEED --convergence_interval=2000 out 38 | -------------------------------------------------------------------------------- /sampleMultNCRPPRixFixe.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATAFILE=$1 4 | DEPTH=$2 5 | CRP_ALPHA=$3 6 | CRP_ETA=$4 7 | SEED=$5 8 | MEM=$6 9 | 10 | if [ $MEM = "4G" ]; then 11 | CONDORIZER=/projects/nn/joeraii/condorizer.py 12 | BINARY=/scratch/cluster/joeraii/ncrp/sampleMultNCRP 13 | echo Using 4G machines 14 | elif [ $MEM = "8G" ]; then 15 | CONDORIZER=/projects/nn/joeraii/condorizer-8G.py 16 | BINARY=/scratch/cluster/joeraii/ncrp/sampleMultNCRP64 17 | echo Using 8G machines 18 | else 19 | echo Error parsing memory requirement 20 | exit 21 | fi 22 | 23 | SHORN_DATAFILE=${DATAFILE##*/} 24 | 25 | RESULTS_PATH=RESULTS_PRIX_FIXE 26 | BASE_RUN_PATH=RUN-$SHORN_DATAFILE-DEPTH$DEPTH-$CRP_ALPHA-$CRP_ETA 27 | FULL_RUN_PATH=$RESULTS_PATH/$BASE_RUN_PATH/$SEED 28 | 29 | ORIGINAL_PATH=`pwd` 30 | 31 | mkdir $RESULTS_PATH 32 | cd $RESULTS_PATH 33 | mkdir $BASE_RUN_PATH 34 | cd $BASE_RUN_PATH 35 | mkdir $SEED 36 | cd $SEED 37 | 38 | #GLOG_logtostderr=1 GLOG_v=2 $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --ncrp_depth=$DEPTH --sample_lag=100 --random_seed=$SEED out 39 | python $CONDORIZER $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_prix_fixe=true --ncrp_gamma=1.0 --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --ncrp_depth=$DEPTH --sample_lag=100 --random_seed=$SEED --convergence_interval=2000 out 40 | -------------------------------------------------------------------------------- /samplePrecomputedFixedNCRP-Local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BINARY=/scratch/cluster/joeraii/ncrp/samplePrecomputedFixedNCRP 4 | 5 | DATAFILE=$1 6 | TOPIC_ASSIGNMENTS_FILE=$2 7 | NOISE_TOPICS=$3 8 | CRP_ALPHA=$4 9 | CRP_ETA=$5 10 | SEED=$6 11 | 12 | SHORN_DATAFILE=${DATAFILE##*/} 13 | SHORN_TOPIC_ASSIGNMENTS_FILE=${TOPIC_ASSIGNMENTS_FILE##*/} 14 | 15 | RESULTS_PATH=RESULTS_FIXED 16 | BASE_RUN_PATH=RUN-$SHORN_DATAFILE-$SHORN_TOPIC_ASSIGNMENTS_FILE-$NOISE_TOPICS-$CRP_ALPHA-$CRP_ETA 17 | FULL_RUN_PATH=$RESULTS_PATH/$BASE_RUN_PATH/$SEED 18 | 19 | ORIGINAL_PATH=`pwd` 20 | 21 | echo "culling" 22 | 23 | mkdir $RESULTS_PATH 24 | cd $RESULTS_PATH 25 | mkdir $BASE_RUN_PATH 26 | cd $BASE_RUN_PATH 27 | mkdir $SEED 28 | cd $SEED 29 | 30 | GLOG_logtostderr=1 $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --topic_assignments_file=$ORIGINAL_PATH/$TOPIC_ASSIGNMENTS_FILE --additional_noise_topics=$NOISE_TOPICS --use_dag=true --sample_lag=100 --random_seed=$SEED out 31 | -------------------------------------------------------------------------------- /samplePrecomputedFixedNCRP.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATAFILE=$1 4 | TOPIC_ASSIGNMENTS_FILE=$2 5 | NOISE_TOPICS=$3 6 | CRP_ALPHA=$4 7 | CRP_ETA=$5 8 | SEED=$6 9 | 10 | CONDORIZER=/projects/nn/joeraii/condorizer.py 11 | BINARY=/scratch/cluster/joeraii/ncrp/samplePrecomputedFixedNCRP 12 | 13 | SHORN_DATAFILE=${DATAFILE##*/} 14 | SHORN_TOPIC_ASSIGNMENTS_FILE=${TOPIC_ASSIGNMENTS_FILE##*/} 15 | 16 | RESULTS_PATH=RESULTS_FIXED 17 | BASE_RUN_PATH=RUN-$SHORN_DATAFILE-$SHORN_TOPIC_ASSIGNMENTS_FILE-$NOISE_TOPICS-$CRP_ALPHA-$CRP_ETA 18 | FULL_RUN_PATH=$RESULTS_PATH/$BASE_RUN_PATH/$SEED 19 | 20 | ORIGINAL_PATH=`pwd` 21 | 22 | echo "culling" 23 | 24 | mkdir $RESULTS_PATH 25 | cd $RESULTS_PATH 26 | mkdir $BASE_RUN_PATH 27 | cd $BASE_RUN_PATH 28 | mkdir $SEED 29 | cd $SEED 30 | 31 | # echo $BINARY 
--crp_datafile=$ORIGINAL_PATH/$DATAFILE --crp_alpha=$CRP_ALPHA --crp_eta=$CRP_ETA --topic_assignments_file=$ORIGINAL_PATH/$TOPIC_ASSIGNMENTS_FILE --additional_noise_topics=$NOISE_TOPICS --use_dag=true --sample_lag=50 --random_seed=$SEED out 32 | python $CONDORIZER $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --topic_assignments_file=$ORIGINAL_PATH/$TOPIC_ASSIGNMENTS_FILE --additional_noise_topics=$NOISE_TOPICS --use_dag=true --sample_lag=5 --random_seed=$SEED out 33 | # GLOG_logtostderr=1 GLOG_v=2 $BINARY --ncrp_datafile=$ORIGINAL_PATH/$DATAFILE --ncrp_alpha=$CRP_ALPHA --eta=$CRP_ETA --topic_assignments_file=$ORIGINAL_PATH/$TOPIC_ASSIGNMENTS_FILE --additional_noise_topics=$NOISE_TOPICS --use_dag=true --sample_lag=100 --random_seed=$SEED out 34 | -------------------------------------------------------------------------------- /sampleSoftCrossCatMixtureModel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATAFILE=$1 4 | M=$2 5 | ALPHA=$3 6 | ETA=$4 7 | XI=$5 8 | CONVERGENCE=$6 9 | SEED=$7 10 | NOISE=${8} 11 | KMAX=${9} 12 | IMPLEMENTATION=${10} 13 | 14 | echo $KMAX MAX_CLUSTERS 15 | echo $M VIEWS 16 | echo $SEED SEED 17 | echo $NOISE NOISE 18 | 19 | CONDORIZER=/projects/nn/joeraii/condorizer.py 20 | BINARY=/scratch/cluster/joeraii/ncrp/sampleSoftCrossCatMixtureModel 21 | 22 | SHORN_DATAFILE=${DATAFILE##*/} 23 | 24 | BASE_RUN_PATH=XCATSOFT-$SHORN_DATAFILE-$IMPLEMENTATION-${KMAX}KMAX-${M}M-$ALPHA-$ETA-$XI-noise_is_${NOISE} 25 | FULL_RUN_PATH=$BASE_RUN_PATH/$SEED 26 | 27 | ORIGINAL_PATH=`pwd` 28 | 29 | mkdir $BASE_RUN_PATH 30 | cd $BASE_RUN_PATH 31 | mkdir $SEED 32 | cd $SEED 33 | 34 | 35 | python $CONDORIZER $BINARY \ 36 | --mm_datafile=$ORIGINAL_PATH/$DATAFILE --KMAX=$KMAX --M=$M --mm_alpha=$ALPHA --eta=$ETA \ 37 | --max_gibbs_iterations=2000 --cc_xi=$XI --random_seed=$SEED \ 38 | --sample_lag=0 \ 39 | --implementation=$IMPLEMENTATION \ 40 | --cc_include_noise_view=$NOISE \ 41 | --convergence_interval=$CONVERGENCE \ 42 | --cc_resume_from_best=true \ 43 | out 44 | -------------------------------------------------------------------------------- /scripts/model_file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for parsing wikipedia category graph 3 | """ 4 | import sys 5 | import re 6 | from string import lower 7 | from bz2 import * 8 | from collections import defaultdict 9 | from operator import itemgetter 10 | from glob import glob 11 | from math import log, exp 12 | 13 | def open_or_bz2(file): 14 | if file.endswith('.bz2'): 15 | import bz2 16 | return bz2.BZ2File(file) 17 | else: 18 | return open(file) 19 | 20 | def log_add(x, y): 21 | if x == 0: 22 | return y 23 | if y == 0: 24 | return x 25 | if x-y > 16: 26 | return x 27 | elif x > y: 28 | return x + log(1 + exp(y-x)) 29 | elif y-x > 16: 30 | return y 31 | else: 32 | return y + log(1 + exp(x-y)) 33 | 34 | TOLERANCE = 0.000001 35 | 36 | def read_docify(docify_file, parse_counts=True): 37 | if docify_file.endswith('.bz2'): 38 | f = open_or_bz2(docify_file) 39 | else: 40 | f = open(docify_file) 41 | for line in f: 42 | tokens = line.strip().split('\t') 43 | if parse_counts: 44 | doc_title, words = tokens[0], map(parse_count, tokens[1:]) 45 | else: 46 | doc_title, words = tokens[0], tokens[1:] 47 | yield doc_title, words 48 | 49 | def read_hlda_dictionary(dictionary_file): 50 | category_of = {} 51 | word_of = {} 52 | doc_of = {} 53 | doc_to_categories = {} 54 | 
unvisited_categories = [] 55 | visited_categories = set() 56 | word_dictionary_section = False 57 | sys.stderr.write('Loading document-topic map...\n') 58 | doc_count = 0 59 | 60 | lines = open_or_bz2(dictionary_file) 61 | 62 | for line in lines: 63 | if line.strip(): 64 | tokens = line.strip().split('\t') 65 | if word_dictionary_section: 66 | assert(len(tokens) == 2) 67 | key, word = tokens 68 | key = int(key, 0) 69 | assert(key not in word_of) 70 | word_of[key] = word 71 | elif not unvisited_categories and not line.startswith('0x'): # REALLY NASTY HACK 72 | document, doc_no, categories = tokens[0], tokens[1], tokens[2:] 73 | categories = [int(x.strip(),0) for x in categories] 74 | # print 'original', set(categories) 75 | # print 'adding', set(categories).difference(visited_categories) 76 | # TODO: fix this bug where there can be multiple _c entries for 77 | # the same topic 78 | unvisited_categories = [x for x in categories if not x in visited_categories] 79 | doc_to_categories[document] = categories 80 | doc_of[doc_count] = document 81 | doc_count += 1 82 | else: 83 | # print 'here bc not visited', unvisited_categories 84 | assert(len(tokens) == 2) # What 85 | key, category = tokens 86 | key = int(key, 0) 87 | # print '[%s]' % category 88 | # assert(category in unvisited_categories) 89 | assert(key not in category_of) 90 | if not key in unvisited_categories: 91 | sys.stderr.write('BUG: dup [%s] for %s\n' % (category, 92 | document)) 93 | else: 94 | unvisited_categories.remove(key) 95 | visited_categories.add(key) 96 | category_of[key] = category 97 | else: 98 | sys.stderr.write('Loading word dictionary...\n') 99 | word_dictionary_section = True 100 | 101 | return (category_of, word_of, doc_of, doc_to_categories) 102 | 103 | 104 | 105 | def glob_samples(Moniker, MapOrBayes, ToKeep=20): 106 | if MapOrBayes == 'map': 107 | Header = Moniker + '*-best.hlda*' 108 | # Samples = [os.path.dirname(Header)+'/'+x for x in os.listdir(os.path.dirname(Header)) if x.startswith(os.path.basename(Header))] 109 | Samples = glob(Header) 110 | else: 111 | Header = Moniker + '*-sample*' 112 | # Samples = [os.path.dirname(Header)+'/'+x for x in os.listdir(os.path.dirname(Header)) if x.startswith(os.path.basename(Header))] 113 | Samples = glob(Header) 114 | sys.stderr.write('%s\n' % Header) 115 | sys.stderr.write('%d\n' % len(Samples)) 116 | 117 | if not Samples: 118 | sys.stderr.write('FAIL No samples matched %s\n' % Header) 119 | sys.exit() 120 | 121 | if len(Samples) > ToKeep: 122 | Samples = Samples[::-1][:ToKeep] 123 | sys.stderr.write('keeping these %r\n' % Samples) 124 | 125 | return Samples 126 | 127 | def load_all_ncrp_samples(Moniker, MapOrBayes, restrict_docs=False, ToKeep=20): 128 | """ Loads a set of samples from an ncrp and returns some sufficient 129 | statistics """ 130 | node_term_dist = defaultdict(lambda: defaultdict(int)) 131 | term_node_dist = defaultdict(lambda: defaultdict(int)) 132 | node_doc_dist = defaultdict(lambda: defaultdict(int)) 133 | doc_node_dist = defaultdict(lambda: defaultdict(int)) 134 | prev = defaultdict(set) 135 | 136 | sys.stderr.write('Quantizing to TOLERANCE=%f\n' % TOLERANCE) 137 | 138 | alpha, eta, V, T = load_append_ncrp_samples(glob_samples(Moniker, MapOrBayes, ToKeep), 139 | node_term_dist, 140 | term_node_dist, 141 | node_doc_dist, 142 | doc_node_dist, 143 | prev, 144 | restrict_docs=restrict_docs) 145 | 146 | return (node_term_dist, term_node_dist, node_doc_dist, doc_node_dist, prev, 147 | alpha, eta, V, T) 148 | 149 | def ncrp_sample_iterator(sample): 150 | 
""" 151 | Yields node data from an ncrp sample 152 | """ 153 | visited = set() 154 | for (line_no, line) in enumerate(open_or_bz2(sample)): 155 | if line.startswith('ll ='): 156 | m = re.search('alpha = (.*) eta = (.*) gamma .* L = (.*)', line) 157 | alpha = float(m.group(1)) 158 | eta = float(m.group(2)) 159 | L = int(m.group(3)) 160 | sys.stderr.write('Got alpha = %f eta = %f L = %d\n' % (alpha, eta, L)) 161 | continue 162 | 163 | line = line.replace('\n','') 164 | try: 165 | (node, _, m, raw_nw, raw_nd, tables) = line.split('||') 166 | except: 167 | sys.stderr.write('Excepted on %s\n' % sample) 168 | break 169 | 170 | # nodename is the actual memory address (uid of the node) nodelabel 171 | # is the tree label 172 | node = int(node.replace('\t',''), 0) 173 | 174 | assert node not in visited 175 | visited.add(node) 176 | 177 | parsed_nw = [x.rsplit('@@@') for x in raw_nw.split('\t') if x] 178 | parsed_nd = [x.rsplit('@@@') for x in raw_nd.split('\t') if x] 179 | 180 | nwsum = float(sum([int(c) for _,c in parsed_nw])) 181 | 182 | yield (node, parsed_nw, parsed_nd, nwsum, tables, alpha, eta, L) 183 | 184 | 185 | def collect_term_term_count(Samples): 186 | """ 187 | Collects the number of times two terms co-occur across all topics; this is 188 | normalized for frequency, so beware 189 | """ 190 | 191 | # This version doesnt use the intermediate 192 | joint = defaultdict(float) 193 | marginal = defaultdict(float) 194 | for file_no, file in enumerate(Samples): 195 | for (node, nw, nd, nwsum, tables, alpha, eta, L) in ncrp_sample_iterator(file): 196 | for i, (word, count) in enumerate(nw): 197 | for k, (word2, count2) in enumerate(nw): 198 | marginal[intern(word)] += 1.0 199 | marginal[intern(word2)] += 1.0 200 | if i < k: 201 | (w1, w2) = sorted([word, word2]) 202 | joint[(intern(w1),intern(w2))] += 1.0 203 | 204 | return (joint, marginal) 205 | 206 | def collect_term_pmi(Samples): 207 | """ 208 | Computes the pmi from a joint and marginal distribution 209 | """ 210 | joint, marginal = collect_term_term_count(Samples) 211 | 212 | pmi_and_freq = defaultdict(float) 213 | for (w1,w2), freq in joint.iteritems(): 214 | pmi_and_freq[(intern(w1),intern(w2))] = (log(freq) - log(marginal[w1]*marginal[w2]), freq) 215 | return pmi_and_freq 216 | 217 | 218 | 219 | def load_append_ncrp_samples(Samples, node_term_dist, term_node_dist, 220 | node_doc_dist, doc_node_dist, prev, restrict_docs=set()): 221 | V = set() 222 | for file_no, file in enumerate(Samples): 223 | for (node, nw, nd, nwsum, tables, alpha, eta, L) in ncrp_sample_iterator(file): 224 | for (word, count) in nw: 225 | if float(count) / nwsum > TOLERANCE: 226 | node_term_dist[node][intern(word)] += int(count) 227 | term_node_dist[intern(word)][node] += int(count) 228 | V.add(intern(word)) 229 | for (_, doc_name, count) in nd: 230 | node_doc_dist[node][intern(doc_name)] += int(count) 231 | doc_node_dist[intern(doc_name)][node] += int(count) 232 | 233 | # prev stores the DAG structure 234 | tables = [int(x,0) for x in tables.split('\t') if x != ''] 235 | for t in tables: 236 | prev[t].add(node) 237 | 238 | return (alpha, eta, len(V), L) 239 | 240 | def get_smoothed_terms_for(doc, doc_to_categories, node_doc_dist, 241 | node_term_dist, word_of, category_of, alpha=0, 242 | eta=0): 243 | pw = defaultdict(float) 244 | if not doc_to_categories.has_key(doc): 245 | sys.stderr.write('missing [%s]\n' % doc) 246 | return pw 247 | 248 | sys.stderr.write('Quantizing to TOLERANCE=%f\n' % TOLERANCE) 249 | 250 | T = len(doc_to_categories[doc]) 251 | V = 
len(word_of) 252 | 253 | if alpha != None: 254 | sys.stderr.write('Smoothing with alpha=%f eta=%f\n' % (alpha,eta)) 255 | 256 | dsum = float(sum([node_doc_dist[category_of[c]][doc] for c in 257 | doc_to_categories[doc]])) 258 | 259 | for raw_node in doc_to_categories[doc]: 260 | node = category_of[raw_node] 261 | 262 | d = node_doc_dist[node][doc] 263 | # print node, 'in ndd?', node_doc_dist.has_key(node) 264 | # print node, 'in ntd?', node_term_dist.has_key(node) 265 | sys.stderr.write('found %d/%d of [%s] in [%s]\n' % (d, dsum, doc, node)) 266 | lpd = log(float(d)+alpha) - log(dsum+alpha*T) 267 | wsum = float(sum(node_term_dist[node].itervalues())) 268 | for word, w in node_term_dist[node].iteritems(): 269 | lp = lpd + log(float(w)+eta) - log(wsum+eta*V) 270 | pw[intern(word)] = log_add(lp, pw[intern(word)]) 271 | 272 | return pw 273 | 274 | def HACK_ncrp_get_smoothed_terms_for(doc, doc_node_dist, node_doc_dist, node_term_dist, V, T, alpha=0, eta=0): 275 | """ 276 | The bug here is that we can't get a list of all the possible nodes for doc; 277 | hence instead we have to just rely on where it is actually present 278 | (problematic) 279 | """ 280 | pw = defaultdict(float) 281 | 282 | sys.stderr.write('USING HACK SMOOTHED TERMS\n') 283 | sys.stderr.write('Quantizing to TOLERANCE=%f\n' % TOLERANCE) 284 | 285 | if alpha != None: 286 | sys.stderr.write('Smoothing with alpha=%f eta=%f\n' % (alpha,eta)) 287 | 288 | dsum = float(sum(doc_node_dist[doc].itervalues())) 289 | 290 | assert doc_node_dist.has_key(doc) 291 | 292 | for node, d in doc_node_dist[doc].iteritems(): 293 | sys.stderr.write('found %d/%d of [%s] in [%s]\n' % (d, dsum, doc, node)) 294 | lpd = log(float(d)+alpha) - log(dsum+alpha*T) 295 | wsum = float(sum(node_term_dist[node].itervalues())) 296 | for word, w in node_term_dist[node].iteritems(): 297 | lp = lpd + log(float(w)+eta) - log(wsum+eta*V) 298 | pw[intern(word)] = log_add(lp, pw[intern(word)]) 299 | 300 | return pw 301 | 302 | def build_sort(dist, to_show=100): 303 | """ 304 | Builds a sorted list summarizing the distribution 305 | """ 306 | sorted_dist = [] 307 | for concept in dist.keys(): 308 | attribs = '\t%s' % '\n\t'.join(['%s %d' % (v, k) for (v, k) in 309 | sorted(dist[concept].items(), key=itemgetter(1), 310 | reverse=True)[:to_show]]) 311 | count = sum(dist[concept].values()) 312 | 313 | sorted_dist.append((count, concept, attribs)) 314 | 315 | return sorted_dist 316 | 317 | def load_append_llda_samples(samples, word_of, doc_of, category_of, node_term_dist, 318 | term_node_dist, node_doc_dist, doc_node_dist, restrict_categories=set(), 319 | skip_noise=False): 320 | alpha, eta = 0, 0 321 | sys.stderr.write('Loading samples...\n') 322 | for sample in samples: 323 | sys.stderr.write('%s\n' % sample) 324 | f = open_or_bz2(sample) 325 | for line in f: 326 | if line.startswith('ll ='): 327 | # ll = -434060662.375486 (-434060662.375486 at 123) -627048446 alpha = 0.001000 eta = 0.100000 L = 454 328 | # ll = -432994653.484200 (-432994653.484200 at 100) -1071382526 alpha = 0.001000 eta = 0.100000 L = 211 329 | m = re.search('ll = .* alpha = (.*) eta = (.*) L = .*', line) 330 | alpha = float(m.group(1)) 331 | eta = float(m.group(2)) 332 | sys.stderr.write('Got alpha = %f eta = %f\n' % (alpha, eta)) 333 | continue 334 | 335 | concept, _, rest = line.partition('\t||\t') 336 | concept = int(concept, 0) 337 | 338 | if (not skip_noise or not category_of[concept].startswith('NOISE')) and \ 339 | (not restrict_categories or concept in restrict_categories): 340 | tokens = 
rest.strip().split('\t||\t') 341 | if len(tokens) == 3: 342 | m, w, d = tokens 343 | else: 344 | assert len(tokens) == 2 345 | m, w = tokens 346 | d = [] 347 | if w != '||': 348 | d = map(parse_int_float_count, d.split('\t')) 349 | w = map(parse_int_float_count, w.split('\t')) 350 | # print 'Found', len(d), 'docs and', len(w), 'terms at', category_of[concept] 351 | assert concept in category_of 352 | 353 | nw_sum = float(sum([c for _,c in w])) 354 | for ww, c in w: 355 | assert ww in word_of 356 | if nw_sum < 10000 or float(c) / nw_sum > TOLERANCE: # Allow small nodes and high prob things 357 | node_term_dist[intern(category_of[concept])][intern(word_of[ww])] += c 358 | term_node_dist[intern(word_of[ww])][intern(category_of[concept])] += c 359 | for dd, c in d: 360 | assert dd in doc_of 361 | node_doc_dist[intern(category_of[concept])][intern(doc_of[dd])] += c 362 | doc_node_dist[intern(doc_of[dd])][intern(category_of[concept])] += c 363 | 364 | return alpha, eta 365 | 366 | 367 | def load_all_llda_samples(DictionaryFile, MapOrBayes, restrict_docs=set(), 368 | skip_noise=False): 369 | node_term_dist = defaultdict(lambda: defaultdict(int)) 370 | term_node_dist = defaultdict(lambda: defaultdict(int)) 371 | node_doc_dist = defaultdict(lambda: defaultdict(int)) 372 | doc_node_dist = defaultdict(lambda: defaultdict(int)) 373 | 374 | sys.stderr.write('Quantizing to TOLERANCE=%f\n' % TOLERANCE) 375 | 376 | if MapOrBayes == 'map': 377 | Header = DictionaryFile.split('.dictionary')[0] + '-best' 378 | # Samples = [os.path.dirname(Header)+'/'+x for x in os.listdir(os.path.dirname(Header)) if x.startswith(os.path.basename(Header))] 379 | Samples = glob(Header+'*') 380 | else: 381 | Header = DictionaryFile.split('.dictionary')[0] + '-sample' 382 | # Samples = [os.path.dirname(Header)+'/'+x for x in os.listdir(os.path.dirname(Header)) if x.startswith(os.path.basename(Header))] 383 | Samples = glob(Header+'*') 384 | sys.stderr.write('%s\n' % Header) 385 | sys.stderr.write('%d\n' % len(Samples)) 386 | 387 | # Load the dictioanry 388 | (category_of, word_of, doc_of, doc_to_categories) = read_hlda_dictionary(DictionaryFile) 389 | 390 | # Compute all the categories covered by the interesting docs 391 | restrict_categories = set() 392 | for d in restrict_docs: 393 | if doc_to_categories.has_key(d): 394 | sys.stderr.write('FOUND [%s] at %r\n' % (d, [category_of[dd] for dd 395 | in doc_to_categories[d]])) 396 | restrict_categories.update(doc_to_categories[d]) 397 | else: 398 | sys.stderr.write('MISSING [%s]\n' % (d)) 399 | 400 | # Actually load the samples 401 | alpha, eta = load_append_llda_samples(Samples, word_of, doc_of, category_of, 402 | node_term_dist, term_node_dist, node_doc_dist, doc_node_dist, 403 | restrict_categories=restrict_categories, skip_noise=skip_noise) 404 | 405 | return (word_of, doc_of, category_of, doc_to_categories, node_term_dist, 406 | term_node_dist, node_doc_dist, doc_node_dist, alpha, eta) 407 | -------------------------------------------------------------------------------- /scripts/summarize_ncrp_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Given a set of hlda samples, print out some statistics about them. 
3 | """ 4 | 5 | import sys 6 | import os 7 | from model_file_utils import * 8 | from operator import itemgetter 9 | from collections import defaultdict 10 | 11 | SampleBase = sys.argv[1] 12 | MapOrBayes = sys.argv[2] 13 | assert MapOrBayes in ['map', 'bayes'] 14 | 15 | ToShow = 5 16 | 17 | # Load the model 18 | # If we specify bayes, then average over the last ToKeep samples. 19 | sys.stderr.write('loading model...\n') 20 | (node_term_dist, term_node_dist, node_doc_dist, doc_node_dist, parents, alpha, eta, V, T) = load_all_ncrp_samples(SampleBase, MapOrBayes, ToKeep=10) 21 | 22 | # Find the root node (hint: it doesn't have parents) 23 | # root = [n for n in node_term_dist.iterkeys() if not parents[n]] 24 | # assert len(root) == 1 25 | # root = root[0] 26 | 27 | # Print out a summary of the node-term distributions 28 | for node, terms in node_term_dist.iteritems(): 29 | print 'Node:', node, 'Parents:', parents[node] 30 | print '\n'.join(['\t%.1f: %s' % (f,w) for w,f in sorted(terms.iteritems(), key=itemgetter(1), reverse=True)[:ToShow]]) 31 | -------------------------------------------------------------------------------- /test_data/testW: -------------------------------------------------------------------------------- 1 | A a:1 b:1 2 | AA a:2 b:1 3 | AAA b:2 4 | AAAA a:4 5 | AAAAA a:1 b:1 6 | B c:1 d:1 7 | BB c:2 d:1 8 | BBB d:2 9 | BBBB c:4 10 | BBBBB c:1 d:1 11 | C a:1 b:1 12 | CC a:2 b:1 13 | CCC b:2 14 | CCCC a:4 15 | CCCCC a:1 b:1 16 | D c:1 d:1 17 | DD c:2 d:1 18 | DDD d:2 19 | DDDD c:4 20 | DDDDD c:1 d:1 21 | -------------------------------------------------------------------------------- /test_data/testWclustered: -------------------------------------------------------------------------------- 1 | doc_1 w1 a:1 b:2 w2 b:1 c:2 w3 b:1 b:1 2 | doc_2 w2 c:2 w3 b:1 c:2 w4 c:1 c:1 b:1 e:3 3 | doc_3 w4 d:1 b:1 d:2 w5 c:1 d:1 b:1 e:3 4 | doc_4 w6 dd:1 d:2 w7 dd:1 dd:1 d:2 w7 b:1 e:3 5 | doc_5 w1 c:1 d:2 w1 c:1 c:1 d:2 w8 b:1 e:3 6 | doc_6 w1 a:1 b:2 w1 b:1 dd:2 w8 b:1 b:1 a:3 7 | doc_7 w1 dd:2 w1 b:1 dd:2 w8 dd:1 dd:1 b:1 a:3 8 | doc_8 w1 d:1 b:1 d:2 w10 dd:1 d:1 b:1 a:3 9 | -------------------------------------------------------------------------------- /test_data/testWnoisy: -------------------------------------------------------------------------------- 1 | A a:1 b:1 f:2 2 | AA a:2 b:1 f:2 3 | AAA b:2 e:1 f:2 4 | AAAA a:4 e:1 f:2 5 | AAAAA a:1 b:1 e:1 f:2 6 | B c:1 d:1 e:1 f:2 7 | BB c:2 d:1 e:1 f:2 8 | BBB d:2 e:1 f:2 9 | BBBB c:4 e:1 f:2 10 | BBBBB c:1 d:1 e:1 f:2 11 | C a:1 b:1 e:1 12 | CC a:2 b:1 e:1 13 | CCC b:2 e:1 f:2 14 | CCCC a:4 e:1 f:2 15 | CCCCC a:1 b:1 f:2 16 | D c:1 d:1 e:1 f:2 17 | DD c:2 d:1 e:1 f:2 18 | DDD d:2 e:1 f:2 19 | DDDD c:4 e:1 20 | DDDDD c:1 d:1 e:1 f:2 21 | -------------------------------------------------------------------------------- /test_data/testWtree: -------------------------------------------------------------------------------- 1 | A A AA AA 2 | AA AA AAA AAA 3 | AAA AAA AAAA AAAA 4 | AAAA AAAA AAAAA AAAAA 5 | AAAAA AAAAA B B 6 | B B BB BB 7 | BB BB BBB BBBB 8 | BBB BBB BBBB BBBBB 9 | BBBB BBBB BBBBB C 10 | BBBBB BBBBB C CC 11 | C C CC CCC 12 | CC CC CCC CCCC 13 | CCC CCC CCCC CCCCC 14 | CCCC CCCC CCCCC D 15 | CCCCC CCCCC D DD 16 | D D DD DDD 17 | DD DD DDD DDDD 18 | DDD DDD DDDD DDDDD 19 | DDDD DDDD DDDDD A 20 | DDDDD DDDDD A BBB 21 | --------------------------------------------------------------------------------
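
Note on the test_data format (not part of the original repository): each testW-style line appears to pair a document name with sparse word:count features, which is also what scripts/model_file_utils.py expects when it splits lines into doc_title and word tokens. The short Python sketch below parses one such file under that assumption; the helper name parse_testw_line and the whitespace-delimited tokenization are illustrative guesses, not the repository's actual loader (the C++ code may require tab-separated fields).

# parse_testw.py -- illustrative sketch only; assumes "DOC word:count word:count ..." lines
import sys
from collections import defaultdict

def parse_testw_line(line):
    # e.g. "AA a:2 b:1" -> ("AA", {"a": 2, "b": 1})
    tokens = line.split()                 # assumes any-whitespace delimiters
    doc_name, features = tokens[0], tokens[1:]
    counts = defaultdict(int)
    for tok in features:
        word, _, count = tok.rpartition(':')   # split on the last ':' in "word:count"
        counts[word] += int(count)
    return doc_name, dict(counts)

if __name__ == '__main__':
    # usage: python parse_testw.py test_data/testW
    for line in open(sys.argv[1]):
        if line.strip():
            doc, counts = parse_testw_line(line)
            sys.stdout.write('%s %r\n' % (doc, counts))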