├── .gitignore
├── test
    ├── basic.cpp
    └── CMakeLists.txt
├── get_gflags.sh
├── src
    ├── lda.cpp
    ├── CMakeLists.txt
    ├── Partition.hpp
    ├── Types.hpp
    ├── clock.hpp
    ├── Vocab.hpp
    ├── NumaArray.cpp
    ├── lda.hpp
    ├── AdjList.cpp
    ├── AdjList.hpp
    ├── Bigraph.hpp
    ├── Xorshift.hpp
    ├── clock.cpp
    ├── Vocab.cpp
    ├── warplda.hpp
    ├── Bigraph.cpp
    ├── NumaArray.hpp
    ├── main.cpp
    ├── HashTable.hpp
    ├── Shuffle.hpp
    ├── alias_urn.h
    ├── format.cpp
    ├── Utils.hpp
    └── warplda.cpp
├── CMakeLists.txt
├── LICENSE
├── python
    ├── model.py
    └── sampling.py
├── data
    └── uci-to-yahoo
└── README.md


/.gitignore:
--------------------------------------------------------------------------------
debug/
release/
gflags/
*.txt
# The CMake build scripts also end in .txt; never ignore them.
!CMakeLists.txt

--------------------------------------------------------------------------------
/test/basic.cpp:
--------------------------------------------------------------------------------
#include

int main()
{
    printf("OK\n");
    return 0;
}

--------------------------------------------------------------------------------
/get_gflags.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Fetch gflags (built through add_subdirectory(gflags) in CMakeLists.txt) and
# pin it to a known commit. Safe to re-run: an existing checkout is reused,
# and any failing step aborts instead of running git in the wrong directory.
set -e

if [ ! -d gflags ]; then
    git clone https://github.com/gflags/gflags.git
fi
cd gflags
git checkout d701ceac73be2c43b6e7b97474184e626fded88b

--------------------------------------------------------------------------------
/src/lda.cpp:
--------------------------------------------------------------------------------
#include "lda.hpp"

// Load the document/word bigraph stored under the prefix `fname` into `g` and
// print its size. Throws std::runtime_error if the binary files cannot be loaded.
void LDA::loadBinary(std::string fname)
{
    if (!g.Load(fname))
        throw std::runtime_error(std::string("Load Binary failed : ") + fname);
    // NE() is a 64-bit TEID; cast so the conversion specifier matches on every
    // platform (%lu is wrong wherever long is 32 bits).
    printf("Bigraph loaded from %s, %u documents, %u unique tokens, %llu total words\n",
           fname.c_str(), g.NU(), g.NV(), static_cast<unsigned long long>(g.NE()));
}

--------------------------------------------------------------------------------
/src/CMakeLists.txt:
--------------------------------------------------------------------------------
add_library (common Bigraph.cpp AdjList.cpp Vocab.cpp NumaArray.cpp clock.cpp)
set (LINK_LIBS ${LINK_LIBS} common gflags numa )

add_executable (warplda main.cpp lda.cpp warplda.cpp)
add_executable (format format.cpp)

target_link_libraries (warplda ${LINK_LIBS})
target_link_libraries (format ${LINK_LIBS})

--------------------------------------------------------------------------------
/test/CMakeLists.txt:
--------------------------------------------------------------------------------
set (LINK_LIBS ${LINK_LIBS} pthread)

macro (test_program target)
    add_executable (${target} "${target}.cpp")
    target_link_libraries (${target} ${LINK_LIBS})
endmacro (test_program)

macro (do_test target arg result)
    add_test (${target}-${arg} ${target} ${arg})
    set_tests_properties (${target}-${arg} PROPERTIES PASS_REGULAR_EXPRESSION ${result})
endmacro (do_test)

test_program (basic)
do_test (basic "" "OK")

--------------------------------------------------------------------------------
/src/Partition.hpp:
--------------------------------------------------------------------------------
#pragma once

// Splits the index range [0, n) into `size` contiguous parts whose lengths
// differ by at most one: the first n % size parts hold n / size + 1 indices,
// the remaining parts hold n / size indices. `size` must be positive.
class Partition
{
public:
    //Partition(){}
    Partition(int size, int n)
        : size_(size)
        , n_(n)
        , d_(n / size)
        , r_(n % size)
    {}

    // First index of part `par`; Startid(size) == n, so part `par` is
    // [Startid(par), Startid(par + 1)).
    int Startid(int par) const
    {
        return d_ * par + (par < r_ ? par : r_ );
    }

    // Part that owns index `idx`. Expects 0 <= idx < n: for idx >= n with
    // n < size the second branch would divide by zero (d_ == 0).
    int Parid(int idx) const
    {
        return idx / (d_ + 1) < r_ ? idx / (d_ + 1) : (idx - r_) / d_;
    }

private:
    int size_;
    int n_;
    int d_;   // base part length, n / size
    int r_;   // number of parts that get one extra index, n % size
};

--------------------------------------------------------------------------------
/src/Types.hpp:
--------------------------------------------------------------------------------
#pragma once

#include
#include
#include

//using Int = int;
//using VI = std::vector;
//using PII = std::pair;
//using VII = std::vector;

#include

using TUID = uint32_t;// document id
using TVID = uint32_t;// word id
using TEID = uint64_t;// number of tokens
using TDegree = uint32_t; //num of degree of vertex
using TTopic = uint32_t;//k
using TCount = uint32_t;//ck[k]

--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required (VERSION 2.8)

project (gridlda)

# gflags is expected in ./gflags (see get_gflags.sh).
add_subdirectory(gflags)

# set compiling flags
set (CMAKE_CXX_FLAGS "-march=native -std=c++1y -Wall -g ${CMAKE_CXX_FLAGS}")

find_package(OpenMP)
if(OPENMP_FOUND)
    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
    set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
endif()


# set include directories
include_directories (./src)

add_subdirectory (src)

enable_testing ()
add_subdirectory (test)

--------------------------------------------------------------------------------
/src/clock.hpp:
--------------------------------------------------------------------------------
// Identifiers starting with a double underscore are reserved for the
// implementation, so the guard uses a project-specific name instead.
#ifndef WARPLDA_CLOCK_HPP
#define WARPLDA_CLOCK_HPP
#include
#include

// Wall-clock stopwatch that accumulates time across pause()/resume() cycles.
class Clock {
public:
    void start();          // reset the accumulated time and start timing
    double timeElapsed();  // accumulated seconds, including a running segment
    double restart();      // return timeElapsed(), then start() again
    void pause();          // stop accumulating; no-op when already paused
    void resume();         // continue accumulating; no-op when already running

    // Time a single call of f() with std::chrono::high_resolution_clock and
    // return diff.count() of the measured duration.
    template
    static double CalcTime(Function f)
    {
        auto start = std::chrono::high_resolution_clock::now();
        f();
        auto end = std::chrono::high_resolution_clock::now();
        std::chrono::duration diff = end-start;
        return diff.count();
    }
    Clock():_elapsed(0), _last(0), _started(false){}

private:
    double _elapsed;   // seconds accumulated by completed segments
    double _last;      // wall time at which the running segment began
    bool _started;     // true while a segment is running
};

#endif

--------------------------------------------------------------------------------
/src/Vocab.hpp:
--------------------------------------------------------------------------------
#pragma once

#include
#include
#include

// Bidirectional word <-> integer id mapping; ids are positions in `words`.
class Vocab
{
private:
    std::unordered_map dict;
    std::vector words;
public:
    Vocab();
    ~Vocab();
    bool load(std::string fname);
    bool store(std::string fname);
    int addWord(std::string w);
    void clear();
    int getIdByWord(std::string w) const;        // -1 if unknown
    std::string getWordById(int id) const;       // "" if out of range
    void RearrangeId(const unsigned int* new_id);
    int nWords() const;
    int operator[](std::string w) const { return getIdByWord(w); }
    std::string operator[](int id) const { return getWordById(id); }
};

--------------------------------------------------------------------------------
/src/NumaArray.cpp:
--------------------------------------------------------------------------------
#include "NumaArray.hpp"


// For every OpenMP thread record its NUMA node (numa_id) and its rank among
// the threads already registered on that node (ord, info[nid][tid]).
NumaInfo::info_t::info_t()
{
#pragma omp parallel
    {
        int tid = omp_get_thread_num();
        // NOTE(review): the OpenMP thread number is used as a CPU number, which
        // only matches reality when threads are pinned 1:1 to CPUs -- confirm.
        int nid = numa_node_of_cpu(tid);
#pragma omp critical
        {
            // Read the rank once. In `info[nid][tid] = info[nid].size()` the
            // order of the two operator calls is unspecified before C++17 (the
            // project builds with -std=c++1y), so the stored rank could differ
            // from ord[tid] by one.
            auto rank = info[nid].size();
            numa_id[tid] = nid;
            ord[tid] = rank;
            info[nid][tid] = rank;
            //printf("NumaInfo::Init thread id %d at numa %d ord = %d\n", tid, nid, info[nid][tid]);
        }
    }
}

// Work range of `thread_id` over n items: the items are split evenly across
// NUMA nodes, and the threads of one node interleave within their node's
// range ([beg, end) visited with stride `step`).
NumaInfo::NumaInfo(int thread_id, size_t n)
{
    Partition p(info.info.size(), n);
    beg = p.Startid(info.numa_id[thread_id])+info.ord[thread_id];
    end = p.Startid(info.numa_id[thread_id]+1);
    step = info.info[info.numa_id[thread_id]].size();
}

NumaInfo::info_t NumaInfo::info;

--------------------------------------------------------------------------------
/src/lda.hpp:
--------------------------------------------------------------------------------
#pragma once

#include
#include "Bigraph.hpp"

// Abstract LDA trainer over a document/word bigraph `g`.
class LDA
{
protected:
    Bigraph g;
public:
    LDA() {}
    // Polymorphic base: implementations may be destroyed through an LDA*.
    virtual ~LDA() = default;
    virtual void loadBinary(std::string prefix);
    virtual void estimate(int K, float alpha, float beta, int niter, int perplexity_interval, int eval, std::string fmodel, std::string vocab_fname, std::string info, uint32_t ntop) = 0;
    virtual void inference(int niter, int perplexity_interval) = 0;
    virtual void loadModel(std::string prefix) = 0;
    virtual void storeModel(std::string prefix) = 0;
    virtual void loadZ(std::string prefix) = 0;
    virtual void storeZ(std::string prefix) = 0;
    virtual void writeInfo(std::string vocab, std::string info, uint32_t ntop) = 0;
};

template
class WarpLDA;

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2016 Jianfei Chen, Kaiwei Li, Jun Zhu and Wenguang Chen

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/src/AdjList.cpp:
--------------------------------------------------------------------------------
#include "AdjList.hpp"
#include "Utils.hpp"

#include
#include


// Read `count` elements of type T, starting at element index `idx_beg`, from
// the binary stream `f` into `a`. Returns false if the seek or the read failed
// (for example because the file is shorter than expected).
template
static bool readvec(NumaArray &a, std::ifstream &f, uint64_t idx_beg, uint64_t count)
{
    a.Assign(count);
    f.seekg(idx_beg * sizeof(T), std::ios::beg);
    f.read((char*)&a[0], count * sizeof(T));
    return static_cast<bool>(f);
}

