├── .gitignore ├── .gitmodules ├── windows ├── README.md ├── lightlda.sln ├── dump_binary │ └── dump_binary.vcxproj ├── infer │ └── infer.vcxproj └── lightlda │ └── lightlda.vcxproj ├── src ├── document.cpp ├── data_stream.h ├── util.h ├── eval.h ├── trainer.h ├── document.h ├── model.h ├── alias_table.h ├── eval.cpp ├── sampler.h ├── common.h ├── data_block.cpp ├── meta.h ├── data_block.h ├── data_stream.cpp ├── common.cpp ├── trainer.cpp ├── meta.cpp ├── sampler.cpp ├── model.cpp ├── lightlda.cpp └── alias_table.cpp ├── example ├── get_meta.py ├── pubmed.sh ├── nytimes.sh ├── text2libsvm.py └── README.md ├── LICENSE ├── inference ├── inferer.h ├── inferer.cpp └── infer.cpp ├── MakefileDocker ├── Makefile ├── README.md └── preprocess └── dump_binary.cpp /.gitignore: -------------------------------------------------------------------------------- 1 | bin 2 | multiverso 3 | *.o 4 | example/data 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "multiverso"] 2 | path = multiverso 3 | url = https://github.com/Microsoft/multiverso.git 4 | branch = multiverso-initial 5 | -------------------------------------------------------------------------------- /windows/README.md: -------------------------------------------------------------------------------- 1 | # Windows Installation 2 | 3 | 1. Get and build the DMTK Framework [multiverso](https://github.com/Microsoft/multiverso.git). 4 | 5 | 2. Open lightlda.sln, change configuration and platform to Release and x64, set the include and lib path of multiverso in project property. 6 | 7 | 3. Build the solution. 
8 | -------------------------------------------------------------------------------- /src/document.cpp: -------------------------------------------------------------------------------- 1 | #include "document.h" 2 | 3 | #include 4 | 5 | namespace multiverso { namespace lightlda 6 | { 7 | Document::Document(int32_t* begin, int32_t* end) 8 | : begin_(begin), end_(end), cursor_(*begin_) 9 | {} 10 | 11 | void Document::GetDocTopicVector(Row& topic_counter) 12 | { 13 | int32_t* p = begin_ + 2; 14 | int32_t num = 0; 15 | while (p < end_) 16 | { 17 | topic_counter.Add(*p, 1); 18 | ++p; ++p; 19 | if (++num == topic_counter.Capacity()) 20 | return; 21 | } 22 | } 23 | } // namespace lightlda 24 | } // namespace multiverso 25 | -------------------------------------------------------------------------------- /example/get_meta.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | 3 | ''' this script is used to count the tf information from 4 | libsvm data 5 | Usage: 6 | python get_meta.py 7 | ''' 8 | 9 | import sys 10 | import re 11 | 12 | finput = open(sys.argv[1], 'r') 13 | 14 | word_dict = {} 15 | 16 | line = finput.readline() 17 | while line: 18 | doc = re.split(" |\t", line.strip())[1:] 19 | for word_count in doc: 20 | col = word_count.strip().split(":") 21 | if len(col) != 2: 22 | print "error!" 
23 | if not word_dict.has_key(int(col[0])): 24 | word_dict[int(col[0])] = 0 25 | word_dict[int(col[0])] += int(col[1]) 26 | line = finput.readline() 27 | 28 | foutput = open(sys.argv[2], 'w') 29 | for word in word_dict: 30 | line = '\t'.join([str(word), "word", str(word_dict[word])]) + '\n' 31 | foutput.write(line) 32 | 33 | 34 | -------------------------------------------------------------------------------- /example/pubmed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | root=`pwd` 4 | echo $root 5 | bin=$root/../bin 6 | dir=$root/data/pubmed 7 | 8 | mkdir -p $dir 9 | cd $dir 10 | 11 | # 1. Download the data 12 | wget https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/docword.pubmed.txt.gz 13 | gunzip $dir/docword.pubmed.txt.gz 14 | wget https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/vocab.pubmed.txt 15 | 16 | # 2. UCI format to libsvm format 17 | python $root/text2libsvm.py $dir/docword.pubmed.txt $dir/vocab.pubmed.txt $dir/pubmed.libsvm $dir/pubmed.word_id.dict 18 | 19 | # 3. libsvm format to binary format 20 | $bin/dump_binary $dir/pubmed.libsvm $dir/pubmed.word_id.dict $dir 0 21 | 22 | # 4. Run LightLDA 23 | $bin/lightlda -num_vocabs 144400 -num_topics 1000 -num_iterations 100 -alpha 0.1 -beta 0.01 -mh_steps 2 -num_local_workers 1 -num_blocks 1 -max_num_document 8300000 -input_dir $dir -data_capacity 6200 24 | -------------------------------------------------------------------------------- /example/nytimes.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | root=`pwd` 4 | echo $root 5 | bin=$root/../bin 6 | dir=$root/data/nytimes 7 | 8 | mkdir -p $dir 9 | cd $dir 10 | 11 | # 1. 
Download the data 12 | wget https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/docword.nytimes.txt.gz 13 | gunzip $dir/docword.nytimes.txt.gz 14 | wget https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/vocab.nytimes.txt 15 | 16 | # 2. UCI format to libsvm format 17 | python $root/text2libsvm.py $dir/docword.nytimes.txt $dir/vocab.nytimes.txt $dir/nytimes.libsvm $dir/nytimes.word_id.dict 18 | 19 | # 3. libsvm format to binary format 20 | $bin/dump_binary $dir/nytimes.libsvm $dir/nytimes.word_id.dict $dir 0 21 | 22 | # 4. Run LightLDA 23 | $bin/lightlda -num_vocabs 111400 -num_topics 1000 -num_iterations 100 -alpha 0.1 -beta 0.01 -mh_steps 2 -num_local_workers 1 -num_blocks 1 -max_num_document 300000 -input_dir $dir -data_capacity 800 24 | -------------------------------------------------------------------------------- /src/data_stream.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * \file data_stream.h 3 | * \brief Defines interface for data access 4 | */ 5 | 6 | #ifndef LIGHTLDA_DATA_STREAM_H_ 7 | #define LIGHTLDA_DATA_STREAM_H_ 8 | 9 | namespace multiverso { namespace lightlda 10 | { 11 | class DataBlock; 12 | /*! \brief interface of data stream */ 13 | class IDataStream 14 | { 15 | public: 16 | virtual ~IDataStream() {} 17 | /*! \brief Should call this method before access a data block */ 18 | virtual void BeforeDataAccess() = 0; 19 | /*! \brief Should call this method after access a data block */ 20 | virtual void EndDataAccess() = 0; 21 | /*! 22 | * \brief Get one data block 23 | * \return reference to data block 24 | */ 25 | virtual DataBlock& CurrDataBlock() = 0; 26 | }; 27 | 28 | /*! 
\brief Factory method to create data stream */ 29 | IDataStream* CreateDataStream(); 30 | 31 | } // namespace lightlda 32 | } // namespace multiverso 33 | 34 | #endif // LIGHTLDA_DATA_STREAM_H_ 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) Microsoft Corporation 4 | 5 | All rights reserved. 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | -------------------------------------------------------------------------------- /src/util.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * \file util.h 3 | * \brief Defines random number generator 4 | */ 5 | 6 | #ifndef LIGHTLDA_UTIL_H_ 7 | #define LIGHTLDA_UTIL_H_ 8 | 9 | #include 10 | 11 | namespace multiverso { namespace lightlda 12 | { 13 | /*! 
\brief xorshift_rng is a random number generator */ 14 | class xorshift_rng 15 | { 16 | public: 17 | xorshift_rng() 18 | { 19 | jxr_ = static_cast(time(nullptr)); 20 | } 21 | ~xorshift_rng() {} 22 | 23 | /*! \brief get random (xorshift) 32-bit integer*/ 24 | int32_t rand() 25 | { 26 | jxr_ ^= (jxr_ << 13); jxr_ ^= (jxr_ >> 17); jxr_ ^= (jxr_ << 5); 27 | return jxr_ & 0x7fffffff; 28 | } 29 | 30 | double rand_double() 31 | { 32 | return rand() * 4.6566125e-10; 33 | } 34 | int32_t rand_k(int K) 35 | { 36 | return static_cast(rand() * 4.6566125e-10 * K); 37 | } 38 | private: 39 | // No copying allowed 40 | xorshift_rng(const xorshift_rng &other); 41 | void operator=(const xorshift_rng &other); 42 | /*! \brief seed */ 43 | uint32_t jxr_; 44 | }; 45 | } // namespace lightlda 46 | } // namespace multiverso 47 | 48 | #endif // LIGHTLDA_UTIL_H_ 49 | -------------------------------------------------------------------------------- /inference/inferer.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * \file inferer.h 3 | * \brief data inference 4 | */ 5 | #ifndef LIGHTLDA_INFERER_H_ 6 | #define LIGHTLDA_INFERER_H_ 7 | 8 | // #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace multiverso 14 | { 15 | class Barrier; 16 | 17 | namespace lightlda 18 | { 19 | class AliasTable; 20 | class LDADataBlock; 21 | class LightDocSampler; 22 | class Meta; 23 | class LocalModel; 24 | class IDataStream; 25 | 26 | class Inferer 27 | { 28 | public: 29 | Inferer(AliasTable* alias_table, 30 | IDataStream * data_stream, 31 | Meta* meta, LocalModel * model, 32 | Barrier* barrier, 33 | int32_t id, int32_t thread_num); 34 | 35 | ~Inferer(); 36 | void BeforeIteration(int32_t block); 37 | void DoIteration(int32_t iter); 38 | void EndIteration(); 39 | private: 40 | AliasTable* alias_; 41 | IDataStream * data_stream_; 42 | Meta* meta_; 43 | LocalModel * model_; 44 | Barrier* barrier_; 45 | int32_t id_; 46 | int32_t thread_num_; 47 | LightDocSampler* sampler_; 48 | }; 49 | } // namespace lightlda 50 | } // namespace multiverso 51 | 52 | 53 | #endif //LIGHTLDA_INFERER_H_ 54 | -------------------------------------------------------------------------------- /src/eval.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * \file eval.h 3 | * \brief Defines utility for evaluating likelihood of lda 4 | */ 5 | 6 | #ifndef LIGHTLDA_EVAL_H_ 7 | #define LIGHTLDA_EVAL_H_ 8 | 9 | #include "common.h" 10 | 11 | namespace multiverso 12 | { 13 | template 14 | class Row; 15 | } 16 | 17 | namespace multiverso { namespace lightlda 18 | { 19 | class Document; 20 | class Trainer; 21 | 22 | /*! 23 | * \brief Eval defines functions to compute the likelihood of lightlda 24 | * Likelihood is split into doc-likelihood and word-likelihood. 25 | * The total likelihood can be get by adding these values. 26 | */ 27 | class Eval 28 | { 29 | public: 30 | /*! 
31 | * \brief Compute doc-likelihood for one document 32 | * \param doc input document for evaluation 33 | */ 34 | static double ComputeOneDocLLH(Document* doc, 35 | Row& doc_topic_counter); 36 | 37 | /*! 38 | * \brief Compute word-likelihood for one word 39 | * \param word input word for evaluation 40 | * \param trainer for multiverso parameter access 41 | */ 42 | static double ComputeOneWordLLH(int32_t word, Trainer* trainer); 43 | 44 | /*! 45 | * \brief Compute normalization item for word-likelihood 46 | * \param trainer for multiverso parameter access 47 | */ 48 | static double NormalizeWordLLH(Trainer* trainer); 49 | }; 50 | } // namespace lightlda 51 | } // namespace multiverso 52 | 53 | #endif // LIGHTLDA_EVAL_H_ 54 | -------------------------------------------------------------------------------- /example/text2libsvm.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is for converting UCI format docword and vocab file to libsvm format data and dict 3 | 4 | (How to run) 5 | 6 | python text2libsvm.py 7 | 8 | """ 9 | 10 | import sys 11 | 12 | if len(sys.argv) != 5: 13 | print "Usage: python text2libsvm.py " 14 | exit(1) 15 | 16 | data_file = open(sys.argv[1], 'r') 17 | vocab_file = open(sys.argv[2], 'r') 18 | 19 | libsvm_file = open(sys.argv[3], 'w') 20 | dict_file = open(sys.argv[4], 'w') 21 | 22 | word_dict = {} 23 | vocab_dict = [] 24 | doc = "" 25 | last_doc_id = 0 26 | 27 | line = vocab_file.readline() 28 | while line: 29 | vocab_dict.append(line.strip()) 30 | line = vocab_file.readline() 31 | 32 | line = data_file.readline() 33 | while line: 34 | col = line.strip().split(' ') 35 | if len(col) == 3: 36 | doc_id = int(col[0]) 37 | word_id = int(col[1]) - 1 38 | word_count = int(col[2]) 39 | if not word_dict.has_key(word_id): 40 | word_dict[word_id] = 0 41 | word_dict[word_id] += word_count 42 | if doc_id != last_doc_id: 43 | if doc != "": 44 | libsvm_file.write(doc.strip() + '\n') 45 | doc = 
str(doc_id) + '\t' 46 | doc += str(word_id) + ':' + str(word_count) + ' ' 47 | last_doc_id = doc_id 48 | line = data_file.readline() 49 | 50 | if doc != "": 51 | libsvm_file.write(doc.strip() + '\n') 52 | 53 | libsvm_file.close() 54 | 55 | for word in word_dict: 56 | line = '\t'.join([str(word), vocab_dict[word], str(word_dict[word])]) + '\n' 57 | dict_file.write(line) 58 | 59 | -------------------------------------------------------------------------------- /MakefileDocker: -------------------------------------------------------------------------------- 1 | PROJECT := $(shell readlink $(dir $(lastword $(MAKEFILE_LIST))) -f) 2 | 3 | CXX = mpic++ 4 | CXXFLAGS = -O3 \ 5 | -std=c++11 \ 6 | -Wall \ 7 | -Wno-sign-compare \ 8 | -fno-omit-frame-pointer 9 | 10 | MULTIVERSO_DIR = $(PROJECT)/multiverso 11 | MULTIVERSO_INC = $(MULTIVERSO_DIR)/include 12 | MULTIVERSO_LIB = $(MULTIVERSO_DIR)/lib 13 | THIRD_PARTY_LIB = $(MULTIVERSO_DIR)/third_party/lib 14 | 15 | INC_FLAGS = -I$(MULTIVERSO_INC) -I$(PROJECT)/src -I$(PROJECT)/inference 16 | LD_FLAGS = -L$(MULTIVERSO_LIB) -lmultiverso 17 | LD_FLAGS += -L$(THIRD_PARTY_LIB) -lzmq -lpthread 18 | 19 | BASE_SRC = $(shell find $(PROJECT)/src -type f -name "*.cpp" -type f ! 
-name "lightlda.cpp") 20 | BASE_OBJ = $(BASE_SRC:.cpp=.o) 21 | 22 | LIGHTLDA_HEADERS = $(shell find $(PROJECT)/src -type f -name "*.h") 23 | LIGHTLDA_SRC = $(shell find $(PROJECT)/src -type f -name "*.cpp") 24 | LIGHTLDA_OBJ = $(LIGHTLDA_SRC:.cpp=.o) 25 | 26 | INFER_HEADERS = $(shell find $(PROJECT)/inference -type f -name "*.h") 27 | INFER_SRC = $(shell find $(PROJECT)/inference -type f -name "*.cpp") 28 | INFER_OBJ = $(INFER_SRC:.cpp=.o) 29 | 30 | DUMP_BINARY_SRC = $(shell find $(PROJECT)/preprocess -type f -name "*.cpp") 31 | 32 | BIN_DIR = $(PROJECT)/bin 33 | LIGHTLDA = $(BIN_DIR)/lightlda 34 | INFER = $(BIN_DIR)/infer 35 | DUMP_BINARY = $(BIN_DIR)/dump_binary 36 | 37 | all: path \ 38 | lightlda \ 39 | infer \ 40 | dump_binary 41 | 42 | path: $(BIN_DIR) 43 | 44 | $(BIN_DIR): 45 | mkdir -p $@ 46 | 47 | $(LIGHTLDA): $(LIGHTLDA_OBJ) 48 | $(CXX) $(LIGHTLDA_OBJ) $(CXXFLAGS) $(INC_FLAGS) $(LD_FLAGS) -o $@ 49 | 50 | $(LIGHTLDA_OBJ): %.o: %.cpp $(LIGHTLDA_HEADERS) $(MULTIVERSO_INC) 51 | $(CXX) $(CXXFLAGS) $(INC_FLAGS) -c $< -o $@ 52 | 53 | $(INFER): $(INFER_OBJ) $(BASE_OBJ) 54 | $(CXX) $(INFER_OBJ) $(BASE_OBJ) $(CXXFLAGS) $(INC_FLAGS) $(LD_FLAGS) -o $@ 55 | 56 | $(INFER_OBJ): %.o: %.cpp $(INFER_HEADERS) $(MULTIVERSO_INC) 57 | $(CXX) $(CXXFLAGS) $(INC_FLAGS) -c $< -o $@ 58 | 59 | $(DUMP_BINARY): $(DUMP_BINARY_SRC) 60 | $(CXX) $(CXXFLAGS) $< -o $@ 61 | 62 | lightlda: path $(LIGHTLDA) 63 | 64 | infer: path $(INFER) 65 | 66 | dump_binary: path $(DUMP_BINARY) 67 | 68 | clean: 69 | rm -rf $(BIN_DIR) $(LIGHTLDA_OBJ) $(INFER_OBJ) 70 | 71 | .PHONY: all path lightlda infer dump_binary clean 72 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PROJECT := $(shell readlink $(dir $(lastword $(MAKEFILE_LIST))) -f) 2 | 3 | CXX = g++ 4 | CXXFLAGS = -O3 \ 5 | -std=c++11 \ 6 | -Wall \ 7 | -Wno-sign-compare \ 8 | -fno-omit-frame-pointer 9 | 10 | MULTIVERSO_DIR = 
$(PROJECT)/multiverso 11 | MULTIVERSO_INC = $(MULTIVERSO_DIR)/include 12 | MULTIVERSO_LIB = $(MULTIVERSO_DIR)/lib 13 | THIRD_PARTY_LIB = $(MULTIVERSO_DIR)/third_party/lib 14 | 15 | INC_FLAGS = -I$(MULTIVERSO_INC) -I$(PROJECT)/src -I$(PROJECT)/inference 16 | LD_FLAGS = -L$(MULTIVERSO_LIB) -lmultiverso 17 | LD_FLAGS += -L$(THIRD_PARTY_LIB) -lzmq -lmpich -lmpl -lpthread 18 | 19 | BASE_SRC = $(shell find $(PROJECT)/src -type f -name "*.cpp" -type f ! -name "lightlda.cpp") 20 | BASE_OBJ = $(BASE_SRC:.cpp=.o) 21 | 22 | LIGHTLDA_HEADERS = $(shell find $(PROJECT)/src -type f -name "*.h") 23 | LIGHTLDA_SRC = $(shell find $(PROJECT)/src -type f -name "*.cpp") 24 | LIGHTLDA_OBJ = $(LIGHTLDA_SRC:.cpp=.o) 25 | 26 | INFER_HEADERS = $(shell find $(PROJECT)/inference -type f -name "*.h") 27 | INFER_SRC = $(shell find $(PROJECT)/inference -type f -name "*.cpp") 28 | INFER_OBJ = $(INFER_SRC:.cpp=.o) 29 | 30 | DUMP_BINARY_SRC = $(shell find $(PROJECT)/preprocess -type f -name "*.cpp") 31 | 32 | BIN_DIR = $(PROJECT)/bin 33 | LIGHTLDA = $(BIN_DIR)/lightlda 34 | INFER = $(BIN_DIR)/infer 35 | DUMP_BINARY = $(BIN_DIR)/dump_binary 36 | 37 | all: path \ 38 | lightlda \ 39 | infer \ 40 | dump_binary 41 | 42 | path: $(BIN_DIR) 43 | 44 | $(BIN_DIR): 45 | mkdir -p $@ 46 | 47 | $(LIGHTLDA): $(LIGHTLDA_OBJ) 48 | $(CXX) $(LIGHTLDA_OBJ) $(CXXFLAGS) $(INC_FLAGS) $(LD_FLAGS) -o $@ 49 | 50 | $(LIGHTLDA_OBJ): %.o: %.cpp $(LIGHTLDA_HEADERS) $(MULTIVERSO_INC) 51 | $(CXX) $(CXXFLAGS) $(INC_FLAGS) -c $< -o $@ 52 | 53 | $(INFER): $(INFER_OBJ) $(BASE_OBJ) 54 | $(CXX) $(INFER_OBJ) $(BASE_OBJ) $(CXXFLAGS) $(INC_FLAGS) $(LD_FLAGS) -o $@ 55 | 56 | $(INFER_OBJ): %.o: %.cpp $(INFER_HEADERS) $(MULTIVERSO_INC) 57 | $(CXX) $(CXXFLAGS) $(INC_FLAGS) -c $< -o $@ 58 | 59 | $(DUMP_BINARY): $(DUMP_BINARY_SRC) 60 | $(CXX) $(CXXFLAGS) $< -o $@ 61 | 62 | lightlda: path $(LIGHTLDA) 63 | 64 | infer: path $(INFER) 65 | 66 | dump_binary: path $(DUMP_BINARY) 67 | 68 | clean: 69 | rm -rf $(BIN_DIR) $(LIGHTLDA_OBJ) $(INFER_OBJ) 70 | 
71 | .PHONY: all path lightlda infer dump_binary clean 72 | -------------------------------------------------------------------------------- /src/trainer.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * \file trainer.h 3 | * \brief Defines multiverso interface for parameter loading and data training 4 | */ 5 | 6 | #ifndef LIGHTLDA_TRAINER_H_ 7 | #define LIGHTLDA_TRAINER_H_ 8 | 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | namespace multiverso { namespace lightlda 15 | { 16 | class AliasTable; 17 | class LDADataBlock; 18 | class LightDocSampler; 19 | class Meta; 20 | class PSModel; 21 | 22 | /*! \brief Trainer is responsible for training a data block */ 23 | class Trainer : public TrainerBase 24 | { 25 | public: 26 | Trainer(AliasTable* alias, Barrier* barrier, Meta* meta); 27 | ~Trainer(); 28 | /*! 29 | * \brief Defines Trainning method for a data_block in one iteration 30 | * \param data_block pointer to data block base 31 | */ 32 | void TrainIteration(DataBlockBase* data_block) override; 33 | /*! 34 | * \brief Evaluates a data block, compute its loss function 35 | * \param block pointer to data block 36 | */ 37 | void Evaluate(LDADataBlock* block); 38 | 39 | void Dump(int32_t iter, LDADataBlock* lda_data_block); 40 | 41 | private: 42 | /*! \brief alias table, for alias access */ 43 | AliasTable* alias_; 44 | /*! \brief sampler for lightlda */ 45 | LightDocSampler* sampler_; 46 | /*! \brief barrier for thread-sync */ 47 | Barrier* barrier_; 48 | /*! \brief meta information */ 49 | Meta* meta_; 50 | /*! \brief model acceccor */ 51 | PSModel * model_; 52 | static std::mutex mutex_; 53 | 54 | static double doc_llh_; 55 | static double word_llh_; 56 | }; 57 | 58 | /*! 59 | * \brief ParamLoader is responsible for parsing a data block and 60 | * preload parameters needed by this block 61 | */ 62 | class ParamLoader : public ParameterLoaderBase 63 | { 64 | /*! 
65 | * \brief Parse a data block to record which parameters (word) is 66 | * needed for training this block 67 | */ 68 | void ParseAndRequest(DataBlockBase* data_block) override; 69 | }; 70 | 71 | } // namespace lightlda 72 | } // namespace multiverso 73 | 74 | #endif // LIGHTLDA_TRAINER_H_ 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LightLDA 2 | 3 | LightLDA is a distributed system for large scale topic modeling. It implements a distributed sampler that enables very large data sizes and models. LightLDA improves sampling throughput and convergence speed via a fast O(1) metropolis-Hastings algorithm, and allows small cluster to tackle very large data and model sizes through model scheduling and data parallelism architecture. LightLDA is implemented with C++ for performance consideration. 4 | 5 | We have sucessfully trained big topic models (with trillions of parameters) on big data (Top 10% PageRank values of Bing indexed page, containing billions of documents) in Microsoft. For more technical details, please refer to our [WWW'15 paper](http://www.www2015.it/documents/proceedings/proceedings/p1351.pdf). 6 | 7 | For documents, please view our website [http://www.dmtk.io](http://www.dmtk.io). 8 | 9 | ## Why LightLDA 10 | 11 | The highlight features of LightLDA are 12 | 13 | * **Scalable**: LightLDA can train models with trillions of parameters on big data with billions of documents, a scale previous implementations cann't handle. 14 | * **Fast**: The sampler can sample millions of tokens per second per multi-core node. 15 | * **Lightweight**: Such big tasks can be trained with as few as tens of machines. 16 | 17 | ## Quick Start 18 | 19 | Run ``` $ sh build.sh ``` to build lightlda. 20 | Run ``` $ sh example/nytimes.sh ``` for a simple example. 
21 | 22 | 23 | ## Reference 24 | 25 | Please cite LightLDA if it helps in your research: 26 | 27 | ``` 28 | @inproceedings{yuan2015lightlda, 29 | title={LightLDA: Big Topic Models on Modest Computer Clusters}, 30 | author={Yuan, Jinhui and Gao, Fei and Ho, Qirong and Dai, Wei and Wei, Jinliang and Zheng, Xun and Xing, Eric Po and Liu, Tie-Yan and Ma, Wei-Ying}, 31 | booktitle={Proceedings of the 24th International Conference on World Wide Web}, 32 | pages={1351--1361}, 33 | year={2015}, 34 | organization={International World Wide Web Conferences Steering Committee} 35 | } 36 | ``` 37 | 38 | Microsoft Open Source Code of Conduct 39 | ------------ 40 | 41 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 42 | -------------------------------------------------------------------------------- /src/document.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * \file document.h 3 | * \brief Defines Document data structure 4 | */ 5 | 6 | #ifndef LIGHTLDA_DOCUMENT_H_ 7 | #define LIGHTLDA_DOCUMENT_H_ 8 | 9 | #include "common.h" 10 | 11 | namespace multiverso 12 | { 13 | template class Row; 14 | } 15 | 16 | namespace multiverso { namespace lightlda 17 | { 18 | /*! 19 | * \brief Document presents a document. Document doesn't own memory, but 20 | * would interpret a contiguous piece of extern memory as a document 21 | * with the format : 22 | * #cursor, word1, topic1, word2, topic2, ..., wordn, topicn.# 23 | */ 24 | class Document 25 | { 26 | public: 27 | /*! 28 | * \brief Constructs a document based on the start and end pointer 29 | */ 30 | Document(int32_t* begin, int32_t* end); 31 | /*! 
\brief Get the length of the document */ 32 | int32_t Size() const; 33 | /*! \brief Get the word based on the index */ 34 | int32_t Word(int32_t index) const; 35 | /*! \brief Get the topic based on the index */ 36 | int32_t Topic(int32_t index) const; 37 | /*! \brief Get the cursor */ 38 | int32_t& Cursor(); 39 | /*! \brief Set the topic based on the index */ 40 | void SetTopic(int32_t index, int32_t topic); 41 | /*! \brief Get the doc-topic vector */ 42 | void GetDocTopicVector(Row& vec); 43 | private: 44 | int32_t* begin_; 45 | int32_t* end_; 46 | int32_t& cursor_; 47 | 48 | // No copying allowed 49 | Document(const Document&); 50 | void operator=(const Document&); 51 | }; 52 | 53 | // -- inline functions definition area --------------------------------- // 54 | inline int32_t Document::Size() const 55 | { 56 | return static_cast((end_ - begin_) / 2); 57 | } 58 | inline int32_t Document::Word(int32_t index) const 59 | { 60 | return *(begin_ + 1 + index * 2); 61 | } 62 | inline int32_t Document::Topic(int32_t index) const 63 | { 64 | return *(begin_ + 2 + index * 2); 65 | } 66 | inline int32_t& Document::Cursor() { return cursor_; } 67 | inline void Document::SetTopic(int32_t index, int32_t topic) 68 | { 69 | *(begin_ + 2 + index * 2) = topic; 70 | } 71 | // -- inline functions definition area --------------------------------- // 72 | 73 | } // namespace lightlda 74 | } // namespace multiverso 75 | 76 | #endif // LIGHTLDA_DOCUMENT_H_ 77 | -------------------------------------------------------------------------------- /src/model.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * \file model.h 3 | * \brief define local model reader 4 | */ 5 | 6 | #ifndef LIGHTLDA_MODEL_H_ 7 | #define LIGHTLDA_MODEL_H_ 8 | 9 | #include 10 | #include 11 | 12 | #include "common.h" 13 | #include 14 | 15 | namespace multiverso 16 | { 17 | template class Row; 18 | class Table; 19 | 20 | namespace lightlda 21 | { 22 | class Meta; 23 | class Trainer; 24 | 25 | /*! \brief interface for acceess to model */ 26 | class ModelBase 27 | { 28 | public: 29 | virtual ~ModelBase() {} 30 | virtual Row& GetWordTopicRow(integer_t word_id) = 0; 31 | virtual Row& GetSummaryRow() = 0; 32 | virtual void AddWordTopicRow(integer_t word_id, integer_t topic_id, 33 | int32_t delta) = 0; 34 | virtual void AddSummaryRow(integer_t topic_id, int64_t delta) = 0; 35 | }; 36 | 37 | /*! \brief model based on local buffer */ 38 | class LocalModel : public ModelBase 39 | { 40 | public: 41 | explicit LocalModel(Meta * meta); 42 | void Init(); 43 | 44 | Row& GetWordTopicRow(integer_t word_id) override; 45 | Row& GetSummaryRow() override; 46 | void AddWordTopicRow(integer_t word_id, integer_t topic_id, 47 | int32_t delta) override; 48 | void AddSummaryRow(integer_t topic_id, int64_t delta) override; 49 | 50 | private: 51 | void CreateTable(); 52 | void LoadTable(); 53 | void LoadWordTopicTable(const std::string& model_fname); 54 | void LoadSummaryTable(const std::string& model_fname); 55 | 56 | std::unique_ptr word_topic_table_; 57 | std::unique_ptr
summary_table_; 58 | Meta* meta_; 59 | 60 | LocalModel(const LocalModel&) = delete; 61 | void operator=(const LocalModel&) = delete; 62 | }; 63 | 64 | /*! \brief model based on parameter server */ 65 | class PSModel : public ModelBase 66 | { 67 | public: 68 | explicit PSModel(Trainer* trainer) : trainer_(trainer) {} 69 | 70 | Row& GetWordTopicRow(integer_t word_id) override; 71 | Row& GetSummaryRow() override; 72 | void AddWordTopicRow(integer_t word_id, integer_t topic_id, 73 | int32_t delta) override; 74 | void AddSummaryRow(integer_t topic_id, int64_t delta) override; 75 | 76 | private: 77 | Trainer* trainer_; 78 | 79 | PSModel(const PSModel&) = delete; 80 | void operator=(const PSModel&) = delete; 81 | }; 82 | 83 | } // namespace lightlda 84 | } // namespace multiverso 85 | 86 | #endif // LIGHTLDA_MODEL_H_ 87 | -------------------------------------------------------------------------------- /inference/inferer.cpp: -------------------------------------------------------------------------------- 1 | #include "inferer.h" 2 | 3 | #include "alias_table.h" 4 | #include "common.h" 5 | #include "data_block.h" 6 | #include "meta.h" 7 | #include "sampler.h" 8 | #include "model.h" 9 | #include "data_stream.h" 10 | #include 11 | #include 12 | #include 13 | 14 | namespace multiverso { namespace lightlda 15 | { 16 | Inferer::Inferer(AliasTable* alias_table, 17 | IDataStream * data_stream, 18 | Meta* meta, LocalModel * model, 19 | Barrier* barrier, 20 | int32_t id, int32_t thread_num): 21 | alias_(alias_table), data_stream_(data_stream), 22 | meta_(meta), model_(model), 23 | barrier_(barrier), 24 | id_(id), thread_num_(thread_num) 25 | { 26 | sampler_ = new LightDocSampler(); 27 | } 28 | 29 | Inferer::~Inferer() 30 | { 31 | delete sampler_; 32 | } 33 | 34 | void Inferer::BeforeIteration(int32_t block) 35 | { 36 | //init current data block 37 | if(id_ == 0) 38 | { 39 | data_stream_->BeforeDataAccess(); 40 | DataBlock& data = data_stream_->CurrDataBlock(); 41 | 
data.set_meta(&(meta_->local_vocab(block))); 42 | alias_->Init(meta_->alias_index(block, 0)); 43 | alias_->Build(-1, model_); 44 | } 45 | barrier_->Wait(); 46 | 47 | // build alias table 48 | DataBlock& data = data_stream_->CurrDataBlock(); 49 | const LocalVocab& local_vocab = data.meta(); 50 | StopWatch watch; watch.Start(); 51 | for (const int32_t* pword = local_vocab.begin(0) + id_; 52 | pword < local_vocab.end(0); 53 | pword += thread_num_) 54 | { 55 | alias_->Build(*pword, model_); 56 | } 57 | barrier_->Wait(); 58 | if (id_ == 0) 59 | { 60 | Log::Info("block=%d, Alias Time used: %.2f s \n", block, watch.ElapsedSeconds()); 61 | } 62 | } 63 | 64 | void Inferer::DoIteration(int32_t iter) 65 | { 66 | if (id_ == 0) 67 | { 68 | Log::Info("iter=%d\n", iter); 69 | } 70 | DataBlock& data = data_stream_->CurrDataBlock(); 71 | const LocalVocab& local_vocab = data.meta(); 72 | int32_t lastword = local_vocab.LastWord(0); 73 | // Inference with lightlda sampler 74 | for (int32_t doc_id = id_; doc_id < data.Size(); doc_id += thread_num_) 75 | { 76 | Document* doc = data.GetOneDoc(doc_id); 77 | sampler_->SampleOneDoc(doc, 0, lastword, model_, alias_); 78 | } 79 | } 80 | 81 | void Inferer::EndIteration() 82 | { 83 | barrier_->Wait(); 84 | if(id_ == 0) 85 | { 86 | data_stream_->EndDataAccess(); 87 | alias_->Clear(); 88 | } 89 | } 90 | 91 | } // namespace lightlda 92 | } // namespace multiverso 93 | -------------------------------------------------------------------------------- /src/alias_table.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * \file alias_table.h 3 | * \brief Defines alias table 4 | */ 5 | 6 | #ifndef LIGHTLDA_ALIAS_TABLE_H_ 7 | #define LIGHTLDA_ALIAS_TABLE_H_ 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #if defined(_WIN32) || defined(_WIN64) 14 | // vs currently not support c++11 keyword thread_local 15 | #define _THREAD_LOCAL __declspec(thread) 16 | #else 17 | #define _THREAD_LOCAL thread_local 18 | #endif 19 | 20 | namespace multiverso { namespace lightlda 21 | { 22 | class ModelBase; 23 | class xorshift_rng; 24 | class AliasTableIndex; 25 | 26 | /*! 27 | * \brief AliasTable is the storage for alias tables used for fast sampling 28 | * from lightlda word proposal distribution. It optimize memory usage 29 | * through a hybrid storage by exploiting the sparsity of word proposal. 30 | * AliasTable containes two part: 1) a memory pool to store the alias 31 | * 2) an index table to access each row 32 | */ 33 | class AliasTable 34 | { 35 | public: 36 | AliasTable(); 37 | ~AliasTable(); 38 | /*! 39 | * \brief Set the table index. Must call this method before 40 | */ 41 | void Init(AliasTableIndex* table_index); 42 | /*! 43 | * \brief Build alias table for a word 44 | * \param word word to bulid 45 | * \param model access 46 | * \return success of not 47 | */ 48 | int Build(int word, ModelBase* model); 49 | /*! 50 | * \brief sample from word proposal distribution 51 | * \param word word to sample 52 | * \param rng random number generator 53 | * \return sample proposed from the distribution 54 | */ 55 | int Propose(int word, xorshift_rng& rng); 56 | /*! 
\brief Clear the alias table */ 57 | void Clear(); 58 | private: 59 | void AliasMultinomialRNG(int32_t size, float mass, int32_t& height, 60 | int32_t* kv_vector); 61 | int* memory_block_; 62 | int64_t memory_size_; 63 | AliasTableIndex* table_index_; 64 | 65 | std::vector height_; 66 | std::vector mass_; 67 | int32_t beta_height_; 68 | float beta_mass_; 69 | 70 | int32_t* beta_kv_vector_; 71 | 72 | // thread local storage used for building alias 73 | _THREAD_LOCAL static std::vector* q_w_proportion_; 74 | _THREAD_LOCAL static std::vector* q_w_proportion_int_; 75 | _THREAD_LOCAL static std::vector>* L_; 76 | _THREAD_LOCAL static std::vector>* H_; 77 | 78 | int num_vocabs_; 79 | int num_topics_; 80 | float beta_; 81 | float beta_sum_; 82 | 83 | // No copying allowed 84 | AliasTable(const AliasTable&); 85 | void operator=(const AliasTable&); 86 | }; 87 | } // namespace lightlda 88 | } // namespace multiverso 89 | #endif // LIGHTLDA_ALIAS_TABLE_H_ 90 | -------------------------------------------------------------------------------- /src/eval.cpp: -------------------------------------------------------------------------------- 1 | #include "eval.h" 2 | 3 | #include 4 | 5 | #include "common.h" 6 | #include "document.h" 7 | #include "trainer.h" 8 | 9 | #include 10 | #include 11 | 12 | namespace 13 | { 14 | const double cof[6] = 15 | { 16 | 76.18009172947146, -86.50532032941677, 17 | 24.01409824083091, -1.231739572450155, 18 | 0.1208650973866179e-2, -0.5395239384953e-5 19 | }; 20 | 21 | double LogGamma(double xx) 22 | { 23 | int j; 24 | double x, y, tmp1, ser; 25 | y = xx; 26 | x = xx; 27 | tmp1 = x + 5.5; 28 | tmp1 -= (x + 0.5)*log(tmp1); 29 | ser = 1.000000000190015; 30 | for (j = 0; j < 6; j++) ser += cof[j] / ++y; 31 | return -tmp1 + log(2.5066282746310005*ser / x); 32 | } 33 | } 34 | 35 | namespace multiverso { namespace lightlda 36 | { 37 | double Eval::ComputeOneDocLLH(Document* doc, Row& doc_topic_counter) 38 | { 39 | if (doc->Size() == 0) return 0.0; 40 | double 
one_doc_llh = LogGamma(Config::num_topics * Config::alpha) 41 | - Config::num_topics * LogGamma(Config::alpha); 42 | int32_t nonzero_num = 0; 43 | doc_topic_counter.Clear(); 44 | doc->GetDocTopicVector(doc_topic_counter); 45 | Row::iterator iter = doc_topic_counter.Iterator(); 46 | while (iter.HasNext()) 47 | { 48 | one_doc_llh += LogGamma(iter.Value() + Config::alpha); 49 | ++nonzero_num; 50 | iter.Next(); 51 | } 52 | one_doc_llh += (Config::num_topics - nonzero_num) 53 | * LogGamma(Config::alpha); 54 | one_doc_llh -= LogGamma(doc->Size() + 55 | Config::alpha * Config::num_topics); 56 | return one_doc_llh; 57 | } 58 | 59 | double Eval::ComputeOneWordLLH(int32_t word, Trainer* trainer) 60 | { 61 | Row& params = trainer->GetRow( 62 | kWordTopicTable, word); 63 | if (params.NonzeroSize() == 0) return 0.0; 64 | double word_llh = 0.0; 65 | int32_t nonzero_num = 0; 66 | RowIterator iter = params.Iterator(); 67 | while (iter.HasNext()) 68 | { 69 | word_llh += LogGamma(iter.Value() + Config::beta); 70 | ++nonzero_num; 71 | iter.Next(); 72 | } 73 | word_llh += (Config::num_topics - nonzero_num) 74 | * LogGamma(Config::beta); 75 | return word_llh; 76 | } 77 | 78 | double Eval::NormalizeWordLLH(Trainer* trainer) 79 | { 80 | Row& params = trainer->GetRow(kSummaryRow, 0); 81 | double llh = Config::num_topics * 82 | (LogGamma(Config::beta * Config::num_vocabs) - 83 | Config::num_vocabs * LogGamma(Config::beta)); 84 | for (int32_t k = 0; k < Config::num_topics; ++k) 85 | { 86 | llh -= LogGamma(params.At(k) 87 | + Config::num_vocabs * Config::beta); 88 | } 89 | return llh; 90 | } 91 | } // namespace lightlda 92 | } // namespace multiverso 93 | -------------------------------------------------------------------------------- /src/sampler.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * \file sampler.h 3 | * \brief Defines lightlda samplers 4 | */ 5 | 6 | #ifndef LIGHTLDA_SAMPLER_H_ 7 | #define LIGHTLDA_SAMPLER_H_ 8 | 9 | #include 10 | #include "util.h" 11 | 12 | namespace multiverso 13 | { 14 | template 15 | class Row; 16 | } 17 | 18 | namespace multiverso { namespace lightlda 19 | { 20 | class AliasTable; 21 | class Document; 22 | class ModelBase; 23 | 24 | /*! \brief lightlda sampler */ 25 | class LightDocSampler 26 | { 27 | public: 28 | LightDocSampler(); 29 | /*! 30 | * \brief Sample one document, update latent topic assignment 31 | * and statistics 32 | * \param doc pointer to document 33 | * \param slice slice id 34 | * \param lastword last word of current slice 35 | * \param model pointer to model, for model access 36 | * \param alias pointer to alias table, for alias access 37 | * \return number of sampled tokens 38 | */ 39 | int32_t SampleOneDoc(Document* doc, int32_t slice, int32_t lastword, 40 | ModelBase* model, AliasTable* alias); 41 | /*! 42 | * \brief Get doc-topic-counter, for reusing this container 43 | * \return reference to light hash map 44 | */ 45 | Row& doc_topic_counter() { return *doc_topic_counter_; } 46 | private: 47 | /*! 48 | * \brief Init document before sampling 49 | * \param doc pointer to document 50 | */ 51 | void DocInit(Document* doc); 52 | /*! 53 | * \brief Sample the latent topic assignment for a token 54 | * \param doc current document 55 | * \param word current token 56 | * \param state state of the word 57 | * \param old_topic old topic assignment of this token 58 | * \param model for model access 59 | * \param alias for alias table access 60 | */ 61 | int32_t Sample(Document* doc, int32_t word, int32_t state, 62 | int32_t old_topic, ModelBase* model, AliasTable* alias); 63 | 64 | /*! 65 | * \brief Sample the latent topic assignment for a token.
This function 66 | * make a little approximation to the proper Metropolis-Hasting 67 | * algorithm, but empirically this converges as good as exact Sample, 68 | * with faster speed. 69 | * \param same with Sample 70 | */ 71 | int32_t ApproxSample(Document* doc, int32_t word, int32_t state, 72 | int32_t old_topic, ModelBase* model, AliasTable* alias); 73 | private: 74 | // lda hyper-parameter 75 | float alpha_; 76 | float beta_; 77 | float alpha_sum_; 78 | float beta_sum_; 79 | 80 | int32_t subtractor_; 81 | 82 | int32_t num_vocab_; 83 | int32_t num_topic_; 84 | int32_t mh_steps_; 85 | 86 | xorshift_rng rng_; 87 | std::unique_ptr> doc_topic_counter_; 88 | }; 89 | } // namespace lightlda 90 | } // namespace multiverso 91 | 92 | #endif // LIGHTLDA_SAMPLER_H_ 93 | -------------------------------------------------------------------------------- /src/common.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * \file common.h 3 | * \brief Defines common settings in LightLDA 4 | */ 5 | 6 | #ifndef LIGHTLDA_COMMON_H_ 7 | #define LIGHTLDA_COMMON_H_ 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | namespace multiverso { namespace lightlda 14 | { 15 | /*! \brief constant variable for table id */ 16 | const int32_t kWordTopicTable = 0; 17 | /*! \brief constant variable for table id */ 18 | const int32_t kSummaryRow = 1; 19 | /*! \brief load factor for sparse hash table */ 20 | const int32_t kLoadFactor = 2; 21 | /*! \brief max length of a document */ 22 | const int32_t kMaxDocLength = 8192; 23 | 24 | // 25 | typedef int64_t DocNumber; 26 | 27 | /*! 28 | * \brief Defines LightLDA configs 29 | */ 30 | struct Config 31 | { 32 | public: 33 | /*! \brief Inits configs from command line arguments */ 34 | static void Init(int argc, char* argv[]); 35 | /*! \brief size of vocabulary */ 36 | static int32_t num_vocabs; 37 | /*! \brief number of topics */ 38 | static int32_t num_topics; 39 | /*! 
\brief number of iterations for training */ 40 | static int32_t num_iterations; 41 | /*! \brief number of metropolis-hastings steps */ 42 | static int32_t mh_steps; 43 | /*! \brief number of servers for Multiverso setting */ 44 | static int32_t num_servers; 45 | /*! \brief server endpoint file */ 46 | static std::string server_file; 47 | /*! \brief number of worker threads */ 48 | static int32_t num_local_workers; 49 | /*! \brief number of local aggregation threads */ 50 | static int32_t num_aggregator; 51 | /*! \brief number of blocks to train in disk */ 52 | static int32_t num_blocks; 53 | /*! \brief maximum number of documents in a block */ 54 | static int64_t max_num_document; 55 | /*! \brief hyper-parameter for symmetric dirichlet prior */ 56 | static float alpha; 57 | /*! \brief hyper-parameter for symmetric dirichlet prior */ 58 | static float beta; 59 | /*! \brief path of input directory */ 60 | static std::string input_dir; 61 | /*! \brief option specify whether warm_start */ 62 | static bool warm_start; 63 | /*! \brief inference mode */ 64 | static bool inference; 65 | /*! \brief option specify whether use out of core computation */ 66 | static bool out_of_core; 67 | /*! \brief memory capacity settings, for memory pools */ 68 | static int64_t data_capacity; 69 | static int64_t model_capacity; 70 | static int64_t delta_capacity; 71 | static int64_t alias_capacity; 72 | private: 73 | /*! \brief Print usage */ 74 | static void PrintUsage(); 75 | static void PrintTrainingUsage(); 76 | static void PrintInferenceUsage(); 77 | /*!
\brief Check if the configs are valid */ 78 | static void Check(); 79 | }; 80 | } // namespace lightlda 81 | } // namespace multiverso 82 | 83 | #endif // LIGHTLDA_COMMON_H_ 84 | -------------------------------------------------------------------------------- /example/README.md: -------------------------------------------------------------------------------- 1 | #LightLDA usage 2 | 3 | Running ```lightlda --help``` gives the usage information. 4 | 5 | LightLDA usage: 6 | ``` 7 | -num_vocabs Size of dataset vocabulary 8 | -num_topics Number of topics. Default: 100 9 | -num_iterations Number of iteratioins. Default: 100 10 | -mh_steps Metropolis-hasting steps. Default: 2 11 | -alpha Dirichlet prior alpha. Default: 0.1 12 | -beta Dirichlet prior beta. Default: 0.01 13 | -num_blocks Number of blocks in disk. Default: 1 14 | -max_num_document Max number of document in a data block 15 | -input_dir Directory of input data, containing 16 | files generated by dump_block 17 | -num_servers Number of servers. Default: 1 18 | -num_local_workers Number of local training threads. Default: 4 19 | -num_aggregator Number of local aggregation threads. Default: 1 20 | -server_file Server endpoint file. Used by MPI-free version 21 | -warm_start Warm start 22 | -out_of_core Use out of core computing 23 | -data_capacity Memory pool size(MB) for data storage, 24 | should larger than the any data block 25 | -model_capacity Memory pool size(MB) for local model cache 26 | -alias_capacity Memory pool size(MB) for alias table 27 | -delta_capacity Memory pool size(MB) for local delta cache 28 | ``` 29 | #Note on the input data 30 | 31 | The input data is placed in a folder, which is specified by the command line argument ```input_dir```. 32 | 33 | This folder should contains files named as ```block.id```, ```vocab.id```. The ```id``` is range from 0 to N-1 where ```N``` is the number of data block. 
34 | 35 | The input data should be generated by the tool ```dump_binary```(released along with LightLDA), which converts the libsvm format into a binary format. This is for training efficiency consideration. 36 | 37 | #Note on the arguments about capacity 38 | 39 | In LightLDA, almost all the memory chunk is pre-allocated. LightLDA uses these fixed-capacity memory as memory pool. 40 | 41 | For data capacity, you should assign a value at least larger than the largest size of your binary training block file(generated by ```dump_binary```, see Note on input data above). 42 | 43 | For ```model/alias/delta capacity```, you can assign any value. LightLDA handles big model challenge under limited memory condition by model scheduling, which loads only a slice of needed parameters that can fit into the pre-allocated memory and schedules only related tokens to train. To reduce the wait time, the next slice is prefetched in the background. Empirically, ```model capacity``` and ```alias capacity``` are of the same order. ```delta capacity``` can be much smaller than model/alias capacity. Logs will give the actual memory size used at the beginning of the program. You can use this information to adjust these arguments to achieve better computation/memory efficiency. 44 | 45 | #Note on distributed running 46 | 47 | Data should be distributed into different nodes. 48 | 49 | Running with MPI, you just need to run ```mpiexec --machinefile machine_file lightlda -lightlda_arguments... ``` 50 | 51 | Running without MPI, you need to prepare a server_endpoint file which contains ip:port information for server process.
52 | -------------------------------------------------------------------------------- /windows/lightlda.sln: -------------------------------------------------------------------------------- 1 | 2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 2013 4 | VisualStudioVersion = 12.0.40629.0 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "lightlda", "lightlda\lightlda.vcxproj", "{9A6EEB61-6B68-49A3-B83A-E044F6678D6A}" 7 | EndProject 8 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "dump_binary", "dump_binary\dump_binary.vcxproj", "{FFD24CAD-825A-4732-8186-4361D3E1438F}" 9 | EndProject 10 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "infer", "infer\infer.vcxproj", "{3CF22D68-9F4B-46B9-B0D4-7129467DD759}" 11 | EndProject 12 | Global 13 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 14 | Debug|Mixed Platforms = Debug|Mixed Platforms 15 | Debug|Win32 = Debug|Win32 16 | Debug|x64 = Debug|x64 17 | Release|Mixed Platforms = Release|Mixed Platforms 18 | Release|Win32 = Release|Win32 19 | Release|x64 = Release|x64 20 | EndGlobalSection 21 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 22 | {9A6EEB61-6B68-49A3-B83A-E044F6678D6A}.Debug|Mixed Platforms.ActiveCfg = Debug|Win32 23 | {9A6EEB61-6B68-49A3-B83A-E044F6678D6A}.Debug|Mixed Platforms.Build.0 = Debug|Win32 24 | {9A6EEB61-6B68-49A3-B83A-E044F6678D6A}.Debug|Win32.ActiveCfg = Debug|Win32 25 | {9A6EEB61-6B68-49A3-B83A-E044F6678D6A}.Debug|Win32.Build.0 = Debug|Win32 26 | {9A6EEB61-6B68-49A3-B83A-E044F6678D6A}.Debug|x64.ActiveCfg = Debug|x64 27 | {9A6EEB61-6B68-49A3-B83A-E044F6678D6A}.Debug|x64.Build.0 = Debug|x64 28 | {9A6EEB61-6B68-49A3-B83A-E044F6678D6A}.Release|Mixed Platforms.ActiveCfg = Release|Win32 29 | {9A6EEB61-6B68-49A3-B83A-E044F6678D6A}.Release|Mixed Platforms.Build.0 = Release|Win32 30 | {9A6EEB61-6B68-49A3-B83A-E044F6678D6A}.Release|Win32.ActiveCfg = Release|Win32 31 | 
{9A6EEB61-6B68-49A3-B83A-E044F6678D6A}.Release|Win32.Build.0 = Release|Win32 32 | {9A6EEB61-6B68-49A3-B83A-E044F6678D6A}.Release|x64.ActiveCfg = Release|x64 33 | {9A6EEB61-6B68-49A3-B83A-E044F6678D6A}.Release|x64.Build.0 = Release|x64 34 | {FFD24CAD-825A-4732-8186-4361D3E1438F}.Debug|Mixed Platforms.ActiveCfg = Debug|Win32 35 | {FFD24CAD-825A-4732-8186-4361D3E1438F}.Debug|Mixed Platforms.Build.0 = Debug|Win32 36 | {FFD24CAD-825A-4732-8186-4361D3E1438F}.Debug|Win32.ActiveCfg = Debug|Win32 37 | {FFD24CAD-825A-4732-8186-4361D3E1438F}.Debug|Win32.Build.0 = Debug|Win32 38 | {FFD24CAD-825A-4732-8186-4361D3E1438F}.Debug|x64.ActiveCfg = Debug|Win32 39 | {FFD24CAD-825A-4732-8186-4361D3E1438F}.Release|Mixed Platforms.ActiveCfg = Release|Win32 40 | {FFD24CAD-825A-4732-8186-4361D3E1438F}.Release|Mixed Platforms.Build.0 = Release|Win32 41 | {FFD24CAD-825A-4732-8186-4361D3E1438F}.Release|Win32.ActiveCfg = Release|Win32 42 | {FFD24CAD-825A-4732-8186-4361D3E1438F}.Release|Win32.Build.0 = Release|Win32 43 | {FFD24CAD-825A-4732-8186-4361D3E1438F}.Release|x64.ActiveCfg = Release|x64 44 | {FFD24CAD-825A-4732-8186-4361D3E1438F}.Release|x64.Build.0 = Release|x64 45 | {3CF22D68-9F4B-46B9-B0D4-7129467DD759}.Debug|Mixed Platforms.ActiveCfg = Debug|Win32 46 | {3CF22D68-9F4B-46B9-B0D4-7129467DD759}.Debug|Mixed Platforms.Build.0 = Debug|Win32 47 | {3CF22D68-9F4B-46B9-B0D4-7129467DD759}.Debug|Win32.ActiveCfg = Debug|Win32 48 | {3CF22D68-9F4B-46B9-B0D4-7129467DD759}.Debug|Win32.Build.0 = Debug|Win32 49 | {3CF22D68-9F4B-46B9-B0D4-7129467DD759}.Debug|x64.ActiveCfg = Debug|Win32 50 | {3CF22D68-9F4B-46B9-B0D4-7129467DD759}.Release|Mixed Platforms.ActiveCfg = Release|Win32 51 | {3CF22D68-9F4B-46B9-B0D4-7129467DD759}.Release|Mixed Platforms.Build.0 = Release|Win32 52 | {3CF22D68-9F4B-46B9-B0D4-7129467DD759}.Release|Win32.ActiveCfg = Release|Win32 53 | {3CF22D68-9F4B-46B9-B0D4-7129467DD759}.Release|Win32.Build.0 = Release|Win32 54 | {3CF22D68-9F4B-46B9-B0D4-7129467DD759}.Release|x64.ActiveCfg = 
Release|x64 55 | {3CF22D68-9F4B-46B9-B0D4-7129467DD759}.Release|x64.Build.0 = Release|x64 56 | EndGlobalSection 57 | GlobalSection(SolutionProperties) = preSolution 58 | HideSolutionNode = FALSE 59 | EndGlobalSection 60 | EndGlobal 61 | -------------------------------------------------------------------------------- /src/data_block.cpp: -------------------------------------------------------------------------------- 1 | #include "data_block.h" 2 | #include "document.h" 3 | #include "common.h" 4 | 5 | #include 6 | 7 | #include 8 | 9 | #if defined(_WIN32) || defined(_WIN64) 10 | #include 11 | #else 12 | #include 13 | #endif 14 | 15 | namespace 16 | { 17 | void AtomicMoveFileExA(std::string existing_file, std::string new_file) 18 | { 19 | #if defined(_WIN32) || defined(_WIN64) 20 | MoveFileExA(existing_file.c_str(), new_file.c_str(), 21 | MOVEFILE_REPLACE_EXISTING); 22 | #else 23 | if (rename(existing_file.c_str(), new_file.c_str()) == -1) 24 | { 25 | multiverso::Log::Error("Failed to move tmp file to final location\n"); 26 | } 27 | #endif 28 | } 29 | } 30 | 31 | namespace multiverso { namespace lightlda 32 | { 33 | DataBlock::DataBlock() 34 | : has_read_(false), num_document_(0), corpus_size_(0), vocab_(nullptr) 35 | { 36 | max_num_document_ = Config::max_num_document; 37 | memory_block_size_ = Config::data_capacity / sizeof(int32_t); 38 | 39 | documents_.resize(max_num_document_); 40 | 41 | try{ 42 | offset_buffer_ = new int64_t[max_num_document_]; 43 | } 44 | catch (std::bad_alloc& ba) { 45 | Log::Fatal("Bad Alloc caught: failed memory allocation for offset_buffer in DataBlock\n"); 46 | } 47 | 48 | try{ 49 | documents_buffer_ = new int32_t[memory_block_size_]; 50 | } 51 | catch (std::bad_alloc& ba) { 52 | Log::Fatal("Bad Alloc caught: failed memory allocation for documents_buffer in DataBlock\n"); 53 | } 54 | } 55 | 56 | DataBlock::~DataBlock() 57 | { 58 | delete[] offset_buffer_; 59 | delete[] documents_buffer_; 60 | } 61 | 62 | void DataBlock::Read(std::string 
file_name) 63 | { 64 | file_name_ = file_name; 65 | 66 | std::ifstream block_file(file_name_, std::ios::in | std::ios::binary); 67 | if (!block_file.good()) 68 | { 69 | Log::Fatal("Failed to read data %s\n", file_name_.c_str()); 70 | } 71 | block_file.read(reinterpret_cast(&num_document_), sizeof(DocNumber)); 72 | 73 | if (num_document_ > max_num_document_) 74 | { 75 | Log::Fatal("Rank %d: Num of documents > max number of documents when reading file %s\n", 76 | Multiverso::ProcessRank(), file_name_.c_str()); 77 | } 78 | 79 | block_file.read(reinterpret_cast(offset_buffer_), 80 | sizeof(int64_t)* (num_document_ + 1)); 81 | 82 | corpus_size_ = offset_buffer_[num_document_]; 83 | 84 | if (corpus_size_ > memory_block_size_) 85 | { 86 | Log::Fatal("Rank %d: corpus_size_ > memory_block_size when reading file %s\n", 87 | Multiverso::ProcessRank(), file_name_.c_str()); 88 | } 89 | 90 | block_file.read(reinterpret_cast(documents_buffer_), 91 | sizeof(int32_t)* corpus_size_); 92 | block_file.close(); 93 | 94 | GenerateDocuments(); 95 | has_read_ = true; 96 | } 97 | 98 | void DataBlock::Write() 99 | { 100 | std::string temp_file = file_name_ + ".temp"; 101 | 102 | std::ofstream block_file(temp_file, std::ios::out | std::ios::binary); 103 | 104 | if (!block_file.good()) 105 | { 106 | Log::Fatal("Failed to open file %s\n", temp_file.c_str()); 107 | } 108 | 109 | block_file.write(reinterpret_cast(&num_document_), 110 | sizeof(DocNumber)); 111 | block_file.write(reinterpret_cast(offset_buffer_), 112 | sizeof(int64_t)* (num_document_ + 1)); 113 | block_file.write(reinterpret_cast(documents_buffer_), 114 | sizeof(int32_t)* corpus_size_); 115 | block_file.flush(); 116 | block_file.close(); 117 | 118 | AtomicMoveFileExA(temp_file, file_name_); 119 | has_read_ = false; 120 | } 121 | 122 | void DataBlock::GenerateDocuments() 123 | { 124 | for (int32_t index = 0; index < num_document_; ++index) 125 | { 126 | documents_[index].reset(new Document( 127 | documents_buffer_ + 
offset_buffer_[index], 128 | documents_buffer_ + offset_buffer_[index + 1])); 129 | } 130 | } 131 | } // namespace lightlda 132 | } // namespace multiverso 133 | -------------------------------------------------------------------------------- /src/meta.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * \file meta.h 3 | * \brief This file defines meta information for training dataset 4 | */ 5 | 6 | #ifndef LIGHTLDA_META_H_ 7 | #define LIGHTLDA_META_H_ 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | namespace multiverso { namespace lightlda 14 | { 15 | /*! 16 | * \brief LocalVocab defines the meta information of a data block. 17 | * It containes 1) which words occurs in this block, 2) slice information 18 | */ 19 | class LocalVocab 20 | { 21 | public: 22 | friend class Meta; 23 | LocalVocab(); 24 | ~LocalVocab(); 25 | /*! \brief Get the last word of current slice */ 26 | int32_t LastWord(int32_t slice) const; 27 | /*! \brief Get the number of slice */ 28 | int32_t num_slice() const; 29 | /*! \brief Get the pointer to first word in this slice */ 30 | const int* begin(int32_t slice) const; 31 | /*! \brief Get the pointer to last word + 1 in this slice */ 32 | const int32_t* end(int32_t slice) const; 33 | private: 34 | int32_t num_slices_; 35 | int32_t* vocabs_; 36 | int32_t size_; 37 | bool own_memory_; 38 | std::vector slice_index_; 39 | }; 40 | 41 | 42 | struct WordEntry 43 | { 44 | bool is_dense; 45 | int64_t begin_offset; 46 | int32_t capacity; 47 | }; 48 | 49 | class AliasTableIndex 50 | { 51 | public: 52 | AliasTableIndex(); 53 | WordEntry& word_entry(int32_t word); 54 | void PushWord(int32_t word, bool is_dense, 55 | int64_t begin_offset, int32_t capacity); 56 | private: 57 | std::vector index_; 58 | std::vector index_map_; 59 | }; 60 | 61 | /*! 62 | * \brief Meta containes all the meta information of training data in 63 | * current process. 
It contains 1) all the local vocabs for all data 64 | * blocks, 2) the global tf for the training dataset 65 | */ 66 | class Meta 67 | { 68 | public: 69 | Meta(); 70 | ~Meta(); 71 | /*! \brief Initialize the Meta information */ 72 | void Init(); 73 | /*! \brief Get the tf of word in the whole dataset */ 74 | int32_t tf(int32_t word) const; 75 | /*! \brief Get the tf of word in local dataset */ 76 | int32_t local_tf(int32_t word) const; 77 | /*! \brief Get the local vocab based on block id */ 78 | const LocalVocab& local_vocab(int32_t id) const; 79 | 80 | AliasTableIndex* alias_index(int32_t block, int32_t slice); 81 | private: 82 | /*! \brief Schedule the model and split as slices based on memory */ 83 | void ModelSchedule(); 84 | /*! \brief Schedule the model without vocabulary splitting */ 85 | void ModelSchedule4Inference(); 86 | /*! \brief Build index for alias table */ 87 | void BuildAliasIndex(); 88 | private: 89 | /*! \brief meta information for all data block */ 90 | std::vector local_vocabs_; 91 | /*! \brief tf information for all word in the dataset */ 92 | std::vector tf_; 93 | /*!
\brief local tf information for all word in this machine */ 94 | std::vector local_tf_; 95 | 96 | std::vector > alias_index_; 97 | // No copying allowed 98 | Meta(const Meta&); 99 | void operator=(const Meta&); 100 | }; 101 | 102 | // -- inline functions definition area --------------------------------- // 103 | inline int32_t LocalVocab::LastWord(int32_t slice) const 104 | { 105 | return vocabs_[slice_index_[slice + 1] - 1]; 106 | } 107 | inline int32_t LocalVocab::num_slice() const { return num_slices_; } 108 | inline const int32_t* LocalVocab::begin(int32_t slice) const 109 | { 110 | return vocabs_ + slice_index_[slice]; 111 | } 112 | inline const int32_t* LocalVocab::end(int32_t slice) const 113 | { 114 | return vocabs_ + slice_index_[slice + 1]; 115 | } 116 | inline int32_t Meta::tf(int32_t word) const { return tf_[word]; } 117 | inline int32_t Meta::local_tf(int32_t word) const { return local_tf_[word]; } 118 | inline const LocalVocab& Meta::local_vocab(int32_t id) const 119 | { 120 | return local_vocabs_[id]; 121 | } 122 | inline AliasTableIndex* Meta::alias_index(int32_t block, int32_t slice) 123 | { 124 | return alias_index_[block][slice]; 125 | } 126 | // -- inline functions definition area --------------------------------- // 127 | 128 | } // namespace lightlda 129 | } // namespace multiverso 130 | 131 | #endif // LIGHTLDA_META_H_ 132 | -------------------------------------------------------------------------------- /src/data_block.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * \file data_block.h 3 | * \brief Defines the training data block 4 | */ 5 | 6 | #ifndef LIGHTLDA_DATA_BLOCK_H_ 7 | #define LIGHTLDA_DATA_BLOCK_H_ 8 | 9 | #include "common.h" 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | namespace multiverso { namespace lightlda 18 | { 19 | class Document; 20 | class LocalVocab; 21 | /*! 
22 | * \brief DataBlock is the an unit of the training dataset, 23 | * it correspond to a data block file in disk. 24 | */ 25 | class DataBlock 26 | { 27 | public: 28 | DataBlock(); 29 | ~DataBlock(); 30 | /*! \brief Reads a block of data into data block from disk */ 31 | void Read(std::string file_name); 32 | /*! \brief Writes a block of data to disk */ 33 | void Write(); 34 | 35 | bool HasLoad() const; 36 | 37 | /*! \brief Gets the size (number of documents) of data block */ 38 | DocNumber Size() const; 39 | /*! 40 | * \brief Gets one document 41 | * \param index index of document 42 | * \return pointer to document 43 | */ 44 | Document* GetOneDoc(int32_t index); 45 | 46 | // mutator and accessor methods 47 | const LocalVocab& meta() const; 48 | void set_meta(const LocalVocab* local_vocab); 49 | private: 50 | void GenerateDocuments(); 51 | bool has_read_; 52 | /*! \brief size of memory pool for document offset */ 53 | int64_t max_num_document_; 54 | /*! \brief size of memory pool for documents */ 55 | int64_t memory_block_size_; 56 | /*! \brief index to each document */ 57 | std::vector> documents_; 58 | /*! \brief number of document in this block */ 59 | DocNumber num_document_; 60 | /*! \brief memory pool to store the document offset */ 61 | int64_t* offset_buffer_; 62 | /*! \brief actual memory size used */ 63 | int64_t corpus_size_; 64 | /*! \brief memory pool to store the documents */ 65 | int32_t* documents_buffer_; 66 | /*! \brief meta(vocabs) information of current data block */ 67 | const LocalVocab* vocab_; 68 | /*! \brief file name in disk */ 69 | std::string file_name_; 70 | // No copying allowed 71 | DataBlock(const DataBlock&); 72 | void operator=(const DataBlock&); 73 | }; 74 | 75 | /*! 
76 | * \brief LDADataBlock is a logic data block that multiverso used to 77 | * train lightlda 78 | */ 79 | class LDADataBlock : public DataBlockBase 80 | { 81 | public: 82 | // mutator and accessor methods 83 | int32_t block() const; 84 | void set_block(int32_t block); 85 | int32_t slice() const; 86 | void set_slice(int32_t slice); 87 | int32_t iteration() const; 88 | void set_iteration(int32_t iteration); 89 | DataBlock& data(); 90 | void set_data(DataBlock* data); 91 | private: 92 | /*! \brief the actual data block */ 93 | DataBlock* data_; 94 | /*! \brief the data block id */ 95 | int32_t block_; 96 | /*! \brief the slice id */ 97 | int32_t slice_; 98 | /*! \brief the i-th iteration */ 99 | int32_t iteration_; 100 | }; 101 | 102 | // -- inline functions definition area --------------------------------- // 103 | 104 | inline bool DataBlock::HasLoad() const { return has_read_; } 105 | inline Document* DataBlock::GetOneDoc(int32_t index) 106 | { 107 | return documents_[index].get(); 108 | } 109 | inline const LocalVocab& DataBlock::meta() const { return *vocab_; } 110 | inline void DataBlock::set_meta(const LocalVocab* local_vocab) 111 | { 112 | vocab_ = local_vocab; 113 | } 114 | inline int32_t LDADataBlock::block() const { return block_; } 115 | inline void LDADataBlock::set_block(int32_t block) { block_ = block; } 116 | inline int32_t LDADataBlock::slice() const { return slice_; } 117 | inline void LDADataBlock::set_slice(int32_t slice) { slice_ = slice; } 118 | inline int32_t LDADataBlock::iteration() const { return iteration_; } 119 | inline void LDADataBlock::set_iteration(int32_t iteration) 120 | { 121 | iteration_ = iteration; 122 | } 123 | 124 | inline DataBlock& LDADataBlock::data() { return *data_; } 125 | inline void LDADataBlock::set_data(DataBlock* data) { data_ = data; } 126 | inline DocNumber DataBlock::Size() const { return num_document_; } 127 | 128 | // -- inline functions definition area --------------------------------- // 129 | 130 | } // 
namespace lightlda 131 | } // namespace multiverso 132 | 133 | #endif // LIGHTLDA_DATA_BLOCK_H_ 134 | -------------------------------------------------------------------------------- /src/data_stream.cpp: -------------------------------------------------------------------------------- 1 | #include "data_stream.h" 2 | #include "common.h" 3 | #include "data_block.h" 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | namespace multiverso { namespace lightlda 11 | { 12 | class MemoryDataStream :public IDataStream 13 | { 14 | public: 15 | MemoryDataStream(int32_t num_blocks, std::string data_path); 16 | virtual ~MemoryDataStream(); 17 | virtual void BeforeDataAccess() override; 18 | virtual void EndDataAccess() override; 19 | virtual DataBlock& CurrDataBlock() override; 20 | private: 21 | std::vector data_buffer_; 22 | std::string data_path_; 23 | int32_t index_; 24 | 25 | // No copying allowed 26 | MemoryDataStream(const MemoryDataStream&); 27 | void operator=(const MemoryDataStream&); 28 | }; 29 | 30 | class DiskDataStream : public IDataStream 31 | { 32 | typedef DoubleBuffer DataBuffer; 33 | public: 34 | DiskDataStream(int32_t num_blocks, std::string data_path, 35 | int32_t num_iterations); 36 | virtual ~DiskDataStream(); 37 | virtual void BeforeDataAccess() override; 38 | virtual void EndDataAccess() override; 39 | virtual DataBlock& CurrDataBlock() override; 40 | private: 41 | /*! \brief Background data thread entrance function */ 42 | void DataPreloadMain(); 43 | /*! \brief buffer for data */ 44 | DataBlock* buffer_0; 45 | DataBlock* buffer_1; 46 | DataBuffer* data_buffer_; 47 | /*! \brief current block id to be accessed */ 48 | int32_t block_id_; 49 | /*! \brief number of data blocks in disk */ 50 | int32_t num_blocks_; 51 | /*! \brief number of training iterations */ 52 | int32_t num_iterations_; 53 | /*! \brief data path */ 54 | std::string data_path_; 55 | /*! 
\brief backend thread for data preload */ 56 | std::thread preload_thread_; 57 | bool working_; 58 | 59 | // No copying allowed 60 | DiskDataStream(const DiskDataStream&); 61 | void operator=(const DiskDataStream&); 62 | }; 63 | 64 | MemoryDataStream::MemoryDataStream(int32_t num_blocks, std::string data_path) 65 | : data_path_(data_path), index_(0) 66 | { 67 | data_buffer_.resize(num_blocks, nullptr); 68 | for (int32_t i = 0; i < num_blocks; ++i) 69 | { 70 | data_buffer_[i] = new DataBlock(); 71 | data_buffer_[i]->Read(data_path_ + "/block." 72 | + std::to_string(i)); 73 | } 74 | } 75 | MemoryDataStream::~MemoryDataStream() 76 | { 77 | for (auto& data : data_buffer_) 78 | { 79 | data->Write(); 80 | delete data; 81 | data = nullptr; 82 | } 83 | } 84 | void MemoryDataStream::BeforeDataAccess() 85 | { 86 | index_ %= data_buffer_.size(); 87 | } 88 | void MemoryDataStream::EndDataAccess() 89 | { 90 | ++index_; 91 | } 92 | 93 | DataBlock& MemoryDataStream::CurrDataBlock() 94 | { 95 | return *data_buffer_[index_]; 96 | } 97 | 98 | DiskDataStream::DiskDataStream(int32_t num_blocks, 99 | std::string data_path, int32_t num_iterations) : 100 | num_blocks_(num_blocks), data_path_(data_path), 101 | num_iterations_(num_iterations), working_(false) 102 | { 103 | block_id_ = 0; 104 | buffer_0 = new DataBlock(); 105 | buffer_1 = new DataBlock(); 106 | data_buffer_ = new DataBuffer(1, buffer_0, buffer_1); 107 | preload_thread_ = std::thread(&DiskDataStream::DataPreloadMain, this); 108 | while (!working_) 109 | { 110 | std::this_thread::sleep_for(std::chrono::microseconds(10)); 111 | } 112 | } 113 | 114 | DiskDataStream::~DiskDataStream() 115 | { 116 | preload_thread_.join(); 117 | if (data_buffer_ != nullptr) 118 | { 119 | delete data_buffer_; 120 | data_buffer_ = nullptr; 121 | delete buffer_1; 122 | buffer_1 = nullptr; 123 | delete buffer_0; 124 | buffer_0 = nullptr; 125 | } 126 | } 127 | 128 | DataBlock& DiskDataStream::CurrDataBlock() 129 | { 130 | return 
data_buffer_->WorkerBuffer(); 131 | } 132 | 133 | void DiskDataStream::BeforeDataAccess() 134 | { 135 | data_buffer_->Start(1); 136 | } 137 | 138 | void DiskDataStream::EndDataAccess() 139 | { 140 | data_buffer_->End(1); 141 | } 142 | 143 | void DiskDataStream::DataPreloadMain() 144 | { 145 | int32_t block_id = 0; 146 | std::string block_file = data_path_ + "/block." 147 | + std::to_string(block_id); 148 | data_buffer_->Start(0); 149 | data_buffer_->IOBuffer().Read(block_file); 150 | data_buffer_->End(0); 151 | working_ = true; 152 | for (int32_t iter = 0; iter <= num_iterations_; ++iter) 153 | { 154 | for (int32_t block_id = 0; block_id < num_blocks_; ++block_id) 155 | { 156 | data_buffer_->Start(0); 157 | 158 | DataBlock& data_block = data_buffer_->IOBuffer(); 159 | if (data_block.HasLoad()) 160 | { 161 | data_block.Write(); 162 | } 163 | if (iter == num_iterations_ && block_id == num_blocks_ - 1) 164 | { 165 | break; 166 | } 167 | // Load New data; 168 | int32_t next_block_id = (block_id + 1) % num_blocks_; 169 | block_file = data_path_ + "/block." 
+ 170 | std::to_string(next_block_id); 171 | data_block.Read(block_file); 172 | data_buffer_->End(0); 173 | } 174 | } 175 | } 176 | 177 | IDataStream* CreateDataStream() 178 | { 179 | if (Config::out_of_core && Config::num_blocks != 1) 180 | { 181 | return new DiskDataStream(Config::num_blocks, Config::input_dir, 182 | Config::num_iterations); 183 | } 184 | else 185 | { 186 | return new MemoryDataStream(Config::num_blocks, Config::input_dir); 187 | } 188 | } 189 | } // namespace lightlda 190 | } // namespace multiverso 191 | -------------------------------------------------------------------------------- /inference/infer.cpp: -------------------------------------------------------------------------------- 1 | #include "common.h" 2 | #include "alias_table.h" 3 | #include "data_stream.h" 4 | #include "data_block.h" 5 | #include "document.h" 6 | #include "meta.h" 7 | #include "util.h" 8 | #include "model.h" 9 | #include "inferer.h" 10 | #include 11 | #include 12 | #include 13 | // #include 14 | #include 15 | 16 | namespace multiverso { namespace lightlda 17 | { 18 | class Infer 19 | { 20 | public: 21 | static void Run(int argc, char** argv) 22 | { 23 | Log::ResetLogFile("LightLDA_infer." 
+ std::to_string(clock()) + ".log"); 24 | Config::Init(argc, argv); 25 | //init meta 26 | meta.Init(); 27 | //init model 28 | LocalModel* model = new LocalModel(&meta); model->Init(); 29 | //init document stream 30 | data_stream = CreateDataStream(); 31 | //init documents 32 | InitDocument(); 33 | //init alias table 34 | AliasTable* alias_table = new AliasTable(); 35 | //init inferers 36 | std::vector inferers; 37 | Barrier barrier(Config::num_local_workers); 38 | // pthread_barrier_t barrier; 39 | // pthread_barrier_init(&barrier, nullptr, Config::num_local_workers); 40 | for (int32_t i = 0; i < Config::num_local_workers; ++i) 41 | { 42 | inferers.push_back(new Inferer(alias_table, data_stream, 43 | &meta, model, 44 | &barrier, i, Config::num_local_workers)); 45 | } 46 | 47 | //do inference in muti-threads 48 | Inference(inferers); 49 | 50 | //dump doc topic 51 | DumpDocTopic(); 52 | 53 | //recycle space 54 | for (auto& inferer : inferers) 55 | { 56 | delete inferer; 57 | inferer = nullptr; 58 | } 59 | // pthread_barrier_destroy(&barrier); 60 | delete data_stream; 61 | delete alias_table; 62 | delete model; 63 | } 64 | private: 65 | static void Inference(std::vector& inferers) 66 | { 67 | //pthread_t * threads = new pthread_t[Config::num_local_workers]; 68 | //if(nullptr == threads) 69 | //{ 70 | // Log::Fatal("failed to allocate space for worker threads"); 71 | //} 72 | std::vector threads; 73 | for(int32_t i = 0; i < Config::num_local_workers; ++i) 74 | { 75 | threads.push_back(std::thread(&InferenceThread, inferers[i])); 76 | //if(pthread_create(threads + i, nullptr, InferenceThread, inferers[i])) 77 | //{ 78 | // Log::Fatal("failed to create worker threads"); 79 | //} 80 | } 81 | for(int32_t i = 0; i < Config::num_local_workers; ++i) 82 | { 83 | // pthread_join(threads[i], nullptr); 84 | threads[i].join(); 85 | } 86 | // delete [] threads; 87 | } 88 | 89 | static void* InferenceThread(void* arg) 90 | { 91 | Inferer* inferer = (Inferer*)arg; 92 | // inference 
corpus block by block 93 | for (int32_t block = 0; block < Config::num_blocks; ++block) 94 | { 95 | inferer->BeforeIteration(block); 96 | for (int32_t i = 0; i < Config::num_iterations; ++i) 97 | { 98 | inferer->DoIteration(i); 99 | } 100 | inferer->EndIteration(); 101 | } 102 | return nullptr; 103 | } 104 | 105 | static void InitDocument() 106 | { 107 | xorshift_rng rng; 108 | for (int32_t block = 0; block < Config::num_blocks; ++block) 109 | { 110 | data_stream->BeforeDataAccess(); 111 | DataBlock& data_block = data_stream->CurrDataBlock(); 112 | int32_t num_slice = meta.local_vocab(block).num_slice(); 113 | for (int32_t slice = 0; slice < num_slice; ++slice) 114 | { 115 | for (int32_t i = 0; i < data_block.Size(); ++i) 116 | { 117 | Document* doc = data_block.GetOneDoc(i); 118 | int32_t& cursor = doc->Cursor(); 119 | if (slice == 0) cursor = 0; 120 | int32_t last_word = meta.local_vocab(block).LastWord(slice); 121 | for (; cursor < doc->Size(); ++cursor) 122 | { 123 | if (doc->Word(cursor) > last_word) break; 124 | // Init the latent variable 125 | if (!Config::warm_start) 126 | doc->SetTopic(cursor, rng.rand_k(Config::num_topics)); 127 | } 128 | } 129 | } 130 | data_stream->EndDataAccess(); 131 | } 132 | } 133 | 134 | 135 | static void DumpDocTopic() 136 | { 137 | Row doc_topic_counter(0, Format::Sparse, kMaxDocLength); 138 | for (int32_t block = 0; block < Config::num_blocks; ++block) 139 | { 140 | std::ofstream fout("doc_topic." 
+ std::to_string(block)); 141 | data_stream->BeforeDataAccess(); 142 | DataBlock& data_block = data_stream->CurrDataBlock(); 143 | for (int i = 0; i < data_block.Size(); ++i) 144 | { 145 | Document* doc = data_block.GetOneDoc(i); 146 | doc_topic_counter.Clear(); 147 | doc->GetDocTopicVector(doc_topic_counter); 148 | fout << i << " "; // doc id 149 | Row::iterator iter = doc_topic_counter.Iterator(); 150 | while (iter.HasNext()) 151 | { 152 | fout << " " << iter.Key() << ":" << iter.Value(); 153 | iter.Next(); 154 | } 155 | fout << std::endl; 156 | } 157 | data_stream->EndDataAccess(); 158 | } 159 | } 160 | private: 161 | /*! \brief training data access */ 162 | static IDataStream* data_stream; 163 | /*! \brief training data meta information */ 164 | static Meta meta; 165 | }; 166 | IDataStream* Infer::data_stream = nullptr; 167 | Meta Infer::meta; 168 | 169 | } // namespace lightlda 170 | } // namespace multiverso 171 | 172 | 173 | int main(int argc, char** argv) 174 | { 175 | multiverso::lightlda::Config::inference = true; 176 | multiverso::lightlda::Infer::Run(argc, argv); 177 | return 0; 178 | } 179 | -------------------------------------------------------------------------------- /src/common.cpp: -------------------------------------------------------------------------------- 1 | #include "common.h" 2 | 3 | #include 4 | 5 | namespace multiverso { namespace lightlda 6 | { 7 | const int64_t kMB = 1024 * 1024; 8 | 9 | // -- Begin: Config definitioin and defalut values --------------------- // 10 | int32_t Config::num_vocabs = -1; 11 | int32_t Config::num_topics = 100; 12 | int32_t Config::num_iterations = 100; 13 | int32_t Config::mh_steps = 2; 14 | int32_t Config::num_servers = 1; 15 | int32_t Config::num_local_workers = 1; 16 | int32_t Config::num_aggregator = 1; 17 | int32_t Config::num_blocks = 1; 18 | int64_t Config::max_num_document = -1; 19 | float Config::alpha = 0.01f; 20 | float Config::beta = 0.01f; 21 | std::string Config::server_file = ""; 22 | 
std::string Config::input_dir = ""; 23 | bool Config::warm_start = false; 24 | bool Config::inference = false; 25 | bool Config::out_of_core = false; 26 | int64_t Config::data_capacity = 1024 * kMB; 27 | int64_t Config::model_capacity = 512 * kMB; 28 | int64_t Config::delta_capacity = 256 * kMB; 29 | int64_t Config::alias_capacity = 512 * kMB; 30 | // -- End: Config definitioin and defalut values ----------------------- // 31 | 32 | void Config::Init(int argc, char* argv[]) 33 | { 34 | if (argc < 2) 35 | { 36 | PrintUsage(); 37 | } 38 | for (int i = 1; i < argc; ++i) 39 | { 40 | if (strcmp(argv[i], "-help") == 0 || strcmp(argv[i], "--help") == 0) 41 | { 42 | PrintUsage(); 43 | } 44 | if (strcmp(argv[i], "-num_vocabs") == 0) num_vocabs = atoi(argv[i + 1]); 45 | if (strcmp(argv[i], "-num_topics") == 0) num_topics = atoi(argv[i + 1]); 46 | if (strcmp(argv[i], "-num_iterations") == 0) num_iterations = atoi(argv[i + 1]); 47 | if (strcmp(argv[i], "-mh_steps") == 0) mh_steps = atoi(argv[i + 1]); 48 | if (strcmp(argv[i], "-num_servers") == 0) num_servers = atoi(argv[i + 1]); 49 | if (strcmp(argv[i], "-num_local_workers") == 0) num_local_workers = atoi(argv[i + 1]); 50 | if (strcmp(argv[i], "-num_aggregator") == 0) num_aggregator = atoi(argv[i + 1]); 51 | if (strcmp(argv[i], "-num_blocks") == 0) num_blocks = atoi(argv[i + 1]); 52 | if (strcmp(argv[i], "-max_num_document") == 0) max_num_document = atoll(argv[i + 1]); 53 | if (strcmp(argv[i], "-alpha") == 0) alpha = static_cast(atof(argv[i + 1])); 54 | if (strcmp(argv[i], "-beta") == 0) beta = static_cast(atof(argv[i + 1])); 55 | if (strcmp(argv[i], "-input_dir") == 0) input_dir = std::string(argv[i + 1]); 56 | if (strcmp(argv[i], "-server_file") == 0) server_file = std::string(argv[i + 1]); 57 | if (strcmp(argv[i], "-warm_start") == 0) warm_start = true; 58 | if (strcmp(argv[i], "-out_of_core") == 0) out_of_core = true; 59 | if (strcmp(argv[i], "-data_capacity") == 0) data_capacity = atoi(argv[i + 1]) * kMB; 60 | if 
(strcmp(argv[i], "-model_capacity") == 0) model_capacity = atoi(argv[i + 1]) * kMB; 61 | if (strcmp(argv[i], "-alias_capacity") == 0) alias_capacity = atoi(argv[i + 1]) * kMB; 62 | if (strcmp(argv[i], "-delta_capacity") == 0) delta_capacity = atoi(argv[i + 1]) * kMB; 63 | } 64 | Check(); 65 | } 66 | 67 | void Config::PrintTrainingUsage() 68 | { 69 | printf("LightLDA usage: \n"); 70 | printf("-num_vocabs Size of dataset vocabulary \n"); 71 | printf("-num_topics Number of topics. Default: 100\n"); 72 | printf("-num_iterations Number of iteratioins. Default: 100\n"); 73 | printf("-mh_steps Metropolis-hasting steps. Default: 2\n"); 74 | printf("-alpha Dirichlet prior alpha. Default: 0.1\n"); 75 | printf("-beta Dirichlet prior beta. Default: 0.01\n\n"); 76 | printf("-num_blocks Number of blocks in disk. Default: 1\n"); 77 | printf("-max_num_document Max number of document in a data block \n"); 78 | printf("-input_dir Directory of input data, containing\n"); 79 | printf(" files generated by dump_block \n\n"); 80 | printf("-num_servers Number of servers. Default: 1\n"); 81 | printf("-num_local_workers Number of local training threads. Default: 4\n"); 82 | printf("-num_aggregator Number of local aggregation threads. Default: 1\n"); 83 | printf("-server_file Server endpoint file. Used by MPI-free version\n"); 84 | printf("-warm_start Warm start \n"); 85 | printf("-out_of_core Use out of core computing \n\n"); 86 | printf("-data_capacity Memory pool size(MB) for data storage, \n"); 87 | printf(" should larger than the any data block\n"); 88 | printf("-model_capacity Memory pool size(MB) for local model cache\n"); 89 | printf("-alias_capacity Memory pool size(MB) for alias table \n"); 90 | printf("-delta_capacity Memory pool size(MB) for local delta cache\n"); 91 | exit(0); 92 | } 93 | 94 | void Config::PrintInferenceUsage() 95 | { 96 | printf("LightLDA Inference usage: \n"); 97 | printf("-num_vocabs Size of dataset vocabulary \n"); 98 | printf("-num_topics Number of topics. 
Default: 100\n"); 99 | printf("-num_iterations Number of iteratioins. Default: 100\n"); 100 | printf("-mh_steps Metropolis-hasting steps. Default: 2\n"); 101 | printf("-alpha Dirichlet prior alpha. Default: 0.1\n"); 102 | printf("-beta Dirichlet prior beta. Default: 0.01\n\n"); 103 | printf("-num_blocks Number of blocks in disk. Default: 1\n"); 104 | printf("-max_num_document Max number of document in a data block \n"); 105 | printf("-input_dir Directory of input data, containing\n"); 106 | printf(" files generated by dump_block \n\n"); 107 | printf("-num_local_workers Number of local training threads. Default: 4\n"); 108 | printf("-warm_start Warm start \n"); 109 | printf("-out_of_core Use out of core computing \n\n"); 110 | printf("-data_capacity Memory pool size(MB) for data storage, \n"); 111 | printf(" should larger than the any data block\n"); 112 | exit(0); 113 | } 114 | 115 | void Config::PrintUsage() 116 | { 117 | if(!inference) 118 | { 119 | PrintTrainingUsage(); 120 | } 121 | else 122 | { 123 | PrintInferenceUsage(); 124 | } 125 | } 126 | 127 | void Config::Check() 128 | { 129 | if (input_dir == "" || num_vocabs <= 0 || max_num_document == -1) 130 | { 131 | PrintUsage(); 132 | } 133 | } 134 | } // namespace lightlda 135 | } // namespace multiverso 136 | -------------------------------------------------------------------------------- /windows/dump_binary/dump_binary.vcxproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Debug 10 | x64 11 | 12 | 13 | Release 14 | Win32 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | {FFD24CAD-825A-4732-8186-4361D3E1438F} 23 | Win32Proj 24 | dump_binary 25 | 26 | 27 | 28 | Application 29 | true 30 | v120 31 | Unicode 32 | 33 | 34 | Application 35 | true 36 | v120 37 | Unicode 38 | 39 | 40 | Application 41 | false 42 | v120 43 | true 44 | Unicode 45 | 46 | 47 | Application 48 | false 49 | v120 50 | true 51 | Unicode 52 | 53 | 54 | 55 | 56 
| 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | true 71 | 72 | 73 | true 74 | 75 | 76 | false 77 | 78 | 79 | false 80 | 81 | 82 | 83 | 84 | 85 | Level3 86 | Disabled 87 | WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 88 | 89 | 90 | Console 91 | true 92 | 93 | 94 | 95 | 96 | 97 | 98 | Level3 99 | Disabled 100 | WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 101 | 102 | 103 | Console 104 | true 105 | 106 | 107 | 108 | 109 | Level3 110 | 111 | 112 | MaxSpeed 113 | true 114 | true 115 | WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 116 | 117 | 118 | Console 119 | true 120 | true 121 | true 122 | 123 | 124 | 125 | 126 | Level3 127 | 128 | 129 | MaxSpeed 130 | true 131 | true 132 | WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 133 | 134 | 135 | Console 136 | true 137 | true 138 | true 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /src/trainer.cpp: -------------------------------------------------------------------------------- 1 | #include "trainer.h" 2 | 3 | #include "alias_table.h" 4 | #include "common.h" 5 | #include "data_block.h" 6 | #include "eval.h" 7 | #include "meta.h" 8 | #include "sampler.h" 9 | #include "model.h" 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | namespace multiverso { namespace lightlda 16 | { 17 | std::mutex Trainer::mutex_; 18 | double Trainer::doc_llh_ = 0.0; 19 | double Trainer::word_llh_ = 0.0; 20 | 21 | Trainer::Trainer(AliasTable* alias_table, 22 | Barrier* barrier, Meta* meta) : 23 | alias_(alias_table), barrier_(barrier), meta_(meta), 24 | model_(nullptr) 25 | { 26 | sampler_ = new LightDocSampler(); 27 | model_ = new PSModel(this); 28 | } 29 | 30 | Trainer::~Trainer() 31 | { 32 | delete sampler_; 33 | delete model_; 34 | } 35 | 36 | void Trainer::TrainIteration(DataBlockBase* data_block) 37 | { 38 | StopWatch watch; watch.Start(); 39 | LDADataBlock* lda_data_block = 40 | 
reinterpret_cast(data_block); 41 | 42 | DataBlock& data = lda_data_block->data(); 43 | int32_t block = lda_data_block->block(); 44 | int32_t slice = lda_data_block->slice(); 45 | int32_t iter = lda_data_block->iteration(); 46 | const LocalVocab& local_vocab = data.meta(); 47 | 48 | int32_t id = TrainerId(); 49 | int32_t trainer_num = TrainerCount(); 50 | int32_t lastword = local_vocab.LastWord(slice); 51 | if (id == 0) 52 | { 53 | Log::Info("Rank = %d, Iter = %d, Block = %d, Slice = %d\n", 54 | Multiverso::ProcessRank(), lda_data_block->iteration(), 55 | lda_data_block->block(), lda_data_block->slice()); 56 | } 57 | // Build Alias table 58 | if (id == 0) alias_->Init(meta_->alias_index(block, slice)); 59 | barrier_->Wait(); 60 | for (const int32_t* pword = local_vocab.begin(slice) + id; 61 | pword < local_vocab.end(slice); 62 | pword += trainer_num) 63 | { 64 | alias_->Build(*pword, model_); 65 | } 66 | if (id == 0) alias_->Build(-1, model_); 67 | barrier_->Wait(); 68 | 69 | if (TrainerId() == 0) 70 | { 71 | Log::Info("Rank = %d, Alias Time used: %.2f s \n", 72 | Multiverso::ProcessRank(), watch.ElapsedSeconds()); 73 | } 74 | int32_t num_token = 0; 75 | watch.Restart(); 76 | // Train with lightlda sampler 77 | for (int32_t doc_id = id; doc_id < data.Size(); doc_id += trainer_num) 78 | { 79 | Document* doc = data.GetOneDoc(doc_id); 80 | num_token += sampler_->SampleOneDoc(doc, slice, lastword, model_, alias_); 81 | } 82 | if (TrainerId() == 0) 83 | { 84 | Log::Info("Rank = %d, Training Time used: %.2f s \n", 85 | Multiverso::ProcessRank(), watch.ElapsedSeconds()); 86 | Log::Info("Rank = %d, sampling throughput: %.6f (tokens/thread/sec) \n", 87 | Multiverso::ProcessRank(), double(num_token) / watch.ElapsedSeconds()); 88 | } 89 | watch.Restart(); 90 | // Evaluate loss function 91 | // Evaluate(lda_data_block); 92 | 93 | if (iter % 5 == 0) 94 | { 95 | Evaluate(lda_data_block); 96 | if (TrainerId() == 0) 97 | Log::Info("Rank = %d, Evaluation Time used: %.2f s \n", 98 | 
Multiverso::ProcessRank(), watch.ElapsedSeconds()); 99 | } 100 | // if (iter != 0 && iter % 50 == 0) Dump(iter, lda_data_block); 101 | 102 | // Clear the thread information in alias table 103 | if (iter == Config::num_iterations - 1) alias_->Clear(); 104 | } 105 | 106 | void Trainer::Evaluate(LDADataBlock* lda_data_block) 107 | { 108 | double thread_doc = 0, thread_word = 0; 109 | 110 | DataBlock& data = lda_data_block->data(); 111 | int32_t block = lda_data_block->block(); 112 | int32_t slice = lda_data_block->slice(); 113 | const LocalVocab& local_vocab = data.meta(); 114 | 115 | // 1. Evaluate doc likelihood 116 | for (int32_t doc_id = TrainerId(); doc_id < data.Size() && slice == 0; 117 | doc_id += TrainerCount()) 118 | { 119 | thread_doc += Eval::ComputeOneDocLLH(data.GetOneDoc(doc_id), 120 | sampler_->doc_topic_counter()); 121 | } 122 | { 123 | std::lock_guard lock(mutex_); 124 | doc_llh_ += thread_doc; 125 | } 126 | if (slice == 0 && barrier_->Wait()) 127 | { 128 | Log::Info("doc likelihood : %e\n", doc_llh_); 129 | doc_llh_ = 0; 130 | } 131 | 132 | // 2. Evaluate word likelihood 133 | for (const int32_t* word = local_vocab.begin(slice) + TrainerId(); 134 | word < local_vocab.end(slice) && block == 0; word += TrainerCount()) 135 | { 136 | thread_word += Eval::ComputeOneWordLLH(*word, this); 137 | } 138 | { 139 | std::lock_guard lock(mutex_); 140 | word_llh_ += thread_word; 141 | } 142 | if (block == 0 && barrier_->Wait()) 143 | { 144 | Log::Info("word likelihood : %e\n", word_llh_); 145 | word_llh_ = 0; 146 | } 147 | 148 | // 3. 
Evaluate normalize item for word likelihood 149 | if (TrainerId() == 0 && block == 0) 150 | { 151 | Log::Info("Normalized likelihood : %e\n", 152 | Eval::NormalizeWordLLH(this)); 153 | } 154 | barrier_->Wait(); 155 | } 156 | 157 | void Trainer::Dump(int32_t iter, LDADataBlock* lda_data_block) 158 | { 159 | DataBlock& data = lda_data_block->data(); 160 | int32_t slice = lda_data_block->slice(); 161 | const LocalVocab& local_vocab = data.meta(); 162 | 163 | std::string out = "model." + std::to_string(iter) + "." 164 | + std::to_string(slice) + "." + std::to_string(TrainerId()); 165 | std::ofstream fout(out); 166 | 167 | for (const int32_t* p = local_vocab.begin(slice) + TrainerId(); 168 | p < local_vocab.end(slice); p += TrainerCount()) 169 | { 170 | int word = *p; 171 | Row& row = GetRow(kWordTopicTable, word); 172 | Row::iterator iter = row.Iterator(); 173 | 174 | fout << word; 175 | while (iter.HasNext()) 176 | { 177 | int32_t topic = iter.Key(); 178 | int32_t count = iter.Value(); 179 | fout << " " << topic << ":" << count; 180 | iter.Next(); 181 | } 182 | fout << std::endl; 183 | } 184 | 185 | fout.close(); 186 | } 187 | 188 | void ParamLoader::ParseAndRequest(DataBlockBase* data_block) 189 | { 190 | LDADataBlock* lda_data_block = 191 | reinterpret_cast(data_block); 192 | // Request word-topic-table 193 | int32_t slice = lda_data_block->slice(); 194 | DataBlock& data = lda_data_block->data(); 195 | const LocalVocab& local_vocab = data.meta(); 196 | 197 | for (const int32_t* p = local_vocab.begin(slice); 198 | p != local_vocab.end(slice); ++p) 199 | { 200 | RequestRow(kWordTopicTable, *p); 201 | } 202 | Log::Debug("Request params. 
start = %d, end = %d\n", 203 | *local_vocab.begin(slice), *(local_vocab.end(slice) - 1)); 204 | // Request summary-row 205 | RequestTable(kSummaryRow); 206 | } 207 | } // namespace lightlda 208 | } // namespace multiverso 209 | -------------------------------------------------------------------------------- /windows/infer/infer.vcxproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Debug 10 | x64 11 | 12 | 13 | Release 14 | Win32 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | {3CF22D68-9F4B-46B9-B0D4-7129467DD759} 23 | Win32Proj 24 | infer 25 | 26 | 27 | 28 | Application 29 | true 30 | v120 31 | Unicode 32 | 33 | 34 | Application 35 | true 36 | v120 37 | Unicode 38 | 39 | 40 | Application 41 | false 42 | v120 43 | true 44 | Unicode 45 | 46 | 47 | Application 48 | false 49 | v120 50 | true 51 | Unicode 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | true 71 | 72 | 73 | true 74 | 75 | 76 | false 77 | 78 | 79 | false 80 | $(SolutionDir)/../../multiverso/include;$(SolutionDir)/../src/;$(VC_IncludePath);$(WindowsSDK_IncludePath); 81 | 82 | 83 | 84 | 85 | 86 | Level3 87 | Disabled 88 | WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 89 | 90 | 91 | Console 92 | true 93 | 94 | 95 | 96 | 97 | 98 | 99 | Level3 100 | Disabled 101 | WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 102 | 103 | 104 | Console 105 | true 106 | 107 | 108 | 109 | 110 | Level3 111 | 112 | 113 | MaxSpeed 114 | true 115 | true 116 | WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 117 | 118 | 119 | Console 120 | true 121 | true 122 | true 123 | 124 | 125 | 126 | 127 | Level3 128 | 129 | 130 | MaxSpeed 131 | true 132 | true 133 | WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 134 | 135 | 136 | Console 137 | true 138 | true 139 | true 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 
-------------------------------------------------------------------------------- /src/meta.cpp: -------------------------------------------------------------------------------- 1 | #include "meta.h" 2 | #include "common.h" 3 | 4 | #include 5 | #include 6 | 7 | namespace multiverso { namespace lightlda 8 | { 9 | LocalVocab::LocalVocab() 10 | : num_slices_(0), own_memory_(false), vocabs_(nullptr), size_(0) 11 | {} 12 | 13 | LocalVocab::~LocalVocab() 14 | { 15 | if (own_memory_) 16 | { 17 | delete[] vocabs_; 18 | } 19 | } 20 | 21 | AliasTableIndex::AliasTableIndex() 22 | { 23 | index_map_.resize(Config::num_vocabs, -1); 24 | } 25 | 26 | WordEntry& AliasTableIndex::word_entry(int32_t word) 27 | { 28 | if (index_map_[word] == -1) 29 | { 30 | Log::Fatal("Fatal in alias index: word %d not exist\n", word); 31 | } 32 | return index_[index_map_[word]]; 33 | } 34 | 35 | void AliasTableIndex::PushWord(int32_t word, 36 | bool is_dense, int64_t begin_offset, int32_t capacity) 37 | { 38 | index_map_[word] = static_cast(index_.size()); 39 | index_.push_back({ is_dense, begin_offset, capacity }); 40 | } 41 | 42 | Meta::Meta() 43 | { 44 | } 45 | 46 | Meta::~Meta() 47 | { 48 | for (int32_t i = 0; i < alias_index_.size(); ++i) 49 | { 50 | for (int32_t j = 0; j < alias_index_[i].size(); ++j) 51 | { 52 | delete alias_index_[i][j]; 53 | } 54 | } 55 | } 56 | 57 | void Meta::Init() 58 | { 59 | tf_.resize(Config::num_vocabs, 0); 60 | local_tf_.resize(Config::num_vocabs, 0); 61 | int32_t* tf = new int32_t[Config::num_vocabs]; 62 | int32_t* local_tf = new int32_t[Config::num_vocabs]; 63 | local_vocabs_.resize(Config::num_blocks); 64 | for (int32_t i = 0; i < Config::num_blocks; ++i) 65 | { 66 | LocalVocab& local_vocab = local_vocabs_[i]; 67 | 68 | std::string file_name = Config::input_dir 69 | + "/vocab." 
+ std::to_string(i); 70 | std::ifstream vocab_file(file_name, std::ios::in|std::ios::binary); 71 | 72 | if (!vocab_file.good()) 73 | { 74 | Log::Fatal("Failed to open file : %s\n", file_name.c_str()); 75 | } 76 | 77 | vocab_file.read(reinterpret_cast(&local_vocab.size_), 78 | sizeof(int)); 79 | local_vocab.vocabs_ = new int[local_vocab.size_]; 80 | local_vocab.own_memory_ = true; 81 | vocab_file.read(reinterpret_cast(local_vocab.vocabs_), 82 | sizeof(int)* local_vocab.size_); 83 | vocab_file.read(reinterpret_cast(tf), 84 | sizeof(int)* local_vocab.size_); 85 | vocab_file.read(reinterpret_cast(local_tf), 86 | sizeof(int)* local_vocab.size_); 87 | 88 | vocab_file.close(); 89 | 90 | for (int32_t i = 0; i < local_vocab.size_; ++i) 91 | { 92 | if (tf[i] > tf_[local_vocab.vocabs_[i]]) 93 | { 94 | tf_[local_vocab.vocabs_[i]] = tf[i]; 95 | } 96 | if (local_tf[i] > local_tf_[local_vocab.vocabs_[i]]) 97 | { 98 | local_tf_[local_vocab.vocabs_[i]] = local_tf[i]; 99 | } 100 | } 101 | } 102 | 103 | delete[] local_tf; 104 | delete[] tf; 105 | 106 | if(!Config::inference) 107 | { 108 | ModelSchedule(); 109 | } 110 | else 111 | { 112 | ModelSchedule4Inference(); 113 | } 114 | BuildAliasIndex(); 115 | } 116 | 117 | void Meta::ModelSchedule() 118 | { 119 | int64_t model_capacity = Config::model_capacity; 120 | int64_t alias_capacity = Config::alias_capacity; 121 | int64_t delta_capacity = Config::delta_capacity; 122 | 123 | int32_t model_thresh = Config::num_topics / (2 * kLoadFactor); 124 | int32_t alias_thresh = (Config::num_topics * 2) / 3; 125 | int32_t delta_thresh = Config::num_topics / (4 * kLoadFactor); 126 | 127 | 128 | // Schedule for each data block 129 | for (int32_t i = 0; i < Config::num_blocks; ++i) 130 | { 131 | LocalVocab& local_vocab = local_vocabs_[i]; 132 | int32_t* vocabs = local_vocab.vocabs_; 133 | local_vocab.slice_index_.push_back(0); 134 | 135 | int64_t model_offset = 0; 136 | int64_t alias_offset = 0; 137 | int64_t delta_offset = 0; 138 | for (int32_t j = 
0; j < local_vocab.size_; ++j) 139 | { 140 | int32_t word = vocabs[j]; 141 | int32_t tf = tf_[word]; 142 | int32_t local_tf = local_tf_[word]; 143 | int32_t model_size = (tf > model_thresh) ? 144 | Config::num_topics* sizeof(int32_t) : 145 | tf * kLoadFactor * sizeof(int32_t); 146 | model_offset += model_size; 147 | 148 | int32_t alias_size = (tf > alias_thresh) ? 149 | Config::num_topics * 2 * sizeof(int32_t) : 150 | tf * 3 * sizeof(int32_t); 151 | alias_offset += alias_size; 152 | 153 | int32_t delta_size = (local_tf > delta_thresh) ? 154 | Config::num_topics * sizeof(int32_t) : 155 | local_tf * kLoadFactor * 2 * sizeof(int32_t); 156 | delta_offset += delta_size; 157 | 158 | if (model_offset > model_capacity || 159 | alias_offset > alias_capacity || 160 | delta_offset > delta_capacity) 161 | { 162 | Log::Info("Actual Model capacity: %d MB, Alias capacity: %d MB, Delta capacity: %dMB\n", 163 | model_offset/1024/1024, alias_offset/1024/1024, delta_offset/1024/1024); 164 | local_vocab.slice_index_.push_back(j); 165 | ++local_vocab.num_slices_; 166 | model_offset = model_size; 167 | alias_offset = alias_size; 168 | delta_offset = delta_size; 169 | } 170 | } 171 | local_vocab.slice_index_.push_back(local_vocab.size_); 172 | ++local_vocab.num_slices_; 173 | Log::Info("INFO: block = %d, the number of slice = %d\n", 174 | i, local_vocab.num_slices_); 175 | } 176 | } 177 | 178 | void Meta::ModelSchedule4Inference() 179 | { 180 | Config::alias_capacity = 0; 181 | int32_t alias_thresh = (Config::num_topics * 2) / 3; 182 | // Schedule for each data block 183 | for (int32_t i = 0; i < Config::num_blocks; ++i) 184 | { 185 | LocalVocab& local_vocab = local_vocabs_[i]; 186 | int32_t* vocabs = local_vocab.vocabs_; 187 | local_vocab.slice_index_.push_back(0); 188 | local_vocab.slice_index_.push_back(local_vocab.size_); 189 | local_vocab.num_slices_ = 1; 190 | int64_t alias_offset = 0; 191 | for (int32_t j = 0; j < local_vocab.size_; ++j) 192 | { 193 | int32_t word = vocabs[j]; 194 
| int32_t tf = tf_[word]; 195 | int32_t alias_size = (tf > alias_thresh) ? 196 | Config::num_topics * 2 * sizeof(int32_t) : 197 | tf * 3 * sizeof(int32_t); 198 | alias_offset += alias_size; 199 | } 200 | if(alias_offset > Config::alias_capacity) 201 | { 202 | Config::alias_capacity = alias_offset; 203 | } 204 | } 205 | Log::Info("Actual Alias capacity: %d MB\n", Config::alias_capacity/1024/1024); 206 | } 207 | 208 | void Meta::BuildAliasIndex() 209 | { 210 | int32_t alias_thresh = (Config::num_topics * 2) / 3; 211 | alias_index_.resize(Config::num_blocks); 212 | // for each block 213 | for (int32_t i = 0; i < Config::num_blocks; ++i) 214 | { 215 | const LocalVocab& vocab = local_vocab(i); 216 | alias_index_[i].resize(vocab.num_slice()); 217 | // for each slice 218 | for (int32_t j = 0; j < vocab.num_slice(); ++j) 219 | { 220 | alias_index_[i][j] = new AliasTableIndex(); 221 | int64_t offset = 0; 222 | for (const int32_t* p = vocab.begin(j); 223 | p != vocab.end(j); ++p) 224 | { 225 | int32_t word = *p; 226 | bool is_dense = true; 227 | int32_t capacity = Config::num_topics; 228 | int64_t size = Config::num_topics * 2; 229 | if (tf(word) <= alias_thresh) 230 | { 231 | is_dense = false; 232 | capacity = tf(word); 233 | size = tf(word) * 3; 234 | } 235 | alias_index_[i][j]->PushWord(word, is_dense, offset, capacity); 236 | offset += size; 237 | } 238 | } 239 | } 240 | } 241 | 242 | } // namespace lightlda 243 | } // namespace multiverso 244 | -------------------------------------------------------------------------------- /windows/lightlda/lightlda.vcxproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Debug 10 | x64 11 | 12 | 13 | Release 14 | Win32 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | {9A6EEB61-6B68-49A3-B83A-E044F6678D6A} 23 | Win32Proj 24 | lightlda 25 | 26 | 27 | 28 | Application 29 | true 30 | v120 31 | Unicode 32 | 33 | 34 | Application 35 | true 36 | v120 37 | 
Unicode 38 | 39 | 40 | Application 41 | false 42 | v120 43 | true 44 | Unicode 45 | 46 | 47 | Application 48 | false 49 | v120 50 | true 51 | Unicode 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | true 71 | 72 | 73 | true 74 | 75 | 76 | false 77 | 78 | 79 | false 80 | $(SolutionDir)/../../multiverso/include;$(VC_IncludePath);$(WindowsSDK_IncludePath); 81 | $(SolutionDir)/../../multiverso/windows/x64/Release/;$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64); 82 | 83 | 84 | 85 | 86 | 87 | Level3 88 | Disabled 89 | WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 90 | 91 | 92 | Console 93 | true 94 | 95 | 96 | 97 | 98 | 99 | 100 | Level3 101 | Disabled 102 | WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 103 | 104 | 105 | Console 106 | true 107 | 108 | 109 | 110 | 111 | Level3 112 | 113 | 114 | MaxSpeed 115 | true 116 | true 117 | WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 118 | 119 | 120 | Console 121 | true 122 | true 123 | true 124 | 125 | 126 | 127 | 128 | Level3 129 | 130 | 131 | MaxSpeed 132 | true 133 | true 134 | WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 135 | 136 | 137 | Console 138 | true 139 | true 140 | true 141 | multiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies) 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /src/sampler.cpp: -------------------------------------------------------------------------------- 1 | #include "sampler.h" 2 | 3 | #include "alias_table.h" 4 | #include "common.h" 5 | #include "document.h" 6 | #include "model.h" 7 | 8 | #include 9 | #include 10 | 11 | namespace multiverso { namespace lightlda 12 | 
{ 13 | LightDocSampler::LightDocSampler() 14 | { 15 | alpha_ = Config::alpha; 16 | beta_ = Config::beta; 17 | num_vocab_ = Config::num_vocabs; 18 | num_topic_ = Config::num_topics; 19 | mh_steps_ = Config::mh_steps; 20 | 21 | alpha_sum_ = num_topic_ * alpha_; 22 | beta_sum_ = num_vocab_ * beta_; 23 | 24 | subtractor_ = Config::inference ? 0 : 1; 25 | 26 | doc_topic_counter_.reset(new Row(0, 27 | multiverso::Format::Sparse, kMaxDocLength)); 28 | } 29 | 30 | int32_t LightDocSampler::SampleOneDoc(Document* doc, int32_t slice, 31 | int32_t lastword, ModelBase* model, AliasTable* alias) 32 | { 33 | DocInit(doc); 34 | int32_t num_tokens = 0; 35 | int32_t& cursor = doc->Cursor(); 36 | if (slice == 0) cursor = 0; 37 | for (; cursor != doc->Size(); ++cursor) 38 | { 39 | int32_t word = doc->Word(cursor); 40 | if (word > lastword) break; 41 | int32_t old_topic = doc->Topic(cursor); 42 | int32_t new_topic = Sample(doc, word, old_topic, old_topic, 43 | model, alias); 44 | if (old_topic != new_topic) 45 | { 46 | doc->SetTopic(cursor, new_topic); 47 | doc_topic_counter_->Add(old_topic, -1); 48 | doc_topic_counter_->Add(new_topic, 1); 49 | if(!Config::inference) 50 | { 51 | model->AddWordTopicRow(word, old_topic, -1); 52 | model->AddSummaryRow(old_topic, -1); 53 | model->AddWordTopicRow(word, new_topic, 1); 54 | model->AddSummaryRow(new_topic, 1); 55 | } 56 | } 57 | ++num_tokens; 58 | } 59 | return num_tokens; 60 | } 61 | 62 | void LightDocSampler::DocInit(Document* doc) 63 | { 64 | doc_topic_counter_->Clear(); 65 | doc->GetDocTopicVector(*doc_topic_counter_); 66 | } 67 | 68 | int32_t LightDocSampler::Sample(Document* doc, 69 | int32_t word, int32_t old_topic, int32_t s, 70 | ModelBase* model, AliasTable* alias) 71 | { 72 | int32_t t, w_t_cnt, w_s_cnt; 73 | int64_t n_t, n_s; 74 | float n_td_alpha, n_sd_alpha; 75 | float n_tw_beta, n_sw_beta, n_t_beta_sum, n_s_beta_sum; 76 | float proposal_t, proposal_s; 77 | float nominator, denominator; 78 | double rejection, pi; 79 | int32_t m; 
80 | 81 | Row& word_topic_row = model->GetWordTopicRow(word); 82 | Row& summary_row = model->GetSummaryRow(); 83 | 84 | for (int32_t i = 0; i < mh_steps_; ++i) 85 | { 86 | // Word proposal 87 | t = alias->Propose(word, rng_); 88 | if (t < 0 || t >= num_topic_) 89 | { 90 | Log::Fatal("Invalid topic assignment %d from word proposal\n", t); 91 | } 92 | if (t != s) 93 | { 94 | rejection = rng_.rand_double(); 95 | 96 | w_t_cnt = word_topic_row.At(t); 97 | w_s_cnt = word_topic_row.At(s); 98 | n_t = summary_row.At(t); 99 | n_s = summary_row.At(s); 100 | 101 | n_td_alpha = doc_topic_counter_->At(t) + alpha_; 102 | n_sd_alpha = doc_topic_counter_->At(s) + alpha_; 103 | n_tw_beta = w_t_cnt + beta_; 104 | n_t_beta_sum = n_t + beta_sum_; 105 | n_sw_beta = w_s_cnt + beta_; 106 | n_s_beta_sum = n_s + beta_sum_; 107 | if (s == old_topic) 108 | { 109 | --n_sd_alpha; 110 | n_sw_beta -= subtractor_; 111 | n_s_beta_sum -= subtractor_; 112 | } 113 | if (t == old_topic) 114 | { 115 | --n_td_alpha; 116 | n_tw_beta -= subtractor_; 117 | n_t_beta_sum -= subtractor_; 118 | } 119 | 120 | proposal_s = (w_s_cnt + beta_) / (n_s + beta_sum_); 121 | proposal_t = (w_t_cnt + beta_) / (n_t + beta_sum_); 122 | 123 | nominator = n_td_alpha * n_tw_beta * n_s_beta_sum * proposal_s; 124 | denominator = n_sd_alpha * n_sw_beta * n_t_beta_sum * proposal_t; 125 | 126 | pi = nominator / denominator; 127 | 128 | m = -(rejection < pi); 129 | s = (t & m) | (s & ~m); 130 | } 131 | // Doc proposal 132 | double n_td_or_alpha = rng_.rand_double() * 133 | (doc->Size() + alpha_sum_); 134 | if (n_td_or_alpha < doc->Size()) 135 | { 136 | int32_t t_idx = static_cast(n_td_or_alpha); 137 | t = doc->Topic(t_idx); 138 | } 139 | else 140 | { 141 | t = rng_.rand_k(num_topic_); 142 | } 143 | if (t != s) 144 | { 145 | rejection = rng_.rand_double(); 146 | 147 | w_t_cnt = word_topic_row.At(t); 148 | w_s_cnt = word_topic_row.At(s); 149 | n_t = summary_row.At(t); 150 | n_s = summary_row.At(s); 151 | 152 | n_td_alpha = 
doc_topic_counter_->At(t) + alpha_; 153 | n_sd_alpha = doc_topic_counter_->At(s) + alpha_; 154 | n_tw_beta = w_t_cnt + beta_; 155 | n_t_beta_sum = n_t + beta_sum_; 156 | n_sw_beta = w_s_cnt + beta_; 157 | n_s_beta_sum = n_s + beta_sum_; 158 | if (s == old_topic) 159 | { 160 | --n_sd_alpha; 161 | n_sw_beta -= subtractor_; 162 | n_s_beta_sum -= subtractor_; 163 | } 164 | if (t == old_topic) 165 | { 166 | --n_td_alpha; 167 | n_tw_beta -= subtractor_; 168 | n_t_beta_sum -= subtractor_; 169 | 170 | } 171 | 172 | proposal_s = (doc_topic_counter_->At(s) + alpha_); 173 | proposal_t = (doc_topic_counter_->At(t) + alpha_); 174 | 175 | nominator = n_td_alpha * n_tw_beta * n_s_beta_sum * proposal_s; 176 | denominator = n_sd_alpha * n_sw_beta * n_t_beta_sum * proposal_t; 177 | 178 | pi = nominator / denominator; 179 | 180 | m = -(rejection < pi); 181 | s = (t & m) | (s & ~m); 182 | } 183 | } 184 | return s; 185 | } 186 | 187 | int32_t LightDocSampler::ApproxSample(Document* doc, 188 | int32_t word, int32_t old_topic, int32_t s, 189 | ModelBase* model, AliasTable* alias) 190 | { 191 | float n_tw_beta, n_sw_beta, n_t_beta_sum, n_s_beta_sum; 192 | float nominator, denominator; 193 | double rejection, pi; 194 | int32_t m, t; 195 | 196 | Row& word_topic_row = model->GetWordTopicRow(word); 197 | Row& summary_row = model->GetSummaryRow(); 198 | 199 | for (int32_t i = 0; i < mh_steps_; ++i) 200 | { 201 | // word proposal 202 | t = alias->Propose(word, rng_); 203 | if (t != s) 204 | { 205 | nominator = doc_topic_counter_->At(t) + alpha_; 206 | denominator = doc_topic_counter_->At(s) + alpha_; 207 | if (t == old_topic) 208 | { 209 | nominator -= 1; 210 | } 211 | if (s == old_topic) 212 | { 213 | denominator -= 1; 214 | } 215 | pi = nominator / denominator; 216 | rejection = rng_.rand_double(); 217 | m = -(rejection < pi); 218 | s = (t & m) | (s & ~m); 219 | } 220 | // doc proposal 221 | double n_td_or_alpha = rng_.rand_double() * 222 | (doc->Size() + alpha_sum_); 223 | if (n_td_or_alpha 
< doc->Size()) 224 | { 225 | int32_t t_idx = static_cast(n_td_or_alpha); 226 | t = doc->Topic(t_idx); 227 | } 228 | else 229 | { 230 | t = rng_.rand_k(num_topic_); 231 | } 232 | if (t != s) 233 | { 234 | n_tw_beta = word_topic_row.At(t) + beta_; 235 | n_sw_beta = word_topic_row.At(s) + beta_; 236 | n_t_beta_sum = summary_row.At(t) + beta_sum_; 237 | n_s_beta_sum = summary_row.At(s) + beta_sum_; 238 | 239 | if (t == old_topic) 240 | { 241 | n_tw_beta -= subtractor_; 242 | n_t_beta_sum -= subtractor_; 243 | } 244 | if (s == old_topic) 245 | { 246 | n_sw_beta -= subtractor_; 247 | n_s_beta_sum -= subtractor_; 248 | } 249 | 250 | nominator = n_tw_beta * n_s_beta_sum; 251 | denominator = n_sw_beta * n_t_beta_sum; 252 | pi = nominator / denominator; 253 | rejection = rng_.rand_double(); 254 | m = -(rejection < pi); 255 | s = (t & m) | (s & ~m); 256 | } 257 | } 258 | return s; 259 | } 260 | } // namespace lightlda 261 | } // namespace multiverso 262 | -------------------------------------------------------------------------------- /src/model.cpp: -------------------------------------------------------------------------------- 1 | #include "model.h" 2 | 3 | #ifdef _MSC_VER 4 | #include 5 | #include 6 | #else 7 | #include 8 | #include 9 | #endif 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "meta.h" 16 | #include "trainer.h" 17 | 18 | #include 19 | #include 20 | 21 | namespace multiverso { namespace lightlda 22 | { 23 | LocalModel::LocalModel(Meta * meta) : word_topic_table_(nullptr), 24 | summary_table_(nullptr), meta_(meta) 25 | { 26 | CreateTable(); 27 | } 28 | 29 | void LocalModel::Init() 30 | { 31 | LoadTable(); 32 | } 33 | 34 | void LocalModel::CreateTable() 35 | { 36 | int32_t num_vocabs = Config::num_vocabs; 37 | int32_t num_topics = Config::num_topics; 38 | multiverso::Format dense_format = multiverso::Format::Dense; 39 | multiverso::Format sparse_format = multiverso::Format::Sparse; 40 | Type int_type = Type::Int; 41 | Type longlong_type = 
Type::LongLong; 42 | 43 | word_topic_table_.reset(new Table(kWordTopicTable, num_vocabs, num_topics, 44 | int_type, dense_format)); 45 | summary_table_.reset(new Table(kSummaryRow, 1, num_topics, 46 | longlong_type, dense_format)); 47 | } 48 | 49 | void LocalModel::LoadTable() 50 | { 51 | #ifdef _MSC_VER 52 | Log::Info("loading model\n"); 53 | //set regex for model files 54 | std::string prefix = "server_[[:digit:]]+_table_"; 55 | std::string suffix = ".model"; 56 | std::ostringstream wordtopic_regstr; 57 | wordtopic_regstr << prefix << kWordTopicTable << suffix; 58 | std::ostringstream summary_regstr; 59 | summary_regstr << prefix << kSummaryRow << suffix; 60 | std::regex model_wordtopic_regex(wordtopic_regstr.str()); 61 | std::regex model_summary_regex(summary_regstr.str()); 62 | 63 | //look for model files & load them 64 | intptr_t handle; 65 | _finddata_t fileinto; 66 | std::string input_dir = Config::input_dir; 67 | handle = _findfirst(input_dir.append("\\*").c_str(), &fileinto); 68 | if (handle != -1) 69 | { 70 | do 71 | { 72 | if (std::regex_match(fileinto.name, fileinto.name + std::strlen(fileinto.name), model_wordtopic_regex)) 73 | { 74 | Log::Info("loading word topic table[%s]\n", fileinto.name); 75 | LoadWordTopicTable(Config::input_dir + "/" + fileinto.name); 76 | } 77 | else if (std::regex_match(fileinto.name, fileinto.name + std::strlen(fileinto.name), model_summary_regex)) 78 | { 79 | Log::Info("loading summary table[%s]\n", fileinto.name); 80 | LoadSummaryTable(Config::input_dir + "/" + fileinto.name); 81 | } 82 | } while (!_findnext(handle, &fileinto)); 83 | } 84 | else 85 | { 86 | Log::Fatal("model dir does not exist : %s\n", Config::input_dir.c_str()); 87 | } 88 | _findclose(handle); 89 | #else 90 | Log::Info("loading model\n"); 91 | //set regex for model files 92 | regex_t model_wordtopic_regex; 93 | regex_t model_summary_regex; 94 | std::string prefix = "server_[[:digit:]]+_table_"; 95 | std::string suffix = ".model"; 96 | std::ostringstream 
wordtopic_regstr; 97 | wordtopic_regstr << prefix << kWordTopicTable << suffix; 98 | std::ostringstream summary_regstr; 99 | summary_regstr << prefix << kSummaryRow << suffix; 100 | regcomp(&model_wordtopic_regex, wordtopic_regstr.str().c_str(), REG_EXTENDED); 101 | regcomp(&model_summary_regex, summary_regstr.str().c_str(), REG_EXTENDED); 102 | 103 | //look for model files & load them 104 | DIR *dir; 105 | struct dirent *ent; 106 | if ((dir = opendir(Config::input_dir.c_str())) != NULL) 107 | { 108 | while ((ent = readdir(dir)) != NULL) 109 | { 110 | if (!regexec(&model_wordtopic_regex, ent->d_name, 0, NULL, 0)) 111 | { 112 | Log::Info("loading word topic table[%s]\n", ent->d_name); 113 | LoadWordTopicTable(Config::input_dir + "/" + ent->d_name); 114 | } 115 | else if (!regexec(&model_summary_regex, ent->d_name, 0, NULL, 0)) 116 | { 117 | Log::Info("loading summary table[%s]\n", ent->d_name); 118 | LoadSummaryTable(Config::input_dir + "/" + ent->d_name); 119 | } 120 | } 121 | closedir(dir); 122 | } 123 | else 124 | { 125 | Log::Fatal("model dir does not exist : %s\n", Config::input_dir.c_str()); 126 | } 127 | regfree(&model_wordtopic_regex); 128 | regfree(&model_summary_regex); 129 | #endif 130 | } 131 | 132 | void LocalModel::LoadWordTopicTable(const std::string& model_fname) 133 | { 134 | multiverso::Format dense_format = multiverso::Format::Dense; 135 | multiverso::Format sparse_format = multiverso::Format::Sparse; 136 | std::ifstream model_file(model_fname, std::ios::in); 137 | std::string line; 138 | while (getline(model_file, line)) 139 | { 140 | std::stringstream ss(line); 141 | std::string word; 142 | std::string fea; 143 | std::vector feas; 144 | int32_t word_id, topic_id, freq; 145 | //assign word id 146 | ss >> word; 147 | word_id = std::stoi(word); 148 | if (meta_->tf(word_id) > 0) 149 | { 150 | //set row 151 | if (meta_->tf(word_id) * kLoadFactor > Config::num_topics) 152 | { 153 | word_topic_table_->SetRow(word_id, dense_format, 154 | 
Config::num_topics); 155 | } 156 | else 157 | { 158 | word_topic_table_->SetRow(word_id, sparse_format, 159 | meta_->tf(word_id) * kLoadFactor); 160 | } 161 | //get row 162 | Row * row = static_cast*> 163 | (word_topic_table_->GetRow(word_id)); 164 | 165 | //add features to row 166 | while (ss >> fea) 167 | { 168 | size_t pos = fea.find_last_of(':'); 169 | if (pos != std::string::npos) 170 | { 171 | topic_id = std::stoi(fea.substr(0, pos)); 172 | freq = std::stoi(fea.substr(pos + 1)); 173 | row->Add(topic_id, freq); 174 | } 175 | else 176 | { 177 | Log::Fatal("bad format of model: %s\n", line.c_str()); 178 | } 179 | } 180 | } 181 | } 182 | model_file.close(); 183 | } 184 | 185 | void LocalModel::LoadSummaryTable(const std::string& model_fname) 186 | { 187 | Row * row = static_cast*> 188 | (summary_table_->GetRow(0)); 189 | std::ifstream model_file(model_fname, std::ios::in); 190 | std::string line; 191 | if (getline(model_file, line)) 192 | { 193 | std::stringstream ss(line); 194 | std::string fea; 195 | std::vector feas; 196 | int32_t topic_id, freq; 197 | //skip word id 198 | ss >> fea; 199 | //add features to row 200 | while (ss >> fea) 201 | { 202 | size_t pos = fea.find_last_of(':'); 203 | if (pos != std::string::npos) 204 | { 205 | topic_id = std::stoi(fea.substr(0, pos)); 206 | freq = std::stoi(fea.substr(pos + 1)); 207 | row->Add(topic_id, freq); 208 | } 209 | else 210 | { 211 | Log::Fatal("bad format of model: %s\n", line.c_str()); 212 | } 213 | } 214 | } 215 | model_file.close(); 216 | } 217 | 218 | void LocalModel::AddWordTopicRow( 219 | integer_t word_id, integer_t topic_id, int32_t delta) 220 | { 221 | Log::Fatal("Not implemented yet\n"); 222 | } 223 | 224 | void LocalModel::AddSummaryRow(integer_t topic_id, int64_t delta) 225 | { 226 | Log::Fatal("Not implemented yet\n"); 227 | } 228 | 229 | Row& LocalModel::GetWordTopicRow(integer_t word) 230 | { 231 | return *(static_cast*>(word_topic_table_->GetRow(word))); 232 | } 233 | 234 | Row& 
LocalModel::GetSummaryRow() 235 | { 236 | return *(static_cast*>(summary_table_->GetRow(0))); 237 | } 238 | 239 | Row& PSModel::GetWordTopicRow(integer_t word_id) 240 | { 241 | return trainer_->GetRow(kWordTopicTable, word_id); 242 | } 243 | 244 | Row& PSModel::GetSummaryRow() 245 | { 246 | return trainer_->GetRow(kSummaryRow, 0); 247 | } 248 | 249 | void PSModel::AddWordTopicRow( 250 | integer_t word_id, integer_t topic_id, int32_t delta) 251 | { 252 | trainer_->Add(kWordTopicTable, word_id, topic_id, delta); 253 | } 254 | 255 | void PSModel::AddSummaryRow(integer_t topic_id, int64_t delta) 256 | { 257 | trainer_->Add(kSummaryRow, 0, topic_id, delta); 258 | } 259 | 260 | } // namespace lightlda 261 | } // namespace multiverso 262 | -------------------------------------------------------------------------------- /src/lightlda.cpp: -------------------------------------------------------------------------------- 1 | #include "common.h" 2 | #include "trainer.h" 3 | #include "alias_table.h" 4 | #include "data_stream.h" 5 | #include "data_block.h" 6 | #include "document.h" 7 | #include "meta.h" 8 | #include "util.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace multiverso { namespace lightlda 16 | { 17 | class LightLDA 18 | { 19 | public: 20 | static void Run(int argc, char** argv) 21 | { 22 | Config::Init(argc, argv); 23 | 24 | AliasTable* alias_table = new AliasTable(); 25 | Barrier* barrier = new Barrier(Config::num_local_workers); 26 | meta.Init(); 27 | std::vector trainers; 28 | for (int32_t i = 0; i < Config::num_local_workers; ++i) 29 | { 30 | Trainer* trainer = new Trainer(alias_table, barrier, &meta); 31 | trainers.push_back(trainer); 32 | } 33 | 34 | ParamLoader* param_loader = new ParamLoader(); 35 | multiverso::Config config; 36 | config.num_servers = Config::num_servers; 37 | config.num_aggregator = Config::num_aggregator; 38 | config.server_endpoint_file = Config::server_file; 39 | 40 | Multiverso::Init(trainers, 
param_loader, config, &argc, &argv); 41 | 42 | Log::ResetLogFile("LightLDA." 43 | + std::to_string(clock()) + ".log"); 44 | 45 | data_stream = CreateDataStream(); 46 | InitMultiverso(); 47 | Train(); 48 | 49 | Multiverso::Close(); 50 | 51 | for (auto& trainer : trainers) 52 | { 53 | delete trainer; 54 | } 55 | delete param_loader; 56 | 57 | DumpDocTopic(); 58 | 59 | delete data_stream; 60 | delete barrier; 61 | delete alias_table; 62 | } 63 | private: 64 | static void Train() 65 | { 66 | Multiverso::BeginTrain(); 67 | for (int32_t i = 0; i < Config::num_iterations; ++i) 68 | { 69 | Multiverso::BeginClock(); 70 | // Train corpus block by block 71 | for (int32_t block = 0; block < Config::num_blocks; ++block) 72 | { 73 | data_stream->BeforeDataAccess(); 74 | DataBlock& data_block = data_stream->CurrDataBlock(); 75 | data_block.set_meta(&meta.local_vocab(block)); 76 | int32_t num_slice = meta.local_vocab(block).num_slice(); 77 | std::vector data(num_slice); 78 | // Train datablock slice by slice 79 | for (int32_t slice = 0; slice < num_slice; ++slice) 80 | { 81 | LDADataBlock* lda_block = &data[slice]; 82 | lda_block->set_data(&data_block); 83 | lda_block->set_iteration(i); 84 | lda_block->set_block(block); 85 | lda_block->set_slice(slice); 86 | Multiverso::PushDataBlock(lda_block); 87 | } 88 | Multiverso::Wait(); 89 | data_stream->EndDataAccess(); 90 | } 91 | Multiverso::EndClock(); 92 | } 93 | Multiverso::EndTrain(); 94 | } 95 | 96 | static void InitMultiverso() 97 | { 98 | Multiverso::BeginConfig(); 99 | CreateTable(); 100 | ConfigTable(); 101 | Initialize(); 102 | Multiverso::EndConfig(); 103 | } 104 | 105 | static void Initialize() 106 | { 107 | xorshift_rng rng; 108 | for (int32_t block = 0; block < Config::num_blocks; ++block) 109 | { 110 | data_stream->BeforeDataAccess(); 111 | DataBlock& data_block = data_stream->CurrDataBlock(); 112 | int32_t num_slice = meta.local_vocab(block).num_slice(); 113 | for (int32_t slice = 0; slice < num_slice; ++slice) 114 | { 
115 | for (int32_t i = 0; i < data_block.Size(); ++i) 116 | { 117 | Document* doc = data_block.GetOneDoc(i); 118 | int32_t& cursor = doc->Cursor(); 119 | if (slice == 0) cursor = 0; 120 | int32_t last_word = meta.local_vocab(block).LastWord(slice); 121 | for (; cursor < doc->Size(); ++cursor) 122 | { 123 | if (doc->Word(cursor) > last_word) break; 124 | // Init the latent variable 125 | if (!Config::warm_start) 126 | doc->SetTopic(cursor, rng.rand_k(Config::num_topics)); 127 | // Init the server table 128 | Multiverso::AddToServer(kWordTopicTable, 129 | doc->Word(cursor), doc->Topic(cursor), 1); 130 | Multiverso::AddToServer(kSummaryRow, 131 | 0, doc->Topic(cursor), 1); 132 | } 133 | } 134 | Multiverso::Flush(); 135 | } 136 | data_stream->EndDataAccess(); 137 | } 138 | } 139 | 140 | static void DumpDocTopic() 141 | { 142 | Row doc_topic_counter(0, Format::Sparse, kMaxDocLength); 143 | for (int32_t block = 0; block < Config::num_blocks; ++block) 144 | { 145 | std::ofstream fout("doc_topic." 
+ std::to_string(block)); 146 | data_stream->BeforeDataAccess(); 147 | DataBlock& data_block = data_stream->CurrDataBlock(); 148 | for (int i = 0; i < data_block.Size(); ++i) 149 | { 150 | Document* doc = data_block.GetOneDoc(i); 151 | doc_topic_counter.Clear(); 152 | doc->GetDocTopicVector(doc_topic_counter); 153 | fout << i << " "; // doc id 154 | Row::iterator iter = doc_topic_counter.Iterator(); 155 | while (iter.HasNext()) 156 | { 157 | fout << " " << iter.Key() << ":" << iter.Value(); 158 | iter.Next(); 159 | } 160 | fout << std::endl; 161 | } 162 | data_stream->EndDataAccess(); 163 | } 164 | } 165 | 166 | static void CreateTable() 167 | { 168 | int32_t num_vocabs = Config::num_vocabs; 169 | int32_t num_topics = Config::num_topics; 170 | Type int_type = Type::Int; 171 | Type longlong_type = Type::LongLong; 172 | multiverso::Format dense_format = multiverso::Format::Dense; 173 | multiverso::Format sparse_format = multiverso::Format::Sparse; 174 | 175 | Multiverso::AddServerTable(kWordTopicTable, num_vocabs, 176 | num_topics, int_type, dense_format); 177 | Multiverso::AddCacheTable(kWordTopicTable, num_vocabs, 178 | num_topics, int_type, dense_format, Config::model_capacity); 179 | Multiverso::AddAggregatorTable(kWordTopicTable, num_vocabs, 180 | num_topics, int_type, dense_format, Config::delta_capacity); 181 | 182 | Multiverso::AddTable(kSummaryRow, 1, Config::num_topics, 183 | longlong_type, dense_format); 184 | } 185 | 186 | static void ConfigTable() 187 | { 188 | multiverso::Format dense_format = multiverso::Format::Dense; 189 | multiverso::Format sparse_format = multiverso::Format::Sparse; 190 | for (int32_t word = 0; word < Config::num_vocabs; ++word) 191 | { 192 | if (meta.tf(word) > 0) 193 | { 194 | if (meta.tf(word) * kLoadFactor > Config::num_topics) 195 | { 196 | Multiverso::SetServerRow(kWordTopicTable, 197 | word, dense_format, Config::num_topics); 198 | Multiverso::SetCacheRow(kWordTopicTable, 199 | word, dense_format, Config::num_topics); 200 | 
} 201 | else 202 | { 203 | Multiverso::SetServerRow(kWordTopicTable, 204 | word, sparse_format, meta.tf(word) * kLoadFactor); 205 | Multiverso::SetCacheRow(kWordTopicTable, 206 | word, sparse_format, meta.tf(word) * kLoadFactor); 207 | } 208 | } 209 | if (meta.local_tf(word) > 0) 210 | { 211 | if (meta.local_tf(word) * 2 * kLoadFactor > Config::num_topics) 212 | Multiverso::SetAggregatorRow(kWordTopicTable, 213 | word, dense_format, Config::num_topics); 214 | else 215 | Multiverso::SetAggregatorRow(kWordTopicTable, word, 216 | sparse_format, meta.local_tf(word) * 2 * kLoadFactor); 217 | } 218 | } 219 | } 220 | private: 221 | /*! \brief training data access */ 222 | static IDataStream* data_stream; 223 | /*! \brief training data meta information */ 224 | static Meta meta; 225 | }; 226 | IDataStream* LightLDA::data_stream = nullptr; 227 | Meta LightLDA::meta; 228 | 229 | } // namespace lightlda 230 | } // namespace multiverso 231 | 232 | 233 | int main(int argc, char** argv) 234 | { 235 | multiverso::lightlda::LightLDA::Run(argc, argv); 236 | return 0; 237 | } 238 | -------------------------------------------------------------------------------- /src/alias_table.cpp: -------------------------------------------------------------------------------- 1 | #include "alias_table.h" 2 | 3 | #include "common.h" 4 | #include "model.h" 5 | #include "util.h" 6 | #include "meta.h" 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace multiverso { namespace lightlda 14 | { 15 | _THREAD_LOCAL std::vector* AliasTable::q_w_proportion_; 16 | _THREAD_LOCAL std::vector* AliasTable::q_w_proportion_int_; 17 | _THREAD_LOCAL std::vector>* AliasTable::L_; 18 | _THREAD_LOCAL std::vector>* AliasTable::H_; 19 | 20 | AliasTable::AliasTable() 21 | { 22 | memory_size_ = Config::alias_capacity / sizeof(int32_t); 23 | num_vocabs_ = Config::num_vocabs; 24 | num_topics_ = Config::num_topics; 25 | beta_ = Config::beta; 26 | beta_sum_ = beta_ * num_vocabs_; 27 | memory_block_ = 
new int32_t[memory_size_]; 28 | 29 | beta_kv_vector_ = new int32_t[2 * num_topics_]; 30 | 31 | height_.resize(num_vocabs_); 32 | mass_.resize(num_vocabs_); 33 | } 34 | 35 | AliasTable::~AliasTable() 36 | { 37 | delete[] memory_block_; 38 | delete[] beta_kv_vector_; 39 | } 40 | 41 | void AliasTable::Init(AliasTableIndex* table_index) 42 | { 43 | table_index_ = table_index; 44 | } 45 | 46 | int32_t AliasTable::Build(int32_t word, ModelBase* model) 47 | { 48 | if (q_w_proportion_ == nullptr) 49 | q_w_proportion_ = new std::vector(num_topics_); 50 | if (q_w_proportion_int_ == nullptr) 51 | q_w_proportion_int_ = new std::vector(num_topics_); 52 | if (L_ == nullptr) 53 | L_ = new std::vector>(num_topics_); 54 | if (H_ == nullptr) 55 | H_ = new std::vector>(num_topics_); 56 | // Compute the proportion 57 | Row& summary_row = model->GetSummaryRow(); 58 | if (word == -1) // build alias row for beta 59 | { 60 | beta_mass_ = 0; 61 | for (int32_t k = 0; k < num_topics_; ++k) 62 | { 63 | (*q_w_proportion_)[k] = beta_ / (summary_row.At(k) + beta_sum_); 64 | beta_mass_ += (*q_w_proportion_)[k]; 65 | } 66 | AliasMultinomialRNG(num_topics_, beta_mass_, beta_height_, 67 | beta_kv_vector_); 68 | } 69 | else // build alias row for word 70 | { 71 | WordEntry& word_entry = table_index_->word_entry(word); 72 | Row& word_topic_row = model->GetWordTopicRow(word); 73 | int32_t size = 0; 74 | mass_[word] = 0; 75 | if (word_entry.is_dense) 76 | { 77 | size = num_topics_; 78 | for (int32_t k = 0; k < num_topics_; ++k) 79 | { 80 | (*q_w_proportion_)[k] = (word_topic_row.At(k) + beta_) 81 | / (summary_row.At(k) + beta_sum_); 82 | mass_[word] += (*q_w_proportion_)[k]; 83 | } 84 | } 85 | else // word_entry.is_dense = false 86 | { 87 | word_entry.capacity = word_topic_row.NonzeroSize(); 88 | int32_t* idx_vector = memory_block_ + word_entry.begin_offset 89 | + 2 * word_entry.capacity; 90 | Row::iterator iter = word_topic_row.Iterator(); 91 | while (iter.HasNext()) 92 | { 93 | int32_t t = iter.Key(); 
// --- tail of AliasTable::Build (the function opens above this excerpt) ---
// Accumulates the word-proposal weight q_w(t) = n_tw / (n_t + beta_sum_) for each
// nonzero topic t of this word; idx_vector records the dense-slot -> topic-id map
// that Propose() uses to translate alias slots back to topic ids.
94 | int32_t n_tw = iter.Value();
95 | int64_t n_t = summary_row.At(t);
96 | idx_vector[size] = t;
97 | (*q_w_proportion_)[size] = (n_tw) / (n_t + beta_sum_);
98 | mass_[word] += (*q_w_proportion_)[size];
99 | ++size;
100 | iter.Next();
101 | }
102 | if (size == 0)
103 | {
104 | Log::Error("Fail to build alias row, capacity of row = %d\n",
105 | word_topic_row.NonzeroSize());
106 | }
107 | }
// Build the O(1)-sample alias structure for this word in its slice of memory_block_.
108 | AliasMultinomialRNG(size, mass_[word], height_[word],
109 | memory_block_ + word_entry.begin_offset);
110 | }
111 | return 0;
112 | }
113 |
// Draw one topic proposal for `word` with a single alias-table lookup.
// Dense word entries sample directly from the word table; sparse entries first
// choose between the word-proposal mass and the shared beta (smoothing) mass.
114 | int32_t AliasTable::Propose(int32_t word, xorshift_rng& rng)
115 | {
116 | WordEntry& word_entry = table_index_->word_entry(word);
117 | int32_t* kv_vector = memory_block_ + word_entry.begin_offset;
118 | int32_t capacity = word_entry.capacity;
119 | if (word_entry.is_dense)
120 | {
// Alias method: pick a slot uniformly (integer sample / per-slot height),
// then keep the slot or jump to its stored alias depending on the threshold v.
121 | auto sample = rng.rand();
122 | int32_t idx = sample / height_[word];
// Clamp: guards against sample landing exactly on the last bucket boundary.
123 | if (capacity <= idx) idx = capacity - 1;
124 |
125 | int32_t* p = kv_vector + 2 * idx;
126 | int32_t k = *p++;
127 | int32_t v = *p;
// m is all-ones (-1) when sample < v, else 0 — branchless select of idx vs. alias k.
128 | int32_t m = -(sample < v);
129 | return (idx & m) | (k & ~m);
130 | }
131 | else
132 | {
// Sparse entry: mix the word-proposal mass with the global beta mass.
133 | auto sample = rng.rand_double() * (mass_[word] + beta_mass_);
134 | if (sample < mass_[word])
135 | {
// idx_vector sits immediately after the (alias, threshold) pairs and maps
// dense alias slots back to sparse topic ids.
136 | int32_t* idx_vector = kv_vector + 2 * word_entry.capacity;
137 | auto n_kw_sample = rng.rand();
138 | int32_t idx = n_kw_sample / height_[word];
139 | if (capacity <= idx) idx = capacity - 1;
140 | int32_t* p = kv_vector + 2 * idx;
141 | int32_t k = *p++;
142 | int32_t v = *p;
143 | int32_t id = idx_vector[idx];
144 | int32_t m = -(n_kw_sample < v);
// NOTE(review): idx_vector[k] translates the alias slot k to a topic id; this
// matches how Build fills idx_vector above, but confirm against upstream LightLDA.
145 | return (id & m) | (idx_vector[k] & ~m);
146 | }
147 | else
148 | {
// Sample from the shared beta (smoothing) alias table over all topics.
149 | auto beta_sample = rng.rand();
150 | int32_t idx = beta_sample / beta_height_;
151 | if (num_topics_ <= idx) idx = num_topics_ - 1;
152 | int32_t* p = beta_kv_vector_ + 2 * idx;
153 | int32_t k = *p++;
154 | int32_t v = *p;
155 | int32_t m = -(beta_sample < v);
156 | return (idx & m) | (k & ~m);
157 | }
158 | }
159 | }
160 |
// Free the construction-time scratch buffers (proportion arrays and L/H worklists).
// NOTE(review): raw delete on members — presumably paired with `new` elsewhere in
// this class; confirm, and consider unique_ptr upstream.
161 | void AliasTable::Clear()
162 | {
163 | delete q_w_proportion_;
164 | q_w_proportion_ = nullptr;
165 | delete q_w_proportion_int_;
166 | q_w_proportion_int_ = nullptr;
167 | delete L_;
168 | L_ = nullptr;
169 | delete H_;
170 | H_ = nullptr;
171 | }
172 |
173 |
// Build one alias table (Walker/Vose method in integer arithmetic) over `size`
// outcomes whose unnormalized weights are in q_w_proportion_[0..size).
// On return, kv_vector holds `size` (alias, threshold) int32 pairs and `height`
// is the integer bucket height; Propose() samples from this layout in O(1).
174 | void AliasTable::AliasMultinomialRNG(int32_t size, float mass, int32_t& height,
175 | int32_t* kv_vector)
176 | {
// Quantize: use the largest multiple of `size` not exceeding 2^31-1 as the total
// integer mass so that every bucket gets the same integer height a_int.
177 | int32_t mass_int = 0x7fffffff;
178 | int32_t a_int = mass_int / size;
179 | mass_int = a_int * size;
180 | height = a_int;
181 | int64_t mass_sum = 0;
182 | for (int32_t i = 0; i < size; ++i)
183 | {
184 | (*q_w_proportion_)[i] /= mass;
// NOTE(review): static_cast lost its template argument (likely <int32_t>) in extraction.
185 | (*q_w_proportion_int_)[i] =
186 | static_cast((*q_w_proportion_)[i] * mass_int);
187 | mass_sum += (*q_w_proportion_int_)[i];
188 | }
// Distribute rounding residue so the integer weights sum exactly to mass_int:
// shave 1 from nonzero entries round-robin when over...
189 | if (mass_sum > mass_int)
190 | {
191 | int32_t more = static_cast(mass_sum - mass_int);
192 | int32_t id = 0;
193 | for (int32_t i = 0; i < more;)
194 | {
195 | if ((*q_w_proportion_int_)[id] >= 1)
196 | {
197 | --(*q_w_proportion_int_)[id];
198 | ++i;
199 | }
200 | id = (id + 1) % size;
201 | }
202 | }
203 |
// ...and add 1 round-robin when under.
204 | if (mass_sum < mass_int)
205 | {
206 | int32_t more = static_cast(mass_int - mass_sum);
207 | int32_t id = 0;
208 | for (int32_t i = 0; i < more; ++i)
209 | {
210 | ++(*q_w_proportion_int_)[id];
211 | id = (id + 1) % size;
212 | }
213 | }
214 |
// Initialize every slot as self-aliased and completely full
// (threshold at the top of its bucket: (k + 1) * height).
215 | for (int32_t k = 0; k < size; ++k)
216 | {
217 | int32_t* p = kv_vector + 2 * k;
218 | *p = k; ++p;
219 | *p = (k + 1) * height;
220 | }
// Partition slots into the L (below height) and H (at/above height) worklists.
221 | int32_t L_head = 0, L_tail = 0, H_head = 0, H_tail = 0;
222 | for (int32_t k = 0; k < size; ++k)
223 | {
224 | int32_t val = (*q_w_proportion_int_)[k];
225 | if (val < height)
226 | {
227 | (*L_)[L_tail].first = k;
228 | (*L_)[L_tail].second = val;
229 | ++L_tail;
230 | }
231 | else
232 | {
233 | (*H_)[H_tail].first = k;
234 | (*H_)[H_tail].second = val;
235 | ++H_tail;
236 | }
237 | }
// Classic alias pairing: top up each low slot from a high donor; the donor's
// remaining mass is requeued on L or H depending on what is left.
238 | while (L_head != L_tail && H_head != H_tail)
239 | {
240 | auto& l_pl = (*L_)[L_head++];
241 | auto&
h_ph = (*H_)[H_head++];
242 | int32_t* p = kv_vector + 2 * l_pl.first;
243 | *p = h_ph.first; ++p;
244 | *p = l_pl.first * height + l_pl.second;
245 |
246 | auto sum = h_ph.second + l_pl.second;
247 | if (sum > 2 * height)
248 | {
249 | (*H_)[H_tail].first = h_ph.first;
250 | (*H_)[H_tail].second = sum - height;
251 | ++H_tail;
252 | }
253 | else
254 | {
255 | (*L_)[L_tail].first = h_ph.first;
256 | (*L_)[L_tail].second = sum - height;
257 | ++L_tail;
258 | }
259 | }
// Leftovers are (up to rounding) exactly full: self-alias them.
260 | while (L_head != L_tail)
261 | {
262 | auto first = (*L_)[L_head].first;
263 | auto second = (*L_)[L_head].second;
264 | int32_t* p = kv_vector + 2 * first;
265 | *p = first; ++p;
266 | *p = first * height + second;
267 | ++L_head;
268 | }
269 | while (H_head != H_tail)
270 | {
271 | auto first = (*H_)[H_head].first;
272 | auto second = (*H_)[H_head].second;
273 | int32_t* p = kv_vector + 2 * first;
274 | *p = first; ++p;
275 | *p = first * height + second;
276 | ++H_head;
277 | }
278 | }
279 | } // namespace lightlda
280 | } // namespace multiverso
281 | -------------------------------------------------------------------------------- /preprocess/dump_binary.cpp: -------------------------------------------------------------------------------- 1 | /*!
2 | * \file dump_binary.cpp
3 | * \brief Preprocessing tool for converting LibSVM data to LightLDA input binary format
4 | * Usage:
5 | * dump_binary
6 | */
7 |
// NOTE(review): the #include targets (angle-bracketed headers such as <fstream>,
// <iostream>, <string>, <vector>, <unordered_map>, <algorithm>, <chrono>, <cstring>)
// were stripped by the text extraction — restore from upstream before compiling.
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 |
18 | namespace lightlda
19 | {
20 | /*
21 | * Output file format:
22 | * 1, the first 4 byte indicates the number of docs in this block
23 | * 2, the 4 * (doc_num + 1) bytes indicate the offset of each doc
24 | * an example
25 | * 3 // there are 3 docs in this block
26 | * 0 // the offset of the 1-st doc
27 | * 10 // the offset of the 2-nd doc, with this we know the length of the 1-st doc is 5 = 10/2
28 | * 16 // the offset of the 3-rd doc, with this we know the length of the 2-nd doc is 3 = (16-10)/2
29 | * 24 // with this, we know the length of the 3-rd doc is 4 = (24 - 16)/2
30 | * w11 t11 w12 t12 w13 t13 w14 t14 w15 t15 // the token-topic list of the 1-st doc
31 | * w21 t21 w22 t22 w23 t23 // the token-topic list of the 2-nd doc
32 | * w31 t31 w32 t32 w33 t33 w34 t34 // the token-topic list of the 3-rd doc
33 |
34 | * the class block_stream helps generate such binary format file, usage:
35 | * int doc_num = 3;
36 | * int64_t* offset_buf = new int64_t[doc_num + 1];
37 | *
38 | * block_stream bs;
39 | * bs.open("block");
40 | * bs.write_empty_header(offset_buf, doc_num);
41 | * ...
42 | * // update offset_buf and doc_num...
43 |
44 | * bs.write_doc(doc_buf, doc_idx);
45 | * ...
46 | * bs.write_real_header(offset_buf, doc_num);
47 | * bs.close();
48 | */
// Buffered binary writer for one LightLDA data block. Doc payloads are staged in
// an int32 buffer and flushed to disk in large writes; non-copyable (deleted copy ops).
49 | class block_stream
50 | {
51 | public:
52 | block_stream();
53 | ~block_stream();
54 | bool open(const std::string file_name);
55 | bool write_doc(int32_t* int32_buf, int32_t count);
56 | bool write_empty_header(int64_t* int64_buf, int64_t count);
57 | bool write_real_header(int64_t* int64_buf, int64_t count);
58 | bool seekp(int64_t pos);
59 | bool close();
60 | private:
61 | // assuming each doc has 500 tokens in average,
62 | // the block_buf_ will hold 1 million document,
63 | // needs 0.8GB RAM.
64 | const int32_t block_buf_size_ = 1024 * 1024 * 2 * 100;
65 |
66 | std::ofstream stream_;
67 | std::string file_name_;
68 |
69 | int32_t *block_buf_;
70 | int32_t buf_idx_;
71 |
72 | block_stream(const block_stream& other) = delete;
73 | block_stream& operator=(const block_stream& other) = delete;
74 | };
75 |
76 | /*
77 | (1) open an utf-8 encoded file in binary mode,
78 | get its content line by line. Working around the CTRL-Z issue in Windows text file reading.
79 | (2) assuming each line ends with '\n'
80 | */
81 | class utf8_stream
82 | {
83 | public:
84 | utf8_stream();
85 | ~utf8_stream();
86 |
87 | bool open(const std::string& file_name);
88 |
89 | /*
90 | return true if successfully get a line (may be empty), false if not.
91 | It is user's task to verify whether a line is empty or not.
92 | */
93 | bool getline(std::string &line);
94 | int64_t count_line();
95 | bool close();
96 | private:
97 | bool block_is_empty();
98 | bool fill_block();
99 | std::ifstream stream_;
100 | std::string file_name_;
101 | const int32_t block_buf_size_ = 1024 * 1024 * 800;
102 | // const int32_t block_buf_size_ = 2;
103 | std::string block_buf_;
104 | std::string::size_type buf_idx_;
105 | std::string::size_type buf_end_;
106 |
107 | utf8_stream(const utf8_stream& other) = delete;
108 | utf8_stream& operator=(const utf8_stream& other) = delete;
109 | };
110 |
111 | block_stream::block_stream()
112 | : buf_idx_(0)
113 | {
114 | block_buf_ = new int32_t[block_buf_size_];
115 | }
116 | block_stream::~block_stream()
117 | {
118 | if (block_buf_)
119 | {
120 | delete[]block_buf_;
121 | }
122 | }
123 |
124 | bool block_stream::open(const std::string file_name)
125 | {
126 | file_name_ = file_name;
127 | stream_.open(file_name_, std::ios::out | std::ios::binary);
128 | return stream_.good();
129 | }
130 |
131 | bool block_stream::seekp(int64_t pos)
132 | {
133 | stream_.seekp(pos);
134 | return true;
135 | }
136 |
// Write a placeholder header: doc count followed by (count + 1) int64 offsets.
// The same layout is overwritten in place by write_real_header once offsets are known.
// NOTE(review): reinterpret_cast lost its template argument (likely <char*>) in extraction.
137 | bool block_stream::write_empty_header(int64_t* int64_buf, int64_t count)
138 | {
139 | stream_.write(reinterpret_cast(&count), sizeof(int64_t));
140 | stream_.write(reinterpret_cast(int64_buf),
141 | sizeof(int64_t)* (count + 1));
142 | return true;
143 | }
144 |
// Flush any buffered doc data, then rewind to offset 0 and rewrite the header
// with the final offsets.
145 | bool block_stream::write_real_header(int64_t* int64_buf, int64_t count)
146 | {
147 | // clear off the block_buf_, if any content not dumped to disk
148 | if (buf_idx_ != 0)
149 | {
150 | stream_.write(reinterpret_cast (block_buf_),
151 | sizeof(int32_t)* buf_idx_);
152 | buf_idx_ = 0;
153 | }
154 |
155 | seekp(0);
156 | write_empty_header(int64_buf, count);
157 | return true;
158 | }
159 |
// Append one doc's int32 payload, flushing the staging buffer first when the doc
// would not fit. NOTE(review): assumes count <= block_buf_size_ — callers cap
// docs at kMaxDocLength, which keeps this true.
160 | bool block_stream::write_doc(int32_t* int32_buf, int32_t count)
161 | {
162 | if (buf_idx_ + count > block_buf_size_)
163 | {
164 | stream_.write(reinterpret_cast(block_buf_),
165 | sizeof(int32_t)*
buf_idx_); 166 | buf_idx_ = 0;
167 | }
168 | memcpy(block_buf_ + buf_idx_, int32_buf, count * sizeof(int32_t));
169 | buf_idx_ += count;
170 | return true;
171 | }
172 |
173 | bool block_stream::close()
174 | {
175 | stream_.close();
176 | return true;
177 | }
178 |
179 | utf8_stream::utf8_stream()
180 | {
181 | block_buf_.resize(block_buf_size_);
182 | }
183 | utf8_stream::~utf8_stream()
184 | {
185 | }
186 |
187 | bool utf8_stream::open(const std::string& file_name)
188 | {
189 | stream_.open(file_name, std::ios::in | std::ios::binary);
190 | buf_idx_ = 0;
191 | buf_end_ = 0;
192 | return stream_.good();
193 | }
194 |
// Accumulate characters into `line` across buffer refills until a '\n' is found.
// Returns false only at end of file; a non-empty partial line at EOF is reported
// as a format violation (every line is assumed to end with '\n').
195 | bool utf8_stream::getline(std::string& line)
196 | {
197 | line = "";
198 | while (true)
199 | {
200 | if (block_is_empty())
201 | {
202 | // if the block_buf_ is empty, fill the block_buf_
203 | if (!fill_block())
204 | {
205 | // if fail to fill the block_buf_, that means we reach the end of file
206 | if (!line.empty())
207 | std::cout << "Invalid format, according to our assumption: "
208 | "each line has an \\n.
However, we reach here with an non-empty line but not find an \\n";
209 | return false;
210 | }
211 | }
212 | // the block is not empty now
213 |
214 | std::string::size_type end_pos = block_buf_.find("\n", buf_idx_);
215 | if (end_pos != std::string::npos)
216 | {
217 | // successfully find a new line
218 | line += block_buf_.substr(buf_idx_, end_pos - buf_idx_);
219 | buf_idx_ = end_pos + 1;
220 | return true;
221 | }
222 | else
223 | {
224 | // do not find an \n until the end of block_buf_
225 | line += block_buf_.substr(buf_idx_, buf_end_ - buf_idx_);
226 | buf_idx_ = buf_end_;
227 | }
228 | }
229 | return false;
230 | }
231 |
// Count '\n' bytes by reading the stream to exhaustion. Leaves the stream at EOF;
// count_doc_num opens a dedicated stream just for this and then closes it.
232 | int64_t utf8_stream::count_line()
233 | {
234 | char* buffer = &block_buf_[0];
235 |
236 | int64_t line_num = 0;
237 | while (true)
238 | {
239 | stream_.read(buffer, block_buf_size_);
// NOTE(review): static_cast lost its template argument (likely <int32_t>) in extraction.
240 | int32_t end_pos = static_cast(stream_.gcount());
241 | if (end_pos == 0)
242 | {
243 | break;
244 | }
245 | line_num += std::count(buffer, buffer + end_pos, '\n');
246 | }
247 | return line_num;
248 | }
249 |
250 | bool utf8_stream::block_is_empty()
251 | {
252 | return buf_idx_ == buf_end_;
253 | }
254 |
// Refill block_buf_ from the stream; returns false when no bytes remain (EOF).
255 | bool utf8_stream::fill_block()
256 | {
257 | char* buffer = &block_buf_[0];
258 | stream_.read(buffer, block_buf_size_);
259 | buf_idx_ = 0;
260 | buf_end_ = static_cast(stream_.gcount());
261 | return buf_end_ != 0;
262 | }
263 |
264 | bool utf8_stream::close()
265 | {
266 | stream_.close();
267 | return true;
268 | }
269 | }
270 |
// One token occurrence: a (word id, topic assignment) pair; topics start at 0.
271 | struct Token {
272 | int32_t word_id;
273 | int32_t topic_id;
274 | };
275 |
// Strict-weak-ordering predicate for std::sort below: orders tokens by word id.
// (Returns int, but the 0/1 result of `<` converts to bool where std::sort needs it.)
276 | int Compare(const Token& token1, const Token& token2) {
277 | return token1.word_id < token2.word_id;
278 | }
279 |
// Wall-clock seconds since the epoch. NOTE(review): the duration_cast template
// arguments (likely std::chrono::duration<double>) were stripped by extraction.
280 | double get_time()
281 | {
282 | auto start = std::chrono::high_resolution_clock::now();
283 | auto since_epoch = start.time_since_epoch();
284 | return std::chrono::duration_cast>>(since_epoch).count();
285 | }
286 |
// Split `line` on `separator` into `output`. Trailing ' ' and '\r' are first
// removed from `line` in place; when trimEmpty is true, empty fields are dropped.
287 | void split_string(std::string& line, char separator, std::vector& output, bool
trimEmpty = false) 288 | {
289 | output.clear();
290 |
291 | if (line.empty())
292 | {
293 | return;
294 | }
295 |
296 | // trim trailing whitespace, \r
297 | while (!line.empty())
298 | {
299 | int32_t last = line.length() - 1;
300 | if (line[last] == ' ' || line[last] == '\r')
301 | {
302 | line.erase(last, 1);
303 | }
304 | else
305 | {
306 | break;
307 | }
308 | }
309 |
310 | std::string::size_type pos;
311 | std::string::size_type lastPos = 0;
312 |
// NOTE(review): the vector's element-type template argument was stripped by
// extraction; these aliases presumably name std::string and its size_type.
313 | using value_type = std::vector::value_type;
314 | using size_type = std::vector::size_type;
315 |
// Emit one field per separator; the final field (no separator after it) is
// emitted inside the `npos` branch before breaking out.
316 | while (true)
317 | {
318 | pos = line.find_first_of(separator, lastPos);
319 | if (pos == std::string::npos)
320 | {
321 | pos = line.length();
322 |
323 | if (pos != lastPos || !trimEmpty)
324 | output.push_back(value_type(line.data() + lastPos,
325 | (size_type)pos - lastPos));
326 |
327 | break;
328 | }
329 | else
330 | {
331 | if (pos != lastPos || !trimEmpty)
332 | output.push_back(value_type(line.data() + lastPos,
333 | (size_type)pos - lastPos));
334 | }
335 |
336 | lastPos = pos + 1;
337 | }
338 | return;
339 | }
340 |
// Set doc_num to the number of '\n'-terminated lines in input_doc.
// Exits the process on open failure (tool-style error handling, as elsewhere here).
341 | void count_doc_num(std::string input_doc, int64_t &doc_num)
342 | {
343 | lightlda::utf8_stream stream;
344 | if (!stream.open(input_doc))
345 | {
346 | std::cout << "Fails to open file: " << input_doc << std::endl;
347 | exit(1);
348 | }
349 | doc_num = stream.count_line();
350 | stream.close();
351 | }
352 |
// Load the word-dict file into {word_id -> tf}. Each line has three
// tab-separated columns; column 0 is the word id and column 2 the term
// frequency (column 1 is ignored here). Duplicate ids abort the run.
353 | void load_global_tf(std::unordered_map& global_tf_map,
354 | std::string word_tf_file,
355 | int64_t& global_tf_count)
356 | {
357 | lightlda::utf8_stream stream;
358 | if (!stream.open(word_tf_file))
359 | {
360 | std::cout << "Fails to open file: " << word_tf_file << std::endl;
361 | exit(1);
362 | }
363 | std::string line;
364 | while (stream.getline(line))
365 | {
366 | std::vector output;
367 | split_string(line, '\t', output);
368 | if (output.size() != 3)
369 | {
370 | std::cout << "Invalid line: " << line << std::endl;
371 | exit(1);
372 | }
373 | int32_t word_id =
std::stoi(output[0]);
374 | int32_t tf = std::stoi(output[2]);
375 | auto it = global_tf_map.find(word_id);
376 | if (it != global_tf_map.end())
377 | {
378 | std::cout << "Duplicate words detected: " << line << std::endl;
379 | exit(1);
380 | }
381 | global_tf_map.insert(std::make_pair(word_id, tf));
382 | global_tf_count += tf;
383 | }
384 | stream.close();
385 | }
386 |
// Entry point: libsvm file -> one binary block + (binary and text) vocab files.
// argv: [1] libsvm input, [2] word-dict file, [3] output dir, [4] block index.
387 | int main(int argc, char* argv[])
388 | {
389 | if (argc != 5)
390 | {
391 | printf("Usage: dump_binary \n");
392 | exit(1);
393 | }
394 |
395 | std::string libsvm_file_name(argv[1]);
396 | std::string word_dict_file_name(argv[2]);
397 | std::string output_dir(argv[3]);
398 | int32_t output_offset = atoi(argv[4]);
// Hard cap on tokens per document; longer docs are truncated below.
399 | const int32_t kMaxDocLength = 8192;
400 |
401 | // 1. count how many documents in the data set
402 | int64_t doc_num;
403 | count_doc_num(libsvm_file_name, doc_num);
404 |
405 | // 2. load the word_dict file, get the global {word_id, tf} mapping
406 | std::unordered_map global_tf_map;
407 | std::unordered_map local_tf_map;
408 | int64_t global_tf_count = 0;
409 | load_global_tf(global_tf_map, word_dict_file_name, global_tf_count);
410 | int32_t word_num = global_tf_map.size();
411 | std::cout << "There are totally " << word_num
412 | << " words in the vocabulary" << std::endl;
413 | std::cout << "There are maximally totally " << global_tf_count
414 | << " tokens in the data set" << std::endl;
415 |
416 | // 3. transform the libsvm -> binary block
// doc_buf record layout: [cursor][w1 t1 w2 t2 ...], hence 2 * kMaxDocLength + 1.
417 | int64_t* offset_buf = new int64_t[doc_num + 1];
418 | int32_t *doc_buf = new int32_t[kMaxDocLength * 2 + 1];
419 |
420 | std::string block_name = output_dir + "/block." + std::to_string(output_offset);
421 | std::string vocab_name = output_dir + "/vocab." + std::to_string(output_offset);
422 | std::string txt_vocab_name = output_dir + "/vocab."
+ std::to_string(output_offset) + ".txt"; 423 |
424 | // open file
425 | lightlda::utf8_stream libsvm_file;
426 | lightlda::block_stream block_file;
427 |
428 | if (!libsvm_file.open(libsvm_file_name))
429 | {
430 | std::cout << "Fails to open file: " << libsvm_file_name << std::endl;
431 | exit(1);
432 | }
433 | if (!block_file.open(block_name))
434 | {
435 | std::cout << "Fails to create file: " << block_name << std::endl;
436 | exit(1);
437 | }
438 | std::ofstream vocab_file(vocab_name, std::ios::out | std::ios::binary);
439 | std::ofstream txt_vocab_file(txt_vocab_name, std::ios::out);
440 |
441 | if (!vocab_file.good())
442 | {
443 | std::cout << "Fails to create file: " << vocab_name << std::endl;
444 | exit(1);
445 | }
446 | if (!txt_vocab_file.good())
447 | {
448 | std::cout << "Fails to create file: " << txt_vocab_name << std::endl;
449 | exit(1);
450 | }
451 |
// Reserve header space now; the real offsets are written back at the end.
452 | block_file.write_empty_header(offset_buf, doc_num);
453 |
454 | int64_t block_token_num = 0;
455 | std::string str_line;
456 | std::string line;
457 | char *endptr = nullptr;
458 | const int kBASE = 10;
459 | int doc_buf_idx;
460 |
461 | double dump_start = get_time();
462 |
// offsets are measured in int32 units relative to the payload start.
463 | offset_buf[0] = 0;
464 | for (int64_t j = 0; j < doc_num; ++j)
465 | {
466 | if (!libsvm_file.getline(str_line) || str_line.empty())
467 | {
468 | std::cout << "Fails to get line" << std::endl;
469 | exit(1);
470 | }
// Re-append '\n' so the strtol scan below has a guaranteed terminator.
471 | str_line += '\n';
472 |
473 | std::vector output;
474 | split_string(str_line, '\t', output);
475 |
476 |
477 | if (output.size() != 2)
478 | {
479 | std::cout << "Invalid format, not key TAB val: " << str_line << std::endl;
480 | exit(1);
481 | }
482 |
483 | int doc_token_count = 0;
484 | std::vector doc_tokens;
485 |
486 | char *ptr = &(output[1][0]);
487 |
// Parse "word_id:count" pairs, expanding each count into `count` tokens
// (topic 0) and truncating the doc at kMaxDocLength tokens.
488 | while (*ptr != '\n')
489 | {
490 | if (doc_token_count >= kMaxDocLength) break;
491 | // read a word_id:count pair
492 | int32_t word_id = strtol(ptr, &endptr, kBASE);
493 | ptr = endptr;
494 | if (':' != *ptr)
495 | {
496 | std::cout <<
"Invalid input" << str_line << std::endl;
497 | exit(1);
498 | }
499 | int32_t count = strtol(++ptr, &endptr, kBASE);
500 |
501 | ptr = endptr;
502 | for (int k = 0; k < count; ++k)
503 | {
504 | doc_tokens.push_back({ word_id, 0 });
505 | if (local_tf_map.find(word_id) == local_tf_map.end())
506 | {
507 | local_tf_map.insert(std::make_pair(word_id, 1));
508 | }
509 | else
510 | {
511 | local_tf_map[word_id]++;
512 | }
513 | ++block_token_num;
514 | ++doc_token_count;
515 | if (doc_token_count >= kMaxDocLength) break;
516 | }
517 | while (*ptr == ' ' || *ptr == '\r') ++ptr;
518 | }
519 | // The input data may be already sorted
520 | std::sort(doc_tokens.begin(), doc_tokens.end(), Compare);
521 |
// doc record layout: [cursor][w1 t1 w2 t2 ...]; the leading cursor slot starts at 0.
522 | doc_buf_idx = 0;
523 | doc_buf[doc_buf_idx++] = 0; // cursor
524 |
525 | for (auto& token : doc_tokens)
526 | {
527 | doc_buf[doc_buf_idx++] = token.word_id;
528 | doc_buf[doc_buf_idx++] = token.topic_id;
529 | }
530 |
531 | block_file.write_doc(doc_buf, doc_buf_idx);
532 | offset_buf[j + 1] = offset_buf[j] + doc_buf_idx;
533 | }
// Flush pending data and overwrite the placeholder header with real offsets.
534 | block_file.write_real_header(offset_buf, doc_num);
535 |
// vocab file layout: [non_zero_count][word ids...][global tfs...][local tfs...];
// a zero placeholder is written first and patched via seekp(0) below.
536 | int32_t vocab_size = 0;
537 |
538 | vocab_file.write(reinterpret_cast(&vocab_size), sizeof(int32_t));
539 |
540 | int32_t non_zero_count = 0;
541 | // write vocab
// NOTE: local_tf_map[i] default-inserts 0 for words absent from this block —
// harmless for the output, but it grows the map to word_num entries.
542 | for (int i = 0; i < word_num; ++i)
543 | {
544 | if (local_tf_map[i] > 0)
545 | {
546 | non_zero_count++;
547 | vocab_file.write(reinterpret_cast (&i), sizeof(int32_t));
548 | }
549 | }
550 | std::cout << "The number of tokens in the output block is: " << block_token_num << std::endl;
551 | std::cout << "Local vocab_size for the output block is: " << non_zero_count << std::endl;
552 |
553 | // write global tf
554 | for (int i = 0; i < word_num; ++i)
555 | {
556 | if (local_tf_map[i] > 0)
557 | {
558 | vocab_file.write(reinterpret_cast (&global_tf_map[i]), sizeof(int32_t));
559 | }
560 | }
561 | // write local tf
562 | for (int i = 0; i < word_num; ++i)
563 | {
564 | if (local_tf_map[i] > 0)
565 | {
566 |
vocab_file.write(reinterpret_cast (&local_tf_map[i]), sizeof(int32_t));
567 | }
568 | }
// Patch the real vocab count into the placeholder at offset 0.
569 | vocab_file.seekp(0);
570 | vocab_file.write(reinterpret_cast(&non_zero_count), sizeof(int32_t));
571 | vocab_file.close();
572 |
// Human-readable mirror of the vocab: "id \t global_tf \t local_tf" per word.
573 | txt_vocab_file << non_zero_count << std::endl;
574 | for (int i = 0; i < word_num; ++i)
575 | {
576 | if (local_tf_map[i] > 0)
577 | {
578 | txt_vocab_file << i << "\t" << global_tf_map[i] << "\t" << local_tf_map[i] << std::endl;
579 | }
580 | }
581 | txt_vocab_file.close();
582 |
583 | double dump_end = get_time();
584 | std::cout << "Elapsed seconds for dump blocks: " << (dump_end - dump_start) << std::endl;
585 |
586 | // close file and release resource
587 | libsvm_file.close();
588 | block_file.close();
589 |
590 | delete[]offset_buf;
591 | delete[]doc_buf;
592 | return 0;
593 | }
594 | --------------------------------------------------------------------------------