├── README.md
├── vocabulary_for_mglda.py
└── mglda.py

/README.md:
--------------------------------------------------------------------------------
MG-LDA with Gibbs sampling
MIT License (C) Masanao Ochi

This is an implementation of the algorithm proposed in the paper
"Modeling online reviews with multi-grain topic models" by I. Titov et al.

I consulted the LDA code written by S. Nakatani
(https://github.com/shuyo/iir/blob/master/lda/lda.py) extensively,
and I want to take this opportunity to express my appreciation for his
great work. Thank you.

USAGE:
$ python mglda.py
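
A minimal programmatic example (a sketch that mirrors the test() function in
mglda.py; the hyperparameter values are simply the defaults used there):

    import vocabulary_for_mglda as vocabulary
    from mglda import MGLDA, mglda_learning

    corpus = vocabulary.load_corpus_each_sentence("0:100")  # first 100 movie reviews
    voca = vocabulary.Vocabulary(True)                      # True: exclude stopwords
    docs = [voca.doc_to_ids_each_sentence(doc) for doc in corpus]

    model = MGLDA(K_gl=50, K_loc=10, gamma=0.1,
                  alpha_gl=0.1, alpha_loc=0.1,
                  alpha_mix_gl=0.1, alpha_mix_loc=0.1,
                  beta_gl=0.1, beta_loc=0.1,
                  T=3, docs=docs, W=voca.size())
    mglda_learning(model, 50, voca)  # run 50 Gibbs sampling iterations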
"a,s,able,about,above,according,accordingly,across,actually,after,afterwards,again,against,ain,t,all,allow,allows,almost,alone,along,already,also,although,always,am,among,amongst,an,and,another,any,anybody,anyhow,anyone,anything,anyway,anyways,anywhere,apart,appear,appreciate,appropriate,are,aren,t,around,as,aside,ask,asking,associated,at,available,away,awfully,be,became,because,become,becomes,becoming,been,before,beforehand,behind,being,believe,below,beside,besides,best,better,between,beyond,both,brief,but,by,c,mon,c,s,came,can,can,t,cannot,cant,cause,causes,certain,certainly,changes,clearly,co,com,come,comes,concerning,consequently,consider,considering,contain,containing,contains,corresponding,could,couldn,t,course,currently,definitely,described,despite,did,didn,t,different,do,does,doesn,t,doing,don,t,done,down,downwards,during,each,edu,eg,eight,either,else,elsewhere,enough,entirely,especially,et,etc,even,ever,every,everybody,everyone,everything,everywhere,ex,exactly,example,except,far,few,fifth,first,five,followed,following,follows,for,former,formerly,forth,four,from,further,furthermore,get,gets,getting,given,gives,go,goes,going,gone,got,gotten,greetings,had,hadn,t,happens,hardly,has,hasn,t,have,haven,t,having,he,he,s,hello,help,hence,her,here,here,s,hereafter,hereby,herein,hereupon,hers,herself,hi,him,himself,his,hither,hopefully,how,howbeit,however,i,d,i,ll,i,m,i,ve,ie,if,ignored,immediate,in,inasmuch,inc,indeed,indicate,indicated,indicates,inner,insofar,instead,into,inward,is,isn,t,it,it,d,it,ll,it,s,its,itself,just,keep,keeps,kept,know,knows,known,last,lately,later,latter,latterly,least,less,lest,let,let,s,like,liked,likely,little,look,looking,looks,ltd,mainly,many,may,maybe,me,mean,meanwhile,merely,might,more,moreover,most,mostly,much,must,my,myself,name,namely,nd,near,nearly,necessary,need,needs,neither,never,nevertheless,new,next,nine,no,nobody,non,none,noone,nor,normally,not,nothing,novel,now,nowhere,obviously,of,off,often,oh,ok,okay,old,on,once,one,ones,only,onto,or,other,others,otherwise,ought,our,ours,ourselves,out,outside,over,overall,own,particular,particularly,per,perhaps,placed,please,plus,possible,presumably,probably,provides,que,quite,qv,rather,rd,re,really,reasonably,regarding,regardless,regards,relatively,respectively,right,said,same,saw,say,saying,says,second,secondly,see,seeing,seem,seemed,seeming,seems,seen,self,selves,sensible,sent,serious,seriously,seven,several,shall,she,should,shouldn,t,since,six,so,some,somebody,somehow,someone,something,sometime,sometimes,somewhat,somewhere,soon,sorry,specified,specify,specifying,still,sub,such,sup,sure,t,s,take,taken,tell,tends,th,than,thank,thanks,thanx,that,that,s,thats,the,their,theirs,them,themselves,then,thence,there,there,s,thereafter,thereby,therefore,therein,theres,thereupon,these,they,they,d,they,ll,they,re,they,ve,think,third,this,thorough,thoroughly,those,though,three,through,throughout,thru,thus,to,together,too,took,toward,towards,tried,tries,truly,try,trying,twice,two,un,under,unfortunately,unless,unlikely,until,unto,up,upon,us,use,used,useful,uses,using,usually,value,various,very,via,viz,vs,want,wants,was,wasn,t,way,we,we,d,we,ll,we,re,we,ve,welcome,well,went,were,weren,t,what,what,s,whatever,when,whence,whenever,where,where,s,whereafter,whereas,whereby,wherein,whereupon,wherever,whether,which,while,whither,who,who,s,whoever,whole,whom,whose,why,will,willing,wish,with,within,without,won,t,wonder,would,would,wouldn,t,yes,yet,you,you,d,you,ll,you,re,you,ve,your,yours,yourself,yourselves,zero".split(',') 42 | 
recover_list = {"wa":"was", "ha":"has"} 43 | wl = nltk.WordNetLemmatizer() 44 | 45 | def is_stopword(w): 46 | return w in stopwords_list 47 | def lemmatize(w0): 48 | w = wl.lemmatize(w0.lower()) 49 | #if w=='de': print w0, w 50 | if w in recover_list: return recover_list[w] 51 | return w 52 | 53 | class Vocabulary: 54 | def __init__(self, excluds_stopwords=False): 55 | self.vocas = [] # id to word 56 | self.vocas_id = dict() # word to id 57 | self.docfreq = [] # id to document frequency 58 | self.excluds_stopwords = excluds_stopwords 59 | 60 | def term_to_id(self, term0): 61 | term = lemmatize(term0) 62 | if not re.match(r'[a-z]+$', term): return None 63 | if self.excluds_stopwords and is_stopword(term): return None 64 | if term not in self.vocas_id: 65 | voca_id = len(self.vocas) 66 | # print str(voca_id) + ": " + term 67 | self.vocas_id[term] = voca_id 68 | self.vocas.append(term) 69 | self.docfreq.append(0) 70 | else: 71 | voca_id = self.vocas_id[term] 72 | return voca_id 73 | 74 | def doc_to_ids(self, doc): 75 | #print ' '.join(doc) 76 | list = [] 77 | words = dict() 78 | for term in doc: 79 | id = self.term_to_id(term) 80 | if id != None: 81 | list.append(id) 82 | if not words.has_key(id): 83 | words[id] = 1 84 | self.docfreq[id] += 1 85 | if "close" in dir(doc): doc.close() 86 | return list 87 | 88 | def doc_to_ids_each_sentence(self, doc): 89 | #print ' '.join(doc) 90 | sent_list = [] 91 | words = dict() 92 | 93 | for sent in doc: 94 | list = [] 95 | for term in sent: 96 | id = self.term_to_id(term) 97 | if id != None: 98 | list.append(id) 99 | if not words.has_key(id): 100 | words[id] = 1 101 | self.docfreq[id] += 1 102 | sent_list.append(list) 103 | if "close" in dir(doc): doc.close() 104 | return sent_list 105 | 106 | def cut_low_freq(self, corpus, threshold=1): 107 | new_vocas = [] 108 | new_docfreq = [] 109 | self.vocas_id = dict() 110 | conv_map = dict() 111 | for id, term in enumerate(self.vocas): 112 | freq = self.docfreq[id] 113 | if freq > threshold: 114 | new_id = len(new_vocas) 115 | self.vocas_id[term] = new_id 116 | new_vocas.append(term) 117 | new_docfreq.append(freq) 118 | conv_map[id] = new_id 119 | self.vocas = new_vocas 120 | self.docfreq = new_docfreq 121 | 122 | def conv(doc): 123 | new_doc = [] 124 | for id in doc: 125 | if id in conv_map: new_doc.append(conv_map[id]) 126 | return new_doc 127 | return [conv(doc) for doc in corpus] 128 | 129 | def __getitem__(self, v): 130 | return self.vocas[v] 131 | 132 | def size(self): 133 | return len(self.vocas) 134 | 135 | def is_stopword_id(self, id): 136 | return self.vocas[id] in stopwords_list 137 | -------------------------------------------------------------------------------- /mglda.py: -------------------------------------------------------------------------------- 1 | #!/usr/local/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # mg-lda 5 | # This code is available under the MIT License. 6 | # (c)2012 Masanao Ochi. 
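#
# Overview (a sketch of the collapsed Gibbs sampler implemented below; the
# notation follows I. Titov and R. McDonald, "Modeling online reviews with
# multi-grain topic models", WWW 2008):
#
#   Each word in sentence s of document d is assigned a sliding window
#   v in {0, ..., T-1}, a granularity r in {gl, loc}, and a topic z (one of
#   K_gl global or K_loc local topics).  inference() resamples (v, r, z)
#   jointly for each word from a distribution proportional to a product of
#   four terms; for r = "gl":
#
#     p(v, gl, z) ~ (n_gl_z_w + beta_gl) / (n_gl_z + W*beta_gl)         # word | topic
#                 * (n_d_s_v + gamma) / (n_d_s + T*gamma)               # window | sentence
#                 * (n_d_v_gl + alpha_mix_gl)
#                   / (n_d_v + alpha_mix_gl + alpha_mix_loc)            # granularity | window
#                 * (n_d_gl_z + alpha_gl) / (n_d_gl + K_gl*alpha_gl)    # topic | document
#
#   and analogously for r = "loc", where the last factor conditions on the
#   window rather than on the document.  All counts exclude the word that is
#   currently being resampled.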

import numpy

class MGLDA:
    def __init__(self, K_gl, K_loc, gamma, alpha_gl, alpha_loc, alpha_mix_gl, alpha_mix_loc, beta_gl, beta_loc, T, docs, W, smartinit=False):

        self.K_gl = K_gl                  # number of global topics
        self.K_loc = K_loc                # number of local topics

        self.gamma = gamma                # prior over sliding windows
        self.alpha_gl = alpha_gl          # prior over global topics
        self.alpha_loc = alpha_loc        # prior over local topics
        self.alpha_mix_gl = alpha_mix_gl  # prior over the gl/loc choice
        self.alpha_mix_loc = alpha_mix_loc

        self.beta_gl = beta_gl            # prior over words (global topics)
        self.beta_loc = beta_loc          # prior over words (local topics)

        self.T = T                        # sliding window width

        self.docs = docs
        self.W = W                        # vocabulary size
        # (smartinit is accepted for interface compatibility but is unused.)

        # per-word assignments, filled per document below
        self.v_d_s_n = []  # sliding window of each word
        self.r_d_s_n = []  # granularity ("gl" or "loc") of each word
        self.z_d_s_n = []  # topic of each word

        # count statistics for the global topics
        self.n_gl_z_w = numpy.zeros((self.K_gl, self.W))  # words per (global topic, word type)
        self.n_gl_z = numpy.zeros(self.K_gl)              # words per global topic
        self.n_d_s_v = []   # words per (doc, sentence, window), filled below
        self.n_d_s = []     # words per (doc, sentence), filled below
        self.n_d_v_gl = []  # global words per (doc, window), filled below
        self.n_d_v = []     # words per (doc, window), filled below
        self.n_d_gl_z = numpy.zeros((len(self.docs), self.K_gl))  # global words per (doc, topic)
        self.n_d_gl = numpy.zeros((len(self.docs)))               # global words per doc

        # count statistics for the local topics
        self.n_loc_z_w = numpy.zeros((self.K_loc, self.W))  # words per (local topic, word type)
        self.n_loc_z = numpy.zeros(self.K_loc)              # words per local topic
        self.n_d_v_loc = []    # local words per (doc, window), filled below
        self.n_d_v_loc_z = []  # local words per (doc, window, topic), filled below

        self.inflation = 0  # initial value for the list-based counters
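
        # Bookkeeping sketch: a document with S sentences has S + T - 1
        # overlapping windows, indexed 0 .. S+T-2; sentence s can draw
        # windows v = 0 .. T-1, and its words are counted at window index
        # s + v.  For example, with T = 3 and S = 4:
        #
        #   sentence 0 -> window indices 0, 1, 2
        #   sentence 3 -> window indices 3, 4, 5   (S + T - 1 = 6 windows)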

        print("random fitting to initialize")
        for m, doc in enumerate(self.docs):
            v_d = []
            r_d = []
            z_d = []

            n_d_s_v_d = []
            n_d_s_d = []

            n_d_v_gl_v = []
            n_d_v_v = []
            n_d_v_loc_v = []
            n_d_v_loc_z_v = []
            for v in range(self.T + len(doc) - 1):
                n_d_v_gl_v.append(self.inflation)   # word count with global topic, per window
                n_d_v_v.append(self.inflation)      # word count, per window
                n_d_v_loc_v.append(self.inflation)  # word count with local topic, per window

                n_d_v_loc_z_z = []
                for k in range(self.K_loc):
                    n_d_v_loc_z_z.append(self.inflation)  # word count per (window, local topic)
                n_d_v_loc_z_v.append(n_d_v_loc_z_z)

            self.n_d_v_gl.append(n_d_v_gl_v)
            self.n_d_v.append(n_d_v_v)
            self.n_d_v_loc.append(n_d_v_loc_v)
            self.n_d_v_loc_z.append(n_d_v_loc_z_v)

            for s, sent in enumerate(doc):
                v_s = []
                r_s = []
                z_s = []
                for i, word in enumerate(sent):
                    v = numpy.random.randint(0, self.T)  # random sliding window for each word
                    v_s.append(v)

                    # random granularity: global or local
                    r = "gl" if numpy.random.randint(0, 2) == 0 else "loc"
                    r_s.append(r)

                    if r == "gl":
                        z = numpy.random.randint(0, self.K_gl)   # random global topic
                    else:
                        z = numpy.random.randint(0, self.K_loc)  # random local topic
                    z_s.append(z)
                v_d.append(v_s)
                r_d.append(r_s)
                z_d.append(z_s)

                n_d_s_v_s = []
                for v in range(self.T):
                    n_d_s_v_s.append(self.inflation)  # initialize n_d_s_v
                n_d_s_v_d.append(n_d_s_v_s)

                n_d_s_d.append(self.inflation)        # initialize n_d_s

            self.v_d_s_n.append(v_d)
            self.r_d_s_n.append(r_d)
            self.z_d_s_n.append(z_d)

            self.n_d_s_v.append(n_d_s_v_d)
            self.n_d_s.append(n_d_s_d)

        print("initialize")
        for m, doc in enumerate(self.docs):
            for s, sent in enumerate(doc):
                for i, word in enumerate(sent):
                    v = self.v_d_s_n[m][s][i]  # 0 .. T-1
                    r = self.r_d_s_n[m][s][i]
                    z = self.z_d_s_n[m][s][i]
                    if r == "gl":
                        self.n_gl_z_w[z][word] += 1
                        self.n_gl_z[z] += 1
                        self.n_d_v_gl[m][s+v] += 1
                        self.n_d_gl_z[m][z] += 1
                        self.n_d_gl[m] += 1
                    elif r == "loc":
                        self.n_loc_z_w[z][word] += 1
                        self.n_loc_z[z] += 1
                        self.n_d_v_loc[m][s+v] += 1
                        self.n_d_v_loc_z[m][s+v][z] += 1
                    else:
                        print("error0: " + str(r))

                    self.n_d_s_v[m][s][v] += 1
                    self.n_d_s[m][s] += 1
                    self.n_d_v[m][s+v] += 1

        print("init comp.")

    def inference(self):
        """One iteration of collapsed Gibbs sampling."""
        for m, doc in enumerate(self.docs):
            for s, sent in enumerate(doc):
                for i, word in enumerate(sent):
                    v = self.v_d_s_n[m][s][i]  # 0 .. T-1
                    r = self.r_d_s_n[m][s][i]
                    z = self.z_d_s_n[m][s][i]

                    # discount the current assignment of this word
                    if r == "gl":
                        self.n_gl_z_w[z][word] -= 1
                        self.n_gl_z[z] -= 1
                        self.n_d_v_gl[m][s+v] -= 1
                        self.n_d_gl_z[m][z] -= 1
                        self.n_d_gl[m] -= 1
                    elif r == "loc":
                        self.n_loc_z_w[z][word] -= 1
                        self.n_loc_z[z] -= 1
                        self.n_d_v_loc[m][s+v] -= 1
                        self.n_d_v_loc_z[m][s+v][z] -= 1
                    else:
                        print("error1: " + str(r))

                    self.n_d_s_v[m][s][v] -= 1
                    self.n_d_s[m][s] -= 1
                    self.n_d_v[m][s+v] -= 1

                    # sample a new (window, granularity, topic) for this word
                    p_v_r_z = []
                    label_v_r_z = []
                    for v_t in range(self.T):
                        # r == "gl"
                        for z_t in range(self.K_gl):
                            label_v_r_z.append((v_t, "gl", z_t))
                            term1 = float(self.n_gl_z_w[z_t][word] + self.beta_gl) / (self.n_gl_z[z_t] + self.W*self.beta_gl)
                            term2 = float(self.n_d_s_v[m][s][v_t] + self.gamma) / (self.n_d_s[m][s] + self.T*self.gamma)
                            term3 = float(self.n_d_v_gl[m][s+v_t] + self.alpha_mix_gl) / (self.n_d_v[m][s+v_t] + self.alpha_mix_gl + self.alpha_mix_loc)
                            term4 = float(self.n_d_gl_z[m][z_t] + self.alpha_gl) / (self.n_d_gl[m] + self.K_gl*self.alpha_gl)
                            p_v_r_z.append(term1 * term2 * term3 * term4)
                        # r == "loc"
                        for z_t in range(self.K_loc):
                            label_v_r_z.append((v_t, "loc", z_t))
                            term1 = float(self.n_loc_z_w[z_t][word] + self.beta_loc) / (self.n_loc_z[z_t] + self.W*self.beta_loc)
                            term2 = float(self.n_d_s_v[m][s][v_t] + self.gamma) / (self.n_d_s[m][s] + self.T*self.gamma)
                            term3 = float(self.n_d_v_loc[m][s+v_t] + self.alpha_mix_loc) / (self.n_d_v[m][s+v_t] + self.alpha_mix_gl + self.alpha_mix_loc)
                            term4 = float(self.n_d_v_loc_z[m][s+v_t][z_t] + self.alpha_loc) / (self.n_d_v_loc[m][s+v_t] + self.K_loc*self.alpha_loc)
                            p_v_r_z.append(term1 * term2 * term3 * term4)

                    np_p_v_r_z = numpy.array(p_v_r_z)
                    new_idx = numpy.random.multinomial(1, np_p_v_r_z / np_p_v_r_z.sum()).argmax()
                    new_v, new_r, new_z = label_v_r_z[new_idx]

                    # add the new assignment back into the counts
                    if new_r == "gl":
                        self.n_gl_z_w[new_z][word] += 1
                        self.n_gl_z[new_z] += 1
                        self.n_d_v_gl[m][s+new_v] += 1
                        self.n_d_gl_z[m][new_z] += 1
                        self.n_d_gl[m] += 1
                    elif new_r == "loc":
                        self.n_loc_z_w[new_z][word] += 1
                        self.n_loc_z[new_z] += 1
                        self.n_d_v_loc[m][s+new_v] += 1
                        self.n_d_v_loc_z[m][s+new_v][new_z] += 1
                    else:
                        print("error2: " + str(new_r))

                    self.n_d_s_v[m][s][new_v] += 1
                    self.n_d_s[m][s] += 1
                    self.n_d_v[m][s+new_v] += 1

                    self.v_d_s_n[m][s][i] = new_v
                    self.r_d_s_n[m][s][i] = new_r
                    self.z_d_s_n[m][s][i] = new_z
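
    # The estimator in worddist() below is the smoothed posterior mean of the
    # topic-word distributions, matching the word|topic factor of the
    # sampling equations:
    #
    #   phi_gl[z][w]  = (n_gl_z_w[z][w]  + beta_gl)  / (n_gl_z[z]  + W*beta_gl)
    #   phi_loc[z][w] = (n_loc_z_w[z][w] + beta_loc) / (n_loc_z[z] + W*beta_loc)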
    def worddist(self):
        """Return the (global, local) topic-word distributions."""
        phi_gl = (self.n_gl_z_w + self.beta_gl) / (self.n_gl_z[:, numpy.newaxis] + self.W * self.beta_gl)
        phi_loc = (self.n_loc_z_w + self.beta_loc) / (self.n_loc_z[:, numpy.newaxis] + self.W * self.beta_loc)
        return phi_gl, phi_loc

def mglda_learning(mglda, iteration, voca):
    for i in range(iteration):
        print("\n\n\n==== %d-th inference ====" % i)
        mglda.inference()
        print("inference complete")
    output_word_topic_dist(mglda, voca)

def output_word_topic_dist(mglda, voca):
    z_gl_count = numpy.zeros(mglda.K_gl, dtype=int)
    z_loc_count = numpy.zeros(mglda.K_loc, dtype=int)
    word_gl_count = [dict() for k in range(mglda.K_gl)]
    word_loc_count = [dict() for k in range(mglda.K_loc)]

    for m, doc in enumerate(mglda.docs):
        for s, sent in enumerate(doc):
            for i, word in enumerate(sent):
                r = mglda.r_d_s_n[m][s][i]
                z = mglda.z_d_s_n[m][s][i]
                if r == "gl":
                    z_gl_count[z] += 1
                    word_gl_count[z][word] = word_gl_count[z].get(word, 0) + 1
                elif r == "loc":
                    z_loc_count[z] += 1
                    word_loc_count[z][word] = word_loc_count[z].get(word, 0) + 1
                else:
                    print("error3: " + str(r))

    phi_gl, phi_loc = mglda.worddist()
    for k in range(mglda.K_gl):
        print("\n-- global topic: %d (%d words)" % (k, z_gl_count[k]))
        for w in numpy.argsort(-phi_gl[k])[:20]:
            print("%s: %f (%d)" % (voca[w], phi_gl[k, w], word_gl_count[k].get(w, 0)))

    for k in range(mglda.K_loc):
        print("\n-- local topic: %d (%d words)" % (k, z_loc_count[k]))
        for w in numpy.argsort(-phi_loc[k])[:20]:
            print("%s: %f (%d)" % (voca[w], phi_loc[k, w], word_loc_count[k].get(w, 0)))

def test():
    import vocabulary_for_mglda as vocabulary

    corpus = vocabulary.load_corpus_each_sentence("0:2000")

    # docs[doc_idx][sentence_idx][word_idx]
    voca = vocabulary.Vocabulary(True)
    docs = [voca.doc_to_ids_each_sentence(doc) for doc in corpus]

    K_gl, K_loc = 50, 10
    gamma = 0.1
    alpha_gl, alpha_loc = 0.1, 0.1
    alpha_mix_gl, alpha_mix_loc = 0.1, 0.1
    beta_gl, beta_loc = 0.1, 0.1
    T = 3
    W = voca.size()

    mglda = MGLDA(K_gl, K_loc, gamma, alpha_gl, alpha_loc, alpha_mix_gl, alpha_mix_loc, beta_gl, beta_loc, T, docs, W)
    print("corpus=%d, words=%d, K_gl=%d, K_loc=%d, gamma=%f, alpha_gl=%f, alpha_loc=%f, alpha_mix_gl=%f, alpha_mix_loc=%f, beta_gl=%f, beta_loc=%f"
          % (len(corpus), len(voca.vocas), K_gl, K_loc, gamma, alpha_gl, alpha_loc, alpha_mix_gl, alpha_mix_loc, beta_gl, beta_loc))

    iteration = 1000
    mglda_learning(mglda, iteration, voca)

if __name__ == "__main__":
    test()
--------------------------------------------------------------------------------