// Load the adjacency list stored as `name`.idx (n_ + 1 TEdge offsets) and
// `name`.lnk (the concatenated TDst neighbour ids). Returns false if either
// file is missing, the index file is empty, or a read comes up short.
template
bool AdjList::Load(std::string name)
{
    std::ifstream fidx(name + ".idx", std::ios::binary);
    std::ifstream flnk(name + ".lnk", std::ios::binary);

    if (!fidx || !flnk)
        return false;

    // The index holds one offset per vertex plus a terminating one; an empty
    // index file would make n_ wrap around.
    const auto idx_count = filesize(fidx) / sizeof(TEdge);
    if (idx_count == 0)
        return false;
    n_ = idx_count - 1;

    // p_ = Partition(1, n_);

    beg_ = 0; //p_.Startid(0);
    end_ = n_; //p_.Startid(1);
    n_local_ = end_ - beg_;

    if (!readvec(idx_vec_, fidx, beg_ , end_ - beg_ + 1))
        return false;

    idx_ = &idx_vec_[0] - beg_;

    if (!readvec(lnk_vec_, flnk, idx_[beg_], idx_[end_] - idx_[beg_]))
        return false;

    lnk_ = &lnk_vec_[0] - idx_[beg_];

    ne_ = filesize(flnk) / sizeof(TDst);

    return true;
}

template bool AdjList::Load(std::string name);

--------------------------------------------------------------------------------
/src/AdjList.hpp:
--------------------------------------------------------------------------------
#pragma once

#include "Types.hpp"
//#include "Partition.h"
#include "NumaArray.hpp"

// Compressed adjacency list: the neighbours of vertex `id` are
// lnk_[idx_[id]] .. lnk_[idx_[id + 1] - 1]. Filled by Load().
// NOTE: idx_/lnk_ point into this object's own idx_vec_/lnk_vec_, so a copy
// of an AdjList keeps pointing at the original's storage.
template
class AdjList
{
    //template using NumaArray=NumaArray1;
public:
    using TSrc = TSrc_;
    using TDst = TDst_;
    using TEdge = TEdge_;
    using TDegree = TDegree_;
public:
    AdjList(){}
    bool Load(std::string);
    TSrc NumVertices() const { return n_; }
    TSrc NumVerticesLocal() const { return n_local_; }
    TEdge NumEdges() const { return ne_; }
    TDegree Degree(TSrc id) const { return idx_[id + 1] - idx_[id]; }
    TEdge Idx(TSrc id) const { return idx_[id]; }
    const TDst* Edges(TSrc id) const { return &lnk_[idx_[id]]; }
    TSrc Begin() const { return beg_; }
    TSrc End() const { return end_; }

    // Call f(id, degree, neighbours) for every vertex in [beg_, end_).
    template
    void Visit(Function f)
    {
        for (TSrc id = beg_; id < end_; id++)
        {
            f(id, Degree(id), lnk_ + idx_[id]);
        }
    }

    TEdge * idx_ = nullptr;
    TDst * lnk_ = nullptr;
private:
    //Partition p_;
    TSrc n_ = 0;
    TSrc n_local_ = 0;
    TEdge ne_ = 0;
    TSrc beg_ = 0;
    TSrc end_ = 0;
    NumaArray idx_vec_;
    NumaArray lnk_vec_;
};

--------------------------------------------------------------------------------
/src/Bigraph.hpp:
--------------------------------------------------------------------------------
#pragma once

#include "AdjList.hpp"

// Bipartite document/word graph: `u` lists the word ids (TVID) of each
// document (TUID), `v` lists the document ids of each word.
class Bigraph
{
public:
    using TAdjU = AdjList;
    using TAdjV = AdjList;

private:
    TAdjU u;
    TAdjV v;
#if 0
    std::vector word_id;
#endif
public:
    Bigraph();
    bool Load(std::string);
    TUID NU() { return u.NumVertices(); }
    TVID NV() { return v.NumVertices(); }
    TEID NE() { return u.NumEdges(); }

    static void Generate(std::string, std::vector> &, TVID nv = 0);

    template
    void VisitU(Function f)
    {
        u.Visit(f);
    }

    template
    void VisitV(Function f)
    {
        v.Visit(f);
    }

    TUID DegreeU(TUID uid)
    {
        return u.Degree(uid);
    }
    TVID DegreeV(TVID vid)
    {
        return v.Degree(vid);
    }
    const TVID* EdgeOfU(TUID uid)
    {
        return u.Edges(uid);
    }
    const TUID* EdgeOfV(TVID vid)
    {
        return v.Edges(vid);
    }
    TAdjU & AdjU() { return u; }
    TAdjV & AdjV() { return v; }

    TEID UIdx(TUID uid) { return u.Idx(uid); }
    TEID VIdx(TVID vid) { return v.Idx(vid); }

    TUID Ubegin() { return u.Begin(); }
    TUID Uend() { return u.End(); }

    TVID Vbegin() { return v.Begin(); }
    TVID Vend() { return v.End(); }
    // TVID WordId(TVID vid) { return word_id[vid]; }
};

--------------------------------------------------------------------------------
/src/Xorshift.hpp:
--------------------------------------------------------------------------------
#pragma once
#include
#include

// Xorshift-family pseudo-random generator; operator() uses xorshift128+.
// The state is seeded from std::rand(). min()/max() are static constexpr so
// the class can serve as a standard UniformRandomBitGenerator.
class XorShift
{
    uint64_t s[16];
    int p;
    uint64_t x; /* The state must be seeded with a nonzero value. */

    uint64_t xorshift1024star(void) {
        uint64_t s0 = s[ p ];
        uint64_t s1 = s[ p = (p+1) & 15 ];
        s1 ^= s1 << 31; // a
        s1 ^= s1 >> 11; // b
        s0 ^= s0 >> 30; // c
        return ( s[p] = s0 ^ s1 ) * UINT64_C(1181783497276652981);
    }
    uint64_t xorshift128plus(void) {
        uint64_t x = s[0];
        uint64_t const y = s[1];
        s[0] = y;
        x ^= x << 23; // a
        x ^= x >> 17; // b
        x ^= y ^ (y >> 26); // c
        s[1] = x;
        return x + y;
    }
    uint64_t xorshift64star(void) {
        x ^= x >> 12; // a
        x ^= x << 25; // b
        x ^= x >> 27; // c
        return x * UINT64_C(2685821657736338717);
    }
public:

    using result_type=uint64_t;

    XorShift() : p(0), x((uint64_t)std::rand() * RAND_MAX + std::rand()){
        // xorshift64* never leaves the all-zero state; fall back to a fixed
        // nonzero seed if both rand() calls happened to return 0.
        if (x == 0)
            x = UINT64_C(0x9E3779B97F4A7C15);
        for (int i = 0; i < 16; i++)
        {
            s[i] = xorshift64star();
        }
    }
    uint64_t operator()(){
        return xorshift128plus();
        //return xorshift64star();
    }
    uint32_t Rand32(){
        return (uint32_t)xorshift128plus();
    }
    // Fill `len` bytes at `p` with random data (low 32 bits of each draw).
    void MakeBuffer(void *p, size_t len)
    {
        size_t N = len / sizeof(uint32_t);
        uint32_t *arr = (uint32_t *)p;
        for (size_t i = 0; i < N; i++)
            arr[i] = (*this)();
        size_t M = len % sizeof(uint32_t);
        if (M > 0)
        {
            uint32_t k = (*this)();
            memcpy(arr + N, &k, M);
        }
    }
    static constexpr uint64_t max() {return std::numeric_limits::max();}
    static constexpr uint64_t min() {return std::numeric_limits::min();}
};


--------------------------------------------------------------------------------
/python/model.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2018 hschen0712
#
# Distributed under terms of the MIT license.

"""
Load a trained WarpLDA model (word-topic counts) and its vocabulary.
"""
from collections import defaultdict
import numpy as np

class Warplda(object):
    """Word-topic counts of a WarpLDA model plus a word -> id vocabulary.

    The model file's first line is "<vocab_size> <num_topics> <alpha> <beta>";
    line i + 1 describes word i as "<num_elements>\\t<topic>:<cnt> <topic>:<cnt> ...".
    Line i of the vocabulary file is the word with id i.
    """

    def __init__(self, model_path, vocab_path):
        with open(model_path, 'rb') as fr:
            lines = [line.strip() for line in fr.readlines()]

        params = lines[0].split()
        self.vocab_size = int(params[0])  # vocabulary size
        self.num_topics = int(params[1])  # number of topics
        self.alpha = float(params[2])
        self.beta = float(params[3])
        self.alpha_bar = self.alpha * self.num_topics
        self.beta_bar = self.beta * self.vocab_size
        # Cvk[v, k]: count for word v and topic k, as stored in the model file.
        self.Cvk = np.zeros((self.vocab_size, self.num_topics))

        for word_id, line in enumerate(lines[1:]):
            _, word_topic_cnt = line.split('\t')
            for pair in word_topic_cnt.split():
                topic_id, cnt = pair.split(':')
                self.Cvk[word_id, int(topic_id)] = int(cnt)
        # Ck[k]: column sums of Cvk, i.e. the total count of topic k.
        self.Ck = self.Cvk.sum(axis=0)
        # Load the vocabulary.
        self.vocab_map = {}
        with open(vocab_path) as fr:
            lines = [line.strip().decode('utf8') for line in fr.readlines()]
        for word_id, word in enumerate(lines):
            self.vocab_map[word] = word_id



if __name__ == '__main__':
    warplda = Warplda('./train.model.iter9900', './train.vocab')


--------------------------------------------------------------------------------
/src/clock.cpp:
--------------------------------------------------------------------------------
#include "clock.hpp"

// Windows
#ifdef _WIN32
#include
double get_wall_time(){
    LARGE_INTEGER time,freq;
    if (!QueryPerformanceFrequency(&freq)){
        //  Handle error
        return 0;
    }
    if (!QueryPerformanceCounter(&time)){
        //  Handle error
        return 0;
    }
    return (double)time.QuadPart / freq.QuadPart;
}
double get_cpu_time(){
    FILETIME a,b,c,d;
    if (GetProcessTimes(GetCurrentProcess(),&a,&b,&c,&d) != 0){
        //  Returns total user time.
        //  Can be tweaked to include kernel times as well.
        return
            (double)(d.dwLowDateTime |
            ((unsigned long long)d.dwHighDateTime << 32)) * 0.0000001;
    }else{
        //  Handle error
        return 0;
    }
}

//  Posix/Linux
#else
#include
#include
double get_wall_time(){
    struct timeval time;
    if (gettimeofday(&time,NULL)){
        //  Handle error
        return 0;
    }
    return (double)time.tv_sec + (double)time.tv_usec * .000001;
}
double get_cpu_time(){
    return (double)clock() / CLOCKS_PER_SEC;
}
#endif

// Reset the accumulated time and begin a new running segment.
void Clock::start()
{
    _last = get_wall_time();
    _elapsed = 0;
    _started = true;
}

// Return the time measured so far, then start over from zero.
double Clock::restart()
{
    double ret = timeElapsed();
    start();
    return ret;
}

// Seconds of all completed segments plus the running one, if any.
double Clock::timeElapsed() {
    if (_started)
        return _elapsed + get_wall_time() - _last;
    else
        return _elapsed;

}

// Close the running segment. Without the guard, pausing a paused (or never
// started) clock would add `now - _last` again and corrupt the total.
void Clock::pause() {
    if (!_started)
        return;
    _elapsed += get_wall_time() - _last;
    _started = false;
}

