├── .gitignore ├── README.md ├── bin └── mallet ├── input ├── empty.cons ├── example.cons ├── synthetic-topic-input.mallet ├── synthetic │ └── .gitignore └── tree_hyperparams ├── output └── .gitignore ├── pom.xml ├── scripts ├── extract_constraints_wordnet.py ├── generate_data_from_protofiles.py └── vocab_filter.py ├── src └── main │ ├── java │ ├── edu │ │ └── umd │ │ │ └── umiacs │ │ │ └── itm │ │ │ ├── tree │ │ │ ├── GenerateVocab.java │ │ │ ├── HIntIntDoubleHashMap.java │ │ │ ├── HIntIntIntHashMap.java │ │ │ ├── HIntIntObjectHashMap.java │ │ │ ├── Node.java │ │ │ ├── NonZeroPath.java │ │ │ ├── Path.java │ │ │ ├── PriorTree.java │ │ │ ├── TopicSampler.java │ │ │ ├── TopicTreeWalk.java │ │ │ ├── TreeTopicModel.java │ │ │ ├── TreeTopicModelFast.java │ │ │ ├── TreeTopicModelFastEst.java │ │ │ ├── TreeTopicModelFastEstSortW.java │ │ │ ├── TreeTopicModelFastSortW.java │ │ │ ├── TreeTopicModelFastSortW1.java │ │ │ ├── TreeTopicModelFastSortW2.java │ │ │ ├── TreeTopicModelNaive.java │ │ │ ├── TreeTopicSampler.java │ │ │ ├── TreeTopicSamplerFast.java │ │ │ ├── TreeTopicSamplerFastEst.java │ │ │ ├── TreeTopicSamplerFastEstSortD.java │ │ │ ├── TreeTopicSamplerFastSortD.java │ │ │ ├── TreeTopicSamplerHashD.java │ │ │ ├── TreeTopicSamplerNaive.java │ │ │ ├── TreeTopicSamplerSortD.java │ │ │ ├── TwoIntHashMap.java │ │ │ ├── testFast.java │ │ │ └── testNaive.java │ │ │ └── tui │ │ │ └── ITMVectors2Topics.java │ └── topicmod_projects_ldawn │ │ └── WordnetFile.java │ └── resources │ └── cc │ └── mallet │ └── util │ └── resources │ └── logging.properties └── tree ├── lib ├── __init__.py ├── flags.py └── proto │ ├── __init__.py │ ├── corpus_pb2.py │ ├── wordnet_file.proto │ └── wordnet_file_pb2.py ├── ontology_writer.py └── ontology_writer_wordleaf.py /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | tmp/ 3 | *.pyc 4 | 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Interactive tree topic modeling 2 | =============================== 3 | 4 | Yuening Hu, Jordan Boyd-Graber, and Brianna Satinoff. 5 | [Interactive Topic Modeling](http://umiacs.umd.edu/~jbg/docs/itm.pdf). 6 | _Association for Computational Linguistics_, 2011. 7 | 8 | Topic models have been used extensively as a tool for corpus exploration, and 9 | a cottage industry has developed to tweak topic models to better encode human 10 | intuitions or to better model data. However, creating such extensions requires 11 | expertise in machine learning unavailable to potential end-users of topic 12 | modeling software. In this work, we develop a framework for allowing users to 13 | iteratively refine the topics discovered by models such as latent Dirichlet 14 | allocation (LDA) by adding constraints that enforce that sets of words must 15 | appear together in the same topic. 16 | 17 | ------------------------------------------------------------------------------ 18 | 19 | The project code has been Mavenified and lightly edited by Travis Brown. All 20 | dependencies are now managed by Maven and are not packaged with the project. 21 | The Java source files have also been moved out of the MALLET namespace, and it 22 | is no longer necessary to merge them manually with the MALLET source. 23 | 24 | Compiling 25 | --------- 26 | 27 | [Apache Maven](http://maven.apache.org/) is required to build this project. 
28 | The following command will download all dependencies (as necessary) and
29 | compile the code:
30 | 
31 |     mvn compile
32 | 
33 | The class files will now be available in `target/classes`, and will be used
34 | when you run the `bin/mallet` script in subsequent steps.
35 | 
36 | Importing documents
37 | -------------------
38 | 
39 | The following command will convert documents into the MALLET format as
40 | described [in the MALLET documentation](http://mallet.cs.umass.edu/import.php):
41 | 
42 |     bin/mallet import-dir --input ../../../data/synthetic/synth_word \
43 |       --output input/synthetic-topic-input.mallet --keep-sequence
44 | 
45 | Note that for this synthetic data set we do not use `--remove-stopwords`, but
46 | in general you would want to include it here. Note also that the `input`
47 | directory already contains the `synthetic-topic-input.mallet` file, so you can
48 | skip this step and continue directly to the steps below.
49 | 
50 | Generating the vocabulary file
51 | ------------------------------
52 | 
53 |     bin/mallet train-topics --input input/synthetic-topic-input.mallet \
54 |       --use-tree-lda true --generate-vocab true --vocab input/synthetic/synthetic.voc
55 | 
56 | Generating the tree
57 | -------------------
58 | 
59 | The following command requires Python 2, so you may need to change the
60 | `python` command if Python 3 is the default on your system.
61 | 
62 |     python tree/ontology_writer_wordleaf.py --vocab=input/synthetic/synthetic.voc \
63 |       --constraints=input/empty.cons --write_wordnet=False \
64 |       --write_constraints=True --wnname=input/synthetic/synthetic.wn
65 | 
66 | Note that the constraints file can be empty, in which case the output is a
67 | tree with symmetric priors and the model behaves like standard LDA.
68 | 
69 | You can check the generated tree structure with the following commands (note
70 | that Protobuf 2.3 is required):
71 | 
72 |     cat input/synthetic/synthetic.wn.0 | protoc tree/lib/proto/wordnet_file.proto \
73 |       --decode=topicmod_projects_ldawn.WordNetFile \
74 |       --proto_path=tree/lib/proto/ > input/synthetic/tmp0.txt
75 | 
76 |     cat input/synthetic/synthetic.wn.1 | protoc tree/lib/proto/wordnet_file.proto \
77 |       --decode=topicmod_projects_ldawn.WordNetFile \
78 |       --proto_path=tree/lib/proto/ > input/synthetic/tmp1.txt
79 | 
80 | Training the tree topic model
81 | -----------------------------
82 | 
83 |     bin/mallet train-topics --input input/synthetic-topic-input.mallet --num-topics 5 \
84 |       --num-iterations 300 --alpha 0.5 --random-seed 0 --output-interval 10 \
85 |       --output-dir output/model --use-tree-lda true --tree-model-type fast \
86 |       --tree input/synthetic/synthetic.wn --tree-hyperparameters input/tree_hyperparams \
87 |       --vocab input/synthetic/synthetic.voc --clear-type term --constraint input/empty.cons
88 | 
89 | Resuming the tree topic model
90 | -----------------------------
91 | 
92 |     bin/mallet train-topics --input input/synthetic-topic-input.mallet --num-topics 5 \
93 |       --num-iterations 600 --alpha 0.5 --random-seed 0 --output-interval 10 \
94 |       --output-dir output/model --use-tree-lda true \
95 |       --tree input/synthetic/synthetic.wn --tree-hyperparameters input/tree_hyperparams \
96 |       --vocab input/synthetic/synthetic.voc --clear-type term --constraint input/empty.cons \
97 |       --resume true --resume-dir output/model
98 | 
-------------------------------------------------------------------------------- /bin/mallet: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | malletdir=`dirname $0`
4 | malletdir=`dirname $malletdir`
5 | 
6 | cp=$CLASSPATH:`mvn dependency:build-classpath | grep "^[^\[]"`:$malletdir/target/classes
7 | 
8 | MEMORY=1g
9 | 
10 | JAVA_COMMAND="java -Xmx$MEMORY -ea -Djava.awt.headless=true -Dfile.encoding=UTF-8 -server -classpath $cp"
11 | 
12 | CMD=$1
13 | shift
14 | 
15 | help()
16 | {
17 | cat <
-------------------------------------------------------------------------------- /pom.xml: --------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
5 |   <modelVersion>4.0.0</modelVersion>
6 |   <groupId>edu.umd.umiacs</groupId>
7 |   <artifactId>itm</artifactId>
8 |   <name>Interactive Topic Modeling</name>
9 |   <version>0.1.0</version>
10 | 
11 |   <properties>
12 |     <scala.version>2.9.1</scala.version>
13 |   </properties>
14 | 
15 |   <repositories>
16 |     <repository>
17 |       <id>scala-tools.releases</id>
18 |       <name>Scala-Tools Dependencies Repository for Releases</name>
19 |       <url>http://scala-tools.org/repo-releases</url>
20 |     </repository>
21 |   </repositories>
22 | 
23 |   <dependencies>
24 |     <dependency>
25 |       <groupId>net.sf.trove4j</groupId>
26 |       <artifactId>trove4j</artifactId>
27 |       <version>2.0.2</version>
28 |     </dependency>
29 |     <dependency>
30 |       <groupId>cc.mallet</groupId>
31 |       <artifactId>mallet</artifactId>
32 |       <version>2.0.7</version>
33 |     </dependency>
34 |     <dependency>
35 |       <groupId>com.google.protobuf</groupId>
36 |       <artifactId>protobuf-java</artifactId>
37 |       <version>2.3.0</version>
38 |     </dependency>
39 |     <dependency>
40 |       <groupId>org.jgrapht</groupId>
41 |       <artifactId>jgrapht-jdk1.5</artifactId>
42 |       <version>0.7.3</version>
43 |     </dependency>
44 |     <dependency>
45 |       <groupId>org.scala-lang</groupId>
46 |       <artifactId>scala-library</artifactId>
47 |       <version>${scala.version}</version>
48 |     </dependency>
49 |     <dependency>
50 |       <groupId>org.specs2</groupId>
51 |       <artifactId>specs2_${scala.version}</artifactId>
52 |       <version>1.6.1</version>
53 |       <scope>test</scope>
54 |     </dependency>
55 |   </dependencies>
56 | 
57 |   <build>
58 |     <plugins>
59 |       <plugin>
60 |         <groupId>org.apache.maven.plugins</groupId>
61 |         <artifactId>maven-compiler-plugin</artifactId>
62 |         <version>2.3.2</version>
63 |         <configuration>
64 |           <source>1.5</source>
65 |           <target>1.5</target>
66 |         </configuration>
67 |       </plugin>
68 |       <plugin>
69 |         <groupId>org.scala-tools</groupId>
70 |         <artifactId>maven-scala-plugin</artifactId>
71 |         <version>2.14.3</version>
72 |         <configuration>
73 |           <charset>${project.build.sourceEncoding}</charset>
74 |           <jvmArgs>
75 |             <jvmArg>-Xmx1024m</jvmArg>
76 |           </jvmArgs>
77 |           <args>
78 |             <arg>-unchecked</arg>
79 |             <arg>-deprecation</arg>
80 |           </args>
81 |         </configuration>
82 |         <executions>
83 |           <execution>
84 |             <id>scala-test-compile</id>
85 |             <goals>
86 |               <goal>testCompile</goal>
87 |             </goals>
88 |             <phase>test-compile</phase>
89 |           </execution>
90 |           <execution>
91 |             <id>scala-compile-first</id>
92 |             <goals>
93 |               <goal>add-source</goal>
94 |               <goal>compile</goal>
95 |             </goals>
96 |             <phase>process-resources</phase>
97 |           </execution>
98 |         </executions>
99 |       </plugin>
100 |     </plugins>
101 |   </build>
102 | </project>
103 | 
-------------------------------------------------------------------------------- /scripts/extract_constraints_wordnet.py: --------------------------------------------------------------------------------
1 | from topicmod.util.wordnet import load_wn
2 | from nltk.corpus.reader.wordnet import WordNetError
3 | from topicmod.util import flags
4 | from collections import defaultdict
5 | import codecs
6 | 
7 | pos_tags = ["n", "v", "a", "r"]
8 | 
9 | def readVocab(vocabname):
10 |   infile = open(vocabname, 'r')
11 |   vocab = defaultdict(dict)
12 |   for line in infile:
13 |     line = line.strip()
14 |     ww = line.split('\t')
15 |     vocab[ww[0]][ww[1]] = 1
16 |   infile.close()
17 |   return vocab
18 | 
19 | def generateCons(vocab, wn, outfilename, num_cons):
20 |   lang = '0'
21 |   cons = defaultdict(dict)
22 |   for word in vocab[lang]:
23 |     for pos in pos_tags:
24 |       synsets = wn.synsets(word, pos)
25 |       for syn in synsets:
26 |         if not syn.offset in cons[pos]:
27 |           cons[pos][syn.offset] = set()
28 |         cons[pos][syn.offset].add(word)
29 | 
30 |   outfile = codecs.open(outfilename, 'w', 'utf-8')
31 |   multipaths = defaultdict(dict)
32 |   count = 0
33 |   for pos in cons:
34 |     for syn in cons[pos]:
35 |       if len(cons[pos][syn]) > 1:
36 |         count += 1
37 |         if count <= num_cons:
38 |           words = list(cons[pos][syn])
39 |           tmp = "\t".join(words)
40 |           outfile.write("MERGE_\t" + tmp + "\n")
41 |           for word in words:
42 |             if not pos in multipaths[word]:
43 |               multipaths[word][pos] = set()
44 |             multipaths[word][pos].add(syn)
45 |   outfile.close()
46 | 
47 |   outfilename = outfilename.replace(".cons", ".interested")
48 |   print outfilename
49 |   outfile = codecs.open(outfilename, 'w', 'utf-8')
50 |   count_word = 0
51 |   count_sense = 0
52 |   word_senses_count = defaultdict()
53 |   im_words = ""
54 |   for word in multipaths:
55 |     word_senses_count[word] = 0
56 |     count_word += 1
57 |     tmp = word
58 |     for pos in multipaths[word]:
59 |       tmp += '\t' + pos
60 |       for index in multipaths[word][pos]:
61 |         word_senses_count[word] += 1
62 |         count_sense += 1
63 |         tmp += '\t' + str(index)
64 |     if word_senses_count[word] > 1:
65 | 
im_words += word + " " 66 | outfile.write(tmp + '\n') 67 | outfile.write("\nThe total number of cons words: " + str(count_word) + "\n") 68 | outfile.write("\nThe total number of cons words senses: " + str(count_sense) + "\n") 69 | outfile.write("\nInteresting words: " + im_words + "\n") 70 | outfile.close() 71 | 72 | 73 | flags.define_string("vocab", None, "The input vocab") 74 | flags.define_string("output", None, "The output constraint file") 75 | flags.define_int("num_cons", 0, "The number of constraints we want") 76 | 77 | if __name__ == "__main__": 78 | 79 | flags.InitFlags() 80 | wordnet_path = "../../../data/wordnet/" 81 | eng_wn = load_wn("3.0", wordnet_path, "wn") 82 | vocab = readVocab(flags.vocab) 83 | generateCons(vocab, eng_wn, flags.output, flags.num_cons) 84 | 85 | -------------------------------------------------------------------------------- /scripts/generate_data_from_protofiles.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from topicmod.util import flags 3 | from topicmod.corpora.proto.corpus_pb2 import * 4 | from topicmod.corpora.proto.wordnet_file_pb2 import * 5 | from collections import defaultdict 6 | from nltk import FreqDist 7 | import codecs 8 | 9 | 10 | def gen_files(proto_corpus_dir, output_dir, lemma_flag): 11 | 12 | doc_num = 0 13 | vocab = defaultdict(dict) 14 | tfidf = defaultdict(FreqDist) 15 | frequency = defaultdict(FreqDist) 16 | 17 | for ii in glob("%s/*.index" % proto_corpus_dir): 18 | inputfile = open(ii, 'rb') 19 | protocorpus = Corpus() 20 | protocorpus.ParseFromString(inputfile.read()) 21 | 22 | if lemma_flag: 23 | source = protocorpus.lemmas 24 | else: 25 | source = protocorpus.tokens 26 | 27 | for ii in source: 28 | lang = ii.language 29 | for jj in ii.terms: 30 | if jj.id in vocab[lang]: 31 | assert vocab[lang][jj.id] == jj.original 32 | else: 33 | vocab[lang][jj.id] = jj.original 34 | print len(vocab[lang]) 35 | 36 | for dd in protocorpus.doc_filenames: 37 | doc_num += 1 38 | if doc_num % 1000 == 0: 39 | print "Finished reading", doc_num, "documents." 
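        # For each document listed in the corpus index, parse the document
        # protobuf, rebuild its token stream as plain text (one output file
        # per document), and tally the tf-idf / frequency statistics that
        # gen_vocab later uses to rank the vocabulary.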
40 | 41 | docfile = open("%s/%s" % (proto_corpus_dir, dd), 'rb') 42 | doc = Document() 43 | doc.ParseFromString(docfile.read()) 44 | lang = doc.language 45 | outputstring = "" 46 | 47 | for jj in doc.sentences: 48 | for kk in jj.words: 49 | if lemma_flag: 50 | word = vocab[lang][kk.lemma] 51 | tfidf[lang].inc(kk.lemma, kk.tfidf) 52 | frequency[lang].inc(kk.lemma) 53 | else: 54 | word = vocab[lang][kk.token] 55 | tfidf[lang].inc(kk.token, kk.tfidf) 56 | frequency[lang].inc(kk.token) 57 | 58 | outputstring += word + " " 59 | outputstring = outputstring.strip() 60 | outputstring += "\n" 61 | 62 | outputfilename = dd.split('/'); 63 | outputfilename = output_dir + "/" + outputfilename[-1] 64 | outputfile = open(outputfilename, 'w') 65 | outputfile.write(outputstring) 66 | outputfile.close() 67 | 68 | inputfile.close() 69 | 70 | return vocab, tfidf, frequency 71 | 72 | 73 | def gen_vocab(vocab, tfidf, frequency, select_tfidf, outputname, vocab_limit, freq_limit): 74 | 75 | for ii in tfidf: 76 | for jj in tfidf[ii]: 77 | tfidf[ii][jj] /= frequency[ii][jj] 78 | 79 | if select_tfidf: 80 | rank = tfidf 81 | else: 82 | rank = frequency 83 | 84 | o = codecs.open(outputname, 'w', 'utf-8') 85 | for ii in rank: 86 | count = 0 87 | for jj in rank[ii]: 88 | count += 1 89 | if count <= vocab_limit and frequency[ii][jj] >= freq_limit: 90 | word = vocab[ii][jj] 91 | o.write(u"%i\t%s\t%f\t%i\n" % (ii, word, tfidf[ii][jj], frequency[ii][jj])) 92 | 93 | o.close() 94 | 95 | 96 | flags.define_string("proto_corpus", None, "The proto files") 97 | flags.define_bool("lemma", False, "Use lemma or tokens") 98 | flags.define_bool("select_tfidf", False, "select the vocab by tfidf or frequency") 99 | flags.define_string("output", "", "Where we output the preprocessed data") 100 | flags.define_string("vocab", None, "Where we output the vocab") 101 | flags.define_int("vocab_limit", 10000, "The vocab size") 102 | flags.define_int("freq_limit", 20, "The minimum frequency of each word") 103 | 104 | if __name__ == "__main__": 105 | 106 | flags.InitFlags() 107 | [vocab, tfidf, frequency] = gen_files(flags.proto_corpus, flags.output, flags.lemma) 108 | gen_vocab(vocab, tfidf, frequency, flags.select_tfidf, flags.vocab, flags.vocab_limit, flags.freq_limit) 109 | 110 | -------------------------------------------------------------------------------- /scripts/vocab_filter.py: -------------------------------------------------------------------------------- 1 | from topicmod.util import flags 2 | from collections import defaultdict 3 | from nltk import FreqDist 4 | import codecs 5 | 6 | def readStats(filename): 7 | tfidf = defaultdict(FreqDist) 8 | frequency = defaultdict(FreqDist) 9 | infile = open(filename, 'r') 10 | for line in infile: 11 | line = line.strip() 12 | ww = line.split('\t') 13 | tfidf[ww[0]].inc(ww[1], float(ww[2])) 14 | frequency[ww[0]].inc(ww[1], int(ww[3])) 15 | 16 | infile.close() 17 | return tfidf, frequency 18 | 19 | def sortVocab(infilename, tfidf, frequency, option, outfilename): 20 | if option == 1: 21 | source = tfidf 22 | else: 23 | source = frequency 24 | 25 | infile = open(infilename, 'r') 26 | vocab = defaultdict(FreqDist) 27 | for line in infile: 28 | line = line.strip() 29 | ww = line.split('\t') 30 | lang = ww[0] 31 | if source[lang][ww[1]] == 0: 32 | print source[lang][ww[1]], ww[1] 33 | vocab[lang].inc(ww[1], source[lang][ww[1]]) 34 | infile.close() 35 | 36 | outfile = codecs.open(outfilename, 'w', 'utf-8') 37 | for ii in vocab: 38 | for jj in vocab[ii]: 39 | outfile.write(u"%s\t%s\n" % (ii, jj)) 40 | 
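    # NB: the pre-3.0 NLTK FreqDist used here iterates over its samples in
    # decreasing order of count, so the vocab below is written out sorted by
    # the chosen statistic (tfidf when option == 1, frequency otherwise).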
#outfile.write(u"%s\t%s\t%f\t%i\n" % (ii, jj, tfidf[ii][jj], frequency[ii][jj]))
41 |   outfile.close()
42 | 
43 | 
44 | flags.define_string("stats_vocab", None, "The vocab file with tfidf and frequency statistics")
45 | flags.define_string("input_vocab", None, "Where we get the original vocab")
46 | flags.define_int("option", 0, "1: tfidf; others: frequency")
47 | flags.define_string("sorted_vocab", None, "Where we output the vocab")
48 | 
49 | if __name__ == "__main__":
50 | 
51 |   flags.InitFlags()
52 |   [tfidf, frequency] = readStats(flags.stats_vocab)
53 | 
54 |   sortVocab(flags.input_vocab, tfidf, frequency, flags.option, flags.sorted_vocab)
55 | 
-------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/GenerateVocab.java: --------------------------------------------------------------------------------
1 | package edu.umd.umiacs.itm.tree;
2 | 
3 | import java.io.File;
4 | import java.io.IOException;
5 | import java.io.PrintStream;
6 | import java.util.Arrays;
7 | 
8 | import gnu.trove.TIntArrayList;
9 | import gnu.trove.TIntIntHashMap;
10 | import gnu.trove.TObjectIntHashMap;
11 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerHashD.DocData;
12 | import cc.mallet.types.Alphabet;
13 | import cc.mallet.types.FeatureSequence;
14 | import cc.mallet.types.Instance;
15 | import cc.mallet.types.InstanceList;
16 | 
17 | 
18 | /**
19 |  * This class generates the vocab file from mallet input.
20 |  * Main entry point: genVocab()
21 |  * Author: Yuening Hu
22 |  */
23 | public class GenerateVocab {
24 | 
25 |     /**
26 |      * After the preprocessing of mallet, a vocab is needed to generate
27 |      * the prior tree, so this function simply reads in the alphabet
28 |      * of the training data and outputs it as the vocab.
29 |      * Currently, the language_id is fixed.
30 |      */
31 |     public static void genVocab_old(InstanceList data, String vocab) {
32 |         try{
33 |             File file = new File(vocab);
34 |             PrintStream out = new PrintStream (file);
35 | 
36 |             // language_id is fixed now, but can be extended to
37 |             // multiple languages
38 |             int language_id = 0;
39 |             Alphabet alphabet = data.getAlphabet();
40 |             for(int ii = 0; ii < alphabet.size(); ii++) {
41 |                 String word = alphabet.lookupObject(ii).toString();
42 |                 System.out.println(word);
43 |                 out.println(language_id + "\t" + word);
44 |             }
45 |             out.close();
46 |         } catch (IOException e) {
47 |             System.out.println(e.getMessage());
48 |         }
49 |     }
50 | 
51 |     public static void genVocab(InstanceList data, String vocab) {
52 | 
53 |         class WordCount implements Comparable {
54 |             String word;
55 |             int count;
56 |             public WordCount (String word, int count) { this.word = word; this.count = count; }
57 |             public final int compareTo (Object o2) {
58 |                 if (count > ((WordCount)o2).count)
59 |                     return -1;
60 |                 else if (count == ((WordCount)o2).count)
61 |                     return 0;
62 |                 else return 1;
63 |             }
64 |         }
65 | 
66 |         try{
67 |             TObjectIntHashMap<String> freq = new TObjectIntHashMap<String> ();
68 |             Alphabet alphabet = data.getAlphabet();
69 |             for(int ii = 0; ii < alphabet.size(); ii++) {
70 |                 String word = alphabet.lookupObject(ii).toString();
71 |                 freq.put(word, 0);
72 |             }
73 | 
74 |             for (Instance instance : data) {
75 |                 FeatureSequence original_tokens = (FeatureSequence) instance.getData();
76 |                 for (int jj = 0; jj < original_tokens.getLength(); jj++) {
77 |                     String word = (String) original_tokens.getObjectAtPosition(jj);
78 |                     freq.adjustValue(word, 1);
79 |                 }
80 |             }
81 | 
82 |             WordCount[] array = new WordCount[freq.keys().length];
83 |             int index = -1;
84 |             for(Object o : freq.keys()) {
85 |                 String word = (String)o;
86 |                 int count = freq.get(word);
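                // wrap each (word, count) pair; WordCount.compareTo orders
                // the array by decreasing count for the sort below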
87 |                 index++;
88 |                 array[index] = new WordCount(word, count);
89 |             }
90 | 
91 |             Arrays.sort(array);
92 | 
93 | 
94 |             File file = new File(vocab);
95 |             PrintStream out = new PrintStream (file);
96 | 
97 |             // language_id is fixed now, but can be extended to
98 |             // multiple languages
99 |             int language_id = 0;
100 |             for(int ii = 0; ii < array.length; ii++) {
101 |                 out.println(language_id + "\t" + array[ii].word + "\t" + array[ii].count);
102 |             }
103 |             out.close();
104 | 
105 |         } catch (IOException e) {
106 |             System.out.println(e.getMessage());
107 |         }
108 |     }
109 | }
110 | 
-------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/HIntIntDoubleHashMap.java: --------------------------------------------------------------------------------
1 | package edu.umd.umiacs.itm.tree;
2 | 
3 | import gnu.trove.TIntDoubleHashMap;
4 | 
5 | /**
6 |  * This class defines a two level hashmap, so a value will be indexed by two keys.
7 |  * The value is double, and the two keys are both int.
8 |  * Author: Yuening Hu
9 |  */
10 | public class HIntIntDoubleHashMap extends TwoIntHashMap<TIntDoubleHashMap> {
11 |     /**
12 |      * If the keys do not exist, insert the value.
13 |      * Else update with the new value.
14 |      */
15 |     public void put(int key1, int key2, double value) {
16 |         if(! this.data.contains(key1)) {
17 |             this.data.put(key1, new TIntDoubleHashMap());
18 |         }
19 |         TIntDoubleHashMap tmp = this.data.get(key1);
20 |         tmp.put(key2, value);
21 |     }
22 | 
23 |     /**
24 |      * Return the value indexed by key1 and key2.
25 |      */
26 |     public double get(int key1, int key2) {
27 |         if (this.data.contains(key1)) {
28 |             TIntDoubleHashMap tmp1 = this.data.get(key1);
29 |             if (tmp1.contains(key2)) {
30 |                 return tmp1.get(key2);
31 |             }
32 |         }
33 |         System.out.println("HIntIntDoubleHashMap: key does not exist!");
34 |         return -1;
35 |     }
36 | 
37 |     /**
38 |      * Remove the second key
39 |      */
40 |     public void removeKey2(int key1, int key2) {
41 |         if (this.data.contains(key1)) {
42 |             this.data.get(key1).remove(key2);
43 |         }
44 |     }
45 | }
46 | 
-------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/HIntIntIntHashMap.java: --------------------------------------------------------------------------------
1 | package edu.umd.umiacs.itm.tree;
2 | 
3 | import gnu.trove.TIntIntHashMap;
4 | 
5 | /**
6 |  * This class defines a two level hashmap, so a value will be indexed by two keys.
7 |  * The value is int, and the two keys are both int.
8 |  * Author: Yuening Hu
9 |  */
10 | public class HIntIntIntHashMap extends TwoIntHashMap<TIntIntHashMap> {
11 |     /**
12 |      * If the keys do not exist, insert the value.
13 |      * Else update with the new value.
14 |      */
15 |     public void put(int key1, int key2, int value) {
16 |         if(! this.data.contains(key1)) {
17 |             this.data.put(key1, new TIntIntHashMap());
18 |         }
19 |         TIntIntHashMap tmp = this.data.get(key1);
20 |         tmp.put(key2, value);
21 |     }
22 | 
23 |     /**
24 |      * Return the value indexed by key1 and key2.
25 |      */
26 |     public int get(int key1, int key2) {
27 |         if (this.contains(key1, key2)) {
28 |             return this.data.get(key1).get(key2);
29 |         } else {
30 |             System.out.println("HIntIntIntHashMap: key does not exist!");
31 |             return 0;
32 |         }
33 |     }
34 | 
35 |     /**
36 |      * Adjust the value indexed by the key pair (key1, key2) by the specified amount.
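     * A minimal usage sketch (illustrative only, not from the original source):
     *   HIntIntIntHashMap counts = new HIntIntIntHashMap();
     *   counts.put(3, 7, 2);
     *   counts.adjustValue(3, 7, 1);  // (3, 7) now maps to 3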
37 |      */
38 |     public void adjustValue(int key1, int key2, int increment) {
39 |         int old = this.get(key1, key2);
40 |         this.put(key1, key2, old+increment);
41 |     }
42 | 
43 |     /**
44 |      * If the key pair (key1, key2) exists, adjust the value by the specified amount;
45 |      * otherwise insert the new value.
46 |      */
47 |     public void adjustOrPutValue(int key1, int key2, int increment, int newvalue) {
48 |         if (this.contains(key1, key2)) {
49 |             int old = this.get(key1, key2);
50 |             this.put(key1, key2, old+increment);
51 |         } else {
52 |             this.put(key1, key2, newvalue);
53 |         }
54 |     }
55 | 
56 |     /**
57 |      * Remove the first key
58 |      */
59 |     public void removeKey1(int key1) {
60 |         this.data.remove(key1);
61 |     }
62 | 
63 |     /**
64 |      * Remove the second key
65 |      */
66 |     public void removeKey2(int key1, int key2) {
67 |         if (this.data.contains(key1)) {
68 |             this.data.get(key1).remove(key2);
69 |         }
70 |     }
71 | }
72 | 
-------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/HIntIntObjectHashMap.java: --------------------------------------------------------------------------------
1 | package edu.umd.umiacs.itm.tree;
2 | 
3 | import gnu.trove.TIntObjectHashMap;
4 | 
5 | /**
6 |  * This class defines a two level hashmap, so a value will be indexed by two keys.
7 |  * The value is a generic object of type V, and the two keys are both int.
8 |  * Author: Yuening Hu
9 |  */
10 | public class HIntIntObjectHashMap<V> extends TwoIntHashMap<TIntObjectHashMap<V>> {
11 |     /**
12 |      * If the keys do not exist, insert the value.
13 |      * Else update with the new value.
14 |      */
15 |     public void put(int key1, int key2, V value) {
16 |         if(! this.data.contains(key1)) {
17 |             this.data.put(key1, new TIntObjectHashMap<V>());
18 |         }
19 |         TIntObjectHashMap<V> tmp = this.data.get(key1);
20 |         tmp.put(key2, value);
21 |     }
22 | 
23 |     /**
24 |      * Return the HashMap indexed by the first key.
25 |      */
26 |     public TIntObjectHashMap<V> get(int key1) {
27 |         return this.data.get(key1);
28 |     }
29 | 
30 |     /**
31 |      * Return the value indexed by key1 and key2.
32 |      */
33 |     public V get(int key1, int key2) {
34 |         if (this.contains(key1, key2)) {
35 |             return this.data.get(key1).get(key2);
36 |         } else {
37 |             System.out.println("HIntIntObjectHashMap: key does not exist! " + key1 + " " + key2);
38 |             return null;
39 |         }
40 |     }
41 | }
42 | 
-------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/Node.java: --------------------------------------------------------------------------------
1 | package edu.umd.umiacs.itm.tree;
2 | 
3 | import gnu.trove.TDoubleArrayList;
4 | import gnu.trove.TIntArrayList;
5 | 
6 | /**
7 |  * This class defines a node, which might have children,
8 |  * and a distribution scaled by the node prior over the children.
9 |  * A node is a synset, which may have child nodes and words
10 |  * at the same time.
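 * A minimal construction sketch (illustrative only, not from the original source):
 *   Node n = new Node();
 *   n.setOffset(12);
 *   n.addChildrenOffset(13);     // child synset at offset 13
 *   n.addChildrenOffset(14);
 *   n.setTransitionScalor(1.0);
 *   n.initializePrior(2);        // one prior entry per child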
11 | * Author: Yuening Hu 12 | */ 13 | public class Node { 14 | int offset; 15 | double rawCount; 16 | double hypoCount; 17 | String hyperparamName; 18 | 19 | TIntArrayList words; 20 | TDoubleArrayList wordsCount; 21 | TIntArrayList childOffsets; 22 | 23 | int numChildren; 24 | int numPaths; 25 | int numWords; 26 | 27 | double transitionScalor; 28 | TDoubleArrayList transitionPrior; 29 | 30 | public Node() { 31 | this.words = new TIntArrayList (); 32 | this.wordsCount = new TDoubleArrayList (); 33 | this.childOffsets = new TIntArrayList (); 34 | this.transitionPrior = new TDoubleArrayList (); 35 | this.numChildren = 0; 36 | this.numWords = 0; 37 | this.numPaths = 0; 38 | } 39 | 40 | /** 41 | * Initialize the prior distribution. 42 | */ 43 | public void initializePrior(int size) { 44 | for (int ii = 0; ii < size; ii++ ) { 45 | this.transitionPrior.add(0.0); 46 | } 47 | } 48 | 49 | /** 50 | * Initialize the prior distribution. 51 | */ 52 | public void setOffset(int val) { 53 | this.offset = val; 54 | } 55 | 56 | /** 57 | * set the raw count. 58 | */ 59 | public void setRawCount(double count) { 60 | this.rawCount = count; 61 | } 62 | 63 | /** 64 | * set the hypo count. 65 | */ 66 | public void setHypoCount(double count) { 67 | this.hypoCount = count; 68 | } 69 | 70 | /** 71 | * set the hyperparameter name of this node. 72 | */ 73 | public void setHyperparamName(String name) { 74 | this.hyperparamName = name; 75 | } 76 | 77 | /** 78 | * set the prior scaler. 79 | */ 80 | public void setTransitionScalor(double val) { 81 | this.transitionScalor = val; 82 | } 83 | 84 | /** 85 | * set the prior for the given child index. 86 | */ 87 | public void setPrior(int index, double value) { 88 | this.transitionPrior.set(index, value); 89 | } 90 | 91 | /** 92 | * Add a child, which is defined by the offset. 93 | */ 94 | public void addChildrenOffset(int childOffset) { 95 | this.childOffsets.add(childOffset); 96 | this.numChildren += 1; 97 | } 98 | 99 | /** 100 | * Add a word. 101 | */ 102 | public void addWord(int wordIndex, double wordCount) { 103 | this.words.add(wordIndex); 104 | this.wordsCount.add(wordCount); 105 | this.numWords += 1; 106 | } 107 | 108 | /** 109 | * Increase the number of paths. 110 | */ 111 | public void addPaths(int inc) { 112 | this.numPaths += inc; 113 | } 114 | 115 | /** 116 | * return the offset of current node. 117 | */ 118 | public int getOffset() { 119 | return this.offset; 120 | } 121 | 122 | /** 123 | * return the number of children. 124 | */ 125 | public int getNumChildren() { 126 | return this.numChildren; 127 | } 128 | 129 | /** 130 | * return the number of words. 131 | */ 132 | public int getNumWords() { 133 | return this.numWords; 134 | } 135 | 136 | /** 137 | * return the child offset given the child index. 138 | */ 139 | public int getChild(int child_index) { 140 | return this.childOffsets.get(child_index); 141 | } 142 | 143 | /** 144 | * return the word given the word index. 145 | */ 146 | public int getWord(int word_index) { 147 | return this.words.get(word_index); 148 | } 149 | 150 | /** 151 | * return the word count given the word index. 152 | */ 153 | public double getWordCount(int word_index) { 154 | return this.wordsCount.get(word_index); 155 | } 156 | 157 | /** 158 | * return the hypocount of the node. 159 | */ 160 | public double getHypoCount() { 161 | return this.hypoCount; 162 | } 163 | 164 | /** 165 | * return the transition scalor. 
166 | */ 167 | public double getTransitionScalor() { 168 | return this.transitionScalor; 169 | } 170 | 171 | /** 172 | * return the scaled transition prior distribution. 173 | */ 174 | public TDoubleArrayList getTransitionPrior() { 175 | return this.transitionPrior; 176 | } 177 | 178 | /** 179 | * normalize the prior to be a distribution and then scale it. 180 | */ 181 | public void normalizePrior() { 182 | double norm = 0; 183 | for (int ii = 0; ii < this.transitionPrior.size(); ii++) { 184 | norm += this.transitionPrior.get(ii); 185 | } 186 | for (int ii = 0; ii < this.transitionPrior.size(); ii++) { 187 | double tmp = this.transitionPrior.get(ii) / norm; 188 | tmp *= this.transitionScalor; 189 | this.transitionPrior.set(ii, tmp); 190 | } 191 | } 192 | } 193 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/NonZeroPath.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TIntDoubleHashMap; 4 | import gnu.trove.TIntIntHashMap; 5 | import gnu.trove.TIntObjectHashMap; 6 | 7 | public class NonZeroPath { 8 | 9 | HIntIntIntHashMap data; 10 | 11 | public NonZeroPath () { 12 | this.data = new HIntIntIntHashMap(); 13 | } 14 | 15 | public void put(int key1, int key2, int value) { 16 | this.data.put(key1, key2, value); 17 | } 18 | 19 | public void get(int key1, int key2) { 20 | this.data.get(key1, key2); 21 | } 22 | 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/Path.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TIntArrayList; 4 | 5 | /** 6 | * This class defines a path. 7 | * A path is a list of nodes, and the last node emits a word. 8 | * Author: Yuening Hu 9 | */ 10 | public class Path { 11 | 12 | TIntArrayList nodes; 13 | //TIntArrayList children; 14 | int finalWord; 15 | 16 | public Path () { 17 | this.nodes = new TIntArrayList(); 18 | this.finalWord = -1; 19 | } 20 | 21 | /** 22 | * Add nodes to this path. 23 | */ 24 | public void addNodes (TIntArrayList innodes) { 25 | for (int ii = 0; ii < innodes.size(); ii++) { 26 | int node_index = innodes.get(ii); 27 | this.nodes.add(node_index); 28 | } 29 | } 30 | 31 | /** 32 | * Add the final word of this path. 33 | */ 34 | public void addFinalWord(int word) { 35 | this.finalWord = word; 36 | } 37 | 38 | /** 39 | * return the node list. 40 | */ 41 | public TIntArrayList getNodes() { 42 | return this.nodes; 43 | } 44 | 45 | /** 46 | * return the final word. 
47 |     */
48 |     public int getFinalWord() {
49 |         return this.finalWord;
50 |     }
51 | }
52 | 
-------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/PriorTree.java: --------------------------------------------------------------------------------
1 | package edu.umd.umiacs.itm.tree;
2 | 
3 | import gnu.trove.TIntArrayList;
4 | import gnu.trove.TIntObjectHashMap;
5 | import gnu.trove.TIntObjectIterator;
6 | import gnu.trove.TObjectDoubleHashMap;
7 | 
8 | import java.io.BufferedReader;
9 | import java.io.DataInputStream;
10 | import java.io.File;
11 | import java.io.FileInputStream;
12 | import java.io.FileNotFoundException;
13 | import java.io.FilenameFilter;
14 | import java.io.IOException;
15 | import java.io.InputStreamReader;
16 | import java.util.ArrayList;
17 | 
18 | import cc.mallet.types.Alphabet;
19 | import cc.mallet.types.InstanceList;
20 | 
21 | import topicmod_projects_ldawn.WordnetFile.WordNetFile;
22 | import topicmod_projects_ldawn.WordnetFile.WordNetFile.Synset;
23 | import topicmod_projects_ldawn.WordnetFile.WordNetFile.Synset.Word;
24 | 
25 | /**
26 |  * This class loads the prior tree structure from the protocol buffer files of the tree structure.
27 |  * Main entry point: initialize()
28 |  * Author: Yuening Hu
29 |  */
30 | public class PriorTree {
31 | 
32 |     int root;
33 |     int maxDepth;
34 | 
35 |     TObjectDoubleHashMap<String> hyperparams;
36 |     TIntObjectHashMap<Node> nodes;
37 |     TIntObjectHashMap<ArrayList<Path>> wordPaths;
38 | 
39 |     public PriorTree () {
40 |         this.hyperparams = new TObjectDoubleHashMap<String> ();
41 |         this.nodes = new TIntObjectHashMap<Node> ();
42 |         this.wordPaths = new TIntObjectHashMap<ArrayList<Path>> ();
43 |     }
44 | 
45 |     /**
46 |      * Get the input tree file list from the given tree file names
47 |      */
48 |     private ArrayList<String> getFileList(String tree_files) {
49 | 
50 |         int split_index = tree_files.lastIndexOf('/');
51 |         String dirname = tree_files.substring(0, split_index);
52 |         String fileprefix = tree_files.substring(split_index+1);
53 |         fileprefix = fileprefix.replace("*", "");
54 | 
55 |         //System.out.println(dirname);
56 |         //System.out.println(fileprefix);
57 | 
58 |         File dir = new File(dirname);
59 |         String[] children = dir.list();
60 |         ArrayList<String> filelist = new ArrayList<String>();
61 | 
62 |         for (int i = 0; i < children.length; i++) {
63 |             if (children[i].startsWith(fileprefix)) {
64 |                 System.out.println("Found one: " + dirname + "/" + children[i]);
65 |                 String filename = dirname + "/" + children[i];
66 |                 filelist.add(filename);
67 |             }
68 |         }
69 |         return filelist;
70 |     }
71 | 
72 |     /**
73 |      * Load hyperparameters from the given file
74 |      */
75 |     private void loadHyperparams(String hyperFile) {
76 |         try {
77 |             FileInputStream infstream = new FileInputStream(hyperFile);
78 |             DataInputStream in = new DataInputStream(infstream);
79 |             BufferedReader br = new BufferedReader(new InputStreamReader(in));
80 | 
81 |             String strLine;
82 |             //Read File Line By Line
83 |             while ((strLine = br.readLine()) != null) {
84 |                 strLine = strLine.trim();
85 |                 String[] str = strLine.split(" ");
86 |                 double tmp = Double.parseDouble(str[1]);
87 |                 hyperparams.put(str[0], tmp);
88 |             }
89 |             in.close();
90 | 
91 |             // Iterator<Map.Entry<String, Double>> it = hyperparams.entrySet().iterator();
92 |             // while (it.hasNext()) {
93 |             //     Map.Entry<String, Double> entry = it.next();
94 |             //     System.out.println(entry.getKey());
95 |             //     System.out.println(entry.getValue());
96 |             // }
97 | 
98 |         } catch (IOException e) {
99 |             System.out.println("No hyperparameter file found!");
100 |         }
101 |     }
102 | 
103 |     /**
104 |      * Load tree nodes one by one: load the children, words of 
each node 105 | */ 106 | private void loadTree(String tree_files, ArrayList vocab) { 107 | 108 | ArrayList filelist = getFileList(tree_files); 109 | 110 | for (int ii = 0; ii < filelist.size(); ii++) { 111 | String filename = filelist.get(ii); 112 | WordNetFile tree = null; 113 | try { 114 | tree = WordNetFile.parseFrom(new FileInputStream(filename)); 115 | } catch (IOException e) { 116 | System.out.println("No file Found!"); 117 | } 118 | 119 | int new_root = tree.getRoot(); 120 | assert( (new_root == -1) || (this.root == -1) || (new_root == this.root)); 121 | if (new_root >= 0) { 122 | this.root = new_root; 123 | } 124 | 125 | for (int jj = 0; jj < tree.getSynsetsCount(); jj++) { 126 | Synset synset = tree.getSynsets(jj); 127 | Node n = new Node(); 128 | n.setOffset(synset.getOffset()); 129 | n.setRawCount(synset.getRawCount()); 130 | n.setHypoCount(synset.getHyponymCount()); 131 | 132 | double transition = hyperparams.get(synset.getHyperparameter()); 133 | n.setTransitionScalor(transition); 134 | for (int cc = 0; cc < synset.getChildrenOffsetsCount(); cc++) { 135 | n.addChildrenOffset(synset.getChildrenOffsets(cc)); 136 | } 137 | 138 | for (int ww = 0; ww < synset.getWordsCount(); ww++) { 139 | Word word = synset.getWords(ww); 140 | int term_id = vocab.indexOf(word.getTermStr()); 141 | //int term_id = vocab.lookupIndex(word.getTermStr()); 142 | double word_count = word.getCount(); 143 | n.addWord(term_id, word_count); 144 | } 145 | 146 | nodes.put(n.getOffset(), n); 147 | } 148 | } 149 | } 150 | 151 | /** 152 | * Get all the paths in the tree, keep the (word, path) pairs 153 | * Note the word in the pair is actually the word of the leaf node 154 | */ 155 | private int searchDepthFirst(int depth, 156 | int node_index, 157 | TIntArrayList traversed, 158 | TIntArrayList next_pointers) { 159 | int max_depth = depth; 160 | traversed.add(node_index); 161 | Node current_node = this.nodes.get(node_index); 162 | current_node.addPaths(1); 163 | 164 | // go over the words that current node emits 165 | for (int ii = 0; ii < current_node.getNumWords(); ii++) { 166 | int word = current_node.getWord(ii); 167 | Path p = new Path(); 168 | p.addNodes(traversed); 169 | // p.addChildren(next_pointers); 170 | p.addFinalWord(word); 171 | if (! this.wordPaths.contains(word)) { 172 | this.wordPaths.put(word, new ArrayList ()); 173 | } 174 | ArrayList tmp = this.wordPaths.get(word); 175 | tmp.add(p); 176 | } 177 | 178 | // go over the child nodes of the current node 179 | for (int ii = 0; ii < current_node.getNumChildren(); ii++) { 180 | int child = current_node.getChild(ii); 181 | next_pointers.add(child); 182 | int child_depth = this.searchDepthFirst(depth+1, child, traversed, next_pointers); 183 | next_pointers.remove(next_pointers.size()-1); 184 | max_depth = max_depth >= child_depth ? 
max_depth : child_depth;
185 |         }
186 | 
187 |         traversed.remove(traversed.size()-1);
188 |         return max_depth;
189 |     }
190 | 
191 |     /**
192 |      * Set the scaled prior distribution of each node.
193 |      * According to the hypoCount of each node's children, generate a multinomial
194 |      * distribution, then scale it by transitionScalor.
195 |      */
196 |     private void setPrior() {
197 |         for (TIntObjectIterator<Node> it = this.nodes.iterator(); it.hasNext(); ) {
198 |             it.advance();
199 |             Node n = it.value();
200 |             int numChildren = n.getNumChildren();
201 |             int numWords = n.getNumWords();
202 | 
203 |             // first set the hypoCount for each child
204 |             if (numChildren > 0) {
205 |                 assert numWords == 0;
206 |                 n.initializePrior(numChildren);
207 |                 for (int ii = 0; ii < numChildren; ii++) {
208 |                     int child = n.getChild(ii);
209 |                     n.setPrior(ii, this.nodes.get(child).getHypoCount());
210 |                 }
211 |             }
212 | 
213 |             // this step is for tree structures whose leaf nodes contain more than one word:
214 |             // if a leaf node contains multiple words, we treat each word
215 |             // as a "leaf node" and set a multinomial over all the words;
216 |             // if the leaf node contains only one word, this step is skipped.
217 |             if (numWords > 1) {
218 |                 assert numChildren == 0;
219 |                 n.initializePrior(numWords);
220 |                 for (int ii = 0; ii < numWords; ii++) {
221 |                     n.setPrior(ii, n.getWordCount(ii));
222 |                 }
223 |             }
224 | 
225 |             // then normalize and scale
226 |             n.normalizePrior();
227 |         }
228 |     }
229 | 
230 |     /**
231 |      * The entry point of this class.
232 |      */
233 |     public void initialize(String treeFiles, String hyperFile, ArrayList<String> vocab) {
234 |         this.loadHyperparams(hyperFile);
235 |         this.loadTree(treeFiles, vocab);
236 | 
237 |         TIntArrayList traversed = new TIntArrayList ();
238 |         TIntArrayList next_pointers = new TIntArrayList ();
239 |         this.maxDepth = this.searchDepthFirst(0, 0, traversed, next_pointers);
240 | 
241 |         this.setPrior();
242 | 
243 |         //System.out.println("**************************");
244 |         // check the word paths
245 |         System.out.println("Number of words: " + this.wordPaths.size());
246 |         System.out.println("Initialized paths");
247 | 
248 |         /*
249 |         for (TIntObjectIterator<ArrayList<Path>> it = this.wordPaths.iterator(); it.hasNext(); ) {
250 |             it.advance();
251 |             ArrayList<Path> paths = it.value();
252 |             System.out.print(it.key() + ", " + vocab.get(it.key()));
253 |             //System.out.print(it.key() + ", " + vocab.lookupObject(it.key()));
254 |             for (int ii = 0; ii < paths.size(); ii++) {
255 |                 Path p = paths.get(ii);
256 |                 System.out.print(", Path " + ii);
257 |                 System.out.print(", Path nodes list: " + p.getNodes());
258 |                 System.out.println(", Path final word: " + p.getFinalWord());
259 |             }
260 |         }
261 |         System.out.println("**************************");
262 | 
263 |         // check the prior
264 |         System.out.println("Check the prior");
265 |         for (TIntObjectIterator<Node> it = this.nodes.iterator(); it.hasNext(); ) {
266 |             it.advance();
267 |             if (it.value().getTransitionPrior().size() > 0) {
268 |                 System.out.print("Node " + it.key());
269 |                 System.out.println(", Transition prior " + it.value().getTransitionPrior());
270 |             }
271 |         }
272 |         System.out.println("**************************");
273 |         */
274 | 
275 |     }
276 | 
277 |     public int getMaxDepth() {
278 |         return this.maxDepth;
279 |     }
280 | 
281 |     public int getRoot() {
282 |         return this.root;
283 |     }
284 | 
285 |     public TIntObjectHashMap<Node> getNodes() {
286 |         return this.nodes;
287 |     }
288 | 
289 |     public TIntObjectHashMap<ArrayList<Path>> getWordPaths() {
290 |         return this.wordPaths;
291 |     }
292 | 
293 |     /**
294 |      * Load vocab
295 |      */
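    // The vocab file is the output of GenerateVocab: one line per word, in the
    // tab-separated form "language_id<TAB>word[<TAB>count]". Only the word
    // column (str[1]) is used here; illustrative entries (not from the repo):
    //   0	dog	57
    //   0	cat	42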
296 | public ArrayList readVocab(String vocabFile) { 297 | 298 | ArrayList vocab = new ArrayList (); 299 | 300 | try { 301 | FileInputStream infstream = new FileInputStream(vocabFile); 302 | DataInputStream in = new DataInputStream(infstream); 303 | BufferedReader br = new BufferedReader(new InputStreamReader(in)); 304 | 305 | String strLine; 306 | //Read File Line By Line 307 | while ((strLine = br.readLine()) != null) { 308 | strLine = strLine.trim(); 309 | String[] str = strLine.split("\t"); 310 | if (str.length > 1) { 311 | vocab.add(str[1]); 312 | } else { 313 | System.out.println("Error! " + strLine); 314 | return null; 315 | } 316 | } 317 | in.close(); 318 | 319 | } catch (IOException e) { 320 | System.out.println("No vocab file Found!"); 321 | } 322 | 323 | return vocab; 324 | } 325 | 326 | /** 327 | * test main 328 | */ 329 | public static void main(String[] args) throws Exception{ 330 | 331 | String treeFiles = "../toy/toy_set1.wn.*"; 332 | String hyperFile = "../toy/tree_hyperparams"; 333 | String inputFile = "../input/toy-topic-input.mallet"; 334 | String vocabFile = "../toy/toy.voc"; 335 | 336 | //String treeFiles = "../synthetic/synthetic_set1.wn.*"; 337 | //String hyperFile = "../synthetic/tree_hyperparams"; 338 | //String inputFile = "../input/synthetic-topic-input.mallet"; 339 | //String vocabFile = "../synthetic/synthetic.voc"; 340 | 341 | PriorTree tree = new PriorTree(); 342 | ArrayList vocab = tree.readVocab(vocabFile); 343 | 344 | InstanceList ilist = InstanceList.load (new File(inputFile)); 345 | tree.initialize(treeFiles, hyperFile, vocab); 346 | } 347 | 348 | } 349 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TopicSampler.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TDoubleArrayList; 4 | import gnu.trove.TIntArrayList; 5 | import gnu.trove.TIntHashSet; 6 | import gnu.trove.TIntIntHashMap; 7 | import gnu.trove.TIntIntIterator; 8 | 9 | import java.io.BufferedReader; 10 | import java.io.DataInputStream; 11 | import java.io.File; 12 | import java.io.FileInputStream; 13 | import java.io.IOException; 14 | import java.io.InputStreamReader; 15 | import java.io.PrintStream; 16 | import java.io.Serializable; 17 | import java.util.ArrayList; 18 | import java.util.Arrays; 19 | 20 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerHashD.DocData; 21 | import cc.mallet.types.Dirichlet; 22 | import cc.mallet.types.FeatureSequence; 23 | import cc.mallet.types.Instance; 24 | import cc.mallet.types.InstanceList; 25 | import cc.mallet.util.Randoms; 26 | 27 | public abstract class TopicSampler{ 28 | 29 | int numTopics; // Number of topics to be fit 30 | int numIterations; 31 | int startIter; 32 | Randoms random; 33 | double[] alpha; 34 | double alphaSum; 35 | TDoubleArrayList lhood; 36 | TDoubleArrayList iterTime; 37 | ArrayList vocab; 38 | 39 | TreeTopicModel topics; 40 | TIntHashSet cons; 41 | 42 | public TopicSampler (int numberOfTopics, double alphaSum, int seed) { 43 | this.numTopics = numberOfTopics; 44 | this.random = new Randoms(seed); 45 | 46 | this.alphaSum = alphaSum; 47 | this.alpha = new double[numTopics]; 48 | Arrays.fill(alpha, alphaSum / numTopics); 49 | 50 | this.vocab = new ArrayList (); 51 | this.cons = new TIntHashSet(); 52 | 53 | this.lhood = new TDoubleArrayList(); 54 | this.iterTime = new TDoubleArrayList(); 55 | this.startIter = 0; 56 | 57 | // notice: this.topics and this.data are not 
initialized in this abstract class, 58 | // in each sub class, the topics variable is initialized differently. 59 | } 60 | 61 | 62 | 63 | public void setNumIterations(int iters) { 64 | this.numIterations = iters; 65 | } 66 | 67 | public int getNumIterations() { 68 | return this.numIterations; 69 | } 70 | 71 | 72 | 73 | /** 74 | * This function returns the likelihood. 75 | */ 76 | public double lhood() { 77 | return this.docLHood() + this.topics.topicLHood(); 78 | } 79 | 80 | /** 81 | * Resume lhood and iterTime from the saved lhood file. 82 | */ 83 | public void resumeLHood(String lhoodFile) throws IOException{ 84 | FileInputStream lhoodfstream = new FileInputStream(lhoodFile); 85 | DataInputStream lhooddstream = new DataInputStream(lhoodfstream); 86 | BufferedReader brLHood = new BufferedReader(new InputStreamReader(lhooddstream)); 87 | // the first line is the title 88 | String strLine = brLHood.readLine(); 89 | while ((strLine = brLHood.readLine()) != null) { 90 | strLine = strLine.trim(); 91 | String[] str = strLine.split("\t"); 92 | // iteration, likelihood, iter_time 93 | myAssert(str.length == 3, "lhood file problem!"); 94 | this.lhood.add(Double.parseDouble(str[1])); 95 | this.iterTime.add(Double.parseDouble(str[2])); 96 | } 97 | this.startIter = this.lhood.size(); 98 | if (this.startIter > this.numIterations) { 99 | System.out.println("Have already sampled " + this.numIterations + " iterations!"); 100 | System.exit(0); 101 | } 102 | System.out.println("Start sampling for iteration " + this.startIter); 103 | brLHood.close(); 104 | } 105 | 106 | /** 107 | * Resumes from the saved files. 108 | */ 109 | public void resume(InstanceList training, String resumeDir) { 110 | try { 111 | String statesFile = resumeDir + ".states"; 112 | resumeStates(training, statesFile); 113 | 114 | String lhoodFile = resumeDir + ".lhood"; 115 | resumeLHood(lhoodFile); 116 | } catch (IOException e) { 117 | System.out.println(e.getMessage()); 118 | } 119 | } 120 | 121 | /** 122 | * This function prints the topic words of each topic. 123 | */ 124 | public void printTopWords(File file, int numWords) throws IOException { 125 | PrintStream out = new PrintStream (file); 126 | out.print(displayTopWords(numWords)); 127 | out.close(); 128 | } 129 | 130 | /** 131 | * By implementing the comparable interface, this function ranks the words 132 | * in each topic, and returns the top words for each topic. 
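     * Each candidate path is scored with computeTopicPathProb, i.e. the product
     * over the edges (parent, child) along the word's path of
     *   (beta(parent, child) + n(parent, child)) / (betaSum(parent) + n(parent)),
     * where n(.) denotes the current topic's edge and node counts.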
133 |     */
134 |     public String displayTopWords (int numWords) {
135 | 
136 |         class WordProb implements Comparable {
137 |             int wi;
138 |             double p;
139 |             public WordProb (int wi, double p) { this.wi = wi; this.p = p; }
140 |             public final int compareTo (Object o2) {
141 |                 if (p > ((WordProb)o2).p)
142 |                     return -1;
143 |                 else if (p == ((WordProb)o2).p)
144 |                     return 0;
145 |                 else return 1;
146 |             }
147 |         }
148 | 
149 |         StringBuilder out = new StringBuilder();
150 |         int numPaths = this.topics.getPathNum();
151 |         //System.out.println(numPaths);
152 | 
153 |         for (int tt = 0; tt < this.numTopics; tt++){
154 |             String tmp = "\n--------------\nTopic " + tt + "\n------------------------\n";
155 |             //System.out.print(tmp);
156 |             out.append(tmp);
157 |             WordProb[] wp = new WordProb[numPaths];
158 |             for (int pp = 0; pp < numPaths; pp++){
159 |                 int ww = this.topics.getWordFromPath(pp);
160 |                 double val = this.topics.computeTopicPathProb(tt, ww, pp);
161 |                 wp[pp] = new WordProb(pp, val);
162 |             }
163 |             Arrays.sort(wp);
164 |             for (int ii = 0; ii < wp.length && ii < numWords; ii++){
165 |                 int pp = wp[ii].wi;
166 |                 int ww = this.topics.getWordFromPath(pp);
167 |                 //tmp = wp[ii].p + "\t" + this.vocab.lookupObject(ww) + "\n";
168 |                 tmp = wp[ii].p + "\t" + this.vocab.get(ww) + "\n";
169 |                 //System.out.print(tmp);
170 |                 out.append(tmp);
171 |             }
172 |         }
173 |         return out.toString();
174 |     }
175 | 
176 |     /**
177 |      * Prints likelihood and iter time.
178 |      */
179 |     public void printStats (File file) throws IOException {
180 |         PrintStream out = new PrintStream (file);
181 |         String tmp = "Iteration\t\tlikelihood\titer_time\n";
182 |         out.print(tmp);
183 |         for (int iter = 0; iter < this.lhood.size(); iter++) {
184 |             tmp = iter + "\t" + this.lhood.get(iter) + "\t" + this.iterTime.get(iter);
185 |             out.println(tmp);
186 |         }
187 |         out.close();
188 |     }
189 | 
190 |     public void loadVocab(String vocabFile) {
191 | 
192 |         try {
193 |             FileInputStream infstream = new FileInputStream(vocabFile);
194 |             DataInputStream in = new DataInputStream(infstream);
195 |             BufferedReader br = new BufferedReader(new InputStreamReader(in));
196 | 
197 |             String strLine;
198 |             //Read File Line By Line
199 |             while ((strLine = br.readLine()) != null) {
200 |                 strLine = strLine.trim();
201 |                 String[] str = strLine.split("\t");
202 |                 if (str.length > 1) {
203 |                     this.vocab.add(str[1]);
204 |                 } else {
205 |                     System.out.println("Error! " + strLine);
206 |                 }
207 |             }
208 |             in.close();
209 | 
210 |         } catch (IOException e) {
211 |             System.out.println("No vocab file found!");
212 |         }
213 | 
214 |     }
215 | 
216 |     /**
217 |      * Load constraints
218 |      */
219 |     public void loadConstraints(String consFile) {
220 |         try {
221 |             FileInputStream infstream = new FileInputStream(consFile);
222 |             DataInputStream in = new DataInputStream(infstream);
223 |             BufferedReader br = new BufferedReader(new InputStreamReader(in));
224 | 
225 |             String strLine;
226 |             //Read File Line By Line
227 |             while ((strLine = br.readLine()) != null) {
228 |                 strLine = strLine.trim();
229 |                 String[] str = strLine.split("\t");
230 |                 if (str.length > 1) {
231 |                     // str[0] is either "MERGE_" or "SPLIT_", not a word
232 |                     for(int ii = 1; ii < str.length; ii++) {
233 |                         int word = this.vocab.indexOf(str[ii]);
234 |                         myAssert(word >= 0, "Constraint word not found in vocab: " + str[ii]);
235 |                         cons.add(word);
236 |                     }
237 |                 } else {
238 |                     System.out.println("Error! " + strLine);
239 |                 }
240 |             }
241 |             in.close();
242 | 
243 |         } catch (IOException e) {
244 |             System.out.println("No constraints file found!");
245 |         }
246 | 
247 |     }
248 | 
249 |     /**
250 |      * For testing.
251 |      */
252 |     public static void myAssert(boolean flag, String info) {
253 |         if(!flag) {
254 |             System.out.println(info);
255 |             System.exit(0);
256 |         }
257 |     }
258 | 
259 |     abstract void addInstances(InstanceList training);
260 |     abstract void resumeStates(InstanceList training, String statesFile) throws IOException;
261 |     abstract void clearTopicAssignments(String option, String consFile);
262 |     abstract void changeTopic(int doc, int index, int word, int new_topic, int new_path);
263 |     abstract double docLHood();
264 |     abstract void printDocumentTopics (File file) throws IOException;
265 |     abstract void sampleDoc(int doc);
266 | }
267 | 
-------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TopicTreeWalk.java: --------------------------------------------------------------------------------
1 | package edu.umd.umiacs.itm.tree;
2 | 
3 | import gnu.trove.TIntArrayList;
4 | import gnu.trove.TIntHashSet;
5 | import gnu.trove.TIntIntHashMap;
6 | 
7 | /**
8 |  * This class counts each node and each edge for a topic with a tree structure.
9 |  * Author: Yuening Hu
10 |  */
11 | public class TopicTreeWalk {
12 | 
13 |     // *** To be sorted
14 |     HIntIntIntHashMap counts;
15 |     TIntIntHashMap nodeCounts;
16 | 
17 |     public TopicTreeWalk() {
18 |         this.counts = new HIntIntIntHashMap();
19 |         this.nodeCounts = new TIntIntHashMap();
20 |     }
21 | 
22 |     /**
23 |      * Given a path (a list of nodes), increase the node and edge counts by
24 |      * the specified amount. When a node count changes from zero or changes
25 |      * to zero, return that node. (When this happens, the non-zero paths of
26 |      * the node might need to be updated; that is why we need this list.)
27 |      */
28 |     public int[] changeCount(TIntArrayList path_nodes, int increment) {
29 |         for (int nn = 0; nn < path_nodes.size()-1; nn++) {
30 |             int parent = path_nodes.get(nn);
31 |             int child = path_nodes.get(nn+1);
32 |             this.counts.adjustOrPutValue(parent, child, increment, increment);
33 |         }
34 | 
35 |         // keep the nodes whose counts changed from zero or to zero
36 |         TIntHashSet affected_nodes = new TIntHashSet();
37 | 
38 |         for (int nn = 0; nn < path_nodes.size(); nn++) {
39 |             int node = path_nodes.get(nn);
40 |             if (! this.nodeCounts.contains(node)) {
41 |                 this.nodeCounts.put(node, 0);
42 |             }
43 | 
44 |             int old_count = this.nodeCounts.get(node);
45 |             this.nodeCounts.adjustValue(node, increment);
46 |             int new_count = this.nodeCounts.get(node);
47 | 
48 |             // keep the nodes whose counts changed from zero or to zero
49 |             if (nn != 0 && (old_count == 0 || new_count == 0)) {
50 |                 affected_nodes.add(node);
51 |             }
52 |         }
53 | 
54 |         if (affected_nodes.size() > 0) {
55 |             return affected_nodes.toArray();
56 |         } else {
57 |             return null;
58 |         }
59 |     }
60 | 
61 |     /**
62 |      * Return an edge count.
63 |      */
64 |     public int getCount(int key1, int key2) {
65 |         if (this.counts.contains(key1, key2)) {
66 |             return this.counts.get(key1, key2);
67 |         }
68 |         return 0;
69 |     }
70 | 
71 |     /**
72 |      * Return a node count.
73 |      */
74 |     public int getNodeCount(int key) {
75 |         if (this.nodeCounts.contains(key)) {
76 |             return this.nodeCounts.get(key);
77 |         }
78 |         return 0;
79 |     }
80 | 
81 | }
82 | 
-------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicModel.java: --------------------------------------------------------------------------------
1 | package edu.umd.umiacs.itm.tree;
2 | 
3 | import gnu.trove.TDoubleArrayList;
4 | import gnu.trove.TIntArrayList;
5 | import gnu.trove.TIntDoubleHashMap;
6 | import gnu.trove.TIntDoubleIterator;
7 | import gnu.trove.TIntHashSet;
8 | import gnu.trove.TIntIntHashMap;
9 | import gnu.trove.TIntObjectHashMap;
10 | import gnu.trove.TIntObjectIterator;
11 | 
12 | import java.util.ArrayList;
13 | import java.util.Random;
14 | import java.util.TreeMap;
15 | 
16 | import cc.mallet.types.Dirichlet;
17 | 
18 | /**
19 |  * This class defines the tree topic model.
20 |  * It implements most of the functions and leaves four abstract methods,
21 |  * whose implementations vary across the different models.
22 |  * Author: Yuening Hu
23 |  */
24 | public abstract class TreeTopicModel {
25 | 
26 |     int numTopics;
27 |     Random random;
28 |     int maxDepth;
29 |     int root;
30 |     HIntIntObjectHashMap<TIntArrayList> wordPaths;
31 |     TIntArrayList pathToWord;
32 |     TIntObjectHashMap<TIntArrayList> nodeToPath;
33 | 
34 |     TIntDoubleHashMap betaSum;
35 |     HIntIntDoubleHashMap beta; // 2 levels hash map
36 |     //HIntIntObjectHashMap beta;
37 |     TIntDoubleHashMap priorSum;
38 |     HIntIntDoubleHashMap priorPath;
39 |     //HIntIntObjectHashMap priorPath;
40 | 
41 |     TIntObjectHashMap<HIntIntIntHashMap> nonZeroPaths;
42 |     TIntObjectHashMap nonZeroPathsSorted;
43 |     TIntObjectHashMap nonZeroPathsBubbleSorted;
44 |     TIntObjectHashMap<TopicTreeWalk> traversals;
45 | 
46 |     HIntIntDoubleHashMap normalizer;
47 |     //HIntIntObjectHashMap normalizer;
48 |     TIntDoubleHashMap rootNormalizer;
49 |     TIntDoubleHashMap smoothingEst;
50 | 
51 |     public TreeTopicModel(int numTopics, Random random) {
52 |         this.numTopics = numTopics;
53 |         this.random = random;
54 | 
55 |         this.betaSum = new TIntDoubleHashMap ();
56 |         this.beta = new HIntIntDoubleHashMap ();
57 |         //this.beta = new HIntIntObjectHashMap ();
58 |         this.priorSum = new TIntDoubleHashMap ();
59 |         this.priorPath = new HIntIntDoubleHashMap ();
60 |         //this.priorPath = new HIntIntObjectHashMap ();
61 | 
62 |         this.wordPaths = new HIntIntObjectHashMap<TIntArrayList> ();
63 |         this.pathToWord = new TIntArrayList ();
64 |         this.nodeToPath = new TIntObjectHashMap<TIntArrayList> ();
65 | 
66 |         this.nonZeroPaths = new TIntObjectHashMap<HIntIntIntHashMap> ();
67 |         this.nonZeroPathsSorted = new TIntObjectHashMap ();
68 |         this.nonZeroPathsBubbleSorted = new TIntObjectHashMap ();
69 |         this.traversals = new TIntObjectHashMap<TopicTreeWalk> ();
70 |     }
71 | 
72 |     /**
73 |      * Initialize the parameters, including:
74 |      * (1) loading the tree
75 |      * (2) initialize betaSum and beta
76 |      * (3) initialize priorSum, priorPath
77 |      * (4) initialize wordPaths, pathToWord, nodeToPath
78 |      * (5) initialize traversals
79 |      * (6) initialize nonZeroPaths
80 |      */
81 |     protected void initializeParams(String treeFiles, String hyperFile, ArrayList<String> vocab) {
82 | 
83 |         PriorTree tree = new PriorTree();
84 |         tree.initialize(treeFiles, hyperFile, vocab);
85 | 
86 |         // get tree depth
87 |         this.maxDepth = tree.getMaxDepth();
88 |         // get root index
89 |         this.root = tree.getRoot();
90 |         // get tree nodes
91 |         TIntObjectHashMap<Node> nodes = tree.getNodes();
92 |         // get tree paths
93 |         TIntObjectHashMap<ArrayList<Path>> word_paths = tree.getWordPaths();
94 | 
95 |         // if one node contains multiple words, we need to change each word to a leaf node
leaf node 96 | // (assigning a leaf index for each word). 97 | int leaf_index = nodes.size(); 98 | HIntIntIntHashMap tmp_wordleaf = new HIntIntIntHashMap(); 99 | 100 | // initialize betaSum and beta 101 | for (TIntObjectIterator it = nodes.iterator(); it.hasNext(); ) { 102 | it.advance(); 103 | int index = it.key(); 104 | Node node = it.value(); 105 | TDoubleArrayList transition_prior = node.getTransitionPrior(); 106 | 107 | // when node has children 108 | if (node.getNumChildren() > 0) { 109 | //assert node.getNumWords() == 0; 110 | this.betaSum.put(index, node.getTransitionScalor()); 111 | for (int ii = 0; ii < node.getNumChildren(); ii++) { 112 | int child = node.getChild(ii); 113 | this.beta.put(index, child, transition_prior.get(ii)); 114 | } 115 | } 116 | 117 | // when a node contains multiple words, 118 | // we replace it with a node whose children are leaf nodes, 119 | // one leaf node per word 120 | if (node.getNumWords() > 1) { 121 | //assert node.getNumChildren() == 0; 122 | this.betaSum.put(index, node.getTransitionScalor()); 123 | for (int ii = 0; ii < node.getNumWords(); ii++) { 124 | int word = node.getWord(ii); 125 | leaf_index++; 126 | this.beta.put(index, leaf_index, transition_prior.get(ii)); 127 | 128 | // one word might have multiple paths, 129 | // so we keep (word_index, word_parent) as the key 130 | // for this leaf index, which is needed later 131 | tmp_wordleaf.put(word, index, leaf_index); 132 | } 133 | } 134 | } 135 | 136 | // initialize priorSum, priorPath 137 | // initialize wordPaths, pathToWord, nodeToPath 138 | int path_index = -1; 139 | TIntObjectHashMap tmp_nodeToPath = new TIntObjectHashMap(); 140 | for (TIntObjectIterator<ArrayList<Path>> it = word_paths.iterator(); it.hasNext(); ) { 141 | it.advance(); 142 | 143 | int word = it.key(); 144 | ArrayList<Path> paths = it.value(); 145 | this.priorSum.put(word, 0.0); 146 | 147 | for (int ii = 0; ii < paths.size(); ii++) { 148 | path_index++; 149 | this.pathToWord.add(word); 150 | 151 | double prob = 1.0; 152 | Path p = paths.get(ii); 153 | TIntArrayList path_nodes = p.getNodes(); 154 | 155 | // if the last node of the path contains multiple words, 156 | // retrieve the leaf index for this word 157 | // and append it to the nodes of the path 158 | int parent = path_nodes.get(path_nodes.size()-1); 159 | if (tmp_wordleaf.contains(word, parent)) { 160 | leaf_index = tmp_wordleaf.get(word, parent); 161 | path_nodes.add(leaf_index); 162 | } 163 | 164 | for (int nn = 0; nn < path_nodes.size() - 1; nn++) { 165 | parent = path_nodes.get(nn); 166 | int child = path_nodes.get(nn+1); 167 | prob *= this.beta.get(parent, child); 168 | } 169 | 170 | for (int nn = 0; nn < path_nodes.size(); nn++) { 171 | int node = path_nodes.get(nn); 172 | if (! 
tmp_nodeToPath.contains(node)) { 173 | tmp_nodeToPath.put(node, new TIntHashSet()); 174 | } 175 | tmp_nodeToPath.get(node).add(path_index); 176 | } 177 | 178 | this.priorPath.put(word, path_index, prob); 179 | this.priorSum.adjustValue(word, prob); 180 | this.wordPaths.put(word, path_index, path_nodes); 181 | } 182 | } 183 | 184 | // change tmp_nodeToPath to this.nodeToPath 185 | // this is because arraylist is much more efficient than hashset, when we 186 | // need to go over the whole set multiple times 187 | for(TIntObjectIterator it = tmp_nodeToPath.iterator(); it.hasNext(); ) { 188 | it.advance(); 189 | int node = it.key(); 190 | TIntHashSet paths = (TIntHashSet)it.value(); 191 | TIntArrayList tmp = new TIntArrayList(paths.toArray()); 192 | 193 | // System.out.println("Node" + node); 194 | // for(int ii = 0; ii < tmp.size(); ii++) { 195 | // System.out.print(tmp.get(ii) + " "); 196 | // } 197 | // System.out.println(""); 198 | 199 | this.nodeToPath.put(node, tmp); 200 | } 201 | 202 | // initialize traversals 203 | for (int tt = 0; tt < this.numTopics; tt++) { 204 | TopicTreeWalk tw = new TopicTreeWalk(); 205 | this.traversals.put(tt, tw); 206 | } 207 | 208 | // initialize nonZeroPaths 209 | int[] words = this.wordPaths.getKey1Set(); 210 | for (int ww = 0; ww < words.length; ww++) { 211 | int word = words[ww]; 212 | this.nonZeroPaths.put(word, new HIntIntIntHashMap()); 213 | } 214 | } 215 | 216 | /** 217 | * This function samples a path based on the prior 218 | * and change the node and edge count for a topic. 219 | */ 220 | protected int initialize (int word, int topic) { 221 | double sample = this.random.nextDouble(); 222 | int path_index = this.samplePathFromPrior(word, sample); 223 | this.changeCountOnly(topic, word, path_index, 1); 224 | return path_index; 225 | } 226 | 227 | /** 228 | * This function changes the node and edge count for a topic. 229 | */ 230 | protected void changeCountOnly(int topic, int word, int path, int delta) { 231 | TIntArrayList path_nodes = this.wordPaths.get(word, path); 232 | TopicTreeWalk tw = this.traversals.get(topic); 233 | tw.changeCount(path_nodes, delta); 234 | } 235 | 236 | /** 237 | * This function samples a path from the prior. 238 | */ 239 | protected int samplePathFromPrior(int term, double sample) { 240 | int sampled_path = -1; 241 | sample *= this.priorSum.get(term); 242 | TIntDoubleHashMap paths = this.priorPath.get(term); 243 | for(TIntDoubleIterator it = paths.iterator(); it.hasNext(); ) { 244 | it.advance(); 245 | sample -= it.value(); 246 | if (sample <= 0.0) { 247 | sampled_path = it.key(); 248 | break; 249 | } 250 | } 251 | 252 | return sampled_path; 253 | } 254 | 255 | /** 256 | * This function computes a path probability in a topic. 257 | */ 258 | public double computeTopicPathProb(int topic, int word, int path_index) { 259 | TIntArrayList path_nodes = this.wordPaths.get(word, path_index); 260 | TopicTreeWalk tw = this.traversals.get(topic); 261 | double val = 1.0; 262 | for(int ii = 0; ii < path_nodes.size()-1; ii++) { 263 | int parent = path_nodes.get(ii); 264 | int child = path_nodes.get(ii+1); 265 | val *= this.beta.get(parent, child) + tw.getCount(parent, child); 266 | val /= this.betaSum.get(parent) + tw.getNodeCount(parent); 267 | } 268 | return val; 269 | } 270 | 271 | /** 272 | * This function computes the topic likelihood (by node). 
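 * In the loop below, every topic contributes, for each internal node n with children c: logGamma(betaSum_n) - sum_c logGamma(beta_nc) + sum_c logGamma(beta_nc + edgeCount_nc) - logGamma(betaSum_n + nodeCount_n), the standard Dirichlet-multinomial marginal per node.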
273 | */ 274 | public double topicLHood() { 275 | double val = 0.0; 276 | for (int tt = 0; tt < this.numTopics; tt++) { 277 | for (int nn : this.betaSum.keys()) { 278 | double beta_sum = this.betaSum.get(nn); 279 | //val += Dirichlet.logGamma(beta_sum) * this.beta.get(nn).size(); 280 | val += Dirichlet.logGamma(beta_sum); 281 | 282 | double tmp = 0.0; 283 | for (int cc : this.beta.get(nn).keys()) { 284 | tmp += Dirichlet.logGamma(this.beta.get(nn, cc)); 285 | } 286 | //val -= tmp * this.beta.get(nn).size(); 287 | val -= tmp; 288 | 289 | for (int cc : this.beta.get(nn).keys()) { 290 | int count = this.traversals.get(tt).getCount(nn, cc); 291 | val += Dirichlet.logGamma(this.beta.get(nn, cc) + count); 292 | } 293 | 294 | int count = this.traversals.get(tt).getNodeCount(nn); 295 | val -= Dirichlet.logGamma(beta_sum + count); 296 | } 297 | //System.out.println("likelihood " + val); 298 | } 299 | return val; 300 | } 301 | 302 | public TIntObjectHashMap getPaths(int word) { 303 | return this.wordPaths.get(word); 304 | } 305 | 306 | public int[] getWordPathIndexSet(int word) { 307 | return this.wordPaths.get(word).keys(); 308 | } 309 | 310 | public int getPathNum() { 311 | return this.pathToWord.size(); 312 | } 313 | 314 | public int getWordFromPath(int pp) { 315 | return this.pathToWord.get(pp); 316 | } 317 | 318 | public double getPathPrior(int word, int path) { 319 | return this.priorPath.get(word, path); 320 | } 321 | 322 | // for TreeTopicSamplerFast 323 | public double computeTermSmoothing(double[] alpha, int word) { 324 | return 0; 325 | } 326 | 327 | public double computeTermTopicBeta(TIntIntHashMap topic_counts, int word) { 328 | return 0; 329 | } 330 | 331 | public double computeTopicTermTest(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict){ 332 | return 0; 333 | } 334 | 335 | public double computeTermTopicBetaSortD(ArrayList topicCounts, int word) { 336 | return 0; 337 | } 338 | 339 | public double computeTopicTermSortD(double[] alpha, ArrayList local_topic_counts, int word, ArrayList dict){ 340 | return 0; 341 | } 342 | 343 | // shared methods 344 | abstract double getNormalizer(int topic, int path); 345 | abstract void updateParams(); 346 | abstract void changeCount(int topic, int word, int path_index, int delta); 347 | abstract double computeTopicTerm(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict); 348 | 349 | 350 | 351 | 352 | } 353 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicModelFast.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TIntArrayList; 4 | import gnu.trove.TIntDoubleHashMap; 5 | import gnu.trove.TIntHashSet; 6 | import gnu.trove.TIntIntHashMap; 7 | import gnu.trove.TIntIterator; 8 | import gnu.trove.TIntObjectIterator; 9 | 10 | import java.util.ArrayList; 11 | import java.util.Random; 12 | 13 | /** 14 | * This class extends the tree topic model 15 | * It implemented the four abstract methods in a faster way: 16 | * (1) normalizer is stored and updated accordingly 17 | * (2) normalizer is split to two parts: root normalizer and normalizer to save computation 18 | * (3) non-zero-paths are stored so when we compute the topic term score, we only compute 19 | * the non-zero paths 20 | * Author: Yuening Hu 21 | */ 22 | public class TreeTopicModelFast extends TreeTopicModel { 23 | 24 | int INTBITS = 31; 25 | 26 | /** 27 | * The 
normalizer is split into two parts: the root normalizer and the path normalizer. 28 | * The root normalizer is stored per topic; the path normalizer is stored per path per topic. 29 | * Both normalizers are updated whenever a count changes. 30 | */ 31 | public TreeTopicModelFast(int numTopics, Random random) { 32 | super(numTopics, random); 33 | this.normalizer = new HIntIntDoubleHashMap (); 34 | //this.normalizer = new HIntIntObjectHashMap (); 35 | this.rootNormalizer = new TIntDoubleHashMap (); 36 | } 37 | 38 | /** 39 | * This function packs the path mask and the real count into a single integer. 40 | * The top (maxDepth - 1) bits mark whether each non-root node on the path 41 | * has a count larger than zero, and the real count of the leaf 42 | * is added in the low bits. 43 | * If a path is shorter than the tree depth, the remaining mask bits are filled with "1". 44 | */ 45 | protected void updatePathMaskedCount(int path, int topic) { 46 | TopicTreeWalk tw = this.traversals.get(topic); 47 | int ww = this.getWordFromPath(path); 48 | TIntArrayList path_nodes = this.wordPaths.get(ww, path); 49 | int leaf_node = path_nodes.get(path_nodes.size() - 1); 50 | int original_count = tw.getNodeCount(leaf_node); 51 | 52 | int shift_count = this.INTBITS; 53 | int count = this.maxDepth - 1; 54 | int val = 0; 55 | boolean flag = false; 56 | 57 | // note root is not included here 58 | // if count of a node in the path is larger than 0, denote as "1" 59 | // else use "0" 60 | for(int nn = 1; nn < path_nodes.size(); nn++) { 61 | int node = path_nodes.get(nn); 62 | shift_count--; 63 | count--; 64 | if (tw.getNodeCount(node) > 0) { 65 | flag = true; 66 | val += 1 << shift_count; 67 | } 68 | } 69 | 70 | // if a path is shorter than tree depth, fill in "1" 71 | // should we fill in "0" instead? 72 | while (flag && count > 0) { 73 | shift_count--; 74 | val += 1 << shift_count; 75 | count--; 76 | } 77 | 78 | int maskedpath = val; 79 | // plus the original count 80 | val += original_count; 81 | if (val > 0) { 82 | this.nonZeroPaths.get(ww).put(topic, path, val); 83 | } else if (val == 0) { 84 | if (this.nonZeroPaths.get(ww).get(topic) != null) { 85 | this.nonZeroPaths.get(ww).removeKey2(topic, path); 86 | if (this.nonZeroPaths.get(ww).get(topic).size() == 0) { 87 | this.nonZeroPaths.get(ww).removeKey1(topic); 88 | } 89 | } 90 | } 91 | 92 | // int shift = this.INTBITS - this.maxDepth - 1; 93 | // int testmaskedpath = val >> shift; 94 | // maskedpath = maskedpath >> shift; 95 | // int testcount = val - (testmaskedpath << shift); 96 | // System.out.println(maskedpath + " " + testmaskedpath + " " + original_count + " " + testcount); 97 | 98 | //System.out.println(original_count + " " + this.nonZeroPaths.get(ww).get(topic, path)); 99 | } 100 | 101 | /** 102 | * Compute the root normalizer and the normalizer per topic per path. 103 | */ 104 | protected void computeNormalizer(int topic) { 105 | TopicTreeWalk tw = this.traversals.get(topic); 106 | double val = this.betaSum.get(root) + tw.getNodeCount(root); 107 | this.rootNormalizer.put(topic, val); 108 | //System.out.println("Topic " + topic + " root normalizer " + this.rootNormalizer.get(topic)); 109 | 110 | for(int pp = 0; pp < this.getPathNum(); pp++) { 111 | int ww = this.getWordFromPath(pp); 112 | val = this.computeNormalizerPath(topic, ww, pp); 113 | this.normalizer.put(topic, pp, val); 114 | //System.out.println("Topic " + topic + " Path " + pp + " normalizer " + this.normalizer.get(topic, pp)); 115 | } 116 | } 117 | 118 | /** 119 | * Compute the normalizer given a path and a topic. 
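 * The result is the product of (betaSum + nodeCount) over the path's internal nodes, root and leaf excluded; the root factor is kept separately in rootNormalizer and multiplied back in getNormalizer().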
120 | */ 121 | private double computeNormalizerPath(int topic, int word, int path) { 122 | TopicTreeWalk tw = this.traversals.get(topic); 123 | TIntArrayList path_nodes = this.wordPaths.get(word, path); 124 | 125 | double norm = 1.0; 126 | // do not consider the root 127 | for (int nn = 1; nn < path_nodes.size() - 1; nn++) { 128 | int node = path_nodes.get(nn); 129 | norm *= this.betaSum.get(node) + tw.getNodeCount(node); 130 | } 131 | return norm; 132 | } 133 | 134 | /** 135 | * Find every path that passes through any of the given nodes. 136 | */ 137 | protected int[] findAffectedPaths(int[] nodes) { 138 | TIntHashSet affected = new TIntHashSet(); 139 | for(int ii = 0; ii < nodes.length; ii++) { 140 | int node = nodes[ii]; 141 | TIntArrayList paths = this.nodeToPath.get(node); 142 | for (int jj = 0; jj < paths.size(); jj++) { 143 | int pp = paths.get(jj); 144 | affected.add(pp); 145 | } 146 | } 147 | return affected.toArray(); 148 | } 149 | 150 | /** 151 | * Scales the cached normalizer of the given paths by the factor delta. 152 | */ 153 | protected void updateNormalizer(int topic, TIntArrayList paths, double delta) { 154 | for (int ii = 0; ii < paths.size(); ii++) { 155 | int pp = paths.get(ii); 156 | double val = this.normalizer.get(topic, pp); 157 | val *= delta; 158 | this.normalizer.put(topic, pp, val); 159 | } 160 | } 161 | 162 | /** 163 | * Computes the observation part: the product of (beta + count) along the path, minus the pure prior product (already covered by the smoothing bucket). 164 | */ 165 | protected double getObservation(int topic, int word, int path_index) { 166 | TIntArrayList path_nodes = this.wordPaths.get(word, path_index); 167 | TopicTreeWalk tw = this.traversals.get(topic); 168 | double val = 1.0; 169 | for(int ii = 0; ii < path_nodes.size()-1; ii++) { 170 | int parent = path_nodes.get(ii); 171 | int child = path_nodes.get(ii+1); 172 | val *= this.beta.get(parent, child) + tw.getCount(parent, child); 173 | } 174 | val -= this.priorPath.get(word, path_index); 175 | return val; 176 | } 177 | 178 | /** 179 | * After adding instances, update the parameters. 180 | */ 181 | public void updateParams() { 182 | for(int tt = 0; tt < this.numTopics; tt++) { 183 | for(int pp = 0; pp < this.getPathNum(); pp++) { 184 | this.updatePathMaskedCount(pp, tt); 185 | } 186 | this.computeNormalizer(tt); 187 | } 188 | } 189 | 190 | /** 191 | * This function updates the count given the topic and path of a word. 
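 * In the body below: (1) the old (betaSum + nodeCount) factors of all paths through the affected nodes are divided out of the cached normalizers, (2) the counts are changed and the path masks of nodes whose count moved from or to zero are refreshed, (3) the new factors are multiplied back in and the root normalizer is recomputed.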
192 | */ 193 | public void changeCount(int topic, int word, int path_index, int delta) { 194 | TopicTreeWalk tw = this.traversals.get(topic); 195 | TIntArrayList path_nodes = this.wordPaths.get(word, path_index); 196 | 197 | // for affected paths, firstly remove the old values 198 | // do not consider the root 199 | for(int nn = 1; nn < path_nodes.size() - 1; nn++) { 200 | int node = path_nodes.get(nn); 201 | double tmp = this.betaSum.get(node) + tw.getNodeCount(node); 202 | tmp = 1 / tmp; 203 | TIntArrayList paths = this.nodeToPath.get(node); 204 | updateNormalizer(topic, paths, tmp); 205 | } 206 | 207 | // change the count for each edge per topic 208 | // return the node index whose count is changed from 0 or to 0 209 | int[] affected_nodes = tw.changeCount(path_nodes, delta); 210 | // change path count 211 | if (delta > 0) { 212 | this.nonZeroPaths.get(word).adjustOrPutValue(topic, path_index, delta, delta); 213 | } else { 214 | this.nonZeroPaths.get(word).adjustValue(topic, path_index, delta); 215 | } 216 | 217 | // if necessary, change the path mask of the affected nodes 218 | if (affected_nodes != null && affected_nodes.length > 0) { 219 | int[] affected_paths = this.findAffectedPaths(affected_nodes); 220 | for(int ii = 0; ii < affected_paths.length; ii++) { 221 | this.updatePathMaskedCount(affected_paths[ii], topic); 222 | } 223 | } 224 | 225 | // for affected paths, update the normalizer 226 | for(int nn = 1; nn < path_nodes.size() - 1; nn++) { 227 | int node = path_nodes.get(nn); 228 | double tmp = this.betaSum.get(node) + tw.getNodeCount(node); 229 | TIntArrayList paths = this.nodeToPath.get(node); 230 | updateNormalizer(topic, paths, tmp); 231 | } 232 | 233 | // update the root normalizer 234 | double val = this.betaSum.get(root) + tw.getNodeCount(root); 235 | this.rootNormalizer.put(topic, val); 236 | } 237 | 238 | /** 239 | * This function returns the real normalizer. 240 | */ 241 | public double getNormalizer(int topic, int path) { 242 | return this.normalizer.get(topic, path) * this.rootNormalizer.get(topic); 243 | } 244 | 245 | /** 246 | * This function computes the smoothing bucket for a word. 247 | */ 248 | public double computeTermSmoothing(double[] alpha, int word) { 249 | double smoothing = 0.0; 250 | int[] paths = this.getWordPathIndexSet(word); 251 | 252 | for(int tt = 0; tt < this.numTopics; tt++) { 253 | for(int pp : paths) { 254 | double val = alpha[tt] * this.getPathPrior(word, pp); 255 | val /= this.getNormalizer(tt, pp); 256 | smoothing += val; 257 | } 258 | } 259 | //myAssert(smoothing > 0, "something wrong with smoothing!"); 260 | return smoothing; 261 | } 262 | 263 | /** 264 | * This function computes the topic beta bucket. 265 | */ 266 | public double computeTermTopicBeta(TIntIntHashMap topic_counts, int word) { 267 | double topic_beta = 0.0; 268 | int[] paths = this.getWordPathIndexSet(word); 269 | for(int tt : topic_counts.keys()) { 270 | if (topic_counts.get(tt) > 0 ) { 271 | for (int pp : paths) { 272 | double val = topic_counts.get(tt) * this.getPathPrior(word, pp); 273 | val /= this.getNormalizer(tt, pp); 274 | topic_beta += val; 275 | } 276 | } 277 | } 278 | //myAssert(topic_beta > 0, "something wrong with topic_beta!"); 279 | return topic_beta; 280 | } 281 | 282 | /** 283 | * This function computes the topic beta bucket. 
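 * For every topic t with a non-zero count n_t in the document and every path p of the word, it accumulates n_t * prior(p) / normalizer(t, p).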
284 | */ 285 | public double computeTermTopicBetaSortD(ArrayList topic_counts, int word) { 286 | double topic_beta = 0.0; 287 | int[] paths = this.getWordPathIndexSet(word); 288 | for(int ii = 0; ii < topic_counts.size(); ii++) { 289 | int[] current = topic_counts.get(ii); 290 | int tt = current[0]; 291 | int count = current[1]; 292 | if (count > 0 ) { 293 | for (int pp : paths) { 294 | double val = count * this.getPathPrior(word, pp); 295 | val /= this.getNormalizer(tt, pp); 296 | topic_beta += val; 297 | } 298 | } 299 | } 300 | //myAssert(topic_beta > 0, "something wrong with topic_beta!"); 301 | return topic_beta; 302 | } 303 | 304 | /** 305 | * This function computes the topic term bucket. 306 | */ 307 | public double computeTopicTerm(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict) { 308 | double norm = 0.0; 309 | HIntIntIntHashMap nonzeros = this.nonZeroPaths.get(word); 310 | 311 | // Notice only the nonzero paths are considered 312 | //for(int tt = 0; tt < this.numTopics; tt++) { 313 | for(int tt : nonzeros.getKey1Set()) { 314 | double topic_alpha = alpha[tt]; 315 | int topic_count = local_topic_counts.get(tt); 316 | int[] paths = nonzeros.get(tt).keys(); 317 | for (int pp = 0; pp < paths.length; pp++) { 318 | int path = paths[pp]; 319 | double val = this.getObservation(tt, word, path); 320 | val *= (topic_alpha + topic_count); 321 | val /= this.getNormalizer(tt, path); 322 | double[] tmp = {tt, path, val}; 323 | dict.add(tmp); 324 | norm += val; 325 | } 326 | } 327 | return norm; 328 | } 329 | 330 | public double computeTopicTermSortD(double[] alpha, ArrayList local_topic_counts, int word, ArrayList dict) { 331 | double norm = 0.0; 332 | HIntIntIntHashMap nonzeros = this.nonZeroPaths.get(word); 333 | 334 | 335 | int[] tmpTopics = new int[this.numTopics]; 336 | for(int jj = 0; jj < this.numTopics; jj++) { 337 | tmpTopics[jj] = 0; 338 | } 339 | for(int jj = 0; jj < local_topic_counts.size(); jj++) { 340 | int[] current = local_topic_counts.get(jj); 341 | int tt = current[0]; 342 | tmpTopics[tt] = current[1]; 343 | } 344 | 345 | // Notice only the nonzero paths are considered 346 | //for(int tt = 0; tt < this.numTopics; tt++) { 347 | for(int tt : nonzeros.getKey1Set()) { 348 | double topic_alpha = alpha[tt]; 349 | int topic_count = tmpTopics[tt]; 350 | //local_topic_counts.get(ii); 351 | int[] paths = nonzeros.get(tt).keys(); 352 | for (int pp = 0; pp < paths.length; pp++) { 353 | int path = paths[pp]; 354 | double val = this.getObservation(tt, word, path); 355 | val *= (topic_alpha + topic_count); 356 | val /= this.getNormalizer(tt, path); 357 | double[] tmp = {tt, path, val}; 358 | dict.add(tmp); 359 | norm += val; 360 | } 361 | } 362 | return norm; 363 | } 364 | ////////////////////////////////////////////////////////// 365 | 366 | 367 | } 368 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicModelFastEst.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TIntArrayList; 4 | import gnu.trove.TIntDoubleHashMap; 5 | 6 | import java.util.Random; 7 | 8 | /** 9 | * This class extends the tree topic model fast class 10 | * Only add one more function, it computes the smoothing for each word 11 | * only based on the prior (treat the real count as zero), so it 12 | * serves as the upper bound of smoothing. 
13 | * Author: Yuening Hu 14 | */ 15 | public class TreeTopicModelFastEst extends TreeTopicModelFast { 16 | public TreeTopicModelFastEst(int numTopics, Random random) { 17 | super(numTopics, random); 18 | this.smoothingEst = new TIntDoubleHashMap(); 19 | } 20 | 21 | /** 22 | * This function computes the upper bound of the smoothing bucket. 23 | */ 24 | public void computeSmoothingEst(double[] alpha) { 25 | for(int ww : this.wordPaths.getKey1Set()) { 26 | this.smoothingEst.put(ww, 0.0); 27 | for(int tt = 0; tt < this.numTopics; tt++) { 28 | for(int pp : this.wordPaths.get(ww).keys()) { 29 | TIntArrayList path_nodes = this.wordPaths.get(ww, pp); 30 | double prob = 1.0; 31 | for(int nn = 0; nn < path_nodes.size() - 1; nn++) { 32 | int parent = path_nodes.get(nn); 33 | int child = path_nodes.get(nn+1); 34 | prob *= this.beta.get(parent, child) / this.betaSum.get(parent); 35 | } 36 | prob *= alpha[tt]; 37 | this.smoothingEst.adjustValue(ww, prob); 38 | } 39 | } 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicModelFastEstSortW.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TIntArrayList; 4 | import gnu.trove.TIntDoubleHashMap; 5 | 6 | import java.util.Random; 7 | 8 | /** 9 | * This class extends the sorted fast tree topic model (TreeTopicModelFastSortW). 10 | * It adds a single method that computes the smoothing for each word 11 | * based only on the prior (treating the real counts as zero), so it 12 | * serves as an upper bound on the smoothing bucket. 13 | * Author: Yuening Hu 14 | */ 15 | public class TreeTopicModelFastEstSortW extends TreeTopicModelFastSortW { 16 | public TreeTopicModelFastEstSortW(int numTopics, Random random) { 17 | super(numTopics, random); 18 | this.smoothingEst = new TIntDoubleHashMap(); 19 | } 20 | 21 | /** 22 | * This function computes the upper bound of the smoothing bucket. 
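 * Counts only appear in the denominators of the smoothing term, so evaluating every path with all counts at zero yields a value at least as large as the true smoothing mass.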
23 | */ 24 | public void computeSmoothingEst(double[] alpha) { 25 | for(int ww : this.wordPaths.getKey1Set()) { 26 | this.smoothingEst.put(ww, 0.0); 27 | for(int tt = 0; tt < this.numTopics; tt++) { 28 | for(int pp : this.wordPaths.get(ww).keys()) { 29 | TIntArrayList path_nodes = this.wordPaths.get(ww, pp); 30 | double prob = 1.0; 31 | for(int nn = 0; nn < path_nodes.size() - 1; nn++) { 32 | int parent = path_nodes.get(nn); 33 | int child = path_nodes.get(nn+1); 34 | prob *= this.beta.get(parent, child) / this.betaSum.get(parent); 35 | } 36 | prob *= alpha[tt]; 37 | this.smoothingEst.adjustValue(ww, prob); 38 | } 39 | } 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicModelFastSortW.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TIntArrayList; 4 | import gnu.trove.TIntIntHashMap; 5 | 6 | import java.util.ArrayList; 7 | import java.util.HashMap; 8 | import java.util.Map; 9 | import java.util.Random; 10 | 11 | /** 12 | * nonZeroPathsBubbleSorted: Arraylist sorted 13 | * sorted[0] = (topic << TOPIC_BITS) + path 14 | * sorted[1] = (masked_path) + real_count 15 | * Author: Yuening Hu 16 | */ 17 | public class TreeTopicModelFastSortW extends TreeTopicModelFast { 18 | 19 | int TOPIC_BITS = 16; 20 | 21 | public TreeTopicModelFastSortW(int numTopics, Random random) { 22 | super(numTopics, random); 23 | } 24 | 25 | /** 26 | * After adding instances, update the parameters. 27 | */ 28 | public void updateParams() { 29 | 30 | for(int ww : this.nonZeroPaths.keys()) { 31 | if (!this.nonZeroPathsBubbleSorted.containsKey(ww)) { 32 | ArrayList sorted = new ArrayList (); 33 | this.nonZeroPathsBubbleSorted.put(ww, sorted); 34 | } 35 | } 36 | for(int tt = 0; tt < this.numTopics; tt++) { 37 | for(int pp = 0; pp < this.getPathNum(); pp++) { 38 | this.updatePathMaskedCount(pp, tt); 39 | } 40 | this.computeNormalizer(tt); 41 | } 42 | 43 | // for(int ww : this.nonZeroPaths.keys()) { 44 | // System.out.println("Word " + ww); 45 | // ArrayList sorted = this.nonZeroPathsBubbleSorted.get(ww); 46 | // for(int ii = 0; ii < sorted.size(); ii++) { 47 | // int[] tmp = sorted.get(ii); 48 | // System.out.println(tmp[0] + " " + tmp[1] + " " + tmp[2] + " " + tmp[3]); 49 | // } 50 | // } 51 | } 52 | 53 | protected void updatePathMaskedCount(int path, int topic) { 54 | TopicTreeWalk tw = this.traversals.get(topic); 55 | int ww = this.getWordFromPath(path); 56 | TIntArrayList path_nodes = this.wordPaths.get(ww, path); 57 | int leaf_node = path_nodes.get(path_nodes.size() - 1); 58 | int original_count = tw.getNodeCount(leaf_node); 59 | 60 | int shift_count = this.INTBITS; 61 | int count = this.maxDepth - 1; 62 | int val = 0; 63 | boolean flag = false; 64 | 65 | // note root is not included here 66 | // if count of a node in the path is larger than 0, denote as "1" 67 | // else use "0" 68 | for(int nn = 1; nn < path_nodes.size(); nn++) { 69 | int node = path_nodes.get(nn); 70 | shift_count--; 71 | count--; 72 | if (tw.getNodeCount(node) > 0) { 73 | flag = true; 74 | val += 1 << shift_count; 75 | } 76 | } 77 | 78 | // if a path is shorter than tree depth, fill in "1" 79 | // should we fit in "0" ??? 
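// Illustrative example (numbers assumed): with maxDepth = 3 and a full-length path whose two non-root nodes both have positive counts, the mask is (1 << 30) + (1 << 29); a leaf count of 5 then packs into (1 << 30) + (1 << 29) + 5.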
80 | while (flag && count > 0) { 81 | shift_count--; 82 | val += 1 << shift_count; 83 | count--; 84 | } 85 | 86 | val += original_count; 87 | this.addOrUpdateValue(topic, path, ww, val, false); 88 | 89 | } 90 | 91 | private void addOrUpdateValueold(int topic, int path, int word, int newvalue, boolean flag) { 92 | ArrayList sorted = this.nonZeroPathsBubbleSorted.get(word); 93 | int key = (topic << TOPIC_BITS) + path; 94 | //remove the old value 95 | int oldindex = sorted.size(); 96 | int oldvalue = -1; 97 | for(int ii = 0; ii < sorted.size(); ii++) { 98 | int[] tmp = sorted.get(ii); 99 | if(tmp[0] == key) { 100 | oldvalue = tmp[1]; 101 | sorted.remove(ii); 102 | break; 103 | } 104 | } 105 | if(oldindex > sorted.size()) { 106 | oldindex--; 107 | } 108 | 109 | // flag is true, increase value, else just update value 110 | int value = 0; 111 | if(flag) { 112 | value = oldvalue + newvalue; 113 | } else { 114 | value = newvalue; 115 | } 116 | 117 | //add the new value 118 | if (value > 0) { 119 | int index; 120 | if (value > oldvalue) { 121 | index = 0; 122 | for(int ii = oldindex - 1; ii >= 0; ii--) { 123 | //System.out.println(ii + " " + oldindex + " " + sorted.size()); 124 | int[] tmp = sorted.get(ii); 125 | if(value <= tmp[1]) { 126 | index = ii; 127 | break; 128 | } 129 | } 130 | } else { 131 | index = sorted.size(); 132 | for(int ii = oldindex; ii < sorted.size(); ii++) { 133 | int[] tmp = sorted.get(ii); 134 | if(value >= tmp[1]) { 135 | index = ii; 136 | break; 137 | } 138 | } 139 | } 140 | 141 | int[] newpair = {key, value}; 142 | sorted.add(index, newpair); 143 | } 144 | } 145 | 146 | private void addOrUpdateValue(int topic, int path, int word, int newvalue, boolean flag) { 147 | ArrayList sorted = this.nonZeroPathsBubbleSorted.get(word); 148 | int key = (topic << TOPIC_BITS) + path; 149 | //remove the old value 150 | int value = 0; 151 | for(int ii = 0; ii < sorted.size(); ii++) { 152 | int[] tmp = sorted.get(ii); 153 | if(tmp[0] == key) { 154 | value = tmp[1]; 155 | sorted.remove(ii); 156 | break; 157 | } 158 | } 159 | 160 | // flag is true, increase value, else just update value 161 | if(flag) { 162 | value += newvalue; 163 | } else { 164 | value = newvalue; 165 | } 166 | 167 | //add the new value 168 | if (value > 0) { 169 | int index = sorted.size(); 170 | for(int ii = 0; ii < sorted.size(); ii++) { 171 | int[] tmp = sorted.get(ii); 172 | if(value >= tmp[1]) { 173 | index = ii; 174 | break; 175 | } 176 | } 177 | int[] newpair = {key, value}; 178 | sorted.add(index, newpair); 179 | } 180 | } 181 | 182 | public void changeCount(int topic, int word, int path_index, int delta) { 183 | TopicTreeWalk tw = this.traversals.get(topic); 184 | TIntArrayList path_nodes = this.wordPaths.get(word, path_index); 185 | 186 | // for affected paths, firstly remove the old values 187 | // do not consider the root 188 | for(int nn = 1; nn < path_nodes.size() - 1; nn++) { 189 | int node = path_nodes.get(nn); 190 | double tmp = this.betaSum.get(node) + tw.getNodeCount(node); 191 | tmp = 1 / tmp; 192 | TIntArrayList paths = this.nodeToPath.get(node); 193 | updateNormalizer(topic, paths, tmp); 194 | } 195 | 196 | // change the count for each edge per topic 197 | // return the node index whose count is changed from 0 or to 0 198 | int[] affected_nodes = tw.changeCount(path_nodes, delta); 199 | 200 | // change path count 201 | this.addOrUpdateValue(topic, path_index, word, delta, true); 202 | 203 | // if necessary, change the path mask of the affected nodes 204 | if (affected_nodes != null && 
affected_nodes.length > 0) { 205 | int[] affected_paths = this.findAffectedPaths(affected_nodes); 206 | for(int ii = 0; ii < affected_paths.length; ii++) { 207 | this.updatePathMaskedCount(affected_paths[ii], topic); 208 | } 209 | } 210 | 211 | // for affected paths, update the normalizer 212 | for(int nn = 1; nn < path_nodes.size() - 1; nn++) { 213 | int node = path_nodes.get(nn); 214 | double tmp = this.betaSum.get(node) + tw.getNodeCount(node); 215 | TIntArrayList paths = this.nodeToPath.get(node); 216 | updateNormalizer(topic, paths, tmp); 217 | } 218 | 219 | // update the root normalizer 220 | double val = this.betaSum.get(root) + tw.getNodeCount(root); 221 | this.rootNormalizer.put(topic, val); 222 | } 223 | 224 | /** 225 | * This function computes the topic term bucket. 226 | */ 227 | public double computeTopicTerm(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict) { 228 | double norm = 0.0; 229 | ArrayList nonzeros = this.nonZeroPathsBubbleSorted.get(word); 230 | 231 | // Notice only the nonzero paths are considered 232 | for(int ii = 0; ii < nonzeros.size(); ii++) { 233 | int[] tmp = nonzeros.get(ii); 234 | int key = tmp[0]; 235 | int tt = key >> TOPIC_BITS; 236 | int pp = key - (tt << TOPIC_BITS); 237 | 238 | double topic_alpha = alpha[tt]; 239 | int topic_count = local_topic_counts.get(tt); 240 | 241 | double val = this.getObservation(tt, word, pp); 242 | val *= (topic_alpha + topic_count); 243 | val /= this.getNormalizer(tt, pp); 244 | 245 | //System.out.println(tt + " " + pp + " " + tmp[2] + " " + val); 246 | 247 | double[] result = {tt, pp, val}; 248 | dict.add(result); 249 | 250 | norm += val; 251 | } 252 | 253 | return norm; 254 | } 255 | 256 | public double computeTopicTermSortD(double[] alpha, ArrayList local_topic_counts, int word, ArrayList dict) { 257 | double norm = 0.0; 258 | ArrayList nonzeros = this.nonZeroPathsBubbleSorted.get(word); 259 | 260 | 261 | int[] tmpTopics = new int[this.numTopics]; 262 | for(int jj = 0; jj < this.numTopics; jj++) { 263 | tmpTopics[jj] = 0; 264 | } 265 | for(int jj = 0; jj < local_topic_counts.size(); jj++) { 266 | int[] current = local_topic_counts.get(jj); 267 | int tt = current[0]; 268 | tmpTopics[tt] = current[1]; 269 | } 270 | 271 | // Notice only the nonzero paths are considered 272 | for(int ii = 0; ii < nonzeros.size(); ii++) { 273 | int[] tmp = nonzeros.get(ii); 274 | int key = tmp[0]; 275 | int tt = key >> TOPIC_BITS; 276 | int pp = key - (tt << TOPIC_BITS); 277 | 278 | double topic_alpha = alpha[tt]; 279 | int topic_count = tmpTopics[tt]; 280 | 281 | double val = this.getObservation(tt, word, pp); 282 | val *= (topic_alpha + topic_count); 283 | val /= this.getNormalizer(tt, pp); 284 | 285 | //System.out.println(tt + " " + pp + " " + tmp[2] + " " + val); 286 | 287 | double[] result = {tt, pp, val}; 288 | dict.add(result); 289 | 290 | norm += val; 291 | } 292 | return norm; 293 | } 294 | 295 | } 296 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicModelFastSortW1.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TIntArrayList; 4 | import gnu.trove.TIntIntHashMap; 5 | 6 | import java.util.ArrayList; 7 | import java.util.HashMap; 8 | import java.util.Map; 9 | import java.util.Random; 10 | 11 | /** 12 | * nonZeroPathsBubbleSorted: Arraylist sorted 13 | * sorted[0] = topic 14 | * sorted[1] = path 15 | * sorted[2] = (masked_path) 
+ real_count 16 | * Author: Yuening Hu 17 | */ 18 | public class TreeTopicModelFastSortW1 extends TreeTopicModelFast { 19 | 20 | public TreeTopicModelFastSortW1(int numTopics, Random random) { 21 | super(numTopics, random); 22 | } 23 | 24 | public void updateParams() { 25 | 26 | for(int ww : this.nonZeroPaths.keys()) { 27 | if (!this.nonZeroPathsBubbleSorted.containsKey(ww)) { 28 | ArrayList sorted = new ArrayList (); 29 | this.nonZeroPathsBubbleSorted.put(ww, sorted); 30 | } 31 | } 32 | for(int tt = 0; tt < this.numTopics; tt++) { 33 | for(int pp = 0; pp < this.getPathNum(); pp++) { 34 | this.updatePathMaskedCount(pp, tt); 35 | } 36 | this.computeNormalizer(tt); 37 | } 38 | 39 | // for(int ww : this.nonZeroPaths.keys()) { 40 | // System.out.println("Word " + ww); 41 | // ArrayList sorted = this.nonZeroPathsBubbleSorted.get(ww); 42 | // for(int ii = 0; ii < sorted.size(); ii++) { 43 | // int[] tmp = sorted.get(ii); 44 | // System.out.println(tmp[0] + " " + tmp[1] + " " + tmp[2] + " " + tmp[3]); 45 | // } 46 | // } 47 | } 48 | 49 | protected void updatePathMaskedCount(int path, int topic) { 50 | TopicTreeWalk tw = this.traversals.get(topic); 51 | int ww = this.getWordFromPath(path); 52 | TIntArrayList path_nodes = this.wordPaths.get(ww, path); 53 | int leaf_node = path_nodes.get(path_nodes.size() - 1); 54 | int original_count = tw.getNodeCount(leaf_node); 55 | 56 | int shift_count = this.INTBITS; 57 | int count = this.maxDepth - 1; 58 | int val = 0; 59 | boolean flag = false; 60 | 61 | // note root is not included here 62 | // if count of a node in the path is larger than 0, denote as "1" 63 | // else use "0" 64 | for(int nn = 1; nn < path_nodes.size(); nn++) { 65 | int node = path_nodes.get(nn); 66 | shift_count--; 67 | count--; 68 | if (tw.getNodeCount(node) > 0) { 69 | flag = true; 70 | val += 1 << shift_count; 71 | } 72 | } 73 | 74 | // if a path is shorter than tree depth, fill in "1" 75 | // should we fit in "0" ??? 
76 | while (flag && count > 0) { 77 | shift_count--; 78 | val += 1 << shift_count; 79 | count--; 80 | } 81 | 82 | val += original_count; 83 | this.addOrUpdateValue(topic, path, ww, val, false); 84 | 85 | } 86 | 87 | private void addOrUpdateValue(int topic, int path, int word, int newvalue, boolean flag) { 88 | ArrayList sorted = this.nonZeroPathsBubbleSorted.get(word); 89 | 90 | //remove the old value 91 | int value = 0; 92 | for(int ii = 0; ii < sorted.size(); ii++) { 93 | int[] tmp = sorted.get(ii); 94 | if(tmp[0] == topic && tmp[1] == path) { 95 | value = tmp[2]; 96 | sorted.remove(ii); 97 | break; 98 | } 99 | } 100 | 101 | // flag is true, increase value, else just update value 102 | if(flag) { 103 | value += newvalue; 104 | } else { 105 | value = newvalue; 106 | } 107 | 108 | //add the new value 109 | if (value > 0) { 110 | int index = sorted.size(); 111 | for(int ii = 0; ii < sorted.size(); ii++) { 112 | int[] tmp = sorted.get(ii); 113 | if(value >= tmp[2]) { 114 | index = ii; 115 | break; 116 | } 117 | } 118 | int[] newpair = {topic, path, value}; 119 | sorted.add(index, newpair); 120 | } 121 | } 122 | 123 | public void changeCount(int topic, int word, int path_index, int delta) { 124 | TopicTreeWalk tw = this.traversals.get(topic); 125 | TIntArrayList path_nodes = this.wordPaths.get(word, path_index); 126 | 127 | // for affected paths, firstly remove the old values 128 | // do not consider the root 129 | for(int nn = 1; nn < path_nodes.size() - 1; nn++) { 130 | int node = path_nodes.get(nn); 131 | double tmp = this.betaSum.get(node) + tw.getNodeCount(node); 132 | tmp = 1 / tmp; 133 | TIntArrayList paths = this.nodeToPath.get(node); 134 | updateNormalizer(topic, paths, tmp); 135 | } 136 | 137 | // change the count for each edge per topic 138 | // return the node index whose count is changed from 0 or to 0 139 | int[] affected_nodes = tw.changeCount(path_nodes, delta); 140 | 141 | // change path count 142 | this.addOrUpdateValue(topic, path_index, word, delta, true); 143 | 144 | // if necessary, change the path mask of the affected nodes 145 | if (affected_nodes != null && affected_nodes.length > 0) { 146 | int[] affected_paths = this.findAffectedPaths(affected_nodes); 147 | for(int ii = 0; ii < affected_paths.length; ii++) { 148 | this.updatePathMaskedCount(affected_paths[ii], topic); 149 | } 150 | } 151 | 152 | // for affected paths, update the normalizer 153 | for(int nn = 1; nn < path_nodes.size() - 1; nn++) { 154 | int node = path_nodes.get(nn); 155 | double tmp = this.betaSum.get(node) + tw.getNodeCount(node); 156 | TIntArrayList paths = this.nodeToPath.get(node); 157 | updateNormalizer(topic, paths, tmp); 158 | } 159 | 160 | // update the root normalizer 161 | double val = this.betaSum.get(root) + tw.getNodeCount(root); 162 | this.rootNormalizer.put(topic, val); 163 | } 164 | 165 | /** 166 | * This function computes the topic term bucket. 
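 * Only (topic, path) pairs with non-zero counts are visited, read from the count-sorted list that addOrUpdateValue() keeps in descending order.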
167 | */ 168 | public double computeTopicTerm(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict) { 169 | double norm = 0.0; 170 | ArrayList nonzeros = this.nonZeroPathsBubbleSorted.get(word); 171 | 172 | // Notice only the nonzero paths are considered 173 | for(int ii = 0; ii < nonzeros.size(); ii++) { 174 | int[] tmp = nonzeros.get(ii); 175 | int tt = tmp[0]; 176 | int pp = tmp[1]; 177 | 178 | double topic_alpha = alpha[tt]; 179 | int topic_count = local_topic_counts.get(tt); 180 | 181 | double val = this.getObservation(tt, word, pp); 182 | val *= (topic_alpha + topic_count); 183 | val /= this.getNormalizer(tt, pp); 184 | 185 | //System.out.println(tt + " " + pp + " " + tmp[2] + " " + val); 186 | 187 | double[] result = {tt, pp, val}; 188 | //dict.add(result); 189 | 190 | // int index = dict.size(); 191 | // for(int jj = 0; jj < dict.size(); jj++) { 192 | // double[] find = dict.get(jj); 193 | // //System.out.println(find[2] + " " + val); 194 | // if(val >= find[2]) { 195 | // index = jj; 196 | // break; 197 | // } 198 | // } 199 | // dict.add(index, result); 200 | 201 | int index = 0; 202 | for(int jj = dict.size() - 1; jj >= 0 ; jj--) { 203 | double[] find = dict.get(jj); 204 | if(val <= find[2]) { 205 | index = jj; 206 | break; 207 | } 208 | } 209 | dict.add(index, result); 210 | 211 | norm += val; 212 | } 213 | 214 | // for(int ii = 0; ii < dict.size(); ii++) { 215 | // double[] tmp = dict.get(ii); 216 | // System.out.println(tmp[0] + " " + tmp[1] + " " + tmp[2]); 217 | // } 218 | 219 | return norm; 220 | } 221 | 222 | public double computeTopicTermSortD(double[] alpha, ArrayList local_topic_counts, int word, ArrayList dict) { 223 | double norm = 0.0; 224 | ArrayList nonzeros = this.nonZeroPathsBubbleSorted.get(word); 225 | 226 | 227 | int[] tmpTopics = new int[this.numTopics]; 228 | for(int jj = 0; jj < this.numTopics; jj++) { 229 | tmpTopics[jj] = 0; 230 | } 231 | for(int jj = 0; jj < local_topic_counts.size(); jj++) { 232 | int[] current = local_topic_counts.get(jj); 233 | int tt = current[0]; 234 | tmpTopics[tt] = current[1]; 235 | } 236 | 237 | // Notice only the nonzero paths are considered 238 | for(int ii = 0; ii < nonzeros.size(); ii++) { 239 | int[] tmp = nonzeros.get(ii); 240 | int tt = tmp[0]; 241 | int pp = tmp[1]; 242 | 243 | double topic_alpha = alpha[tt]; 244 | int topic_count = tmpTopics[tt]; 245 | 246 | double val = this.getObservation(tt, word, pp); 247 | val *= (topic_alpha + topic_count); 248 | val /= this.getNormalizer(tt, pp); 249 | 250 | //System.out.println(tt + " " + pp + " " + tmp[2] + " " + val); 251 | 252 | double[] result = {tt, pp, val}; 253 | dict.add(result); 254 | 255 | norm += val; 256 | } 257 | return norm; 258 | } 259 | } 260 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicModelFastSortW2.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TIntArrayList; 4 | import gnu.trove.TIntIntHashMap; 5 | 6 | import java.util.ArrayList; 7 | import java.util.HashMap; 8 | import java.util.Map; 9 | import java.util.Random; 10 | 11 | /** 12 | * nonZeroPathsBubbleSorted: Arraylist sorted 13 | * sorted[0] = topic 14 | * sorted[1] = path 15 | * sorted[2] = maksed_path 16 | * sorted[3] = real_counts 17 | * Author: Yuening Hu 18 | */ 19 | public class TreeTopicModelFastSortW2 extends TreeTopicModelFast { 20 | 21 | public TreeTopicModelFastSortW2(int numTopics, 
Random random) { 22 | super(numTopics, random); 23 | } 24 | 25 | /** 26 | * After adding instances, update the parameters. 27 | */ 28 | public void updateParams() { 29 | 30 | for(int ww : this.nonZeroPaths.keys()) { 31 | if (!this.nonZeroPathsBubbleSorted.containsKey(ww)) { 32 | ArrayList sorted = new ArrayList (); 33 | this.nonZeroPathsBubbleSorted.put(ww, sorted); 34 | } 35 | } 36 | 37 | for(int tt = 0; tt < this.numTopics; tt++) { 38 | for(int pp = 0; pp < this.getPathNum(); pp++) { 39 | this.updatePathMaskedCount(pp, tt); 40 | } 41 | this.computeNormalizer(tt); 42 | } 43 | 44 | // for(int ww : this.nonZeroPaths.keys()) { 45 | // System.out.println("Word " + ww); 46 | // ArrayList sorted = this.nonZeroPathsBubbleSorted.get(ww); 47 | // for(int ii = 0; ii < sorted.size(); ii++) { 48 | // int[] tmp = sorted.get(ii); 49 | // System.out.println(tmp[0] + " " + tmp[1] + " " + tmp[2] + " " + tmp[3]); 50 | // } 51 | // } 52 | } 53 | 54 | protected void updatePathMaskedCount(int path, int topic) { 55 | TopicTreeWalk tw = this.traversals.get(topic); 56 | int ww = this.getWordFromPath(path); 57 | TIntArrayList path_nodes = this.wordPaths.get(ww, path); 58 | int leaf_node = path_nodes.get(path_nodes.size() - 1); 59 | int original_count = tw.getNodeCount(leaf_node); 60 | 61 | int count = this.maxDepth; 62 | int val = 0; 63 | boolean flag = false; 64 | 65 | // note root is not included here 66 | // if count of a node in the path is larger than 0, denote as "1" 67 | // else use "0" 68 | for(int nn = 1; nn < path_nodes.size(); nn++) { 69 | int node = path_nodes.get(nn); 70 | count--; 71 | if (tw.getNodeCount(node) > 0) { 72 | flag = true; 73 | val += 1 << count; 74 | } 75 | } 76 | 77 | // if a path is shorter than tree depth, fill in "1" 78 | // should we fit in "0" ??? 
79 | while (flag && count > 0) { 80 | count--; 81 | val += 1 << count; 82 | } 83 | 84 | this.addOrUpdateValue(topic, path, ww, val, original_count, false); 85 | 86 | } 87 | 88 | private void addOrUpdateValue (int topic, int path, int word, int newpathvalue, int newcount, boolean flag) { 89 | ArrayList sorted = this.nonZeroPathsBubbleSorted.get(word); 90 | 91 | //remove the old value 92 | int pathvalue = 0; 93 | int count = 0; 94 | for(int ii = 0; ii < sorted.size(); ii++) { 95 | int[] tmp = sorted.get(ii); 96 | if(tmp[0] == topic && tmp[1] == path) { 97 | pathvalue = tmp[2]; 98 | count = tmp[3]; 99 | sorted.remove(ii); 100 | break; 101 | } 102 | } 103 | 104 | // flag is true, increase value, else just update value 105 | if(flag) { 106 | count += newcount; 107 | } else { 108 | count = newcount; 109 | } 110 | 111 | if(newpathvalue != -1) { 112 | pathvalue = newpathvalue; 113 | } 114 | 115 | //add the new value 116 | if (pathvalue > 0) { 117 | int index = sorted.size(); 118 | for(int ii = 0; ii < sorted.size(); ii++) { 119 | int[] tmp = sorted.get(ii); 120 | if(count >= tmp[3]) { 121 | index = ii; 122 | break; 123 | } 124 | } 125 | int[] newpair = {topic, path, pathvalue, count}; 126 | sorted.add(index, newpair); 127 | } 128 | } 129 | 130 | private void addOrUpdateValueold (int topic, int path, int word, int newpathvalue, int newcount, boolean flag) { 131 | 132 | ArrayList sorted = this.nonZeroPathsBubbleSorted.get(word); 133 | 134 | // find the position 135 | int index = sorted.size(); 136 | int pathvalue = -1; 137 | int count = -1; 138 | for(int ii = 0; ii < sorted.size(); ii++) { 139 | int[] tmp = sorted.get(ii); 140 | if(tmp[0] == topic && tmp[1] == path) { 141 | pathvalue = tmp[2]; 142 | count = tmp[3]; 143 | break; 144 | } 145 | } 146 | 147 | // flag is true, increase value, else just update value 148 | if(flag) { 149 | count += newcount; 150 | } else { 151 | count = newcount; 152 | } 153 | 154 | if(newpathvalue != -1) { 155 | pathvalue = newpathvalue; 156 | } 157 | 158 | // adjust the value and update or insert 159 | if (index == sorted.size()) { 160 | int[] newpair = {topic, path, pathvalue, count}; 161 | sorted.add(newpair); 162 | } else { 163 | int[] current = sorted.get(index); 164 | current[0] = topic; 165 | current[1] = path; 166 | current[2] = pathvalue; 167 | current[3] = count; 168 | } 169 | 170 | // bubble to left 171 | for(int ii = index-1; ii >= 0; ii--) { 172 | int[] left = sorted.get(ii); 173 | int[] current = sorted.get(ii+1); 174 | if(current[3] > left[3]) { 175 | int num = 3; 176 | while(num >= 0) { 177 | int tmp = current[num]; 178 | current[num] = left[num]; 179 | left[num] = tmp; 180 | num--; 181 | } 182 | //int a = 0; 183 | } else { 184 | break; 185 | } 186 | } 187 | 188 | // bubble to right 189 | for(int ii = index + 1; ii < sorted.size(); ii++) { 190 | int[] right = sorted.get(ii); 191 | int[] current = sorted.get(ii-1); 192 | if(current[3] < right[3]) { 193 | int num = 3; 194 | while(num >= 0) { 195 | int tmp = current[num]; 196 | current[num] = right[num]; 197 | right[num] = tmp; 198 | num--; 199 | } 200 | } else { 201 | break; 202 | } 203 | } 204 | } 205 | 206 | public void changeCount(int topic, int word, int path_index, int delta) { 207 | TopicTreeWalk tw = this.traversals.get(topic); 208 | TIntArrayList path_nodes = this.wordPaths.get(word, path_index); 209 | 210 | // for affected paths, firstly remove the old values 211 | // do not consider the root 212 | for(int nn = 1; nn < path_nodes.size() - 1; nn++) { 213 | int node = path_nodes.get(nn); 214 | double 
tmp = this.betaSum.get(node) + tw.getNodeCount(node); 215 | tmp = 1 / tmp; 216 | TIntArrayList paths = this.nodeToPath.get(node); 217 | updateNormalizer(topic, paths, tmp); 218 | } 219 | 220 | // change the count for each edge per topic 221 | // return the node index whose count is changed from 0 or to 0 222 | int[] affected_nodes = tw.changeCount(path_nodes, delta); 223 | 224 | // change path count 225 | this.addOrUpdateValue(topic, path_index, word, -1, delta, true); 226 | 227 | // if necessary, change the path mask of the affected nodes 228 | if (affected_nodes != null && affected_nodes.length > 0) { 229 | int[] affected_paths = this.findAffectedPaths(affected_nodes); 230 | for(int ii = 0; ii < affected_paths.length; ii++) { 231 | this.updatePathMaskedCount(affected_paths[ii], topic); 232 | } 233 | } 234 | 235 | // for affected paths, update the normalizer 236 | for(int nn = 1; nn < path_nodes.size() - 1; nn++) { 237 | int node = path_nodes.get(nn); 238 | double tmp = this.betaSum.get(node) + tw.getNodeCount(node); 239 | TIntArrayList paths = this.nodeToPath.get(node); 240 | updateNormalizer(topic, paths, tmp); 241 | } 242 | 243 | // update the root normalizer 244 | double val = this.betaSum.get(root) + tw.getNodeCount(root); 245 | this.rootNormalizer.put(topic, val); 246 | } 247 | 248 | /** 249 | * This function computes the topic term bucket. 250 | */ 251 | public double computeTopicTerm(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict) { 252 | double norm = 0.0; 253 | ArrayList nonzeros = this.nonZeroPathsBubbleSorted.get(word); 254 | 255 | // Notice only the nonzero paths are considered 256 | for(int ii = 0; ii < nonzeros.size(); ii++) { 257 | int[] tmp = nonzeros.get(ii); 258 | int tt = tmp[0]; 259 | int pp = tmp[1]; 260 | 261 | double topic_alpha = alpha[tt]; 262 | int topic_count = local_topic_counts.get(tt); 263 | 264 | double val = this.getObservation(tt, word, pp); 265 | val *= (topic_alpha + topic_count); 266 | val /= this.getNormalizer(tt, pp); 267 | 268 | //System.out.println(tt + " " + pp + " " + tmp[2] + " " + val); 269 | 270 | double[] result = {tt, pp, val}; 271 | dict.add(result); 272 | 273 | // int index = dict.size(); 274 | // for(int jj = 0; jj < dict.size(); jj++) { 275 | // double[] find = dict.get(jj); 276 | // //System.out.println(find[2] + " " + val); 277 | // if(val >= find[2]) { 278 | // index = jj; 279 | // break; 280 | // } 281 | // } 282 | // dict.add(index, result); 283 | 284 | // int index = 0; 285 | // for(int jj = dict.size() - 1; jj >= 0 ; jj--) { 286 | // double[] find = dict.get(jj); 287 | // if(val <= find[2]) { 288 | // index = jj; 289 | // break; 290 | // } 291 | // } 292 | // dict.add(index, result); 293 | 294 | norm += val; 295 | } 296 | 297 | // for(int ii = 0; ii < dict.size(); ii++) { 298 | // double[] tmp = dict.get(ii); 299 | // System.out.println(tmp[0] + " " + tmp[1] + " " + tmp[2]); 300 | // } 301 | 302 | return norm; 303 | } 304 | 305 | public double computeTopicTermSortD(double[] alpha, ArrayList local_topic_counts, int word, ArrayList dict) { 306 | double norm = 0.0; 307 | ArrayList nonzeros = this.nonZeroPathsBubbleSorted.get(word); 308 | 309 | 310 | int[] tmpTopics = new int[this.numTopics]; 311 | for(int jj = 0; jj < this.numTopics; jj++) { 312 | tmpTopics[jj] = 0; 313 | } 314 | for(int jj = 0; jj < local_topic_counts.size(); jj++) { 315 | int[] current = local_topic_counts.get(jj); 316 | int tt = current[0]; 317 | tmpTopics[tt] = current[1]; 318 | } 319 | 320 | // Notice only the nonzero paths are 
considered 321 | for(int ii = 0; ii < nonzeros.size(); ii++) { 322 | int[] tmp = nonzeros.get(ii); 323 | int tt = tmp[0]; 324 | int pp = tmp[1]; 325 | 326 | double topic_alpha = alpha[tt]; 327 | int topic_count = tmpTopics[tt]; 328 | 329 | double val = this.getObservation(tt, word, pp); 330 | val *= (topic_alpha + topic_count); 331 | val /= this.getNormalizer(tt, pp); 332 | 333 | //System.out.println(tt + " " + pp + " " + tmp[2] + " " + val); 334 | 335 | double[] result = {tt, pp, val}; 336 | dict.add(result); 337 | 338 | norm += val; 339 | } 340 | return norm; 341 | } 342 | } 343 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicModelNaive.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Random; 5 | 6 | import cc.mallet.types.Dirichlet; 7 | 8 | import gnu.trove.TDoubleArrayList; 9 | import gnu.trove.TIntArrayList; 10 | import gnu.trove.TIntDoubleHashMap; 11 | import gnu.trove.TIntDoubleIterator; 12 | import gnu.trove.TIntHashSet; 13 | import gnu.trove.TIntIntHashMap; 14 | import gnu.trove.TIntObjectHashMap; 15 | import gnu.trove.TIntObjectIterator; 16 | 17 | /** 18 | * This class extends the tree topic model. 19 | * It implements the four abstract methods in a naive way: given a word, 20 | * it directly recomputes the probability for each topic and path every time. 21 | * Author: Yuening Hu 22 | */ 23 | public class TreeTopicModelNaive extends TreeTopicModel { 24 | 25 | public TreeTopicModelNaive(int numTopics, Random random) { 26 | super(numTopics, random); 27 | } 28 | 29 | /** 30 | * Just calls changeCountOnly(), nothing else. 31 | */ 32 | public void changeCount(int topic, int word, int path, int delta) { 33 | // TIntArrayList path_nodes = this.wordPaths.get(word, path_index); 34 | // TopicTreeWalk tw = this.traversals.get(topic); 35 | // tw.changeCount(path_nodes, delta); 36 | this.changeCountOnly(topic, word, path, delta); 37 | } 38 | 39 | /** 40 | * Given a word and the topic counts in the current document, 41 | * this function computes the probability per path per topic directly 42 | * according to the sampling equation. 43 | */ 44 | public double computeTopicTerm(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict) { 45 | double norm = 0.0; 46 | int[] paths = this.getWordPathIndexSet(word); 47 | for(int tt = 0; tt < this.numTopics; tt++) { 48 | double topic_alpha = alpha[tt]; 49 | int topic_count = local_topic_counts.get(tt); 50 | for (int pp = 0; pp < paths.length; pp++) { 51 | int path_index = paths[pp]; 52 | double val = this.computeTopicPathProb(tt, word, path_index); 53 | val *= (topic_alpha + topic_count); 54 | double[] tmp = {tt, path_index, val}; 55 | dict.add(tmp); 56 | norm += val; 57 | } 58 | } 59 | return norm; 60 | } 61 | 62 | /** 63 | * No parameters need to be updated. 64 | */ 65 | public void updateParams() { 66 | } 67 | 68 | /** 69 | * Not actually used. 
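 * (The naive model recomputes each path probability directly via computeTopicPathProb, so no cached normalizer exists.)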
70 | */ 71 | public double getNormalizer(int topic, int path) { 72 | return 0; 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicSampler.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TDoubleArrayList; 4 | import gnu.trove.TIntArrayList; 5 | import gnu.trove.TIntHashSet; 6 | import gnu.trove.TIntIntHashMap; 7 | import gnu.trove.TIntIntIterator; 8 | 9 | import java.io.BufferedOutputStream; 10 | import java.io.BufferedReader; 11 | import java.io.DataInputStream; 12 | import java.io.File; 13 | import java.io.FileInputStream; 14 | import java.io.FileOutputStream; 15 | import java.io.IOException; 16 | import java.io.InputStreamReader; 17 | import java.io.PrintStream; 18 | import java.io.Serializable; 19 | import java.util.ArrayList; 20 | import java.util.Arrays; 21 | import java.util.zip.GZIPOutputStream; 22 | 23 | import cc.mallet.types.Dirichlet; 24 | import cc.mallet.types.FeatureSequence; 25 | import cc.mallet.types.Instance; 26 | import cc.mallet.types.InstanceList; 27 | import cc.mallet.util.Randoms; 28 | 29 | /** 30 | * This interface defines the tree topic sampler, which loads the instances, 31 | * reports the topics, and leaves the sampling method to the implementing 32 | * classes, since it varies across methods. 33 | * Author: Yuening Hu 34 | */ 35 | public interface TreeTopicSampler { 36 | 37 | public void initialize(String treeFiles, String hyperFile, String vocabFile); 38 | public void setNumIterations(int iters); 39 | //public int getNumIterations(); 40 | //public void resumeLHood(String lhoodFile) throws IOException; 41 | public void resume(InstanceList training, String resumeDir); 42 | public void estimate(int numIterations, String outputFolder, int outputInterval, int topWords); 43 | //public double lhood(); 44 | //public void report(String outputDir, int topWords) throws IOException; 45 | //public void printTopWords(File file, int numWords) throws IOException; 46 | //public String displayTopWords (int numWords); 47 | //public void printState (File file) throws IOException; 48 | //public void printStats (File file) throws IOException; 49 | //public void loadVocab(String vocabFile); 50 | //public void loadConstraints(String consFile); 51 | 52 | public void addInstances(InstanceList training); 53 | //public void resumeStates(InstanceList training, String statesFile) throws IOException; 54 | public void clearTopicAssignments(String option, String consFile); 55 | //public void changeTopic(int doc, int index, int word, int new_topic, int new_path); 56 | //public double docLHood(); 57 | //public void printDocumentTopics (File file) throws IOException; 58 | //public void sampleDoc(int doc); 59 | } 60 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicSamplerFast.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TDoubleArrayList; 4 | import gnu.trove.TIntArrayList; 5 | import gnu.trove.TIntIntHashMap; 6 | 7 | import java.io.Serializable; 8 | import java.util.ArrayList; 9 | import java.util.Arrays; 10 | 11 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerHashD.DocData; 12 | import cc.mallet.util.Randoms; 13 | 14 | /** 15 | * This class defines a fast tree topic sampler, which calls the fast tree topic model. 
16 | * (1) It divides the sampling into three bins: smoothing, topic beta, and topic term,
17 | * as in Yao and Mimno's KDD 2009 paper.
18 | * (2) Each time, the smoothing, topic beta, and topic term bins are recomputed.
19 | * It is faster because:
20 | * (1) For the topic term bin, only the (topic, path) pairs with non-zero paths are computed (see TreeTopicModelFast).
21 | * (2) The normalizer is saved.
22 | * (3) Topic counts for each document are ranked.
23 | * Author: Yuening Hu
24 | */
25 | public class TreeTopicSamplerFast extends TreeTopicSamplerHashD {
26 | 
27 | public TreeTopicSamplerFast (int numberOfTopics, double alphaSum, int seed, boolean sort) {
28 | super(numberOfTopics, alphaSum, seed);
29 | 
30 | if (sort) {
31 | this.topics = new TreeTopicModelFastSortW(this.numTopics, this.random);
32 | //} else if (bubble == 1) {
33 | // this.topics = new TreeTopicModelFastSortT1(this.numTopics, this.random);
34 | //} else if (bubble == 2) {
35 | // this.topics = new TreeTopicModelFastSortT2(this.numTopics, this.random);
36 | } else {
37 | this.topics = new TreeTopicModelFast(this.numTopics, this.random);
38 | }
39 | }
40 | 
41 | /**
42 | * For each word in a document, first clear its current topic and path, then sample a
43 | * new topic and path, and update the counts.
44 | */
45 | public void sampleDoc(int doc_id){
46 | DocData doc = this.data.get(doc_id);
47 | //System.out.println("doc " + doc_id);
48 | 
49 | for(int ii = 0; ii < doc.tokens.size(); ii++) {
50 | //int word = doc.tokens.getIndexAtPosition(ii);
51 | int word = doc.tokens.get(ii);
52 | 
53 | this.changeTopic(doc_id, ii, word, -1, -1);
54 | 
55 | double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word);
56 | double topic_beta_mass = this.topics.computeTermTopicBeta(doc.topicCounts, word);
57 | 
58 | ArrayList topic_term_score = new ArrayList ();
59 | double topic_term_mass = this.topics.computeTopicTerm(this.alpha, doc.topicCounts, word, topic_term_score);
60 | 
61 | double norm = smoothing_mass + topic_beta_mass + topic_term_mass;
62 | double sample = this.random.nextDouble();
63 | //double sample = 0.5;
64 | sample *= norm;
65 | 
66 | int new_topic = -1;
67 | int new_path = -1;
68 | 
69 | int[] paths = this.topics.getWordPathIndexSet(word);
70 | 
71 | // sample the smoothing bin
72 | if (sample < smoothing_mass) {
73 | for (int tt = 0; tt < this.numTopics; tt++) {
74 | for (int pp : paths) {
75 | double val = alpha[tt] * this.topics.getPathPrior(word, pp);
76 | val /= this.topics.getNormalizer(tt, pp);
77 | sample -= val;
78 | if (sample <= 0.0) {
79 | new_topic = tt;
80 | new_path = pp;
81 | break;
82 | }
83 | }
84 | if (new_topic >= 0) {
85 | break;
86 | }
87 | }
88 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling smoothing!");
89 | } else {
90 | sample -= smoothing_mass;
91 | }
92 | 
93 | // sample the topic beta bin
94 | if (new_topic < 0 && sample < topic_beta_mass) {
95 | for(int tt : doc.topicCounts.keys()) {
96 | for (int pp : paths) {
97 | double val = doc.topicCounts.get(tt) * this.topics.getPathPrior(word, pp);
98 | val /= this.topics.getNormalizer(tt, pp);
99 | sample -= val;
100 | if (sample <= 0.0) {
101 | new_topic = tt;
102 | new_path = pp;
103 | break;
104 | }
105 | }
106 | if (new_topic >= 0) {
107 | break;
108 | }
109 | }
110 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic beta!");
111 | } else {
112 | sample -= topic_beta_mass;
113 | }
114 | 
115 | // sample the topic term bin
116 | if (new_topic < 0) {
117 | for(int jj = 0; jj < topic_term_score.size(); jj++) {
118 | 
double[] tmp = topic_term_score.get(jj); 119 | int tt = (int) tmp[0]; 120 | int pp = (int) tmp[1]; 121 | double val = tmp[2]; 122 | sample -= val; 123 | if (sample <= 0.0) { 124 | new_topic = tt; 125 | new_path = pp; 126 | break; 127 | } 128 | } 129 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic term!"); 130 | } 131 | 132 | this.changeTopic(doc_id, ii, word, new_topic, new_path); 133 | } 134 | } 135 | 136 | ///////////////////////////// 137 | // The following methods are for testing only. 138 | 139 | public double callComputeTermTopicBeta(TIntIntHashMap topic_counts, int word) { 140 | return this.topics.computeTermTopicBeta(topic_counts, word); 141 | } 142 | 143 | public double callComputeTermSmoothing(int word) { 144 | return this.topics.computeTermSmoothing(this.alpha, word); 145 | } 146 | 147 | public double computeTopicSmoothTest(int word) { 148 | double smooth = 0.0; 149 | int[] paths = this.topics.getWordPathIndexSet(word); 150 | for(int tt = 0; tt < this.numTopics; tt++) { 151 | double topic_alpha = alpha[tt]; 152 | for (int pp = 0; pp < paths.length; pp++) { 153 | int path_index = paths[pp]; 154 | 155 | TIntArrayList path_nodes = this.topics.wordPaths.get(word, path_index); 156 | TopicTreeWalk tw = this.topics.traversals.get(tt); 157 | 158 | double tmp = 1.0; 159 | for(int ii = 0; ii < path_nodes.size()-1; ii++) { 160 | int parent = path_nodes.get(ii); 161 | int child = path_nodes.get(ii+1); 162 | tmp *= this.topics.beta.get(parent, child); 163 | tmp /= this.topics.betaSum.get(parent) + tw.getNodeCount(parent); 164 | } 165 | tmp *= topic_alpha; 166 | smooth += tmp; 167 | } 168 | } 169 | return smooth; 170 | } 171 | 172 | public double computeTopicTermBetaTest(TIntIntHashMap local_topic_counts, int word) { 173 | double topictermbeta = 0.0; 174 | int[] paths = this.topics.getWordPathIndexSet(word); 175 | for(int tt = 0; tt < this.numTopics; tt++) { 176 | int topic_count = local_topic_counts.get(tt); 177 | for (int pp = 0; pp < paths.length; pp++) { 178 | int path_index = paths[pp]; 179 | 180 | TIntArrayList path_nodes = this.topics.wordPaths.get(word, path_index); 181 | TopicTreeWalk tw = this.topics.traversals.get(tt); 182 | 183 | double tmp = 1.0; 184 | for(int ii = 0; ii < path_nodes.size()-1; ii++) { 185 | int parent = path_nodes.get(ii); 186 | int child = path_nodes.get(ii+1); 187 | tmp *= this.topics.beta.get(parent, child); 188 | tmp /= this.topics.betaSum.get(parent) + tw.getNodeCount(parent); 189 | } 190 | tmp *= topic_count; 191 | 192 | topictermbeta += tmp; 193 | } 194 | } 195 | return topictermbeta; 196 | } 197 | 198 | public double computeTopicTermScoreTest(double[] alpha, TIntIntHashMap local_topic_counts, int word, HIntIntDoubleHashMap dict) { 199 | double termscore = 0.0; 200 | int[] paths = this.topics.getWordPathIndexSet(word); 201 | for(int tt = 0; tt < this.numTopics; tt++) { 202 | double topic_alpha = alpha[tt]; 203 | int topic_count = local_topic_counts.get(tt); 204 | for (int pp = 0; pp < paths.length; pp++) { 205 | int path_index = paths[pp]; 206 | 207 | TIntArrayList path_nodes = this.topics.wordPaths.get(word, path_index); 208 | TopicTreeWalk tw = this.topics.traversals.get(tt); 209 | 210 | double val = 1.0; 211 | double tmp = 1.0; 212 | double normalizer = 1.0; 213 | for(int ii = 0; ii < path_nodes.size()-1; ii++) { 214 | int parent = path_nodes.get(ii); 215 | int child = path_nodes.get(ii+1); 216 | val *= this.topics.beta.get(parent, child) + tw.getCount(parent, child); 217 | tmp *= this.topics.beta.get(parent, 
child);
218 | normalizer *= this.topics.betaSum.get(parent) + tw.getNodeCount(parent);
219 | }
220 | val -= tmp;
221 | val *= (topic_alpha + topic_count);
222 | val /= normalizer;
223 | 
224 | dict.put(tt, path_index, val);
225 | termscore += val;
226 | }
227 | }
228 | return termscore;
229 | }
230 | 
231 | public double computeTopicTermTest(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict) {
232 | double norm = 0.0;
233 | int[] paths = this.topics.getWordPathIndexSet(word);
234 | for(int tt = 0; tt < this.numTopics; tt++) {
235 | double topic_alpha = alpha[tt];
236 | int topic_count = local_topic_counts.get(tt);
237 | for (int pp = 0; pp < paths.length; pp++) {
238 | int path_index = paths[pp];
239 | 
240 | TIntArrayList path_nodes = this.topics.wordPaths.get(word, path_index);
241 | TopicTreeWalk tw = this.topics.traversals.get(tt);
242 | 
243 | double smooth = 1.0;
244 | for(int ii = 0; ii < path_nodes.size()-1; ii++) {
245 | int parent = path_nodes.get(ii);
246 | int child = path_nodes.get(ii+1);
247 | smooth *= this.topics.beta.get(parent, child);
248 | smooth /= this.topics.betaSum.get(parent) + tw.getNodeCount(parent);
249 | }
250 | smooth *= topic_alpha;
251 | 
252 | double topicterm = 1.0;
253 | for(int ii = 0; ii < path_nodes.size()-1; ii++) {
254 | int parent = path_nodes.get(ii);
255 | int child = path_nodes.get(ii+1);
256 | topicterm *= this.topics.beta.get(parent, child);
257 | topicterm /= this.topics.betaSum.get(parent) + tw.getNodeCount(parent);
258 | }
259 | topicterm *= topic_count;
260 | 
261 | double termscore = 1.0;
262 | double tmp = 1.0;
263 | double normalizer = 1.0;
264 | for(int ii = 0; ii < path_nodes.size()-1; ii++) {
265 | int parent = path_nodes.get(ii);
266 | int child = path_nodes.get(ii+1);
267 | termscore *= this.topics.beta.get(parent, child) + tw.getCount(parent, child);
268 | tmp *= this.topics.beta.get(parent, child);
269 | normalizer *= this.topics.betaSum.get(parent) + tw.getNodeCount(parent);
270 | }
271 | termscore -= tmp;
272 | termscore *= (topic_alpha + topic_count);
273 | termscore /= normalizer;
274 | 
275 | double val = smooth + topicterm + termscore;
276 | double[] tmptmp = {tt, path_index, val};
277 | dict.add(tmptmp);
278 | norm += val;
279 | System.out.println("Fast Topic " + tt + " " + smooth + " " + topicterm + " " + termscore + " " + tmp + " " + topic_alpha + " " + topic_count + " " + termscore);
280 | }
281 | }
282 | return norm;
283 | }
284 | }
285 | 
-------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicSamplerFastEst.java: --------------------------------------------------------------------------------
1 | package edu.umd.umiacs.itm.tree;
2 | 
3 | import java.util.ArrayList;
4 | 
5 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerHashD.DocData;
6 | import gnu.trove.TIntArrayList;
7 | import gnu.trove.TIntDoubleHashMap;
8 | 
9 | /**
10 | * This class improves the fast sampler by using an estimate of the smoothing mass.
11 | * Most of the time the smoothing mass is very small and not worth recomputing, since
12 | * it will rarely be hit. So we use an upper bound for it.
13 | * Only if the smoothing bin is hit is the actual smoothing mass computed and the draw resampled. 
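* (For this trick to stay exact, smoothingEst should never underestimate the true
* smoothing mass. One safe construction -- an assumption here, since computeSmoothingEst
* is defined in TreeTopicModelFastEst -- is to evaluate the smoothing products with zero
* topic counts in the denominators, which can only overestimate.)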
14 | * Author: Yuening Hu 15 | */ 16 | public class TreeTopicSamplerFastEst extends TreeTopicSamplerHashD{ 17 | 18 | public TreeTopicSamplerFastEst (int numberOfTopics, double alphaSum, int seed, boolean sort) { 19 | super(numberOfTopics, alphaSum, seed); 20 | 21 | if (sort) { 22 | this.topics = new TreeTopicModelFastEstSortW(this.numTopics, this.random); 23 | } else { 24 | this.topics = new TreeTopicModelFastEst(this.numTopics, this.random); 25 | } 26 | } 27 | 28 | /** 29 | * Use an upper bound for smoothing. Only if the smoothing 30 | * bin is hit, the actual smoothing is computed and resampled. 31 | */ 32 | public void sampleDoc(int doc_id) { 33 | DocData doc = this.data.get(doc_id); 34 | //System.out.println("doc " + doc_id); 35 | //int[] tmpstats = this.stats.get(this.stats.size()-1); 36 | 37 | for(int ii = 0; ii < doc.tokens.size(); ii++) { 38 | //int word = doc.tokens.getIndexAtPosition(ii); 39 | int word = doc.tokens.get(ii); 40 | 41 | this.changeTopic(doc_id, ii, word, -1, -1); 42 | 43 | 44 | //double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word); 45 | double smoothing_mass_est = this.topics.smoothingEst.get(word); 46 | 47 | double topic_beta_mass = this.topics.computeTermTopicBeta(doc.topicCounts, word); 48 | 49 | ArrayList topic_term_score = new ArrayList(); 50 | double topic_term_mass = this.topics.computeTopicTerm(this.alpha, doc.topicCounts, word, topic_term_score); 51 | 52 | double norm_est = smoothing_mass_est + topic_beta_mass + topic_term_mass; 53 | double sample = this.random.nextDouble(); 54 | //double sample = 0.5; 55 | sample *= norm_est; 56 | 57 | int new_topic = -1; 58 | int new_path = -1; 59 | 60 | int[] paths = this.topics.getWordPathIndexSet(word); 61 | 62 | // sample the smoothing bin 63 | if (sample < smoothing_mass_est) { 64 | double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word); 65 | double norm = smoothing_mass + topic_beta_mass + topic_term_mass; 66 | sample /= norm_est; 67 | sample *= norm; 68 | if (sample < smoothing_mass) { 69 | for (int tt = 0; tt < this.numTopics; tt++) { 70 | for (int pp : paths) { 71 | double val = alpha[tt] * this.topics.getPathPrior(word, pp); 72 | val /= this.topics.getNormalizer(tt, pp); 73 | sample -= val; 74 | if (sample <= 0.0) { 75 | new_topic = tt; 76 | new_path = pp; 77 | break; 78 | } 79 | } 80 | if (new_topic >= 0) { 81 | break; 82 | } 83 | } 84 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling smoothing!"); 85 | } else { 86 | sample -= smoothing_mass; 87 | } 88 | } else { 89 | sample -= smoothing_mass_est; 90 | } 91 | 92 | // sample topic beta bin 93 | if (new_topic < 0 && sample < topic_beta_mass) { 94 | for(int tt : doc.topicCounts.keys()) { 95 | for (int pp : paths) { 96 | double val = doc.topicCounts.get(tt) * this.topics.getPathPrior(word, pp); 97 | val /= this.topics.getNormalizer(tt, pp); 98 | sample -= val; 99 | if (sample <= 0.0) { 100 | new_topic = tt; 101 | new_path = pp; 102 | break; 103 | } 104 | } 105 | if (new_topic >= 0) { 106 | break; 107 | } 108 | } 109 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic beta!"); 110 | } else { 111 | sample -= topic_beta_mass; 112 | } 113 | 114 | // sample topic term bin 115 | if (new_topic < 0) { 116 | for(int jj = 0; jj < topic_term_score.size(); jj++) { 117 | double[] tmp = topic_term_score.get(jj); 118 | int tt = (int) tmp[0]; 119 | int pp = (int) tmp[1]; 120 | double val = tmp[2]; 121 | sample -= val; 122 | if (sample <= 0.0) { 123 | new_topic = tt; 124 
| new_path = pp;
125 | break;
126 | }
127 | }
128 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic term!");
129 | }
130 | 
131 | this.changeTopic(doc_id, ii, word, new_topic, new_path);
132 | }
133 | 
134 | }
135 | 
136 | /**
137 | * Before sampling starts, compute the smoothing upper bound for each word.
138 | */
139 | public void estimate(int numIterations, String outputFolder, int outputInterval, int topWords) {
140 | if(this.topics instanceof TreeTopicModelFastEst) {
141 | TreeTopicModelFastEst tmp = (TreeTopicModelFastEst) this.topics;
142 | tmp.computeSmoothingEst(this.alpha);
143 | } else if (this.topics instanceof TreeTopicModelFastEstSortW) {
144 | TreeTopicModelFastEstSortW tmp = (TreeTopicModelFastEstSortW) this.topics;
145 | tmp.computeSmoothingEst(this.alpha);
146 | }
147 | 
148 | super.estimate(numIterations, outputFolder, outputInterval, topWords);
149 | }
150 | 
151 | }
152 | 
-------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicSamplerFastEstSortD.java: --------------------------------------------------------------------------------
1 | package edu.umd.umiacs.itm.tree;
2 | 
3 | import java.util.ArrayList;
4 | 
5 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerHashD.DocData;
6 | 
7 | /**
8 | * This class improves the fast sampler by using an estimate of the smoothing mass.
9 | * Most of the time the smoothing mass is very small and not worth recomputing, since
10 | * it will rarely be hit. So we use an upper bound for it.
11 | * Only if the smoothing bin is hit is the actual smoothing mass computed and the draw resampled.
12 | * Author: Yuening Hu
13 | */
14 | public class TreeTopicSamplerFastEstSortD extends TreeTopicSamplerSortD{
15 | 
16 | public TreeTopicSamplerFastEstSortD (int numberOfTopics, double alphaSum, int seed, boolean sort) {
17 | super(numberOfTopics, alphaSum, seed);
18 | 
19 | if (sort) {
20 | this.topics = new TreeTopicModelFastEstSortW(this.numTopics, this.random);
21 | } else {
22 | this.topics = new TreeTopicModelFastEst(this.numTopics, this.random);
23 | }
24 | }
25 | 
26 | /**
27 | * Use an upper bound for smoothing. Only if the smoothing
28 | * bin is hit is the actual smoothing computed and the draw resampled. 
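* Concretely, the draw is first taken against norm_est = smoothing_est + topic_beta + topic_term;
* if it lands inside the estimated smoothing bin, the exact smoothing mass is computed and the
* draw is rescaled onto the exact normalizer (sample = sample / norm_est * norm) before the
* bins are walked, as in the code below.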
29 | */
30 | public void sampleDoc(int doc_id) {
31 | DocData doc = this.data.get(doc_id);
32 | //System.out.println("doc " + doc_id);
33 | 
34 | for(int ii = 0; ii < doc.tokens.size(); ii++) {
35 | //int word = doc.tokens.getIndexAtPosition(ii);
36 | int word = doc.tokens.get(ii);
37 | 
38 | this.changeTopic(doc_id, ii, word, -1, -1);
39 | 
40 | //double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word);
41 | double smoothing_mass_est = this.topics.smoothingEst.get(word);
42 | double topic_beta_mass = this.topics.computeTermTopicBetaSortD(doc.topicCounts, word);
43 | 
44 | ArrayList topic_term_score = new ArrayList();
45 | double topic_term_mass = this.topics.computeTopicTermSortD(this.alpha, doc.topicCounts, word, topic_term_score);
46 | 
47 | double norm_est = smoothing_mass_est + topic_beta_mass + topic_term_mass;
48 | double sample = this.random.nextDouble();
49 | //double sample = 0.5;
50 | sample *= norm_est;
51 | 
52 | int new_topic = -1;
53 | int new_path = -1;
54 | 
55 | int[] paths = this.topics.getWordPathIndexSet(word);
56 | 
57 | // sample the smoothing bin
58 | if (sample < smoothing_mass_est) {
59 | double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word);
60 | double norm = smoothing_mass + topic_beta_mass + topic_term_mass;
61 | sample /= norm_est;
62 | sample *= norm;
63 | if (sample < smoothing_mass) {
64 | for (int tt = 0; tt < this.numTopics; tt++) {
65 | for (int pp : paths) {
66 | double val = alpha[tt] * this.topics.getPathPrior(word, pp);
67 | val /= this.topics.getNormalizer(tt, pp);
68 | sample -= val;
69 | if (sample <= 0.0) {
70 | new_topic = tt;
71 | new_path = pp;
72 | break;
73 | }
74 | }
75 | if (new_topic >= 0) {
76 | break;
77 | }
78 | }
79 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling smoothing!");
80 | } else {
81 | sample -= smoothing_mass;
82 | }
83 | } else {
84 | sample -= smoothing_mass_est;
85 | }
86 | 
87 | // sample topic beta bin
88 | if (new_topic < 0 && sample < topic_beta_mass) {
89 | for(int jj = 0; jj < doc.topicCounts.size(); jj++) {
90 | int[] current = doc.topicCounts.get(jj);
91 | int tt = current[0];
92 | int count = current[1];
93 | for(int pp : paths) {
94 | double val = count * this.topics.getPathPrior(word, pp);
95 | val /= this.topics.getNormalizer(tt, pp);
96 | sample -= val;
97 | if (sample <= 0.0) {
98 | new_topic = tt;
99 | new_path = pp;
100 | break;
101 | }
102 | }
103 | if (new_topic >= 0) {
104 | break;
105 | }
106 | }
107 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic beta!");
108 | } else {
109 | sample -= topic_beta_mass;
110 | }
111 | 
112 | // sample topic term bin
113 | if (new_topic < 0) {
114 | for(int jj = 0; jj < topic_term_score.size(); jj++) {
115 | double[] tmp = topic_term_score.get(jj);
116 | int tt = (int) tmp[0];
117 | int pp = (int) tmp[1];
118 | double val = tmp[2];
119 | sample -= val;
120 | if (sample <= 0.0) {
121 | new_topic = tt;
122 | new_path = pp;
123 | break;
124 | }
125 | }
126 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic term!");
127 | }
128 | 
129 | this.changeTopic(doc_id, ii, word, new_topic, new_path);
130 | }
131 | 
132 | }
133 | 
134 | /**
135 | * Before sampling starts, compute the smoothing upper bound for each word. 
136 | */
137 | public void estimate(int numIterations, String outputFolder, int outputInterval, int topWords) {
138 | if(this.topics instanceof TreeTopicModelFastEst) {
139 | TreeTopicModelFastEst tmp = (TreeTopicModelFastEst) this.topics;
140 | tmp.computeSmoothingEst(this.alpha);
141 | } else if (this.topics instanceof TreeTopicModelFastEstSortW) {
142 | TreeTopicModelFastEstSortW tmp = (TreeTopicModelFastEstSortW) this.topics;
143 | tmp.computeSmoothingEst(this.alpha);
144 | }
145 | 
146 | super.estimate(numIterations, outputFolder, outputInterval, topWords);
147 | }
148 | }
-------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicSamplerFastSortD.java: --------------------------------------------------------------------------------
1 | package edu.umd.umiacs.itm.tree;
2 | 
3 | import gnu.trove.TDoubleArrayList;
4 | import gnu.trove.TIntArrayList;
5 | import gnu.trove.TIntIntHashMap;
6 | 
7 | import java.io.Serializable;
8 | import java.util.ArrayList;
9 | import java.util.Arrays;
10 | 
11 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerHashD.DocData;
12 | import cc.mallet.util.Randoms;
13 | 
14 | /**
15 | * This class defines a fast tree topic sampler, which calls the fast tree topic model.
16 | * (1) It divides the sampling into three bins: smoothing, topic beta, and topic term,
17 | * as in Yao and Mimno's KDD 2009 paper.
18 | * (2) Each time, the smoothing, topic beta, and topic term bins are recomputed.
19 | * It is faster because:
20 | * (1) For the topic term bin, only the (topic, path) pairs with non-zero paths are computed (see TreeTopicModelFast).
21 | * (2) The normalizer is saved.
22 | * (3) Topic counts for each document are ranked.
23 | * Author: Yuening Hu
24 | */
25 | public class TreeTopicSamplerFastSortD extends TreeTopicSamplerSortD {
26 | 
27 | public TreeTopicSamplerFastSortD (int numberOfTopics, double alphaSum, int seed, boolean sort) {
28 | super(numberOfTopics, alphaSum, seed);
29 | this.topics = new TreeTopicModelFast(this.numTopics, this.random);
30 | 
31 | if (sort) {
32 | this.topics = new TreeTopicModelFastSortW(this.numTopics, this.random);
33 | } else {
34 | this.topics = new TreeTopicModelFast(this.numTopics, this.random);
35 | }
36 | }
37 | 
38 | /**
39 | * For each word in a document, first clear its current topic and path, then sample a
40 | * new topic and path, and update the counts. 
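* The per-(topic, path) mass decomposes into the three bins; following the equivalent
* expanded form in TreeTopicSamplerFast.computeTopicTermTest (products run over the
* parent->child edges on the word's path in topic t):
*   smoothing:  alpha_t * prod( beta / (betaSum + nodeCount) )
*   topic beta: n_{t,doc} * prod( beta / (betaSum + nodeCount) )
*   topic term: (prod(beta + count) - prod(beta)) * (alpha_t + n_{t,doc}) / prod(betaSum + nodeCount)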
41 | */ 42 | public void sampleDoc(int doc_id){ 43 | DocData doc = this.data.get(doc_id); 44 | //System.out.println("doc " + doc_id); 45 | 46 | for(int ii = 0; ii < doc.tokens.size(); ii++) { 47 | //int word = doc.tokens.getIndexAtPosition(ii); 48 | int word = doc.tokens.get(ii); 49 | 50 | this.changeTopic(doc_id, ii, word, -1, -1); 51 | 52 | double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word); 53 | double topic_beta_mass = this.topics.computeTermTopicBetaSortD(doc.topicCounts, word); 54 | 55 | ArrayList topic_term_score = new ArrayList (); 56 | double topic_term_mass = this.topics.computeTopicTermSortD(this.alpha, doc.topicCounts, word, topic_term_score); 57 | 58 | double norm = smoothing_mass + topic_beta_mass + topic_term_mass; 59 | double sample = this.random.nextDouble(); 60 | //double sample = 0.5; 61 | sample *= norm; 62 | 63 | int new_topic = -1; 64 | int new_path = -1; 65 | 66 | int[] paths = this.topics.getWordPathIndexSet(word); 67 | 68 | // sample the smoothing bin 69 | if (sample < smoothing_mass) { 70 | for (int tt = 0; tt < this.numTopics; tt++) { 71 | for (int pp : paths) { 72 | double val = alpha[tt] * this.topics.getPathPrior(word, pp); 73 | val /= this.topics.getNormalizer(tt, pp); 74 | sample -= val; 75 | if (sample <= 0.0) { 76 | new_topic = tt; 77 | new_path = pp; 78 | break; 79 | } 80 | } 81 | if (new_topic >= 0) { 82 | break; 83 | } 84 | } 85 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling smoothing!"); 86 | } else { 87 | sample -= smoothing_mass; 88 | } 89 | 90 | // sample the topic beta bin 91 | if (new_topic < 0 && sample < topic_beta_mass) { 92 | 93 | for(int jj = 0; jj < doc.topicCounts.size(); jj++) { 94 | int[] current = doc.topicCounts.get(jj); 95 | int tt = current[0]; 96 | int count = current[1]; 97 | for(int pp : paths) { 98 | double val = count * this.topics.getPathPrior(word, pp); 99 | val /= this.topics.getNormalizer(tt, pp); 100 | sample -= val; 101 | if (sample <= 0.0) { 102 | new_topic = tt; 103 | new_path = pp; 104 | break; 105 | } 106 | } 107 | if (new_topic >= 0) { 108 | break; 109 | } 110 | } 111 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic beta!"); 112 | } else { 113 | sample -= topic_beta_mass; 114 | } 115 | 116 | // sample the topic term bin 117 | if (new_topic < 0) { 118 | for(int jj = 0; jj < topic_term_score.size(); jj++) { 119 | double[] tmp = topic_term_score.get(jj); 120 | int tt = (int) tmp[0]; 121 | int pp = (int) tmp[1]; 122 | double val = tmp[2]; 123 | sample -= val; 124 | if (sample <= 0.0) { 125 | new_topic = tt; 126 | new_path = pp; 127 | break; 128 | } 129 | } 130 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic term!"); 131 | } 132 | 133 | this.changeTopic(doc_id, ii, word, new_topic, new_path); 134 | } 135 | } 136 | 137 | } 138 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicSamplerHashD.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TDoubleArrayList; 4 | import gnu.trove.TIntArrayList; 5 | import gnu.trove.TIntHashSet; 6 | import gnu.trove.TIntIntHashMap; 7 | import gnu.trove.TIntIntIterator; 8 | 9 | import java.io.BufferedOutputStream; 10 | import java.io.BufferedReader; 11 | import java.io.DataInputStream; 12 | import java.io.File; 13 | import java.io.FileInputStream; 14 | import 
java.io.FileOutputStream;
15 | import java.io.IOException;
16 | import java.io.InputStreamReader;
17 | import java.io.PrintStream;
18 | import java.io.Serializable;
19 | import java.util.ArrayList;
20 | import java.util.Arrays;
21 | import java.util.zip.GZIPOutputStream;
22 | 
23 | import cc.mallet.types.Dirichlet;
24 | import cc.mallet.types.FeatureSequence;
25 | import cc.mallet.types.Instance;
26 | import cc.mallet.types.InstanceList;
27 | import cc.mallet.util.Randoms;
28 | 
29 | /**
30 | * This class defines the tree topic sampler, which loads the instances,
31 | * reports the topics, and leaves the sampling method abstract,
32 | * since it varies across implementations.
33 | * Author: Yuening Hu
34 | */
35 | public abstract class TreeTopicSamplerHashD implements TreeTopicSampler{
36 | 
37 | /**
38 | * This class defines the format of a document.
39 | */
40 | public class DocData {
41 | TIntArrayList tokens;
42 | TIntArrayList topics;
43 | TIntArrayList paths;
44 | // sort
45 | TIntIntHashMap topicCounts;
46 | String docName;
47 | 
48 | public DocData (String name, TIntArrayList tokens, TIntArrayList topics,
49 | TIntArrayList paths, TIntIntHashMap topicCounts) {
50 | this.docName = name;
51 | this.tokens = tokens;
52 | this.topics = topics;
53 | this.paths = paths;
54 | this.topicCounts = topicCounts;
55 | }
56 | 
57 | public String toString() {
58 | String result = "***************\n";
59 | result += docName + "\n";
60 | 
61 | result += "tokens: ";
62 | for (int jj = 0; jj < tokens.size(); jj++) {
63 | int index = tokens.get(jj);
64 | String word = vocab.get(index);
65 | result += word + " " + index + ", ";
66 | }
67 | 
68 | result += "\ntopics: ";
69 | result += topics.toString();
70 | 
71 | result += "\npaths: ";
72 | result += paths.toString();
73 | 
74 | result += "\ntopicCounts: ";
75 | 
76 | for(TIntIntIterator it = this.topicCounts.iterator(); it.hasNext(); ) {
77 | it.advance();
78 | result += "Topic " + it.key() + ": " + it.value() + ", ";
79 | }
80 | result += "\n*****************\n";
81 | return result;
82 | }
83 | }
84 | 
85 | int numTopics; // Number of topics to be fit
86 | int numIterations;
87 | int startIter;
88 | Randoms random;
89 | double[] alpha;
90 | double alphaSum;
91 | TDoubleArrayList lhood;
92 | TDoubleArrayList iterTime;
93 | ArrayList vocab;
94 | ArrayList data;
95 | TreeTopicModel topics;
96 | TIntHashSet cons;
97 | //ArrayList stats;
98 | 
99 | public TreeTopicSamplerHashD (int numberOfTopics, double alphaSum, int seed) {
100 | this.numTopics = numberOfTopics;
101 | this.random = new Randoms(seed);
102 | 
103 | this.alphaSum = alphaSum;
104 | this.alpha = new double[numTopics];
105 | Arrays.fill(alpha, alphaSum / numTopics);
106 | 
107 | this.data = new ArrayList ();
108 | this.vocab = new ArrayList ();
109 | this.cons = new TIntHashSet();
110 | 
111 | this.lhood = new TDoubleArrayList();
112 | this.iterTime = new TDoubleArrayList();
113 | this.startIter = 0;
114 | //this.stats = new ArrayList ();
115 | 
116 | // note: this.topics is not initialized in this abstract class;
117 | // each subclass initializes the topics variable differently.
118 | }
119 | 
120 | /**
121 | * This function loads the vocab, loads the tree, and initializes the parameters. 
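* A typical driver sequence, sketched after the test harness in testFast (the file
* paths are placeholders):
*   TreeTopicSamplerFast sampler = new TreeTopicSamplerFast(numTopics, alphaSum, seed, false);
*   sampler.initialize(treeFiles, hyperFile, vocabFile);
*   sampler.addInstances(InstanceList.load(new File(inputFile)));
*   sampler.setNumIterations(iters);
*   sampler.estimate(iters, outputDir, outputInterval, topWords);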
122 | */
123 | public void initialize(String treeFiles, String hyperFile, String vocabFile) {
124 | this.loadVocab(vocabFile);
125 | this.topics.initializeParams(treeFiles, hyperFile, this.vocab);
126 | }
127 | 
128 | public void setNumIterations(int iters) {
129 | this.numIterations = iters;
130 | }
131 | 
132 | public int getNumIterations() {
133 | return this.numIterations;
134 | }
135 | 
136 | /**
137 | * This function adds instances given the training data in MALLET input format.
138 | * For each token in a document, sample a topic and then sample a path based on the prior.
139 | */
140 | public void addInstances(InstanceList training) {
141 | boolean debug = false;
142 | int count = 0;
143 | for (Instance instance : training) {
144 | count++;
145 | FeatureSequence original_tokens = (FeatureSequence) instance.getData();
146 | String name = instance.getName().toString();
147 | 
148 | // *** remaining issue: keep topicCounts sorted
149 | TIntArrayList tokens = new TIntArrayList(original_tokens.getLength());
150 | TIntIntHashMap topicCounts = new TIntIntHashMap ();
151 | TIntArrayList topics = new TIntArrayList(original_tokens.getLength());
152 | TIntArrayList paths = new TIntArrayList(original_tokens.getLength());
153 | 
154 | for (int jj = 0; jj < original_tokens.getLength(); jj++) {
155 | String word = (String) original_tokens.getObjectAtPosition(jj);
156 | int token = this.vocab.indexOf(word);
157 | if(token != -1) {
158 | int topic = random.nextInt(numTopics);
159 | if(debug) { topic = count % numTopics; }
160 | tokens.add(token);
161 | topics.add(topic);
162 | topicCounts.adjustOrPutValue(topic, 1, 1);
163 | // sample a path for this topic
164 | int path_index = this.topics.initialize(token, topic);
165 | paths.add(path_index);
166 | }
167 | }
168 | 
169 | DocData doc = new DocData(name, tokens, topics, paths, topicCounts);
170 | this.data.add(doc);
171 | 
172 | //System.out.println(doc);
173 | }
174 | 
175 | }
176 | 
177 | /**
178 | * Resume instance states from the saved states file. 
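* Each line of the states file corresponds to one document and holds tab-separated
* topic:path pairs, one per in-vocabulary token, e.g. (illustrative values): 2:0  1:3  2:7
* This mirrors the layout written by printState() below.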
179 | */
180 | public void resumeStates(InstanceList training, String statesFile) throws IOException{
181 | FileInputStream statesfstream = new FileInputStream(statesFile);
182 | DataInputStream statesdstream = new DataInputStream(statesfstream);
183 | BufferedReader states = new BufferedReader(new InputStreamReader(statesdstream));
184 | 
185 | // reading topics, paths
186 | for (Instance instance : training) {
187 | FeatureSequence original_tokens = (FeatureSequence) instance.getData();
188 | String name = instance.getName().toString();
189 | 
190 | // *** remaining issue: keep topicCounts sorted
191 | TIntArrayList tokens = new TIntArrayList(original_tokens.getLength());
192 | TIntIntHashMap topicCounts = new TIntIntHashMap ();
193 | TIntArrayList topics = new TIntArrayList(original_tokens.getLength());
194 | TIntArrayList paths = new TIntArrayList(original_tokens.getLength());
195 | 
196 | // read this document's saved topic:path assignments
197 | String statesLine = states.readLine();
198 | myAssert(statesLine != null, "statesFile doesn't match the training data");
199 | statesLine = statesLine.trim();
200 | String[] str = statesLine.split("\t");
201 | 
202 | int count = -1;
203 | for (int jj = 0; jj < original_tokens.getLength(); jj++) {
204 | String word = (String) original_tokens.getObjectAtPosition(jj);
205 | int token = this.vocab.indexOf(word);
206 | if(token != -1) {
207 | count++;
208 | String[] tp = str[count].split(":");
209 | myAssert(tp.length == 2, "statesFile problem!");
210 | int topic = Integer.parseInt(tp[0]);
211 | int path = Integer.parseInt(tp[1]);
212 | tokens.add(token);
213 | topics.add(topic);
214 | paths.add(path);
215 | topicCounts.adjustOrPutValue(topic, 1, 1);
216 | this.topics.changeCountOnly(topic, token, path, 1);
217 | }
218 | }
219 | if(count != -1) {
220 | count++;
221 | myAssert(str.length == count, "resume problem!");
222 | }
223 | 
224 | DocData doc = new DocData(name, tokens, topics, paths, topicCounts);
225 | this.data.add(doc);
226 | }
227 | states.close();
228 | }
229 | 
230 | /**
231 | * Resume lhood and iterTime from the saved lhood file.
232 | */
233 | public void resumeLHood(String lhoodFile) throws IOException{
234 | FileInputStream lhoodfstream = new FileInputStream(lhoodFile);
235 | DataInputStream lhooddstream = new DataInputStream(lhoodfstream);
236 | BufferedReader brLHood = new BufferedReader(new InputStreamReader(lhooddstream));
237 | // the first line is the title
238 | String strLine = brLHood.readLine();
239 | while ((strLine = brLHood.readLine()) != null) {
240 | strLine = strLine.trim();
241 | String[] str = strLine.split("\t");
242 | // iteration, likelihood, iter_time
243 | myAssert(str.length == 3, "lhood file problem!");
244 | this.lhood.add(Double.parseDouble(str[1]));
245 | this.iterTime.add(Double.parseDouble(str[2]));
246 | }
247 | this.startIter = this.lhood.size();
248 | if (this.startIter > this.numIterations) {
249 | System.out.println("Have already sampled " + this.numIterations + " iterations!");
250 | System.exit(0);
251 | }
252 | System.out.println("Start sampling for iteration " + this.startIter);
253 | brLHood.close();
254 | }
255 | 
256 | /**
257 | * Resumes from the saved files. 
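* Concretely, it expects resumeDir + ".states" and resumeDir + ".lhood", the two files
* written by report() during a previous run.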
258 | */
259 | public void resume(InstanceList training, String resumeDir) {
260 | try {
261 | String statesFile = resumeDir + ".states";
262 | resumeStates(training, statesFile);
263 | 
264 | String lhoodFile = resumeDir + ".lhood";
265 | resumeLHood(lhoodFile);
266 | } catch (IOException e) {
267 | System.out.println(e.getMessage());
268 | }
269 | }
270 | 
271 | /**
272 | * This function clears the topic and path assignments for some words:
273 | * (1) term option: only clears the topic and path for constraint words;
274 | * (2) doc option: clears the topic and path for documents which contain
275 | * at least one of the constraint words.
276 | */
277 | public void clearTopicAssignments(String option, String consFile) {
278 | this.loadConstraints(consFile);
279 | if (this.cons == null || this.cons.size() <= 0) {
280 | return;
281 | }
282 | 
283 | for(int dd = 0; dd < this.data.size(); dd++) {
284 | DocData doc = this.data.get(dd);
285 | boolean flag = false;
286 | for(int ii = 0; ii < doc.tokens.size(); ii++) {
287 | int word = doc.tokens.get(ii);
288 | if(this.cons.contains(word)) {
289 | if (option.equals("term")) {
290 | int topic = doc.topics.get(ii);
291 | int path = doc.paths.get(ii);
292 | // change the count for count and node_count in TopicTreeWalk
293 | this.topics.changeCountOnly(topic, word, path, -1);
294 | doc.topics.set(ii, -1);
295 | doc.paths.set(ii, -1);
296 | myAssert(doc.topicCounts.get(topic) >= 1, "clear topic assignments problem");
297 | doc.topicCounts.adjustValue(topic, -1);
298 | } else if (option.equals("doc")) {
299 | flag = true;
300 | break;
301 | }
302 | }
303 | }
304 | if (flag) {
305 | for(int ii = 0; ii < doc.tokens.size(); ii++) {
306 | int word = doc.tokens.get(ii);
307 | int topic = doc.topics.get(ii);
308 | int path = doc.paths.get(ii);
309 | this.topics.changeCountOnly(topic, word, path, -1);
310 | doc.topics.set(ii, -1);
311 | doc.paths.set(ii, -1);
312 | }
313 | doc.topicCounts.clear();
314 | }
315 | }
316 | }
317 | 
318 | /**
319 | * This function defines how to change a topic during the sampling process.
320 | * It handles the cases where old_topic or new_topic is -1 (no assignment).
321 | */
322 | public void changeTopic(int doc, int index, int word, int new_topic, int new_path) {
323 | DocData current_doc = this.data.get(doc);
324 | int old_topic = current_doc.topics.get(index);
325 | int old_path = current_doc.paths.get(index);
326 | 
327 | if (old_topic != -1) {
328 | myAssert((new_topic == -1 && new_path == -1), "old_topic != -1 but new_topic != -1");
329 | this.topics.changeCount(old_topic, word, old_path, -1);
330 | myAssert(current_doc.topicCounts.get(old_topic) > 0, "Something wrong in changeTopic");
331 | current_doc.topicCounts.adjustValue(old_topic, -1);
332 | current_doc.topics.set(index, -1);
333 | current_doc.paths.set(index, -1);
334 | }
335 | 
336 | if (new_topic != -1) {
337 | myAssert((old_topic == -1 && old_path == -1), "new_topic != -1 but old_topic != -1");
338 | this.topics.changeCount(new_topic, word, new_path, 1);
339 | current_doc.topicCounts.adjustOrPutValue(new_topic, 1, 1);
340 | current_doc.topics.set(index, new_topic);
341 | current_doc.paths.set(index, new_path);
342 | }
343 | }
344 | 
345 | /**
346 | * This function defines the sampling process, computes the likelihood and running time,
347 | * and specifies when to save the states files. 
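* Reports are written every outputInterval iterations and at the final iteration;
* within an iteration, progress is printed every 1000 documents.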
348 | */
349 | public void estimate(int numIterations, String outputFolder, int outputInterval, int topWords) {
350 | // update parameters
351 | this.topics.updateParams();
352 | for (int ii = this.startIter; ii <= numIterations; ii++) {
353 | //int[] tmpstats = {0, 0, 0, 0};
354 | //this.stats.add(tmpstats);
355 | long starttime = System.currentTimeMillis();
356 | //System.out.println("Iter " + ii);
357 | for (int dd = 0; dd < this.data.size(); dd++) {
358 | this.sampleDoc(dd);
359 | if (dd > 0 && dd % 1000 == 0) {
360 | System.out.println("Sampled " + dd + " documents.");
361 | }
362 | }
363 | 
364 | double totaltime = (double)(System.currentTimeMillis() - starttime) / 1000;
365 | double lhood = this.lhood();
366 | this.lhood.add(lhood);
367 | this.iterTime.add(totaltime);
368 | 
369 | String tmp = "Iteration " + ii;
370 | tmp += " likelihood " + lhood;
371 | tmp += " totaltime " + totaltime;
372 | System.out.println(tmp);
373 | 
374 | if ((ii > 0 && ii % outputInterval == 0) || ii == numIterations) {
375 | try {
376 | this.report(outputFolder, topWords);
377 | } catch (IOException e) {
378 | System.out.println(e.getMessage());
379 | }
380 | }
381 | }
382 | }
383 | 
384 | /**
385 | * This function computes the document likelihood.
386 | */
387 | public double docLHood() {
388 | int docNum = this.data.size();
389 | 
390 | double val = 0.0;
391 | val += Dirichlet.logGamma(this.alphaSum) * docNum;
392 | double tmp = 0.0;
393 | for (int tt = 0; tt < this.numTopics; tt++) {
394 | tmp += Dirichlet.logGamma(this.alpha[tt]);
395 | }
396 | val -= tmp * docNum;
397 | for (int dd = 0; dd < docNum; dd++) {
398 | DocData doc = this.data.get(dd);
399 | for (int tt = 0; tt < this.numTopics; tt++) {
400 | val += Dirichlet.logGamma(this.alpha[tt] + doc.topicCounts.get(tt));
401 | }
402 | val -= Dirichlet.logGamma(this.alphaSum + doc.topics.size());
403 | }
404 | return val;
405 | }
406 | 
407 | /**
408 | * This function returns the likelihood.
409 | */
410 | public double lhood() {
411 | return this.docLHood() + this.topics.topicLHood();
412 | }
413 | 
414 | /**
415 | * This function reports the detected topics and the document topics,
416 | * and saves the states file and the lhood file.
417 | */
418 | public void report(String outputDir, int topWords) throws IOException {
419 | 
420 | String topicKeysFile = outputDir + ".topics";
421 | this.printTopWords(new File(topicKeysFile), topWords);
422 | 
423 | String docTopicsFile = outputDir + ".docs";
424 | this.printDocumentTopics(new File(docTopicsFile));
425 | 
426 | String stateFile = outputDir + ".states";
427 | this.printState (new File(stateFile));
428 | 
429 | String statsFile = outputDir + ".lhood";
430 | this.printStats (new File(statsFile));
431 | }
432 | 
433 | /**
434 | * This function prints the topic words of each topic.
435 | */
436 | public void printTopWords(File file, int numWords) throws IOException {
437 | PrintStream out = new PrintStream (file);
438 | out.print(displayTopWords(numWords));
439 | out.close();
440 | }
441 | 
442 | /**
443 | * Using a comparator built on the Comparable interface, this function ranks the words
444 | * in each topic and returns the top words for each topic. 
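* Each topic section of the output looks like (probabilities illustrative):
*   --------------
*   Topic 0
*   ------------------------
*   0.1234  apple
*   0.0987  pear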
445 | */ 446 | public String displayTopWords (int numWords) { 447 | 448 | class WordProb implements Comparable { 449 | int wi; 450 | double p; 451 | public WordProb (int wi, double p) { this.wi = wi; this.p = p; } 452 | public final int compareTo (Object o2) { 453 | if (p > ((WordProb)o2).p) 454 | return -1; 455 | else if (p == ((WordProb)o2).p) 456 | return 0; 457 | else return 1; 458 | } 459 | } 460 | 461 | StringBuilder out = new StringBuilder(); 462 | int numPaths = this.topics.getPathNum(); 463 | //System.out.println(numPaths); 464 | 465 | for (int tt = 0; tt < this.numTopics; tt++){ 466 | String tmp = "\n--------------\nTopic " + tt + "\n------------------------\n"; 467 | //System.out.print(tmp); 468 | out.append(tmp); 469 | WordProb[] wp = new WordProb[numPaths]; 470 | for (int pp = 0; pp < numPaths; pp++){ 471 | int ww = this.topics.getWordFromPath(pp); 472 | double val = this.topics.computeTopicPathProb(tt, ww, pp); 473 | wp[pp] = new WordProb(pp, val); 474 | } 475 | Arrays.sort(wp); 476 | for (int ii = 0; ii < wp.length; ii++){ 477 | if(ii >= numWords) { 478 | break; 479 | } 480 | int pp = wp[ii].wi; 481 | int ww = this.topics.getWordFromPath(pp); 482 | //tmp = wp[ii].p + "\t" + this.vocab.lookupObject(ww) + "\n"; 483 | tmp = wp[ii].p + "\t" + this.vocab.get(ww) + "\n"; 484 | //System.out.print(tmp); 485 | out.append(tmp); 486 | } 487 | } 488 | return out.toString(); 489 | } 490 | 491 | /** 492 | * Prints the index, original document dir, topic counts for each document. 493 | */ 494 | public void printDocumentTopics (File file) throws IOException { 495 | PrintStream out = new PrintStream (file); 496 | 497 | for (int dd = 0; dd < this.data.size(); dd++) { 498 | DocData doc = this.data.get(dd); 499 | String tmp = dd + "\t" + doc.docName + "\t"; 500 | for (int tt : doc.topicCounts.keys()) { 501 | int count = doc.topicCounts.get(tt); 502 | tmp += tt + ":" + count + "\t"; 503 | } 504 | out.print(tmp + "\n"); 505 | } 506 | out.close(); 507 | } 508 | 509 | /** 510 | * Prints the topic and path of each word for all documents. 511 | */ 512 | public void printState (File file) throws IOException { 513 | //PrintStream out = 514 | // new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(file)))); 515 | PrintStream out = new PrintStream(file); 516 | 517 | for (int dd = 0; dd < this.data.size(); dd++) { 518 | DocData doc = this.data.get(dd); 519 | String tmp = ""; 520 | for (int ww = 0; ww < doc.topics.size(); ww++) { 521 | int topic = doc.topics.get(ww); 522 | int path = doc.paths.get(ww); 523 | tmp += topic + ":" + path + "\t"; 524 | } 525 | out.println(tmp); 526 | } 527 | out.close(); 528 | } 529 | 530 | /** 531 | * Prints likelihood and iter time. 
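* The file is tab-separated with a header row and one row per iteration, e.g.
* (illustrative values):
*   Iteration    likelihood    iter_time
*   0            -123456.78    1.25
* resumeLHood() above parses exactly this three-column layout.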
532 | */
533 | public void printStats (File file) throws IOException {
534 | PrintStream out = new PrintStream (file);
535 | String tmp = "Iteration\t\tlikelihood\titer_time\n";
536 | out.print(tmp);
537 | for (int iter = 0; iter < this.lhood.size(); iter++) {
538 | tmp = iter + "\t" + this.lhood.get(iter) + "\t" + this.iterTime.get(iter);
539 | //int[] tmpstats = this.stats.get(iter);
540 | //for(int ii = 0; ii < tmpstats.length; ii++) {
541 | // tmp += "\t" + tmpstats[ii];
542 | //}
543 | out.println(tmp);
544 | }
545 | out.close();
546 | }
547 | 
548 | /**
549 | * Loads the vocab file.
550 | */
551 | public void loadVocab(String vocabFile) {
552 | 
553 | try {
554 | FileInputStream infstream = new FileInputStream(vocabFile);
555 | DataInputStream in = new DataInputStream(infstream);
556 | BufferedReader br = new BufferedReader(new InputStreamReader(in));
557 | 
558 | String strLine;
559 | //Read File Line By Line
560 | while ((strLine = br.readLine()) != null) {
561 | strLine = strLine.trim();
562 | String[] str = strLine.split("\t");
563 | if (str.length > 1) {
564 | this.vocab.add(str[1]);
565 | } else {
566 | System.out.println("Error! " + strLine);
567 | }
568 | }
569 | in.close();
570 | 
571 | } catch (IOException e) {
572 | System.out.println("No vocab file found!");
573 | }
574 | 
575 | }
576 | 
577 | /**
578 | * Loads the constraints file.
579 | */
580 | public void loadConstraints(String consFile) {
581 | try {
582 | FileInputStream infstream = new FileInputStream(consFile);
583 | DataInputStream in = new DataInputStream(infstream);
584 | BufferedReader br = new BufferedReader(new InputStreamReader(in));
585 | 
586 | String strLine;
587 | //Read File Line By Line
588 | while ((strLine = br.readLine()) != null) {
589 | strLine = strLine.trim();
590 | String[] str = strLine.split("\t");
591 | if (str.length > 1) {
592 | // str[0] is either "MERGE_" or "SPLIT_", not a word
593 | for(int ii = 1; ii < str.length; ii++) {
594 | int word = this.vocab.indexOf(str[ii]);
595 | myAssert(word >= 0, "Constraint words not found in vocab: " + str[ii]);
596 | cons.add(word);
597 | }
598 | 
599 | } else {
600 | System.out.println("Error! 
" + strLine); 601 | } 602 | } 603 | in.close(); 604 | 605 | } catch (IOException e) { 606 | System.out.println("No vocab file Found!"); 607 | } 608 | 609 | } 610 | 611 | /** 612 | * For testing~~ 613 | */ 614 | public static void myAssert(boolean flag, String info) { 615 | if(!flag) { 616 | System.out.println(info); 617 | System.exit(0); 618 | } 619 | } 620 | 621 | abstract void sampleDoc(int doc); 622 | } 623 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TreeTopicSamplerNaive.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TDoubleArrayList; 4 | import gnu.trove.TIntArrayList; 5 | import gnu.trove.TIntDoubleHashMap; 6 | import gnu.trove.TIntDoubleIterator; 7 | import gnu.trove.TIntHashSet; 8 | import gnu.trove.TIntIntHashMap; 9 | import gnu.trove.TIntIntIterator; 10 | import gnu.trove.TIntObjectHashMap; 11 | import gnu.trove.TIntObjectIterator; 12 | import gnu.trove.TObjectIntHashMap; 13 | 14 | import java.io.BufferedOutputStream; 15 | import java.io.BufferedReader; 16 | import java.io.DataInputStream; 17 | import java.io.File; 18 | import java.io.FileInputStream; 19 | import java.io.FileOutputStream; 20 | import java.io.IOException; 21 | import java.io.InputStreamReader; 22 | import java.io.PrintStream; 23 | import java.io.PrintWriter; 24 | import java.io.Serializable; 25 | import java.util.ArrayList; 26 | import java.util.Arrays; 27 | import java.util.zip.GZIPOutputStream; 28 | 29 | import cc.mallet.types.Alphabet; 30 | import cc.mallet.types.Dirichlet; 31 | import cc.mallet.types.FeatureSequence; 32 | import cc.mallet.types.Instance; 33 | import cc.mallet.types.InstanceList; 34 | import cc.mallet.types.LabelAlphabet; 35 | import cc.mallet.util.Randoms; 36 | 37 | /** 38 | * This class defines a naive tree topic sampler. 39 | * It calls the naive tree topic model. 40 | * Author: Yuening Hu 41 | */ 42 | public class TreeTopicSamplerNaive extends TreeTopicSamplerHashD { 43 | 44 | public TreeTopicSamplerNaive (int numberOfTopics, double alphaSum) { 45 | this (numberOfTopics, alphaSum, 0); 46 | } 47 | 48 | public TreeTopicSamplerNaive (int numberOfTopics, double alphaSum, int seed) { 49 | super (numberOfTopics, alphaSum, seed); 50 | this.topics = new TreeTopicModelNaive(this.numTopics, this.random); 51 | } 52 | 53 | /** 54 | * For each word in a document, firstly covers its topic and path, then sample a 55 | * topic and path, and update. 
56 | */ 57 | public void sampleDoc(int doc_id){ 58 | DocData doc = this.data.get(doc_id); 59 | //System.out.println("doc " + doc_id); 60 | 61 | for(int ii = 0; ii < doc.tokens.size(); ii++) { 62 | //int word = doc.tokens.getIndexAtPosition(ii); 63 | int word = doc.tokens.get(ii); 64 | 65 | this.changeTopic(doc_id, ii, word, -1, -1); 66 | ArrayList topic_term_score = new ArrayList(); 67 | double norm = this.topics.computeTopicTerm(this.alpha, doc.topicCounts, word, topic_term_score); 68 | //System.out.println(norm); 69 | 70 | int new_topic = -1; 71 | int new_path = -1; 72 | 73 | double sample = this.random.nextDouble(); 74 | //double sample = 0.8; 75 | sample *= norm; 76 | 77 | for(int jj = 0; jj < topic_term_score.size(); jj++) { 78 | double[] tmp = topic_term_score.get(jj); 79 | int tt = (int) tmp[0]; 80 | int pp = (int) tmp[1]; 81 | double val = tmp[2]; 82 | sample -= val; 83 | if (sample <= 0.0) { 84 | new_topic = tt; 85 | new_path = pp; 86 | break; 87 | } 88 | } 89 | 90 | myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling!"); 91 | 92 | this.changeTopic(doc_id, ii, word, new_topic, new_path); 93 | } 94 | } 95 | 96 | } 97 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/TwoIntHashMap.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TIntHash; 4 | import gnu.trove.TIntObjectHashMap; 5 | 6 | /** 7 | * This class defines a two level hashmap, so a value will be indexed by two keys. 8 | * Author: Travis Brown 9 | */ 10 | public abstract class TwoIntHashMap { 11 | protected TIntObjectHashMap data; 12 | 13 | public TwoIntHashMap() { 14 | this.data = new TIntObjectHashMap(); 15 | } 16 | 17 | /** 18 | * Return the HashMap indexed by the first key. 19 | */ 20 | public T get(int key1) { 21 | if (this.contains(key1)) { 22 | return this.data.get(key1); 23 | } 24 | return null; 25 | } 26 | 27 | /** 28 | * Return the first key set. 29 | */ 30 | public int[] getKey1Set() { 31 | return this.data.keys(); 32 | } 33 | 34 | /** 35 | * Check whether key1 is contained in the first key set or not. 36 | */ 37 | public boolean contains(int key1) { 38 | return this.data.contains(key1); 39 | } 40 | 41 | /** 42 | * Check whether the key pair (key1, key2) is contained or not. 43 | */ 44 | public boolean contains(int key1, int key2) { 45 | if (this.data.contains(key1)) { 46 | return this.data.get(key1).contains(key2); 47 | } else { 48 | return false; 49 | } 50 | } 51 | } 52 | 53 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tree/testFast.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TIntArrayList; 4 | import gnu.trove.TIntDoubleHashMap; 5 | import gnu.trove.TIntIntHashMap; 6 | 7 | import java.io.File; 8 | import java.util.ArrayList; 9 | 10 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerHashD.DocData; 11 | import cc.mallet.types.InstanceList; 12 | import junit.framework.TestCase; 13 | 14 | /** 15 | * This class tests the fast sampler. 
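* Note: this test lives under src/main/java rather than src/test/java, so Maven's
* default test phase will not pick it up automatically; run it as a JUnit 3 test,
* e.g. via junit.textui.TestRunner or from an IDE.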
16 | * Author: Yuening Hu 17 | */ 18 | public class testFast extends TestCase{ 19 | 20 | public TreeTopicSamplerFast Initialize() { 21 | 22 | String inputFile = "input/toy/toy-topic-input.mallet"; 23 | String treeFiles = "input/toy/toy.wn.*"; 24 | String hyperFile = "input/toy/tree_hyperparams"; 25 | String vocabFile = "input/toy/toy.voc"; 26 | int numTopics = 3; 27 | double alpha_sum = 0.3; 28 | int randomSeed = 0; 29 | int numIterations = 10; 30 | 31 | // String inputFile = "../input/synthetic-topic-input.mallet"; 32 | // String treeFiles = "../synthetic/synthetic_empty.wn.*"; 33 | // String hyperFile = "../synthetic/tree_hyperparams"; 34 | // String vocabFile = "../synthetic/synthetic.voc"; 35 | // int numTopics = 5; 36 | // double alpha_sum = 0.5; 37 | // int randomSeed = 0; 38 | // int numIterations = 10; 39 | 40 | InstanceList ilist = InstanceList.load (new File(inputFile)); 41 | System.out.println ("Data loaded."); 42 | 43 | TreeTopicSamplerFast topicModel = null; 44 | topicModel = new TreeTopicSamplerFast(numTopics, alpha_sum, randomSeed, false); 45 | 46 | topicModel.initialize(treeFiles, hyperFile, vocabFile); 47 | topicModel.addInstances(ilist); 48 | 49 | topicModel.setNumIterations(numIterations); 50 | 51 | return topicModel; 52 | } 53 | 54 | public void testUpdateParams() { 55 | TreeTopicSamplerFast topicModel = this.Initialize(); 56 | topicModel.topics.updateParams(); 57 | 58 | for(int dd = 0; dd < topicModel.data.size(); dd++) { 59 | System.out.println(topicModel.data.get(dd)); 60 | } 61 | 62 | System.out.println("**************\nNormalizer"); 63 | int numPaths = topicModel.topics.pathToWord.size(); 64 | for(int tt = 0; tt < topicModel.numTopics; tt++) { 65 | for(int pp = 0; pp < numPaths; pp++) { 66 | System.out.println("topic " + tt + " path " + pp + " normalizer " + topicModel.topics.normalizer.get(tt, pp)); 67 | } 68 | } 69 | 70 | System.out.println("**************\nNon zero paths"); 71 | for(int ww : topicModel.topics.nonZeroPaths.keys()) { 72 | for(int tt : topicModel.topics.nonZeroPaths.get(ww).getKey1Set()) { 73 | for(int pp : topicModel.topics.nonZeroPaths.get(ww).get(tt).keys()) { 74 | System.out.println("word " + ww + " topic " + tt + " path " + pp + " " + topicModel.topics.nonZeroPaths.get(ww).get(tt, pp)); 75 | } 76 | } 77 | } 78 | } 79 | 80 | public void testUpdatePathmaskedCount() { 81 | TreeTopicSamplerFast topicModel = this.Initialize(); 82 | topicModel.topics.updateParams(); 83 | int numPaths = topicModel.topics.pathToWord.size(); 84 | 85 | TreeTopicModelFast topics = (TreeTopicModelFast)topicModel.topics; 86 | 87 | for (int ww : topics.nonZeroPaths.keys()) { 88 | for(int tt : topics.nonZeroPaths.get(ww).getKey1Set()) { 89 | for(int pp : topicModel.topics.nonZeroPaths.get(ww).get(tt).keys()) { 90 | TIntArrayList path_nodes = topics.wordPaths.get(ww, pp); 91 | int parent = path_nodes.get(path_nodes.size() - 2); 92 | int child = path_nodes.get(path_nodes.size() - 1); 93 | 94 | int mask = topics.nonZeroPaths.get(ww).get(tt, pp) - topics.traversals.get(tt).getCount(parent, child); 95 | 96 | System.out.println("*************************"); 97 | System.out.println("Topic " + tt + " Word " + ww + " path " + pp); 98 | String tmp = "["; 99 | for (int ii : path_nodes.toNativeArray()) { 100 | tmp += " " + ii; 101 | } 102 | System.out.println("Real path " + tmp + " ]"); 103 | System.out.println("Real count " + topics.traversals.get(tt).getCount(parent, child)); 104 | System.out.println("Masked count " + topics.nonZeroPaths.get(ww).get(tt, pp)); 105 | 
System.out.println("Masekd count " + Integer.toBinaryString(topics.nonZeroPaths.get(ww).get(tt, pp))); 106 | System.out.println("*************************"); 107 | } 108 | } 109 | } 110 | } 111 | 112 | public void testChangeTopic() { 113 | TreeTopicSamplerFast topicModel = this.Initialize(); 114 | topicModel.topics.updateParams(); 115 | TreeTopicModelFast topics = (TreeTopicModelFast)topicModel.topics; 116 | //for(int dd = 0; dd < topicModel.data.size(); dd++){ 117 | for(int dd = 0; dd < 1; dd++){ 118 | DocData doc = topicModel.data.get(dd); 119 | for(int ii = 0; ii < doc.tokens.size(); ii++) { 120 | int word = doc.tokens.get(ii); 121 | int old_topic = doc.topics.get(ii); 122 | int old_path = doc.paths.get(ii); 123 | TIntArrayList path_nodes = topicModel.topics.wordPaths.get(word, old_path); 124 | int node = path_nodes.get(0); 125 | int leaf = path_nodes.get(path_nodes.size() - 1); 126 | int total = 0; 127 | for(int nn : topics.traversals.get(word).counts.get(node).keys()){ 128 | total += topics.traversals.get(word).getCount(node, nn); 129 | } 130 | 131 | assertTrue(topics.traversals.get(word).getNodeCount(node) == total); 132 | 133 | System.out.println("*************************"); 134 | System.out.println("old topic " + old_topic + " word " + word); 135 | System.out.println("old normalizer " + topics.normalizer.get(old_topic, old_path)); 136 | System.out.println("old root count " + topics.traversals.get(old_topic).getNodeCount(node) + " " + total); 137 | System.out.println("old non zero count " + Integer.toBinaryString(topics.nonZeroPaths.get(word).get(old_topic, old_path))); 138 | System.out.println("old leaf count " + topics.traversals.get(old_topic).getNodeCount(leaf)); 139 | 140 | topicModel.changeTopic(dd, ii, word, -1, -1); 141 | 142 | total = 0; 143 | for(int nn : topics.traversals.get(old_topic).counts.get(node).keys()){ 144 | total += topics.traversals.get(old_topic).getCount(node, nn); 145 | } 146 | assertTrue(topics.traversals.get(old_topic).getNodeCount(node) == total); 147 | System.out.println("*************************"); 148 | System.out.println("updated old topic " + old_topic + " word " + word); 149 | System.out.println("updated old normalizer " + topics.normalizer.get(old_topic, old_path)); 150 | System.out.println("updated old root count " + topics.traversals.get(old_topic).getNodeCount(node) + " " + total); 151 | System.out.println("updated old non zero count " + Integer.toBinaryString(topics.nonZeroPaths.get(word).get(old_topic, old_path))); 152 | System.out.println("updated old leaf count " + topics.traversals.get(old_topic).getNodeCount(leaf)); 153 | 154 | 155 | int new_topic = topicModel.numTopics - old_topic - 1; 156 | int new_path = old_path; 157 | 158 | total = 0; 159 | for(int nn : topics.traversals.get(new_topic).counts.get(node).keys()){ 160 | total += topics.traversals.get(new_topic).getCount(node, nn); 161 | } 162 | assertTrue(topics.traversals.get(new_topic).getNodeCount(node) == total); 163 | 164 | System.out.println("*************************"); 165 | System.out.println("new topic " + new_topic + " word " + word); 166 | System.out.println("new normalizer " + topics.normalizer.get(new_topic, new_path)); 167 | System.out.println("new root count " + topics.traversals.get(new_topic).getNodeCount(node) + " " + total); 168 | System.out.println("new non zero count " + Integer.toBinaryString(topics.nonZeroPaths.get(word).get(new_topic, new_path))); 169 | System.out.println("new leaf count " + topics.traversals.get(new_topic).getNodeCount(leaf)); 170 | 171 | 
topicModel.changeTopic(dd, ii, word, new_topic, new_path); 172 | 173 | 174 | total = 0; 175 | for(int nn : topics.traversals.get(new_topic).counts.get(node).keys()){ 176 | total += topics.traversals.get(new_topic).getCount(node, nn); 177 | } 178 | assertTrue(topics.traversals.get(new_topic).getNodeCount(node) == total); 179 | System.out.println("*************************"); 180 | System.out.println("updated new topic " + new_topic + " word " + word); 181 | System.out.println("updated new normalizer " + topics.normalizer.get(new_topic, new_path)); 182 | System.out.println("updated new root count " + topics.traversals.get(new_topic).getNodeCount(node) + " " + total); 183 | System.out.println("updated new non zero count " + Integer.toBinaryString(topics.nonZeroPaths.get(word).get(new_topic, new_path))); 184 | System.out.println("updated new leaf count " + topics.traversals.get(new_topic).getNodeCount(leaf)); 185 | 186 | System.out.println("*************************\n"); 187 | } 188 | } 189 | } 190 | 191 | public void testBinValues() { 192 | TreeTopicSamplerFast topicModelFast = this.Initialize(); 193 | topicModelFast.topics.updateParams(); 194 | 195 | TreeTopicSamplerNaive topicModelNaive = testNaive.Initialize(); 196 | topicModelNaive.topics.updateParams(); 197 | 198 | //for(int dd = 0; dd < topicModelFast.data.size(); dd++){ 199 | for(int dd = 0; dd < 1; dd++){ 200 | DocData doc = topicModelFast.data.get(dd); 201 | DocData doc1 = topicModelNaive.data.get(dd); 202 | 203 | //for(int ii = 0; ii < doc.tokens.size(); ii++) { 204 | for(int ii = 4; ii < 5; ii++) { 205 | int word = doc.tokens.get(ii); 206 | int topic = doc.topics.get(ii); 207 | int path = doc.paths.get(ii); 208 | 209 | double smoothing = topicModelFast.callComputeTermSmoothing(word); 210 | double topicbeta = topicModelFast.callComputeTermTopicBeta(doc.topicCounts, word); 211 | ArrayList<double[]> dict = new ArrayList<double[]>(); 212 | double topictermscore = topicModelFast.topics.computeTopicTerm(topicModelFast.alpha, 213 | doc.topicCounts, word, dict); 214 | double norm = smoothing + topicbeta + topictermscore; 215 | 216 | double smoothing1 = topicModelFast.computeTopicSmoothTest(word); 217 | double topicbeta1 = topicModelFast.computeTopicTermBetaTest(doc.topicCounts, word); 218 | HIntIntDoubleHashMap dict1 = new HIntIntDoubleHashMap(); 219 | double topictermscore1 = topicModelFast.computeTopicTermScoreTest(topicModelFast.alpha, 220 | doc.topicCounts, word, dict1); 221 | double norm1 = smoothing1 + topicbeta1 + topictermscore1; 222 | 223 | System.out.println("*************"); 224 | System.out.println("Index " + ii); 225 | System.out.println(smoothing + " " + smoothing1); 226 | System.out.println(topicbeta + " " + topicbeta1); 227 | System.out.println(topictermscore + " " + topictermscore1); 228 | 229 | ArrayList<double[]> dict2 = new ArrayList<double[]>(); 230 | double norm2 = topicModelFast.computeTopicTermTest(topicModelNaive.alpha, doc.topicCounts, word, dict2); 231 | 232 | ArrayList<double[]> dict3 = new ArrayList<double[]>(); 233 | double norm3 = topicModelNaive.topics.computeTopicTerm(topicModelNaive.alpha, doc.topicCounts, word, dict3); 234 | 235 | System.out.println(norm + " " + norm1 + " " + norm2 + " " + norm3); 236 | // if (norm1 != norm2) { 237 | // System.out.println(norm + " " + norm1 + " " + norm2 + " " + norm3 ); 238 | // } 239 | System.out.println("*************"); 240 | assertEquals(norm, norm1, 1e-10); 241 | // assert(1 == 0); // debugging stop left over from development 242 | } 243 | } 244 | } 245 | } 246 | --------------------------------------------------------------------------------
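The JUnit 3 test classes in this package (`testFast` above and `testNaive` below) discover any `public void test*()` method by name. Since they sit under `src/main/java` rather than `src/test/java`, `mvn test` will not pick them up automatically. A minimal sketch of invoking one by hand, assuming JUnit 3 is among the Maven dependencies, that the `input/toy` fixtures referenced in `Initialize()` exist (they are not part of this repository's `input/` directory), and with `cp.txt` a scratch file name chosen for this example:

    mvn compile dependency:build-classpath -Dmdep.outputFile=cp.txt
    java -cp target/classes:$(cat cp.txt) junit.textui.TestRunner edu.umd.umiacs.itm.tree.testFast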
/src/main/java/edu/umd/umiacs/itm/tree/testNaive.java: -------------------------------------------------------------------------------- 1 | package edu.umd.umiacs.itm.tree; 2 | 3 | import gnu.trove.TIntArrayList; 4 | import gnu.trove.TIntIntHashMap; 5 | 6 | import java.io.File; 7 | import java.util.ArrayList; 8 | 9 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerHashD.DocData; 10 | import cc.mallet.types.Alphabet; 11 | import cc.mallet.types.FeatureSequence; 12 | import cc.mallet.types.Instance; 13 | import cc.mallet.types.InstanceList; 14 | import junit.framework.TestCase; 15 | 16 | /** 17 | * This class tests the naive sampler. 18 | * Author: Yuening Hu 19 | */ 20 | public class testNaive extends TestCase{ 21 | 22 | public static TreeTopicSamplerNaive Initialize() { 23 | 24 | String inputFile = "input/toy/toy-topic-input.mallet"; 25 | String treeFiles = "input/toy/toy.wn.*"; 26 | String hyperFile = "input/toy/tree_hyperparams"; 27 | String vocabFile = "input/toy/toy.voc"; 28 | int numTopics = 3; 29 | double alpha_sum = 0.3; 30 | int randomSeed = 0; 31 | int numIterations = 10; 32 | 33 | // String inputFile = "../input/synthetic-topic-input.mallet"; 34 | // String treeFiles = "../synthetic/synthetic.wn.*"; 35 | // String hyperFile = "../synthetic/tree_hyperparams"; 36 | // String vocabFile = "../synthetic/synthetic.voc"; 37 | // int numTopics = 5; 38 | // double alpha_sum = 0.5; 39 | // int randomSeed = 0; 40 | // int numIterations = 10; 41 | 42 | InstanceList ilist = InstanceList.load (new File(inputFile)); 43 | System.out.println ("Data loaded."); 44 | 45 | TreeTopicSamplerNaive topicModel = null; 46 | topicModel = new TreeTopicSamplerNaive(numTopics, alpha_sum, randomSeed); 47 | 48 | topicModel.initialize(treeFiles, hyperFile, vocabFile); 49 | topicModel.addInstances(ilist); 50 | 51 | topicModel.setNumIterations(numIterations); 52 | 53 | return topicModel; 54 | } 55 | 56 | public void testChangeTopic() { 57 | TreeTopicSamplerNaive topicModel = this.Initialize(); 58 | for (int dd = 0; dd < topicModel.data.size(); dd++ ) { 59 | DocData doc = topicModel.data.get(dd); 60 | for (int index = 0; index < doc.tokens.size(); index++) { 61 | int word = doc.tokens.get(index); 62 | int old_topic = doc.topics.get(index); 63 | int old_path = doc.paths.get(index); 64 | int old_count = doc.topicCounts.get(old_topic); 65 | 66 | topicModel.changeTopic(dd, index, word, -1, -1); 67 | assertTrue(doc.topics.get(index) == -1); 68 | assertTrue(doc.paths.get(index) == -1); 69 | assertTrue(doc.topicCounts.get(old_topic) == old_count-1); 70 | 71 | int new_topic = topicModel.numTopics - old_topic - 1; 72 | int new_path = old_path; 73 | int new_count = doc.topicCounts.get(new_topic); 74 | topicModel.changeTopic(dd, index, word, new_topic, new_path); 75 | 76 | assertTrue(doc.topics.get(index) == new_topic); 77 | assertTrue(doc.paths.get(index) == new_path); 78 | assertTrue(doc.topicCounts.get(new_topic) == new_count+1); 79 | } 80 | } 81 | } 82 | 83 | public void testChangeCount() { 84 | TreeTopicSamplerNaive topicModel = this.Initialize(); 85 | for (int dd = 0; dd < topicModel.data.size(); dd++ ) { 86 | DocData doc = topicModel.data.get(dd); 87 | 88 | for (int index = 0; index < doc.tokens.size(); index++) { 89 | int word = doc.tokens.get(index); 90 | int old_topic = doc.topics.get(index); 91 | int old_path = doc.paths.get(index); 92 | 93 | TopicTreeWalk tw = topicModel.topics.traversals.get(old_topic); 94 | TIntArrayList path_nodes = topicModel.topics.wordPaths.get(word, old_path); 95 | 96 | int[] old_count = new
int[path_nodes.size() - 1]; 97 | for(int nn = 0; nn < path_nodes.size() - 1; nn++) { 98 | int parent = path_nodes.get(nn); 99 | int child = path_nodes.get(nn+1); 100 | old_count[nn] = tw.getCount(parent, child); 101 | } 102 | 103 | int[] old_node_count = new int[path_nodes.size()]; 104 | for(int nn = 0; nn < path_nodes.size(); nn++) { 105 | int node = path_nodes.get(nn); 106 | old_node_count[nn] = tw.getNodeCount(node); 107 | } 108 | 109 | int inc = 1; 110 | tw.changeCount(path_nodes, inc); 111 | 112 | for(int nn = 0; nn < path_nodes.size() - 1; nn++) { 113 | int parent = path_nodes.get(nn); 114 | int child = path_nodes.get(nn+1); 115 | assertTrue(old_count[nn] == tw.getCount(parent, child) - inc); 116 | } 117 | 118 | for(int nn = 0; nn < path_nodes.size(); nn++) { 119 | int node = path_nodes.get(nn); 120 | assertTrue(old_node_count[nn] == tw.getNodeCount(node) - inc); 121 | } 122 | 123 | } 124 | } 125 | 126 | } 127 | 128 | public void testComputeTermScore() { 129 | TreeTopicSamplerNaive topicModel = this.Initialize(); 130 | for (int dd = 0; dd < topicModel.data.size(); dd++ ) { 131 | DocData doc = topicModel.data.get(dd); 132 | System.out.println("------------" + dd + "------------"); 133 | for (int index = 0; index < doc.tokens.size(); index++) { 134 | int word = doc.tokens.get(index); 135 | 136 | //topicModel.changeTopic(dd, index, word, -1, -1); 137 | 138 | ArrayList<double[]> topic_term_score = new ArrayList<double[]>(); 139 | double norm = topicModel.topics.computeTopicTerm(topicModel.alpha, doc.topicCounts, word, topic_term_score); 140 | System.out.println(norm); 141 | 142 | for(int jj = 0; jj < topic_term_score.size(); jj++) { 143 | double[] tmp = topic_term_score.get(jj); 144 | int tt = (int) tmp[0]; 145 | int pp = (int) tmp[1]; 146 | double val = tmp[2]; 147 | System.out.println(tt + " " + pp + " " + val); 148 | } 149 | } 150 | } 151 | } 152 | 153 | } 154 | -------------------------------------------------------------------------------- /src/main/java/edu/umd/umiacs/itm/tui/ITMVectors2Topics.java: -------------------------------------------------------------------------------- 1 | /* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept. 2 | This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). 3 | http://www.cs.umass.edu/~mccallum/mallet 4 | This software is provided under the terms of the Common Public License, 5 | version 1.0, as published by http://www.opensource.org. For further 6 | information, see the file `LICENSE' included with this distribution. 
*/ 7 | 8 | package edu.umd.umiacs.itm.tui; 9 | 10 | import cc.mallet.util.CommandOption; 11 | import cc.mallet.util.Randoms; 12 | import cc.mallet.types.Alphabet; 13 | import cc.mallet.types.Dirichlet; 14 | import cc.mallet.types.InstanceList; 15 | import cc.mallet.types.FeatureSequence; 16 | import cc.mallet.types.LabelSequence; 17 | import cc.mallet.topics.*; 18 | import edu.umd.umiacs.itm.tree.GenerateVocab; 19 | import edu.umd.umiacs.itm.tree.PriorTree; 20 | import edu.umd.umiacs.itm.tree.TopicSampler; 21 | import edu.umd.umiacs.itm.tree.TreeTopicSampler; 22 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerHashD; 23 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerFast; 24 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerFastEst; 25 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerFastEstSortD; 26 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerFastSortD; 27 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerNaive; 28 | import edu.umd.umiacs.itm.tree.TreeTopicSamplerSortD; 29 | 30 | import java.io.*; 31 | 32 | /** Perform topic analysis in the style of LDA and its variants. 33 | * @author Andrew McCallum 34 | */ 35 | 36 | public class ITMVectors2Topics { 37 | 38 | // common options in mallet 39 | static CommandOption.String inputFile = new CommandOption.String 40 | (ITMVectors2Topics.class, "input", "FILENAME", true, null, 41 | "The filename from which to read the list of training instances. Use - for stdin. " + 42 | "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null); 43 | 44 | static CommandOption.Integer numTopics = new CommandOption.Integer 45 | (ITMVectors2Topics.class, "num-topics", "INTEGER", true, 10, 46 | "The number of topics to fit.", null); 47 | 48 | static CommandOption.Integer numIterations = new CommandOption.Integer 49 | (ITMVectors2Topics.class, "num-iterations", "INTEGER", true, 1000, 50 | "The number of iterations of Gibbs sampling.", null); 51 | 52 | static CommandOption.Integer randomSeed = new CommandOption.Integer 53 | (ITMVectors2Topics.class, "random-seed", "INTEGER", true, 0, 54 | "The random seed for the Gibbs sampler. Default is 0, which will use the clock.", null); 55 | 56 | static CommandOption.Integer topWords = new CommandOption.Integer 57 | (ITMVectors2Topics.class, "num-top-words", "INTEGER", true, 20, 58 | "The number of most probable words to print for each topic after model estimation.", null); 59 | 60 | static CommandOption.Double alpha = new CommandOption.Double 61 | (ITMVectors2Topics.class, "alpha", "DECIMAL", true, 50.0, 62 | "Alpha parameter: smoothing over topic distribution.", null); 63 | 64 | //////////////////////////////////// 65 | // new options 66 | 67 | static CommandOption.Boolean useTreeLDA = new CommandOption.Boolean 68 | (ITMVectors2Topics.class, "use-tree-lda", "true|false", false, false, 69 | "Rather than using a flat prior for LDA, use the tree-based prior for LDA, which models word correlations." 
+ 70 | "You cannot do this and also --use-ngrams or --use-PAM.", null); 71 | 72 | static CommandOption.String modelType = new CommandOption.String 73 | (ITMVectors2Topics.class, "tree-model-type", "TYPENAME", true, "fast-est", 74 | "Three possible types: naive, fast, fast-est, fast-sortD, fast-sortW, fast-sortD-sortW, fast-est-sortD, fast-est-sortW, fast-est-sortD-sortW.", null); 75 | 76 | static CommandOption.String treeFiles = new CommandOption.String 77 | (ITMVectors2Topics.class, "tree", "FILENAME", true, null, 78 | "The input files for tree structure.", null); 79 | 80 | static CommandOption.String hyperFile = new CommandOption.String 81 | (ITMVectors2Topics.class, "tree-hyperparameters", "FILENAME", true, null, 82 | "The hyperparameters for tree structure.", null); 83 | 84 | static CommandOption.String vocabFile = new CommandOption.String 85 | (ITMVectors2Topics.class, "vocab", "FILENAME", true, null, 86 | "The input vocabulary.", null); 87 | 88 | static CommandOption.String consFile = new CommandOption.String 89 | (ITMVectors2Topics.class, "constraint", "FILENAME", true, null, 90 | "The input constraint file.", null); 91 | 92 | static CommandOption.Integer outputInteval = new CommandOption.Integer 93 | (ITMVectors2Topics.class, "output-interval", "INTEGER", true, 20, 94 | "For each interval, the result files are output to the outputFolder.", null); 95 | 96 | static CommandOption.String outputDir= new CommandOption.String 97 | (ITMVectors2Topics.class, "output-dir", "FOLDERNAME", true, null, 98 | "The output folder.", null); 99 | 100 | static CommandOption.Boolean resume = new CommandOption.Boolean 101 | (ITMVectors2Topics.class, "resume", "true|false", false, false, 102 | "Resume from the previous output states.", null); 103 | 104 | static CommandOption.String resumeDir = new CommandOption.String 105 | (ITMVectors2Topics.class, "resume-dir", "FOLDERNAME", true, null, 106 | "The resume folder.", null); 107 | 108 | static CommandOption.String clearType = new CommandOption.String 109 | (ITMVectors2Topics.class, "clear-type", "TYPENAME", true, null, 110 | "Two possible types: doc, term.", null); 111 | 112 | static CommandOption.Boolean genVocab = new CommandOption.Boolean 113 | (ITMVectors2Topics.class, "generate-vocab", "true|false", false, false, 114 | "Generate vocab after mallet preprocessing.", null); 115 | 116 | //static CommandOption.Boolean sortTOption = new CommandOption.Boolean 117 | //(ITMVectors2Topics.class, "sort-topic", "true|false", false, true, 118 | // "Sort the topic counts for each term or not.", null); 119 | 120 | public static void main (String[] args) throws java.io.IOException 121 | { 122 | // Process the command-line options 123 | CommandOption.setSummary (ITMVectors2Topics.class, 124 | "A tool for estimating, saving and printing diagnostics for topic models, such as LDA."); 125 | CommandOption.process (ITMVectors2Topics.class, args); 126 | 127 | if (useTreeLDA.value) { 128 | InstanceList ilist = InstanceList.load (new File(inputFile.value)); 129 | System.out.println ("Data loaded."); 130 | 131 | if (genVocab.value) { 132 | GenerateVocab.genVocab(ilist, vocabFile.value); 133 | } else { 134 | //TreeTopicSamplerHashD topicModel = null; 135 | TreeTopicSampler topicModel = null; 136 | boolean sortW = false; 137 | 138 | if (modelType.value.equals("naive")) { 139 | topicModel = new TreeTopicSamplerNaive( 140 | numTopics.value, alpha.value, randomSeed.value); 141 | 142 | } else if (modelType.value.equals("fast")){ 143 | topicModel = new TreeTopicSamplerFast( 144 | 
numTopics.value, alpha.value, randomSeed.value, sortW); 145 | } else if (modelType.value.equals("fast-sortD")){ 146 | topicModel = new TreeTopicSamplerFastSortD( 147 | numTopics.value, alpha.value, randomSeed.value, sortW); 148 | } else if (modelType.value.equals("fast-sortW")){ 149 | sortW = true; 150 | topicModel = new TreeTopicSamplerFast( 151 | numTopics.value, alpha.value, randomSeed.value, sortW); 152 | } else if (modelType.value.equals("fast-sortD-sortW")){ 153 | sortW = true; 154 | topicModel = new TreeTopicSamplerFastSortD( 155 | numTopics.value, alpha.value, randomSeed.value, sortW); 156 | 157 | } else if (modelType.value.equals("fast-est")) { 158 | topicModel = new TreeTopicSamplerFastEst( 159 | numTopics.value, alpha.value, randomSeed.value, sortW); 160 | } else if (modelType.value.equals("fast-est-sortD")) { 161 | topicModel = new TreeTopicSamplerFastEstSortD( 162 | numTopics.value, alpha.value, randomSeed.value, sortW); 163 | } else if (modelType.value.equals("fast-est-sortW")) { 164 | sortW = true; 165 | topicModel = new TreeTopicSamplerFastEst( 166 | numTopics.value, alpha.value, randomSeed.value, sortW); 167 | } else if (modelType.value.equals("fast-est-sortD-sortW")) { 168 | sortW = true; 169 | topicModel = new TreeTopicSamplerFastEstSortD( 170 | numTopics.value, alpha.value, randomSeed.value, sortW); 171 | 172 | } else { 173 | System.out.println("model type wrong! please use one of: naive, fast, " + 174 | "fast-sortD, fast-sortW, fast-sortD-sortW, fast-est, fast-est-sortD, fast-est-sortW, fast-est-sortD-sortW!"); 175 | System.exit(0); 176 | } 177 | 178 | //TreeTopicSamplerSortD topicModel = new TreeTopicSamplerFastEstSortD( 179 | // numTopics.value, alpha.value, randomSeed.value, bubbleOption.value); 180 | //TreeTopicSamplerSort topicModel = new TreeTopicSamplerFastSort( 181 | // numTopics.value, alpha.value, randomSeed.value, bubbleOption.value); 182 | //TreeTopicSamplerSort topicModel = new TreeTopicSamplerFastEstSort( 183 | // numTopics.value, alpha.value, randomSeed.value, bubbleOption.value); 184 | 185 | // load tree and vocab 186 | topicModel.initialize(treeFiles.value, hyperFile.value, vocabFile.value); 187 | topicModel.setNumIterations(numIterations.value); 188 | 189 | if (resume.value) { 190 | // resume instances from the saved states 191 | topicModel.resume(ilist, resumeDir.value); 192 | } else { 193 | // add instances 194 | topicModel.addInstances(ilist); 195 | } 196 | 197 | // if clearType is not null, clear the topic assignments of the 198 | // constraint words 199 | if (clearType.value != null) { 200 | if (clearType.value.equals("term") || clearType.value.equals("doc")) { 201 | topicModel.clearTopicAssignments(clearType.value, consFile.value); 202 | } else { 203 | System.out.println("clear type wrong! please use either 'doc' or 'term'!"); 204 | System.exit(0); 205 | } 206 | } 207 | 208 | // sampling and save states 209 | topicModel.estimate(numIterations.value, outputDir.value, 210 | outputInterval.value, topWords.value); 211 | 212 | // topic report 213 | //System.out.println(topicModel.displayTopWords(topWords.value)); 214 | } 215 | } 216 | } 217 | 218 | } 219 | -------------------------------------------------------------------------------- /src/main/resources/cc/mallet/util/resources/logging.properties: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # Default Logging Configuration File 3 | # 4 | # You can use a different file by specifying a filename 5 | # with the java.util.logging.config.file system property. 
6 | # For example java -Djava.util.logging.config.file=myfile 7 | ############################################################ 8 | 9 | ############################################################ 10 | # Global properties 11 | ############################################################ 12 | 13 | # "handlers" specifies a comma separated list of log Handler 14 | # classes. These handlers will be installed during VM startup. 15 | # Note that these classes must be on the system classpath. 16 | # By default we only configure a ConsoleHandler, which will only 17 | # show messages at the INFO and above levels. 18 | handlers= java.util.logging.ConsoleHandler 19 | 20 | # To also add the FileHandler, use the following line instead. 21 | #handlers= java.util.logging.FileHandler, java.util.logging.ConsoleHandler 22 | 23 | # Default global logging level. 24 | # This specifies which kinds of events are logged across 25 | # all loggers. For any given facility this global level 26 | # can be overridden by a facility specific level 27 | # Note that the ConsoleHandler also has a separate level 28 | # setting to limit messages printed to the console. 29 | .level= INFO 30 | 31 | ############################################################ 32 | # Handler specific properties. 33 | # Describes specific configuration info for Handlers. 34 | ############################################################ 35 | 36 | # default file output is in user's home directory. 37 | java.util.logging.FileHandler.pattern = %h/java%u.log 38 | java.util.logging.FileHandler.limit = 50000 39 | java.util.logging.FileHandler.count = 1 40 | java.util.logging.FileHandler.formatter = cc.mallet.util.PlainLogFormatter 41 | #java.util.logging.FileHandler.formatter = java.util.logging.XMLFormatter 42 | 43 | # Limit the messages that are printed on the console. ALL means all messages are reported. OFF means no messages are reported. 44 | java.util.logging.ConsoleHandler.level = FINE 45 | java.util.logging.ConsoleHandler.formatter = cc.mallet.util.PlainLogFormatter 46 | 47 | 48 | ############################################################ 49 | # Facility specific properties. 50 | # Provides extra control for each logger. 51 | ############################################################ 52 | 53 | # For example, set the com.xyz.foo logger to only log SEVERE 54 | # messages: 55 | 56 | #Put the level of specific loggers here. 
If not included, default is INFO 57 | 58 | #cc.mallet.fst.MaxLatticeDefault.level = FINE 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /tree/lib/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tree/lib/flags.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | import getopt 4 | import glob 5 | import os 6 | import sys 7 | 8 | def dict_parser(d): 9 | print d 10 | if d.startswith("{"): 11 | return eval(d) 12 | else: 13 | print "Cannot parse dictionary: " + str(d) 14 | return {} 15 | 16 | def list_parser(s): 17 | global list_separator 18 | 19 | s = s.strip() 20 | if s != "": 21 | l = s.split(list_separator) 22 | else: 23 | l = [] 24 | return l 25 | 26 | 27 | def intlist_parser(s): 28 | global list_separator 29 | return [int(x) for x in s.split(list_separator)] 30 | 31 | def filename_parser(s): 32 | return os.path.expanduser(s) 33 | 34 | def glob_parser(s): 35 | return glob.glob(s) 36 | 37 | def bool_parser(s): 38 | s = s.lower() 39 | if s == "true": 40 | return True 41 | elif s == "false": 42 | return False 43 | else: 44 | raise ValueError("Cannot parse boolean: '" + s + "'") 45 | 46 | __flags = [] 47 | __post_inits = [] 48 | 49 | def define_dict(name, default, description): 50 | global __flags 51 | __flags.append((name, default, description, dict_parser)) 52 | 53 | def define_bool(name, default, description): 54 | global __flags 55 | __flags.append((name, default, description, bool_parser)) 56 | 57 | def define_string(name, default, description): 58 | global __flags 59 | __flags.append((name, default, description, str)) 60 | 61 | def define_str(name, default, description): 62 | global __flags 63 | __flags.append((name, default, description, str)) 64 | 65 | def define_int(name, default, description): 66 | global __flags 67 | __flags.append((name, default, description, int)) 68 | 69 | def define_float(name, default, description): 70 | global __flags 71 | __flags.append((name, default, description, float)) 72 | 73 | def define_list(name, default, description): 74 | global __flags 75 | __flags.append((name, default, description, list_parser)) 76 | 77 | def define_intlist(name, default, description): 78 | global __flags 79 | __flags.append((name, default, description, intlist_parser)) 80 | 81 | def define_filename(name, default, description): 82 | global __flags 83 | __flags.append((name, default, description, filename_parser)) 84 | 85 | def define_glob(name, default, description): 86 | global __flags 87 | __flags.append((name, default, description, glob_parser)) 88 | 89 | 90 | define_string("list_separator", ",", "String used to separate list items.") 91 | 92 | def RunAfterInit(func): 93 | global __post_inits 94 | __post_inits.append(func) 95 | 96 | def InitFlags(): 97 | global __flags 98 | this_mod_dict = globals() 99 | 100 | usage = "" 101 | switches = [] 102 | for f in __flags: 103 | name = f[0] 104 | description = f[2] 105 | default = f[1] 106 | this_mod_dict[name] = default 107 | switches.append(name + "=") 108 | usage += "\t" + name + "\t" + description + " [default = " + str(default) + "]\n" 109 | 110 | try: 111 | opts, args = getopt.getopt(sys.argv[1:], "", switches) 112 | except getopt.GetoptError: 113 | print usage 114 | sys.exit(2) 115 | 116 | for o, a in opts: 117 | index = [x for x in __flags if x[0] == o[2:]] 118 | if len(index) != 1: 119 | print
usage 120 | sys.exit(2) 121 | else: 122 | index = index[0] 123 | this_mod_dict[index[0]] = index[3](a) 124 | 125 | for func in __post_inits: 126 | func() 127 | 128 | return args 129 | -------------------------------------------------------------------------------- /tree/lib/proto/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tree/lib/proto/wordnet_file.proto: -------------------------------------------------------------------------------- 1 | 2 | package topicmod_projects_ldawn; 3 | 4 | message WordNetFile { 5 | message Synset { 6 | required int32 offset = 1; 7 | optional string key = 2; 8 | repeated uint32 children_offsets = 3; 9 | 10 | // The total count of the synset's words ONLY 11 | optional double raw_count = 4; 12 | 13 | // The total count of the synset's words AND hyponym words 14 | optional double hyponym_count = 5; 15 | 16 | message Word { 17 | required uint32 term_id = 1; 18 | // TODO(jbg): Make this explicitly a language 19 | optional uint32 lang_id = 2[default = 0]; 20 | 21 | required string term_str = 3; 22 | optional string lang_str = 4; 23 | 24 | optional double x_location = 5; 25 | optional double y_location = 6; 26 | 27 | optional int32 depth = 7; 28 | optional double count = 8; 29 | } 30 | repeated Word words = 6; 31 | optional string hyperparameter = 7; 32 | } 33 | repeated Synset synsets = 1; 34 | required int32 root = 2; 35 | } 36 | -------------------------------------------------------------------------------- /tree/lib/proto/wordnet_file_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | 3 | from google.protobuf import descriptor 4 | from google.protobuf import message 5 | from google.protobuf import reflection 6 | from google.protobuf import descriptor_pb2 7 | # @@protoc_insertion_point(imports) 8 | 9 | 10 | DESCRIPTOR = descriptor.FileDescriptor( 11 | name='wordnet_file.proto', 12 | package='topicmod_projects_ldawn', 13 | serialized_pb='\n\x12wordnet_file.proto\x12\x17topicmod_projects_ldawn\"\xb6\x03\n\x0bWordNetFile\x12<\n\x07synsets\x18\x01 \x03(\x0b\x32+.topicmod_projects_ldawn.WordNetFile.Synset\x12\x0c\n\x04root\x18\x02 \x02(\x05\x1a\xda\x02\n\x06Synset\x12\x0e\n\x06offset\x18\x01 \x02(\x05\x12\x0b\n\x03key\x18\x02 \x01(\t\x12\x18\n\x10\x63hildren_offsets\x18\x03 \x03(\r\x12\x11\n\traw_count\x18\x04 \x01(\x01\x12\x15\n\rhyponym_count\x18\x05 \x01(\x01\x12?\n\x05words\x18\x06 \x03(\x0b\x32\x30.topicmod_projects_ldawn.WordNetFile.Synset.Word\x12\x16\n\x0ehyperparameter\x18\x07 \x01(\t\x1a\x95\x01\n\x04Word\x12\x0f\n\x07term_id\x18\x01 \x02(\r\x12\x12\n\x07lang_id\x18\x02 \x01(\r:\x01\x30\x12\x10\n\x08term_str\x18\x03 \x02(\t\x12\x10\n\x08lang_str\x18\x04 \x01(\t\x12\x12\n\nx_location\x18\x05 \x01(\x01\x12\x12\n\ny_location\x18\x06 \x01(\x01\x12\r\n\x05\x64\x65pth\x18\x07 \x01(\x05\x12\r\n\x05\x63ount\x18\x08 \x01(\x01') 14 | 15 | 16 | 17 | 18 | _WORDNETFILE_SYNSET_WORD = descriptor.Descriptor( 19 | name='Word', 20 | full_name='topicmod_projects_ldawn.WordNetFile.Synset.Word', 21 | filename=None, 22 | file=DESCRIPTOR, 23 | containing_type=None, 24 | fields=[ 25 | descriptor.FieldDescriptor( 26 | name='term_id', full_name='topicmod_projects_ldawn.WordNetFile.Synset.Word.term_id', index=0, 27 | number=1, type=13, cpp_type=3, label=2, 28 | has_default_value=False, default_value=0, 29 | message_type=None, enum_type=None, 
containing_type=None, 30 | is_extension=False, extension_scope=None, 31 | options=None), 32 | descriptor.FieldDescriptor( 33 | name='lang_id', full_name='topicmod_projects_ldawn.WordNetFile.Synset.Word.lang_id', index=1, 34 | number=2, type=13, cpp_type=3, label=1, 35 | has_default_value=True, default_value=0, 36 | message_type=None, enum_type=None, containing_type=None, 37 | is_extension=False, extension_scope=None, 38 | options=None), 39 | descriptor.FieldDescriptor( 40 | name='term_str', full_name='topicmod_projects_ldawn.WordNetFile.Synset.Word.term_str', index=2, 41 | number=3, type=9, cpp_type=9, label=2, 42 | has_default_value=False, default_value=unicode("", "utf-8"), 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None), 46 | descriptor.FieldDescriptor( 47 | name='lang_str', full_name='topicmod_projects_ldawn.WordNetFile.Synset.Word.lang_str', index=3, 48 | number=4, type=9, cpp_type=9, label=1, 49 | has_default_value=False, default_value=unicode("", "utf-8"), 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None), 53 | descriptor.FieldDescriptor( 54 | name='x_location', full_name='topicmod_projects_ldawn.WordNetFile.Synset.Word.x_location', index=4, 55 | number=5, type=1, cpp_type=5, label=1, 56 | has_default_value=False, default_value=0, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None), 60 | descriptor.FieldDescriptor( 61 | name='y_location', full_name='topicmod_projects_ldawn.WordNetFile.Synset.Word.y_location', index=5, 62 | number=6, type=1, cpp_type=5, label=1, 63 | has_default_value=False, default_value=0, 64 | message_type=None, enum_type=None, containing_type=None, 65 | is_extension=False, extension_scope=None, 66 | options=None), 67 | descriptor.FieldDescriptor( 68 | name='depth', full_name='topicmod_projects_ldawn.WordNetFile.Synset.Word.depth', index=6, 69 | number=7, type=5, cpp_type=1, label=1, 70 | has_default_value=False, default_value=0, 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | options=None), 74 | descriptor.FieldDescriptor( 75 | name='count', full_name='topicmod_projects_ldawn.WordNetFile.Synset.Word.count', index=7, 76 | number=8, type=1, cpp_type=5, label=1, 77 | has_default_value=False, default_value=0, 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=None), 81 | ], 82 | extensions=[ 83 | ], 84 | nested_types=[], 85 | enum_types=[ 86 | ], 87 | options=None, 88 | is_extendable=False, 89 | extension_ranges=[], 90 | serialized_start=337, 91 | serialized_end=486, 92 | ) 93 | 94 | _WORDNETFILE_SYNSET = descriptor.Descriptor( 95 | name='Synset', 96 | full_name='topicmod_projects_ldawn.WordNetFile.Synset', 97 | filename=None, 98 | file=DESCRIPTOR, 99 | containing_type=None, 100 | fields=[ 101 | descriptor.FieldDescriptor( 102 | name='offset', full_name='topicmod_projects_ldawn.WordNetFile.Synset.offset', index=0, 103 | number=1, type=5, cpp_type=1, label=2, 104 | has_default_value=False, default_value=0, 105 | message_type=None, enum_type=None, containing_type=None, 106 | is_extension=False, extension_scope=None, 107 | options=None), 108 | descriptor.FieldDescriptor( 109 | name='key', full_name='topicmod_projects_ldawn.WordNetFile.Synset.key', index=1, 110 | number=2, type=9, cpp_type=9, label=1, 111 | 
has_default_value=False, default_value=unicode("", "utf-8"), 112 | message_type=None, enum_type=None, containing_type=None, 113 | is_extension=False, extension_scope=None, 114 | options=None), 115 | descriptor.FieldDescriptor( 116 | name='children_offsets', full_name='topicmod_projects_ldawn.WordNetFile.Synset.children_offsets', index=2, 117 | number=3, type=13, cpp_type=3, label=3, 118 | has_default_value=False, default_value=[], 119 | message_type=None, enum_type=None, containing_type=None, 120 | is_extension=False, extension_scope=None, 121 | options=None), 122 | descriptor.FieldDescriptor( 123 | name='raw_count', full_name='topicmod_projects_ldawn.WordNetFile.Synset.raw_count', index=3, 124 | number=4, type=1, cpp_type=5, label=1, 125 | has_default_value=False, default_value=0, 126 | message_type=None, enum_type=None, containing_type=None, 127 | is_extension=False, extension_scope=None, 128 | options=None), 129 | descriptor.FieldDescriptor( 130 | name='hyponym_count', full_name='topicmod_projects_ldawn.WordNetFile.Synset.hyponym_count', index=4, 131 | number=5, type=1, cpp_type=5, label=1, 132 | has_default_value=False, default_value=0, 133 | message_type=None, enum_type=None, containing_type=None, 134 | is_extension=False, extension_scope=None, 135 | options=None), 136 | descriptor.FieldDescriptor( 137 | name='words', full_name='topicmod_projects_ldawn.WordNetFile.Synset.words', index=5, 138 | number=6, type=11, cpp_type=10, label=3, 139 | has_default_value=False, default_value=[], 140 | message_type=None, enum_type=None, containing_type=None, 141 | is_extension=False, extension_scope=None, 142 | options=None), 143 | descriptor.FieldDescriptor( 144 | name='hyperparameter', full_name='topicmod_projects_ldawn.WordNetFile.Synset.hyperparameter', index=6, 145 | number=7, type=9, cpp_type=9, label=1, 146 | has_default_value=False, default_value=unicode("", "utf-8"), 147 | message_type=None, enum_type=None, containing_type=None, 148 | is_extension=False, extension_scope=None, 149 | options=None), 150 | ], 151 | extensions=[ 152 | ], 153 | nested_types=[_WORDNETFILE_SYNSET_WORD, ], 154 | enum_types=[ 155 | ], 156 | options=None, 157 | is_extendable=False, 158 | extension_ranges=[], 159 | serialized_start=140, 160 | serialized_end=486, 161 | ) 162 | 163 | _WORDNETFILE = descriptor.Descriptor( 164 | name='WordNetFile', 165 | full_name='topicmod_projects_ldawn.WordNetFile', 166 | filename=None, 167 | file=DESCRIPTOR, 168 | containing_type=None, 169 | fields=[ 170 | descriptor.FieldDescriptor( 171 | name='synsets', full_name='topicmod_projects_ldawn.WordNetFile.synsets', index=0, 172 | number=1, type=11, cpp_type=10, label=3, 173 | has_default_value=False, default_value=[], 174 | message_type=None, enum_type=None, containing_type=None, 175 | is_extension=False, extension_scope=None, 176 | options=None), 177 | descriptor.FieldDescriptor( 178 | name='root', full_name='topicmod_projects_ldawn.WordNetFile.root', index=1, 179 | number=2, type=5, cpp_type=1, label=2, 180 | has_default_value=False, default_value=0, 181 | message_type=None, enum_type=None, containing_type=None, 182 | is_extension=False, extension_scope=None, 183 | options=None), 184 | ], 185 | extensions=[ 186 | ], 187 | nested_types=[_WORDNETFILE_SYNSET, ], 188 | enum_types=[ 189 | ], 190 | options=None, 191 | is_extendable=False, 192 | extension_ranges=[], 193 | serialized_start=48, 194 | serialized_end=486, 195 | ) 196 | 197 | 198 | _WORDNETFILE_SYNSET_WORD.containing_type = _WORDNETFILE_SYNSET; 199 | 
_WORDNETFILE_SYNSET.fields_by_name['words'].message_type = _WORDNETFILE_SYNSET_WORD 200 | _WORDNETFILE_SYNSET.containing_type = _WORDNETFILE; 201 | _WORDNETFILE.fields_by_name['synsets'].message_type = _WORDNETFILE_SYNSET 202 | 203 | class WordNetFile(message.Message): 204 | __metaclass__ = reflection.GeneratedProtocolMessageType 205 | 206 | class Synset(message.Message): 207 | __metaclass__ = reflection.GeneratedProtocolMessageType 208 | 209 | class Word(message.Message): 210 | __metaclass__ = reflection.GeneratedProtocolMessageType 211 | DESCRIPTOR = _WORDNETFILE_SYNSET_WORD 212 | 213 | # @@protoc_insertion_point(class_scope:topicmod_projects_ldawn.WordNetFile.Synset.Word) 214 | DESCRIPTOR = _WORDNETFILE_SYNSET 215 | 216 | # @@protoc_insertion_point(class_scope:topicmod_projects_ldawn.WordNetFile.Synset) 217 | DESCRIPTOR = _WORDNETFILE 218 | 219 | # @@protoc_insertion_point(class_scope:topicmod_projects_ldawn.WordNetFile) 220 | 221 | # @@protoc_insertion_point(module_scope) 222 | --------------------------------------------------------------------------------
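The generated `wordnet_file_pb2` module above can also be used to inspect the serialized tree files programmatically. A minimal sketch, assuming Python 2 with the protobuf 2.3 runtime installed and run from the repository root (the `.wn.0` path below is just an example of one tree shard):

    from tree.lib.proto import wordnet_file_pb2

    # Each .wn.N shard is a single serialized WordNetFile message.
    wn = wordnet_file_pb2.WordNetFile()
    with open('input/synthetic/synthetic.wn.0', 'rb') as f:
        wn.ParseFromString(f.read())

    print 'root synset:', wn.root
    for synset in wn.synsets:
        words = ', '.join(w.term_str for w in synset.words)
        print 'synset %d: hyperparameter=%r children=%s words=[%s]' % (
            synset.offset, synset.hyperparameter,
            list(synset.children_offsets), words)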