// Open a new segment. Without the guard, resuming a running clock would move
// _last forward and silently drop the time of the current segment.
void Clock::resume() {
    if (_started)
        return;
    _last = get_wall_time();
    _started = true;
}

-------------------------------------------------------------------------------- /data/uci-to-yahoo: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse, random 3 | 4 | parser = argparse.ArgumentParser(description="Convert data from UCI ML repository format to NIPS format.") 5 | parser.add_argument('uci', metavar='uci', type=str, help='Corpus in UCI format') 6 | parser.add_argument('dict', metavar='dict', type=str, help='Dictionary in UCI format') 7 | parser.add_argument('-o', metavar='o', type=str, help='Output corpus') 8 | parser.add_argument('--max_doc_length', metavar='max_doc_length', type=int, default=8192, help='Max document length, will subsample words if this is exceeded') 9 | 10 | args = parser.parse_args() 11 | print 'Max doc length is %d' % args.max_doc_length 12 | 13 | wlist = map(lambda x:x.rstrip(), open(args.dict, 'r').readlines()) 14 | 15 | f = open(args.uci, 'r') 16 | try: 17 | fo = open(args.o, 'w') 18 | except AttributeError: 19 | args.output = 'yahoo_' + args.uci 20 | fo = open(args.output, 'w') 21 | 22 | d = int(f.readline().rstrip()) 23 | w = int(f.readline().rstrip()) 24 | nnz = int(f.readline().rstrip()) 25 | 26 | subsampled = 0 27 | def write_doc(doc_id1, doc_id2, doc): 28 | global subsampled 29 | if len(doc) == 0: 30 | return 31 | random.shuffle(doc) 32 | 33 | if len(doc) > args.max_doc_length: 34 | subsampled += len(doc) - args.max_doc_length 35 | doc = doc[:args.max_doc_length] 36 | 37 | fo.write(str(doc_id1) + ' ' + str(doc_id2) + ' ') 38 | fo.write(' '.join(doc)) 39 | fo.write('\n') 40 | 41 | last_docid = 1 42 | doc = [] 43 | while True: 44 | line = f.readline() 45 | if line == '': 46 | break 47 | 48 | line = line.rstrip().split(' ') 49 | docid = int(line[0]) 50 | 51 | if docid != last_docid: 52 | last_docid = docid 53 | write_doc(docid-1, docid-1, doc) 54 | doc = [] 55 | 56 | wid = int(line[1]) 57 | times = int(line[2]) 58 | for i in xrange(times): 59 | 
doc.append(wlist[wid-1]) 60 | 61 | write_doc(last_docid, last_docid, doc) 62 | f.close() 63 | fo.close() 64 | 65 | print '%d tokens truncated' % subsampled 66 | -------------------------------------------------------------------------------- /src/Vocab.cpp: -------------------------------------------------------------------------------- 1 | #include "Vocab.hpp" 2 | #include "Utils.hpp" 3 | 4 | #include 5 | #include 6 | 7 | Vocab::Vocab() 8 | { 9 | } 10 | 11 | Vocab::~Vocab() 12 | { 13 | } 14 | 15 | void Vocab::clear() 16 | { 17 | dict.clear(); 18 | words.clear(); 19 | } 20 | 21 | bool Vocab::load(std::string fname) 22 | { 23 | clear(); 24 | int id = 0; 25 | bool success = ForEachLinesInFile(fname, [&](std::string line) 26 | { 27 | std::istringstream sin(line); 28 | std::string word; 29 | sin >> word; 30 | dict[word] = id++; 31 | words.push_back(word); 32 | }); 33 | return success; 34 | } 35 | 36 | bool Vocab::store(std::string fname) 37 | { 38 | std::ofstream fou(fname); 39 | if (!fou) return false; 40 | std::string line; 41 | for (unsigned i = 0; i < words.size(); i++) 42 | { 43 | fou << words[i] << std::endl; 44 | } 45 | fou.close(); 46 | return true; 47 | } 48 | 49 | int Vocab::addWord(std::string w) 50 | { 51 | auto it = dict.find(w); 52 | if (it == dict.end()) 53 | { 54 | dict[w] = words.size(); 55 | words.push_back(w); 56 | return dict[w]; 57 | }else 58 | return it->second; 59 | } 60 | 61 | int Vocab::getIdByWord(std::string w) const 62 | { 63 | auto it = dict.find(w); 64 | if (it == dict.end()) 65 | return -1; 66 | else 67 | return it->second; 68 | } 69 | 70 | std::string Vocab::getWordById(int id) const 71 | { 72 | if (id < 0 || id >= (int)words.size()) 73 | return ""; 74 | else 75 | return words[id]; 76 | } 77 | 78 | int Vocab::nWords() const 79 | { 80 | return words.size(); 81 | } 82 | 83 | void Vocab::RearrangeId(const unsigned int* new_id) 84 | { 85 | for (auto &e : this->dict) 86 | e.second = new_id[e.second]; 87 | decltype(words) 
new_words(words.size()); 88 | for (unsigned i = 0; i < words.size(); i++) 89 | new_words[new_id[i]] = words[i]; 90 | words = std::move(new_words); 91 | } 92 | -------------------------------------------------------------------------------- /python/sampling.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim:fenc=utf-8 4 | # 5 | # Copyright © 2018 hschen0712 6 | # 7 | # Distributed under terms of the MIT license. 8 | 9 | """ 10 | Collapsed Gibbs Sampling 11 | """ 12 | import numpy as np 13 | import json 14 | from model import Warplda 15 | from collections import Counter 16 | 17 | def gibbs_sampling(doc, model, num_iter=10): 18 | ''' 19 | gibbs采样 20 | :param doc: 预处理后的文档 21 | :param model: warplda模型 22 | :param num_iter: 迭代次数,默认为10 23 | :return: 24 | ''' 25 | vocab_map = model.vocab_map 26 | doc = [vocab_map[word] for word in doc.split() if word in vocab_map] 27 | 28 | K = model.num_topics 29 | num_tokens = len(doc) 30 | alpha = model.alpha 31 | alpha_bar = model.alpha_bar 32 | beta = model.beta 33 | beta_bar = model.beta_bar 34 | Cvk = model.Cvk 35 | Ck = model.Ck 36 | z = np.zeros((num_tokens, K)) 37 | # 随机初始化每个token的指派 38 | for n in range(num_tokens): 39 | rand_topic = np.random.randint(0, K) 40 | z[n, rand_topic] = 1 41 | 42 | for i in range(num_iter): 43 | for n, word_id in enumerate(doc): 44 | pz = np.divide(np.multiply(z.sum(axis=0) + alpha, Cvk[word_id, :] + beta), Ck + beta_bar) 45 | k = np.random.multinomial(1, pz / pz.sum()).argmax() 46 | z[n, :] *= 0 47 | z[n, k] = 1 48 | topic_cnt = Counter(z.argmax(axis=1)) 49 | topic_dist = [(topic_id, (cnt + alpha)/(num_tokens + alpha_bar)) for topic_id, cnt in topic_cnt.iteritems()] 50 | topic_dist = json.dumps(dict([(topic_id, prob) for topic_id, prob in topic_dist if prob >=0.05 ])) 51 | 52 | return topic_dist 53 | 54 | 55 | if __name__ == '__main__': 56 | from time import time 57 | 58 | docs = 
[line.strip().decode('utf8') for line in open('./test_corpus_for_lda.dat').readlines()] 59 | 60 | warplda = Warplda('./train.model.iter9900', './train.vocab') 61 | start = time() 62 | for d, doc in enumerate(docs): 63 | print 'doc {}'.format(d) 64 | topic_dist = gibbs_sampling(doc, warplda) 65 | print topic_dist 66 | stop = time() 67 | 68 | print stop - start 69 | -------------------------------------------------------------------------------- /src/warplda.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "lda.hpp" 5 | #include "HashTable.hpp" 6 | #include "Xorshift.hpp" 7 | #include "Utils.hpp" 8 | #include "Shuffle.hpp" 9 | #include "alias_urn.h" 10 | 11 | template 12 | class WarpLDA; 13 | 14 | template 15 | class WarpLDA : public LDA 16 | { 17 | public: 18 | WarpLDA(); 19 | virtual void estimate(int K, float alpha, float beta, int niter, int perplexity_interval, int neval, std::string fmodel, std::string vocab_fname, std::string info, uint32_t ntop) override; 20 | virtual void inference(int niter, int perplexity_interval) override; 21 | virtual void loadModel(std::string prefix) override; 22 | virtual void storeModel(std::string prefix) override; 23 | virtual void loadZ(std::string prefix) override; 24 | virtual void storeZ(std::string prefix) override; 25 | virtual void writeInfo(std::string vocab, std::string info, uint32_t ntop) override; 26 | 27 | private: 28 | struct TData 29 | { 30 | TTopic newk[MH]; 31 | TTopic oldk; 32 | }; 33 | TTopic K; 34 | float alpha, beta, alpha_bar, beta_bar; 35 | int niter, perplexity_interval; 36 | NumaArray nnz_d; 37 | NumaArray nnz_w; 38 | NumaArray ck; 39 | std::unique_ptr> shuffle; 40 | XorShift generator; 41 | std::vector> cwk_model; 42 | std::vector cwk_urns; 43 | std::vector cwk_sums; 44 | AliasUrn global_urn; 45 | double global_sum; 46 | double total_log_likelihood; // p(w | \alpha, \beta) 47 | template 48 | double perplexity(); 49 | 
template 50 | void initialize(); 51 | template 52 | void accept_d_propose_w(); 53 | template 54 | void accept_w_propose_d(); 55 | void reduce_ck(); 56 | struct LocalBuffer{ 57 | std::vector ck_new; 58 | HashTable cxk_sparse; 59 | std::vector local_data; 60 | float log_likelihood; 61 | XorShift generator; 62 | uint32_t Rand32() { return generator.Rand32(); } 63 | LocalBuffer(TTopic K, TDegree maxdegree) 64 | : ck_new(K), cxk_sparse(logceil(K)), local_data(maxdegree) 65 | { 66 | } 67 | void Init(); 68 | }; 69 | std::vector> local_buffers; 70 | }; 71 | 72 | extern template class WarpLDA<1>; 73 | -------------------------------------------------------------------------------- /src/Bigraph.cpp: -------------------------------------------------------------------------------- 1 | #include "Bigraph.hpp" 2 | 3 | #include 4 | #include 5 | 6 | 7 | Bigraph::Bigraph() 8 | { 9 | } 10 | 11 | bool Bigraph::Load(std::string name) 12 | { 13 | if (u.Load(name + ".u") && v.Load(name + ".v")) 14 | { 15 | return true; 16 | #if 0 17 | std::ifstream fwordid(name + ".wordid", std::ios::binary); 18 | if (fwordid){ 19 | word_id.resize(NV()); 20 | if (fwordid.read((char*)&word_id[0], NV() * sizeof(TVID))) 21 | return true; 22 | } 23 | #endif 24 | } 25 | return false; 26 | } 27 | 28 | void Bigraph::Generate(std::string name, std::vector>& edge_list, TVID nv) 29 | { 30 | TUID nu = 0; 31 | for (auto &e : edge_list) 32 | nu = std::max(nu, e.first); 33 | nu = nu + 1; 34 | 35 | if (nv == 0) { 36 | for (auto &e : edge_list) 37 | nv = std::max(nv, e.second); 38 | nv = nv + 1; 39 | } 40 | 41 | #if 0 42 | std::vector pu(nu); 43 | std::vector pv(nv); 44 | for (TUID i = 0; i < nu; i++) 45 | pu[i] = i; 46 | for (TVID i = 0; i < nv; i++) 47 | pv[i] = i; 48 | std::random_shuffle(pu.begin(), pu.end()); 49 | std::random_shuffle(pv.begin(), pv.end()); 50 | for (auto &e : edge_list) 51 | { 52 | e.first = pu[e.first]; 53 | e.second = pv[e.second]; 54 | } 55 | 56 | std::vector word_id(nv); 57 | for (TVID i = 0; i 
< nv; i++) 58 | word_id[pv[i]] = i; 59 | std::ofstream fwordid(name + ".wordid", std::ios::binary); 60 | fwordid.write((char*)&word_id[0], nv * sizeof(TVID)); 61 | fwordid.close(); 62 | #endif 63 | 64 | std::ofstream fuidx(name + ".u.idx", std::ios::binary); 65 | std::ofstream fulnk(name + ".u.lnk", std::ios::binary); 66 | std::ofstream fvidx(name + ".v.idx", std::ios::binary); 67 | std::ofstream fvlnk(name + ".v.lnk", std::ios::binary); 68 | 69 | std::sort(edge_list.begin(), edge_list.end(), [](const std::pair & a, const std::pair & b){ return a.first < b.first || (a.first == b.first && a.second < b.second); }); 70 | 71 | TEID off = 0; 72 | fuidx.write((char*)&off, sizeof(off)); 73 | for (TUID i = 1; i <= nu; i++) 74 | { 75 | while (off < edge_list.size() && edge_list[off].first < i) 76 | { 77 | auto tar = edge_list[off].second; 78 | fulnk.write((char*)&tar, sizeof(tar)); 79 | off++; 80 | } 81 | fuidx.write((char*)&off, sizeof(off)); 82 | } 83 | 84 | std::sort(edge_list.begin(), edge_list.end(), [](const std::pair & a, const std::pair & b){ return a.second < b.second || (a.second == b.second && a.first < b.first); }); 85 | 86 | off = 0; 87 | fvidx.write((char*)&off, sizeof(off)); 88 | for (TVID i = 1; i <= nv; i++) 89 | { 90 | while (off < edge_list.size() && edge_list[off].second < i) 91 | { 92 | auto src = edge_list[off].first; 93 | fvlnk.write((char*)&src, sizeof(src)); 94 | off++; 95 | } 96 | fvidx.write((char*)&off, sizeof(off)); 97 | } 98 | 99 | fuidx.close(); 100 | fulnk.close(); 101 | fvidx.close(); 102 | fvlnk.close(); 103 | } 104 | -------------------------------------------------------------------------------- /src/NumaArray.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "Partition.hpp" 3 | #include "omp.h" 4 | #include 5 | #include 6 | #include "Types.hpp" 7 | 8 | template 9 | class RemoteArray 10 | { 11 | using TIdx = int; 12 | public: 13 | RemoteArray() : arr_(nullptr), 
idx_(nullptr){} 14 | RemoteArray(T* arr, const TIdx* idx) : arr_(arr), idx_(idx){} 15 | void Assign(T* arr, const TIdx* idx) { arr_ = arr; idx_ = idx; } 16 | T& operator[](size_t pos) { return arr_[idx_[pos]]; } 17 | const T& operator[](size_t pos) const { return arr_[idx_[pos]]; } 18 | private: 19 | T* arr_; 20 | const TIdx* idx_; 21 | }; 22 | 23 | 24 | template 25 | class RemoteArray64 26 | { 27 | using TIdx = uint64_t; 28 | public: 29 | RemoteArray64() : arr_(nullptr), idx_(nullptr){} 30 | RemoteArray64(T* arr, TIdx* idx) : arr_(arr), idx_(idx){} 31 | void Assign(T* arr, TIdx* idx) { arr_ = arr; idx_ = idx; } 32 | T& operator[](size_t pos) { return arr_[idx_[pos]]; } 33 | const T& operator[](size_t pos) const { return arr_[idx_[pos]]; } 34 | private: 35 | T* arr_; 36 | TIdx* idx_; 37 | }; 38 | 39 | 40 | class NumaInfo 41 | { 42 | private: 43 | struct info_t 44 | { 45 | std::map> info; 46 | std::map ord; 47 | std::map numa_id; 48 | info_t(); 49 | }; 50 | static info_t info; 51 | public: 52 | NumaInfo(int thread_id, size_t n); 53 | size_t beg, end, step; 54 | }; 55 | 56 | template 57 | class NumaArray 58 | { 59 | public: 60 | NumaArray() : arr_(nullptr) {} 61 | ~NumaArray() { Free(); } 62 | NumaArray(size_t n, T v = T()) 63 | { 64 | Assign(n, v); 65 | } 66 | T* data() { return arr_; } 67 | const T* data() const { return arr_; } 68 | void Assign(size_t n, T v = T()) 69 | { 70 | Free(); 71 | arr_ = new T[n]; 72 | #pragma omp parallel for schedule(static) 73 | for (size_t i = 0; i < n; i++) 74 | arr_[i] = v; 75 | size_ = n; 76 | } 77 | void Free() 78 | { 79 | if (arr_) 80 | delete[] arr_; 81 | } 82 | size_t size() { return size_; } 83 | T& operator[](size_t pos) { return arr_[pos]; } 84 | const T& operator[](size_t pos) const { return arr_[pos]; } 85 | private: 86 | T* arr_; 87 | size_t size_; 88 | }; 89 | 90 | 91 | #if 0 92 | template 93 | class NumaArray1 94 | { 95 | public: 96 | NumaArray1() : arr_(nullptr) {} 97 | ~NumaArray1() { Free(); } 98 | NumaArray1(size_t 
n, T v = T()) 99 | { 100 | Assign(n, v); 101 | } 102 | void Assign(size_t n, T v = T()) 103 | { 104 | Free(); 105 | arr_ = new T[n]; 106 | //#pragma omp parallel for schedule(static) 107 | for (size_t i = 0; i < n; i++) 108 | arr_[i] = v; 109 | size_ = n; 110 | } 111 | void Free() 112 | { 113 | if (arr_) 114 | delete[] arr_; 115 | } 116 | size_t size() { return size_; } 117 | T& operator[](size_t pos) { return arr_[pos]; } 118 | const T& operator[](size_t pos) const { return arr_[pos]; } 119 | private: 120 | T* arr_; 121 | size_t size_; 122 | }; 123 | 124 | #endif 125 | -------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | #define STRIP_FLAG_HELP 0 2 | #include 3 | #include "Bigraph.hpp" 4 | #include "Utils.hpp" 5 | #include "warplda.hpp" 6 | 7 | DEFINE_string(prefix, "./prefix", "prefix of result files"); 8 | DEFINE_int32(niter, 10, "number of iterations"); 9 | DEFINE_int32(neval, 100, "save model every neval iterations"); 10 | DEFINE_int32(k, 1000, "number of topics"); 11 | DEFINE_double(alpha, 50, "sum of alpha"); 12 | DEFINE_double(beta, 0.01, "beta"); 13 | DEFINE_int32(mh, 1, "number of Metropolis-Hastings steps"); 14 | DEFINE_int32(ntop, 10, "num top words per each topic"); 15 | DEFINE_string(bin, "", "binary file"); 16 | DEFINE_string(model, "", "model file"); 17 | DEFINE_string(info, "", "info"); 18 | DEFINE_string(vocab, "", "vocabulary file"); 19 | DEFINE_string(topics, "", "topic assignment file"); 20 | DEFINE_string(z, "", "Z file name"); 21 | DEFINE_bool(estimate, false, "estimate model parameters"); 22 | DEFINE_bool(inference, false, "inference latent topic assignments"); 23 | DEFINE_bool(writeinfo, true, "write info"); 24 | DEFINE_bool(dumpmodel, true, "dump model"); 25 | DEFINE_bool(dumpz, true, "dump Z"); 26 | DEFINE_int32(perplexity, -1, "Interval to evaluate perplexity. 
-1 for don't evaluate."); 27 | 28 | int main(int argc, char** argv) 29 | { 30 | gflags::SetUsageMessage("Usage : ./warplda [ flags... ]"); 31 | gflags::ParseCommandLineFlags(&argc, &argv, true); 32 | 33 | if ((FLAGS_inference || FLAGS_estimate) == false) 34 | FLAGS_estimate = true; 35 | if (!FLAGS_z.empty()) 36 | FLAGS_dumpz = true; 37 | 38 | SetIfEmpty(FLAGS_bin, FLAGS_prefix + ".bin"); 39 | SetIfEmpty(FLAGS_model, FLAGS_prefix + ".model"); 40 | SetIfEmpty(FLAGS_info, FLAGS_prefix + ".info"); 41 | SetIfEmpty(FLAGS_vocab, FLAGS_prefix + ".vocab"); 42 | SetIfEmpty(FLAGS_topics, FLAGS_prefix + ".topics"); 43 | 44 | LDA *lda = new WarpLDA<1>(); 45 | lda->loadBinary(FLAGS_bin); 46 | if (FLAGS_estimate) 47 | { 48 | lda->estimate(FLAGS_k, FLAGS_alpha / FLAGS_k, FLAGS_beta, FLAGS_niter, FLAGS_perplexity, FLAGS_neval, FLAGS_model, FLAGS_vocab, FLAGS_info, FLAGS_ntop); 49 | if (FLAGS_dumpmodel) 50 | { 51 | std::cout << "Dump model " << FLAGS_model << std::endl; 52 | lda->storeModel(FLAGS_model); 53 | } 54 | if (FLAGS_writeinfo) 55 | { 56 | std::cout << "Write Info " << FLAGS_info << " ntop " << FLAGS_ntop << std::endl; 57 | lda->writeInfo(FLAGS_vocab, FLAGS_info, FLAGS_ntop); 58 | } 59 | if (FLAGS_dumpz) 60 | { 61 | SetIfEmpty(FLAGS_z, FLAGS_prefix + ".z.estimate"); 62 | std::cout << "Dump Z " << FLAGS_z << std::endl; 63 | lda->storeZ(FLAGS_z); 64 | } 65 | } 66 | else if(FLAGS_inference) 67 | { 68 | lda->loadModel(FLAGS_model); 69 | lda->inference(FLAGS_niter, FLAGS_perplexity); 70 | if (FLAGS_dumpz) 71 | { 72 | SetIfEmpty(FLAGS_z, FLAGS_prefix + ".z.inference"); 73 | std::cout << "Dump Z " << FLAGS_z << std::endl; 74 | lda->storeZ(FLAGS_z); 75 | } 76 | } 77 | return 0; 78 | } 79 | -------------------------------------------------------------------------------- /src/HashTable.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | template 6 | class HashTable 7 | { 8 | public: 9 | using key_type = TKey; 
10 | using value_type = TValue; 11 | using reference = value_type&; 12 | using const_reference = const value_type&; 13 | const key_type EMPTY_KEY = key_type(-1); 14 | 15 | //std::vector count_step[2]; 16 | public: // Jianfei: hack 17 | struct Entry { 18 | key_type key; 19 | value_type value; 20 | int32_t l, r; 21 | }; 22 | std::vector table; 23 | std::vector keyset; 24 | uint32_t sizeFactor; 25 | uint32_t sizeMask; 26 | uint32_t nKey; 27 | 28 | template 29 | uint32_t findkey(key_type key) 30 | { 31 | uint32_t pos = key & sizeMask; 32 | if (table[pos].key == key) 33 | { 34 | return pos; 35 | }else if (table[pos].key == EMPTY_KEY) 36 | { 37 | if (AddKey) 38 | { 39 | table[pos] = Entry{key, value_type(), -1, -1}; 40 | nKey++; 41 | return pos; 42 | }else 43 | { 44 | return -1; 45 | } 46 | }else 47 | { 48 | int32_t father = -1; 49 | bool isleft = true; 50 | int i = 0; 51 | while (pos != (uint32_t)-1 && table[pos].key != key) 52 | { 53 | i++; 54 | if (key < table[pos].key) 55 | { 56 | if (AddKey) 57 | { 58 | father = pos; 59 | isleft = true; 60 | } 61 | pos = table[pos].l; 62 | }else 63 | { 64 | if (AddKey) 65 | { 66 | father = pos; 67 | isleft = false; 68 | } 69 | pos = table[pos].r; 70 | } 71 | } 72 | if (pos == (uint32_t)-1) 73 | { 74 | int current = table.size(); 75 | if (AddKey) 76 | { 77 | if (isleft) 78 | table[father].l = current; 79 | else 80 | table[father].r = current; 81 | table.push_back(Entry{key, value_type(), -1, -1}); 82 | nKey++; 83 | return current; 84 | }else 85 | { 86 | return -1; 87 | } 88 | }else 89 | { 90 | return pos; 91 | } 92 | } 93 | } 94 | 95 | public: 96 | HashTable(size_t sizeFactor = 0) 97 | { 98 | Rebuild(sizeFactor); 99 | } 100 | HashTable& operator = (const HashTable &from) 101 | { 102 | table = from.table; 103 | keyset = from.keyset; 104 | sizeFactor = from.sizeFactor; 105 | sizeMask = from.sizeMask; 106 | nKey = from.nKey; 107 | return *this; 108 | } 109 | reference Put(key_type key) 110 | { 111 | uint32_t pos = findkey(key); 112 | 
return table[pos].value; 113 | } 114 | value_type Get(key_type key) 115 | { 116 | uint32_t pos = findkey(key); 117 | if (pos == (uint32_t)-1) 118 | return value_type(); 119 | else 120 | return table[pos].value; 121 | } 122 | void Rebuild(size_t sizeFactor) 123 | { 124 | sizeMask = (1< 6 | #include 7 | #include 8 | #include 9 | 10 | #include "Types.hpp" 11 | #include "Bigraph.hpp" 12 | #include "Partition.hpp" 13 | #include "NumaArray.hpp" 14 | 15 | template 16 | class Shuffle 17 | { 18 | public: 19 | Shuffle(Bigraph &g) : g_(g) { Init(); } 20 | ~Shuffle() {} 21 | public: 22 | T* DataV(TVID v) { return &v_data_vec_[g_.VIdx(v)]; } 23 | 24 | template 25 | void VisitURemoteData(Function f) 26 | { 27 | #pragma omp parallel 28 | { 29 | int thread_id = omp_get_thread_num(); 30 | NumaInfo info(thread_id, g_.NU()); 31 | for (TUID u = info.beg; u < info.end; u+= info.step) 32 | { 33 | TDegree N = g_.DegreeU(u); 34 | const TVID* lnks = g_.EdgeOfU(u); 35 | RemoteArray64 data = RemoteArray64(DataV(0), &v2u_shuffle_pos_[g_.UIdx(u)]); 36 | f(u, N, lnks, data); 37 | } 38 | } 39 | } 40 | template 41 | void VisitURemoteDataSequential(Function f) 42 | { 43 | for (TUID u = g_.Ubegin(); u < g_.Uend(); u++) 44 | { 45 | TDegree N = g_.DegreeU(u); 46 | const TVID* lnks = g_.EdgeOfU(u); 47 | RemoteArray64 data = RemoteArray64(DataV(0), &v2u_shuffle_pos_[g_.UIdx(u)]); 48 | f(u, N, lnks, data); 49 | } 50 | } 51 | template 52 | void VisitByV(Function f) 53 | { 54 | #pragma omp parallel for 55 | for (TVID v = g_.Vbegin(); v < g_.Vend(); v++) 56 | { 57 | TDegree N = g_.DegreeV(v); 58 | const TUID* lnks = g_.EdgeOfV(v); 59 | T* data = DataV(v); 60 | f(v, N, lnks, data); 61 | } 62 | } 63 | static void shuffle_gather( NumaArray &src_data, NumaArray &tar_data, NumaArray& shuffle_pos) 64 | { 65 | #pragma omp parallel for //schedule(static, 256) 66 | for (TEID i = 0; i < shuffle_pos.size(); i++) 67 | tar_data[i] = src_data[shuffle_pos[i]]; 68 | } 69 | static void shuffle_scatter( NumaArray &src_data, 
NumaArray &tar_data, NumaArray& shuffle_pos) 70 | { 71 | #pragma omp parallel for //schedule(static, 256) 72 | for (TEID i = 0; i < shuffle_pos.size(); i++) 73 | tar_data[shuffle_pos[i]] = src_data[i]; 74 | } 75 | 76 | private: 77 | void Init() 78 | { 79 | cnt_u_data = g_.EdgeOfU(g_.Uend()) - g_.EdgeOfU(g_.Ubegin()); 80 | cnt_v_data = g_.EdgeOfV(g_.Vend()) - g_.EdgeOfV(g_.Vbegin()); 81 | 82 | v_data_vec_.Assign(cnt_v_data); 83 | InitShuffle(g_.AdjV(), g_.AdjU(), v2u_shuffle_pos_); 84 | } 85 | 86 | template 87 | static void InitShuffle(TAdjsrc &src_adj, TAdjtar &tar_adj, 88 | NumaArray& shuffle_pos) 89 | { 90 | using TSrc = typename TAdjsrc::TSrc; 91 | using TDst = typename TAdjsrc::TDst; 92 | 93 | shuffle_pos.Assign(tar_adj.NumEdges()); 94 | 95 | int64_t threshold = (1<<20) / sizeof(TEID); 96 | 97 | std::vector src_off(src_adj.NumVertices(), 0); 98 | for (TSrc i = 1; i < src_adj.NumVertices(); i++) 99 | src_off[i] = src_off[i-1] + src_adj.Degree(i-1); 100 | 101 | TDst tar_id = 0; 102 | TEID pos = 0; 103 | while (tar_id < tar_adj.NumVertices()) 104 | { 105 | TDst tar_end_id = tar_id; 106 | while (tar_end_id < tar_adj.NumVertices() && tar_adj.Edges(tar_end_id) - tar_adj.Edges(tar_id) < threshold) 107 | { 108 | TDst id = tar_end_id; 109 | const TSrc* lnks = tar_adj.Edges(id); 110 | for (TDegree k = 0; k < tar_adj.Degree(id); k++) 111 | { 112 | shuffle_pos[pos++] = src_off[lnks[k]]++; 113 | } 114 | tar_end_id++; 115 | } 116 | tar_id = tar_end_id; 117 | } 118 | } 119 | public: 120 | Bigraph &g_; 121 | NumaArray v_data_vec_; 122 | NumaArray v2u_shuffle_pos_; 123 | size_t cnt_u_data; 124 | size_t cnt_v_data; 125 | }; 126 | -------------------------------------------------------------------------------- /src/alias_urn.h: -------------------------------------------------------------------------------- 1 | #ifndef __ALIAS_URN 2 | #define __ALIAS_URN 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "Types.hpp" 10 | using std::vector; 11 | 
using std::cout; 12 | using std::endl; 13 | 14 | class AliasUrn { 15 | public: 16 | struct AliasEntry { 17 | int LoKey, HiKey; 18 | long long p; 19 | AliasEntry(): LoKey(0), HiKey(0), p(0) {} 20 | AliasEntry(int l, int h, int p): LoKey(l), HiKey(h), p(p) {} 21 | }; 22 | 23 | virtual void BuildAlias(const vector &p, uint32_t u) 24 | { 25 | if (p.empty()) 26 | { 27 | table.resize(0); 28 | return; 29 | } 30 | std::uniform_real_distribution u01; 31 | long long nMax = std::numeric_limits::max(); 32 | 33 | size = p.size(); 34 | this->binSize = nMax / size; 35 | table.resize(size); 36 | 37 | float totalMass = 0; 38 | for (size_t i=0; i= binSize 45 | long long remaining = nMax; 46 | int pLo = 0; 47 | int pHi = size - 1; 48 | for (int i=0; i 0) { 69 | //int pos = pHi + 1 + u % nHi; 70 | //if (pos < 0 || pos > table.size()) { 71 | // cout << pos << " " << size << " " << table.size() << endl; 72 | //} 73 | //table[pHi + 1 + u % nHi].p += remaining; 74 | table.back().p += remaining; 75 | } else { 76 | table.back().p += remaining; 77 | assert(table.back().p >= binSize); 78 | pHi--; 79 | } 80 | 81 | int uHi = pHi + 1; 82 | 83 | // Build alias table 84 | // while lo is not empty, pick up a lo and a hi 85 | // create a table entry 86 | // put the remaining of hi back 87 | for (int i=0; i &keys) 109 | { 110 | for (auto &entry: table) 111 | { 112 | entry.LoKey = keys[entry.LoKey]; 113 | entry.HiKey = keys[entry.HiKey]; 114 | } 115 | } 116 | 117 | int DrawSample(size_t rSize, float u2) 118 | { 119 | if (table.empty()) assert(0); 120 | int bin = rSize % size; 121 | auto &entry = table[bin]; 122 | long long pos = u2 * binSize; 123 | 124 | return pos u01; 130 | vector table; 131 | 132 | long long size; 133 | long long binSize; 134 | }; 135 | 136 | /*class ParallelAliasUrn : public AliasUrn { 137 | // User should guarantee only one ParallelAliasUrn.Build is called at the same time 138 | public: 139 | void BuildAlias(const vector &p, uint32_t u); 140 | };*/ 141 | 142 | #endif 143 | 
-------------------------------------------------------------------------------- /src/format.cpp: -------------------------------------------------------------------------------- 1 | #define STRIP_FLAG_HELP 0 2 | 3 | #include 4 | 5 | #include "Bigraph.hpp" 6 | #include "Vocab.hpp" 7 | #include "Utils.hpp" 8 | #include 9 | #include 10 | using namespace std; 11 | 12 | DEFINE_string(prefix, "./prefix", "prefix of output files"); 13 | DEFINE_string(vocab_in, "", "input vocabulary file"); 14 | DEFINE_string(vocab_out, "", "output vocabulary file"); 15 | DEFINE_string(input, "", "input file"); 16 | DEFINE_string(output, "", "output file"); 17 | DEFINE_string(type, "text", "type of input: text, uci, libsvm"); 18 | DEFINE_int32(skip, 2, "skip num of words at first of each line (only for text)"); 19 | DEFINE_bool(test, false, "test mode (throw away unseen words)"); 20 | 21 | template 22 | void parse_document(string &line, std::vector &v, Vocab &vocab) 23 | { 24 | std::istringstream sin(line); 25 | std::string w; 26 | v.clear(); 27 | if (FLAGS_type == "text") 28 | { 29 | for (int i = 0; sin >> w; i++) 30 | if (i >= FLAGS_skip) 31 | { 32 | TVID vid; 33 | if (!testMode) vid = vocab.addWord(w); 34 | else vid = vocab.getIdByWord(w); 35 | 36 | if (vid != -1) v.push_back(vid); 37 | } 38 | } 39 | else if (FLAGS_type == "uci") 40 | { 41 | } 42 | else if (FLAGS_type == "libsvm") 43 | { 44 | } 45 | else 46 | { 47 | throw std::runtime_error(std::string("Unknown input type " + FLAGS_type)); 48 | } 49 | } 50 | 51 | template 52 | void text_to_bin(std::string in, std::string out) 53 | { 54 | size_t num_tokens = 0; 55 | Vocab v; 56 | if (testMode || FLAGS_type != "text") 57 | v.load(FLAGS_vocab_in); 58 | int doc_id = 0; 59 | std::vector> edge_list; 60 | std::vector vlist; 61 | bool success = ForEachLinesInFile(in, [&](std::string line) 62 | { 63 | parse_document(line, vlist, v); 64 | for (auto word_id: vlist) 65 | edge_list.emplace_back(doc_id, word_id); 66 | 67 | doc_id++; 68 | num_tokens 
+= vlist.size(); 69 | }); 70 | if (!success) 71 | throw std::runtime_error(std::string("Failed to input file ") + in); 72 | // Shuffle tokens 73 | std::vector new_vid(v.nWords()); 74 | for (unsigned i = 0; i < new_vid.size(); i++) 75 | new_vid[i] = i; 76 | if (!testMode) 77 | std::random_shuffle(new_vid.begin(), new_vid.end()); 78 | v.RearrangeId(new_vid.data()); 79 | v.store(FLAGS_vocab_out); 80 | for (auto &e : edge_list) 81 | e.second = new_vid[e.second]; 82 | if (!testMode) 83 | Bigraph::Generate(out, edge_list); 84 | else 85 | Bigraph::Generate(out, edge_list, v.nWords()); 86 | cout << "Done. Processed " << num_tokens << " tokens." << endl; 87 | } 88 | 89 | int main(int argc, char** argv) 90 | { 91 | gflags::SetUsageMessage("Usage : ./transform [ flags... ]"); 92 | gflags::ParseCommandLineFlags(&argc, &argv, true); 93 | 94 | if (FLAGS_input.empty()) 95 | FLAGS_input = FLAGS_prefix + ".txt"; 96 | if (FLAGS_output.empty()) 97 | FLAGS_output = FLAGS_prefix + ".bin"; 98 | if (FLAGS_vocab_out.empty()) 99 | FLAGS_vocab_out = FLAGS_prefix + ".vocab"; 100 | if (FLAGS_vocab_in.empty() && (FLAGS_test || FLAGS_type != "text")) 101 | throw runtime_error("Input vocabulary is not specified."); 102 | if (FLAGS_vocab_in == FLAGS_vocab_out) 103 | throw runtime_error("Input prefix and output prefix are the same."); 104 | 105 | cout << "Reading corpus from " << FLAGS_input << endl; 106 | if (FLAGS_test || FLAGS_type != "text") 107 | cout << "Reading vocabulary from " << FLAGS_vocab_in << endl; 108 | else 109 | cout << "Vocabulary will be wrote to " << FLAGS_vocab_out << endl; 110 | cout << "Output will be wrote as " << FLAGS_output << endl; 111 | 112 | if (FLAGS_test) 113 | text_to_bin(FLAGS_input, FLAGS_output); 114 | else 115 | text_to_bin(FLAGS_input, FLAGS_output); 116 | 117 | return 0; 118 | } 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 
# WarpLDA: Cache Efficient Implementation of Latent Dirichlet Allocation 2 | 3 | ## Introduction 4 | 5 | WarpLDA is a cache-efficient implementation of Latent Dirichlet Allocation that samples each token in O(1) time. 6 | 7 | ## Installation 8 | Prerequisites: 9 | 10 | * GCC (>=4.8.5) 11 | * CMake (>=2.8.12) 12 | * git 13 | * libnuma 14 | - CentOS: `yum install libnuma-devel` 15 | - Ubuntu: `apt-get install libnuma-dev` 16 | 17 | Clone this project 18 | 19 | git clone https://github.com/thu-ml/warplda 20 | 21 | Install the third-party dependency (gflags) 22 | 23 | ./get_gflags.sh 24 | 25 | Download some data and split it into training and testing sets 26 | 27 | cd data 28 | wget https://raw.githubusercontent.com/sudar/Yahoo_LDA/master/test/ydir_1k.txt 29 | head -n 900 ydir_1k.txt > ydir_train.txt 30 | tail -n 100 ydir_1k.txt > ydir_test.txt 31 | cd .. 32 | 33 | Compile the project 34 | 35 | ./build.sh 36 | cd release/src 37 | make -j 38 | 39 | ## Quick-start 40 | 41 | Format the data 42 | 43 | ./format -input ../../data/ydir_train.txt -prefix train 44 | ./format -input ../../data/ydir_test.txt -vocab_in train.vocab -test -prefix test 45 | 46 | Train the model 47 | 48 | ./warplda --prefix train --k 100 --niter 300 49 | 50 | Check the result. Each line is a topic: its id, the number of tokens assigned to it, and its ten most frequent words with their probabilities. 51 | 52 | vim train.info.full.txt 53 | 54 | Infer the latent topics of the testing data. 55 | 56 | ./warplda --prefix test --model train.model --inference -niter 40 --perplexity 10 57 | 58 | ## Data format 59 | 60 | The data format is identical to Yahoo! LDA. The input data is a text file with a number of lines, where each line is a document. The format of each line is 61 | 62 | id1 id2 word1 word2 word3 ... 63 | 64 | id1 and id2 are two string document identifiers, and each word is a string; all fields are separated by whitespace.
65 | 66 | ## Output format 67 | 68 | WarpLDA generates a number of files: 69 | 70 | #### `.vocab` (generated by `format`) 71 | Each line is a word in the vocabulary. 72 | 73 | #### `.info.full.txt` (generated by `warplda -estimate`) 74 | The most frequent words for each topic. Each line is a topic, with its topic id, the number of tokens assigned to it, and its most frequent words in the format `(probability, word)`. The number of most frequent words is controlled by `-ntop`. `.info.words.txt` is a simpler version that contains only the words. 75 | 76 | #### `.model` (generated by `warplda -estimate`) 77 | The word-topic count matrix. The first line contains four integers 78 | 79 | 80 | 81 | Each of the remaining lines is a row of the word-topic count matrix, represented in the libsvm sparse vector format, 82 | 83 | index:count index:count ... 84 | 85 | For example, `0:2` on the first line means that the first word in the vocabulary is assigned to topic 0 twice. 86 | 87 | #### `.z.estimate` (generated by `warplda -estimate`) 88 | The topic assignment of each token, in the libsvm format. Each line is a document, 89 | 90 | : : ... 91 | 92 | #### `.z.inference` (generated by `warplda -inference`) 93 | The format is the same as `.z.estimate`.
94 | 95 | ## Other features 96 | 97 | * Use custom prefix for output `-prefix myprefix` 98 | * Output perplexity every 10 iterations `-perplexity 10` 99 | * Tune Dirichlet hyperparameters `-alpha 10 -beta 0.1` 100 | * Use UCI machine learning repository data 101 | 102 | wget https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/vocab.nips.txt 103 | wget https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/docword.nips.txt.gz 104 | gunzip docword.nips.txt.gz 105 | ./uci-to-yahoo docword.nips.txt vocab.nips.txt -o nips.txt 106 | head -n 1400 nips.txt > nips_train.txt 107 | tail -n 100 nips.txt > nips_test.txt 108 | 109 | ## License 110 | 111 | MIT 112 | 113 | ## Reference 114 | 115 | Please cite WarpLDA if you find it is useful! 116 | 117 | @inproceedings{chen2016warplda, 118 | title={WarpLDA: a Cache Efficient O(1) Algorithm for Latent Dirichlet Allocation}, 119 | author={Chen, Jianfei and Li, Kaiwei and Zhu, Jun and Chen, Wenguang}, 120 | booktitle={VLDB}, 121 | year={2016} 122 | } 123 | -------------------------------------------------------------------------------- /src/Utils.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | //#include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | //#include 10 | #include 11 | 12 | inline uint32_t logceil(uint32_t n) 13 | { 14 | for (uint32_t i = 0; i < 31; i++) 15 | if ((1U << i) >= n) 16 | return i; 17 | return -1; 18 | } 19 | 20 | inline unsigned int Divup(unsigned int a, unsigned int b) 21 | { 22 | return (a + b - 1) / b; 23 | } 24 | 25 | inline std::vector ReadlinesFromFile(std::string fname) 26 | { 27 | std::vector ret; 28 | std::ifstream fin(fname); 29 | if (!fin) 30 | { 31 | std::cerr << " : ReadlinesFromFile " << fname << " Failed" << std::endl; 32 | abort(); 33 | } 34 | std::string line; 35 | while (std::getline(fin, line)) 36 | { 37 | ret.push_back(line); 38 | } 39 | return ret; 40 | } 41 | 42 | inline 
void SetIfEmpty(std::string &s, std::string t) 43 | { 44 | if (s.empty()) 45 | s = t; 46 | } 47 | 48 | template 49 | bool ForEachLinesInFile(std::string fname, Function f) 50 | { 51 | std::ifstream fin(fname); 52 | if (!fin) return false; 53 | std::string line; 54 | while (std::getline(fin, line)) 55 | { 56 | f(line); 57 | } 58 | return true; 59 | } 60 | #if 0 61 | template 62 | void ForEachLinesInFile(std::string fname, Function f) 63 | { 64 | if (fname.substr(fname.size() - 3) == ".gz") 65 | { 66 | igzstream fin(fname.c_str()); 67 | if (!fin) 68 | { 69 | std::cerr << " : Open " << fname << " Failed" << std::endl; 70 | abort(); 71 | } 72 | 73 | std::string line; 74 | while (std::getline(fin, line)) 75 | { 76 | std::istringstream sin(line); 77 | f(sin); 78 | } 79 | }else 80 | { 81 | std::ifstream fin(fname); 82 | if (!fin) 83 | { 84 | std::cerr << " : Open " << fname << " Failed" << std::endl; 85 | abort(); 86 | } 87 | 88 | std::string line; 89 | while (std::getline(fin, line)) 90 | { 91 | std::istringstream sin(line); 92 | f(sin); 93 | } 94 | } 95 | } 96 | #endif 97 | 98 | inline std::string operator+(std::string str, int x) 99 | { 100 | std::ostringstream ss; 101 | ss << str << x; 102 | return ss.str(); 103 | } 104 | 105 | inline ssize_t Filesize(std::istream &fin) 106 | { 107 | auto pos = fin.tellg(); 108 | fin.seekg(0, std::ios_base::end); 109 | auto sz = fin.tellg(); 110 | fin.seekg(pos); 111 | return sz; 112 | } 113 | 114 | /* 115 | template 116 | void MyMPISerialize(Function f, int mpi_size, int mpi_rank) 117 | { 118 | MPI_Barrier(MPI_COMM_WORLD); 119 | for (int i = 0 ; i < mpi_size; i++) 120 | { 121 | if (i == mpi_rank) 122 | { 123 | std::cout << "Serialize rank " << i << std::endl; 124 | f(); 125 | } 126 | MPI_Barrier(MPI_COMM_WORLD); 127 | } 128 | } 129 | 130 | template 131 | struct MyMPIDataType 132 | { 133 | static MPI_Datatype get_type(); 134 | }; 135 | 136 | template <> 137 | struct MyMPIDataType 138 | { 139 | static MPI_Datatype get_type(){ return 
MPI_INT; } 140 | }; 141 | 142 | template <> 143 | struct MyMPIDataType 144 | { 145 | static MPI_Datatype get_type(){ return MPI_UNSIGNED; } 146 | }; 147 | 148 | template <> 149 | struct MyMPIDataType 150 | { 151 | static MPI_Datatype get_type(){ return MPI_UNSIGNED_LONG; } 152 | }; 153 | */ 154 | 155 | inline void Memoryinfo(double &tot, double &used) 156 | { 157 | long phypz = sysconf(_SC_PHYS_PAGES); 158 | long psize = sysconf(_SC_PAGE_SIZE); 159 | long avphys = sysconf(_SC_AVPHYS_PAGES); 160 | tot = 1.0/(1L<<30)*psize*phypz; 161 | used = 1.0/(1L<<30)*avphys*psize; 162 | } 163 | static size_t filesize(std::ifstream &fs) 164 | { 165 | size_t last = fs.tellg(); 166 | fs.seekg(0, std::ios::end); 167 | size_t ret = fs.tellg(); 168 | fs.seekg(last, std::ios::beg); 169 | return ret; 170 | 171 | } 172 | 173 | template 174 | bool ReadVector(T & vec, std::string fname) 175 | { 176 | std::ifstream f(fname, std::ios::binary); 177 | if (!f) 178 | return false; 179 | size_t sz = filesize(f); 180 | vec.resize(sz / sizeof(T::value_type)); 181 | if (!f.read((char*)vec.data(), sz)) 182 | return false; 183 | f.close(); 184 | return true; 185 | } 186 | 187 | template 188 | bool WriteVector(T & vec, std::string fname) 189 | { 190 | std::ofstream f(fname, std::ios::binary); 191 | if (!f) 192 | return false; 193 | if (!f.write((char*)vec.data(), sizeof(T::value_type) * vec.size())) 194 | return false; 195 | f.close(); 196 | return true; 197 | } 198 | -------------------------------------------------------------------------------- /src/warplda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "warplda.hpp" 5 | #include "Vocab.hpp" 6 | #include "clock.hpp" 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | const int LOAD_FACTOR = 4; 14 | 15 | template 16 | WarpLDA::WarpLDA() : LDA() 17 | { 18 | 19 | } 20 | 21 | template 22 | void WarpLDA::reduce_ck() 23 | { 24 | #pragma omp 
parallel for 25 | for (TTopic i = 0; i < K; i++) 26 | { 27 | TCount s = 0; 28 | for (auto& buffer : local_buffers) 29 | s += buffer->ck_new[i]; 30 | ck[i] = s; 31 | } 32 | } 33 | 34 | template 35 | template 36 | void WarpLDA::initialize() 37 | { 38 | shuffle.reset(new Shuffle(g)); 39 | alpha_bar = alpha * K; 40 | beta_bar = beta * g.NV(); 41 | 42 | if (!testMode) 43 | ck.Assign(K, 0); 44 | 45 | nnz_d.Assign(g.NU()); 46 | nnz_w.Assign(g.NV()); 47 | 48 | TDegree max_degree_u = 0; 49 | 50 | local_buffers.resize(omp_get_max_threads()); 51 | 52 | #pragma omp parallel for reduction(max:max_degree_u) 53 | for (TUID i = 0; i < g.NU(); i++) 54 | { 55 | max_degree_u = std::max(max_degree_u, g.DegreeU(i)); 56 | nnz_d[i] = g.DegreeU(i); 57 | } 58 | #pragma omp parallel for 59 | for (TVID i = 0; i < g.NV(); i++) 60 | nnz_w[i] = g.DegreeV(i); 61 | 62 | #pragma omp parallel 63 | { 64 | #pragma omp critical 65 | local_buffers[omp_get_thread_num()].reset(new LocalBuffer(K, max_degree_u)); 66 | } 67 | 68 | shuffle->VisitByV([&](TVID v, TDegree N, const TUID* lnks, TData* data) 69 | { 70 | LocalBuffer* buffer = local_buffers[omp_get_thread_num()].get(); 71 | TCount* ck_new = buffer->ck_new.data(); 72 | for (TDegree i = 0; i < N; i++) 73 | { 74 | TTopic k = buffer->generator() % K; 75 | data[i].oldk = k; 76 | for (unsigned mh=0; mh 89 | void WarpLDA::LocalBuffer::Init() 90 | { 91 | std::fill(ck_new.begin(), ck_new.end(), 0); 92 | log_likelihood = 0; 93 | } 94 | 95 | template 96 | template 97 | void WarpLDA::accept_d_propose_w() 98 | { 99 | #pragma omp parallel 100 | { 101 | local_buffers[omp_get_thread_num()]->Init(); 102 | } 103 | 104 | shuffle->VisitByV([&](TVID v, TDegree N, const TUID* lnks, TData* data) 105 | { 106 | LocalBuffer *local_buffer = local_buffers[omp_get_thread_num()].get(); 107 | TCount* ck_new = local_buffer->ck_new.data(); 108 | 109 | HashTable* cxk; 110 | if (testMode) 111 | cxk = &cwk_model[v]; 112 | else 113 | { 114 | cxk = &local_buffer->cxk_sparse; 115 | 
cxk->Rebuild(logceil(std::min(K, nnz_w[v] * LOAD_FACTOR))); 116 | for (TDegree i=0; iPut(data[i].oldk)++; 118 | } 119 | } 120 | 121 | // Perplexity 122 | float lgammabeta = lgamma(beta); 123 | for (auto entry: cxk->table) 124 | if (entry.key != cxk->EMPTY_KEY) 125 | local_buffer->log_likelihood += lgamma(beta+entry.value) - lgammabeta; 126 | 127 | int remove_current = !testMode; 128 | for (TDegree i = 0; i < N; i++) 129 | { 130 | TTopic oldk = data[i].oldk; 131 | TTopic originalk = data[i].oldk; 132 | float b = cxk->Get(oldk)+beta-remove_current; 133 | float d = ck[oldk]+beta_bar-remove_current; 134 | //#pragma simd 135 | for (unsigned mh=0; mhGet(newk)+beta - (newk==originalk); 140 | c = ck[newk]+beta_bar - (newk==originalk); 141 | } else { 142 | a = cxk->Get(newk)+beta; 143 | c = ck[newk]+beta_bar; 144 | } 145 | float ad = a*d; 146 | float bc = b*c; 147 | bool accept = local_buffer->Rand32() *bc < ad * std::numeric_limits::max(); 148 | if (accept) { 149 | oldk = newk; 150 | b = a; d = c; 151 | } 152 | } 153 | data[i].oldk = oldk; 154 | if (!testMode) 155 | ck_new[oldk]++; 156 | } 157 | nnz_w[v] = cxk->NKey(); 158 | 159 | double new_topic = K*beta / (K*beta + N); 160 | uint32_t new_topic_th = std::numeric_limits::max() * new_topic; 161 | 162 | // TODO this propose is incorrect! 163 | // TODO use alias table 164 | if (!testMode) { 165 | for (TDegree i = 0; i < N; i++) 166 | { 167 | for (unsigned mh=0; mhRand32(); 170 | uint32_t rk = local_buffer->Rand32() % K; 171 | uint32_t rn = local_buffer->Rand32() % N; 172 | data[i].newk[mh] = r < new_topic_th ? 
rk : data[rn].oldk; 173 | } 174 | } 175 | } else { 176 | unsigned global_th = std::numeric_limits::max() 177 | * (global_sum / (global_sum + cwk_sums[v])); 178 | double one_scale = 1.0 / ((double)std::numeric_limits::max()+1); 179 | for (TDegree i = 0; i < N; i++) 180 | { 181 | for (unsigned mh=0; mhRand32(); 184 | if (r < global_th) 185 | data[i].newk[mh] = global_urn.DrawSample(local_buffer->Rand32(), local_buffer->Rand32() * one_scale); 186 | else 187 | data[i].newk[mh] = cwk_urns[v].DrawSample(local_buffer->Rand32(), local_buffer->Rand32() * one_scale); 188 | } 189 | } 190 | } 191 | }); 192 | 193 | if (!testMode) 194 | reduce_ck(); 195 | 196 | for (TTopic i = 0; i < K; i++) 197 | total_log_likelihood += -lgamma(ck[i]+beta_bar) + lgamma(beta_bar); 198 | 199 | for (auto &buffer : local_buffers) 200 | total_log_likelihood += buffer->log_likelihood; 201 | } 202 | 203 | template 204 | template 205 | void WarpLDA::accept_w_propose_d() 206 | { 207 | #pragma omp parallel 208 | { 209 | local_buffers[omp_get_thread_num()]->Init(); 210 | } 211 | shuffle->VisitURemoteData([&](TUID d, TDegree N, const TVID* lnks, RemoteArray64 &data) 212 | { 213 | LocalBuffer *local_buffer = local_buffers[omp_get_thread_num()].get(); 214 | TCount* ck_new = local_buffers[omp_get_thread_num()]->ck_new.data(); 215 | TData * local_data = local_buffer->local_data.data(); 216 | 217 | auto& cxk = local_buffer->cxk_sparse; 218 | cxk.Rebuild(logceil(std::min(K, nnz_d[d] * LOAD_FACTOR))); 219 | 220 | for (TDegree i=0; ilog_likelihood += lgamma(alpha+entry.value) - lgammaalpha; 230 | 231 | local_buffer->log_likelihood -= lgamma(alpha_bar+N) - lgamma(alpha_bar); 232 | 233 | for (TDegree i = 0; i < N; i++) 234 | { 235 | TTopic oldk = local_data[i].oldk; 236 | TTopic originalk = local_data[i].oldk; 237 | float b = cxk.Get(oldk)+alpha-1; 238 | float d = ck[oldk]+beta_bar-1; 239 | //#pragma simd 240 | for (unsigned mh=0; mhRand32() *bc < ad * std::numeric_limits::max(); 248 | if (accept) { 249 | oldk = 
newk; 250 | b = a; d = c; 251 | } 252 | } 253 | if (!testMode) 254 | ck_new[oldk]++; 255 | local_data[i].oldk = oldk; 256 | } 257 | nnz_d[d] = cxk.NKey(); 258 | 259 | double new_topic = alpha_bar / (alpha_bar + N); 260 | uint32_t new_topic_th = std::numeric_limits::max() * new_topic; 261 | for (TDegree i = 0; i < N; i++) 262 | { 263 | data[i].oldk = local_data[i].oldk; 264 | for (unsigned mh=0; mhRand32(); 266 | uint32_t rk = local_buffer->Rand32() % K; 267 | uint32_t rn = local_buffer->Rand32() % N; 268 | data[i].newk[mh] = r < new_topic_th ? rk : local_data[rn].oldk; 269 | } 270 | } 271 | }); 272 | if (!testMode) 273 | reduce_ck(); 274 | for (auto& buffer : local_buffers) 275 | total_log_likelihood += buffer->log_likelihood; 276 | } 277 | 278 | template 279 | void WarpLDA::estimate(int _K, float _alpha, float _beta, int _niter, int _perperplexity_interval, int neval, std::string fmodel, std::string vocab_fname, std::string info, uint32_t ntop) 280 | { 281 | this->K = _K; 282 | this->alpha = _alpha; 283 | this->beta = _beta; 284 | this->niter = _niter; 285 | initialize(); 286 | 287 | for (int i = 0; i < niter; i++) 288 | { 289 | Clock clk; 290 | clk.start(); 291 | 292 | total_log_likelihood = 0; 293 | accept_d_propose_w(); 294 | accept_w_propose_d(); 295 | 296 | double ppl = 0; 297 | bool eval_perplexity = _perperplexity_interval != -1 && i % _perperplexity_interval == 0; 298 | // Evaluate perplexity p(w_d | \hat\theta, \hat\phi) 299 | if (eval_perplexity) 300 | ppl = perplexity(); 301 | 302 | double time_elapsed = clk.timeElapsed(); 303 | time_t cur_time = time(0); 304 | struct tm * now = localtime( & cur_time ); 305 | printf("[%d-%d-%d %d:%d:%d] Iteration %d, %f s, %.2f Mtokens/s, log_likelihood (per token) %lf", now->tm_year + 1900, now->tm_mon + 1, now->tm_mday, now->tm_hour,now->tm_min,now->tm_sec, i, time_elapsed, (double)g.NE()/time_elapsed/1e6, total_log_likelihood/g.NE()); 306 | if (eval_perplexity) printf(" perplexity %lf\n", ppl); 307 | else 
printf("\n"); 308 | fflush(stdout); 309 | if (i > 0 && i % neval == 0) { 310 | std::string model_file_name; 311 | std::stringstream ss; 312 | ss << i; 313 | model_file_name = fmodel + ".iter" + ss.str(); 314 | this->storeModel(model_file_name); 315 | this->writeInfo(vocab_fname, info + ".iter" + ss.str(), ntop); 316 | } 317 | } 318 | } 319 | 320 | template 321 | void WarpLDA::inference(int niter, int _perperplexity_interval) 322 | { 323 | initialize(); 324 | for (int i = 0; i < niter; i++) 325 | { 326 | Clock clk; 327 | clk.start(); 328 | 329 | accept_d_propose_w(); 330 | accept_w_propose_d(); 331 | 332 | double tm = clk.timeElapsed(); 333 | 334 | double ppl = 0; 335 | bool eval_perplexity = _perperplexity_interval != -1 && i % _perperplexity_interval == 0; 336 | // Evaluate perplexity p(w_d | \hat\theta, \hat\phi) 337 | if (eval_perplexity) 338 | ppl = perplexity(); 339 | 340 | // Evaluate likelihood p(w_d | \hat\theta, \hat\phi) 341 | printf("Iteration %d, %f s, %.2f Mtokens/s ", i, tm, (double)g.NE()/tm/1e6); 342 | if (eval_perplexity) printf(" perplexity %lf\n", ppl); 343 | else printf("\n"); 344 | fflush(stdout); 345 | } 346 | } 347 | 348 | template 349 | void WarpLDA::loadModel(std::string fmodel) 350 | { 351 | std::ifstream fin(fmodel); 352 | if (!fin) 353 | throw std::runtime_error(std::string("Failed to load model file : ") + fmodel); 354 | TVID nv = 0; 355 | fin >> nv >> K >> alpha >> beta; 356 | cwk_model.clear(); 357 | cwk_model.resize(nv); 358 | cwk_urns.resize(nv); 359 | cwk_sums.resize(nv); 360 | ck.Assign(K); 361 | alpha_bar = alpha * K; 362 | beta_bar = beta * g.NV(); 363 | 364 | for (TVID v = 0; v < g.NV(); v++) 365 | { 366 | auto& cwk = cwk_model[v]; 367 | TTopic nkey; 368 | fin >> nkey; 369 | cwk.Rebuild(logceil(nkey * LOAD_FACTOR)); 370 | for (TDegree i = 0; i < nkey; i++) { 371 | TTopic k; 372 | TCount c; 373 | fin >> k; fin.ignore(); fin >> c; 374 | cwk.Put(k) = c; 375 | ck[k] += c; 376 | } 377 | } 378 | fin.close(); 379 | 380 | std::vector 
kds; 381 | std::vector probs(K); 382 | global_sum = 0; 383 | for (TTopic k = 0; k < K; k++) 384 | global_sum += probs[k] = beta / (ck[k] + beta_bar); 385 | global_urn.BuildAlias(probs, generator.Rand32()); 386 | for (TVID v = 0; v < g.NV(); v++) 387 | { 388 | kds.clear(); 389 | probs.clear(); 390 | auto &cwk = cwk_model[v]; 391 | cwk_sums[v] = 0; 392 | for (auto &entry: cwk.table) if (entry.key != cwk.EMPTY_KEY) { 393 | kds.push_back(entry.key); 394 | double p = entry.value / (ck[entry.key] + beta_bar); 395 | probs.push_back(p); 396 | cwk_sums[v] += p; 397 | } 398 | cwk_urns[v].BuildAlias(probs, generator.Rand32()); 399 | cwk_urns[v].SetKeys(kds); 400 | } 401 | } 402 | 403 | template 404 | void WarpLDA::storeModel(std::string fmodel) 405 | { 406 | cwk_model.clear(); 407 | cwk_model.resize(g.NV()); 408 | shuffle->VisitByV([&](TVID v, TDegree N, const TUID* lnks, TData* data) 409 | { 410 | auto& cxk = cwk_model[v]; 411 | cxk.Rebuild(logceil(std::min(K, nnz_w[v] * LOAD_FACTOR))); 412 | for (TDegree i=0; i 433 | void WarpLDA::loadZ(std::string filez) 434 | { 435 | throw std::runtime_error("This method is not implemented"); 436 | //std::ifstream fin(filez); 437 | //shuffle->VisitURemoteDataSequential([&](TUID d, TDegree N, const TVID* lnks, RemoteArray64 &data) 438 | //{ 439 | // for (unsigned i = 0; i < N; i++) 440 | // { 441 | // fin >> data[i].oldk; 442 | // } 443 | //}); 444 | //fin.close(); 445 | } 446 | 447 | template 448 | void WarpLDA::storeZ(std::string filez) { 449 | std::ofstream fou(filez); 450 | shuffle->VisitURemoteDataSequential([&](TUID d, TDegree N, const TVID* lnks, RemoteArray64 &data) 451 | { 452 | fou << N << ' '; 453 | for (unsigned i = 0; i < N; i++) 454 | { 455 | fou << lnks[i] << ':' << data[i].oldk << ' '; 456 | } 457 | fou << '\n'; 458 | }); 459 | fou.close(); 460 | } 461 | 462 | template 463 | void WarpLDA::writeInfo(std::string vocab_fname, std::string info, uint32_t ntop) 464 | { 465 | Vocab vocab; 466 | if (!vocab.load(vocab_fname)) 467 | 
throw std::runtime_error(std::string("Failed to load vocab file : ") + vocab_fname); 468 | 469 | std::vector>>> result; //result[thread][k][10](value, word) 470 | result.resize(omp_get_max_threads()); 471 | #pragma omp parallel 472 | result[omp_get_thread_num()].resize(K); 473 | shuffle->VisitByV([&](TVID v, TDegree N, const TUID* lnks, TData* data){ 474 | int tid = omp_get_thread_num(); 475 | auto &result_local = result[tid]; 476 | std::unordered_map cnt; 477 | for (TDegree i = 0; i < N; i++) 478 | { 479 | cnt[data[i].oldk]++; 480 | } 481 | for (auto t : cnt) 482 | { 483 | TTopic k = t.first; 484 | TCount c = t.second; //ckw 485 | auto &r = result_local[k]; 486 | double value = double(c + beta)/(ck[k]+beta_bar); 487 | //printf("ckw %d %d = %lf\n", k, v, value); 488 | std::pair p(value, v); 489 | r.push_back(p); 490 | std::push_heap(r.begin(), r.end(), std::greater>()); 491 | if (r.size() > ntop) 492 | { 493 | std::pop_heap(r.begin(), r.end(), std::greater>()); 494 | r.pop_back(); 495 | } 496 | } 497 | }); 498 | std::ofstream fou1(info+".full.txt"); 499 | std::ofstream fou2(info+".words.txt"); 500 | std::vector>> ans; 501 | for (TTopic k = 0; k < K; k++) 502 | { 503 | ans.resize(K); 504 | auto &a = ans[k]; 505 | for (unsigned tid = 0; tid < result.size(); tid++) 506 | { 507 | auto &r = result[tid][k]; 508 | for (auto &p : r) 509 | { 510 | a.push_back(p); 511 | std::push_heap(a.begin(), a.end(), std::greater>()); 512 | if (a.size() > ntop) 513 | { 514 | std::pop_heap(a.begin(), a.end(), std::greater>()); 515 | a.pop_back(); 516 | } 517 | } 518 | } 519 | std::sort(a.rbegin(), a.rend()); 520 | fou1 << "Topic #" << k << ':' << '\n'; 521 | for (auto &p : a) 522 | { 523 | std::string word = vocab.getWordById(p.second); 524 | fou1 << word << '\t' << p.first << '\n'; 525 | fou2 << word << " "; 526 | } 527 | fou1 << std::endl; 528 | fou2 << std::endl; 529 | } 530 | fou1.close(); 531 | fou2.close(); 532 | 533 | } 534 | 535 | template 536 | template 537 | double 
WarpLDA::perplexity() 538 | { 539 | std::vector>> cdk(g.NU()); 540 | shuffle->VisitURemoteData([&](TUID d, TDegree N, const TVID* lnks, RemoteArray64 &data) { 541 | LocalBuffer *local_buffer = local_buffers[omp_get_thread_num()].get(); 542 | 543 | auto& cxk = local_buffer->cxk_sparse; 544 | cxk.Rebuild(logceil(std::min(K, nnz_d[d] * LOAD_FACTOR))); 545 | 546 | for (TDegree i=0; iInit(); 565 | } 566 | 567 | shuffle->VisitByV([&](TVID v, TDegree N, const TUID* lnks, TData* data) 568 | { 569 | LocalBuffer *local_buffer = local_buffers[omp_get_thread_num()].get(); 570 | 571 | auto &cxk = !testMode ? local_buffer->cxk_sparse : cwk_model[v]; 572 | double alpha_term = 0; 573 | if (!testMode) { 574 | cxk.Rebuild(logceil(std::min(K, nnz_w[v] * LOAD_FACTOR))); 575 | for (TDegree i=0; i 1) { 590 | std::cout << cxk.Get(entry.first) << ' ' << beta << ' ' << ck[entry.first] << ' ' << beta_bar << std::endl; 591 | throw std::runtime_error("what?"); 592 | } 593 | } 594 | prob += alpha_term; 595 | prob /= (L + alpha_bar); 596 | local_buffer->log_likelihood += log(prob); 597 | } 598 | }); 599 | double log_likelihood = 0; 600 | for (auto &buffer : local_buffers) 601 | log_likelihood += buffer->log_likelihood; 602 | 603 | return exp(-log_likelihood / g.NE()); 604 | } 605 | 606 | template class WarpLDA<1>; 607 | --------------------------------------------------------------------------------