├── dpmm
│   ├── __init__.py
│   ├── pp_plot.py
│   ├── generative_process.py
│   ├── algorithm_3.py
│   ├── test_inference.py
│   ├── algorithm_8.py
│   ├── getting_it_right.py
│   ├── conjugate_split_merge.py
│   └── nonconjugate_split_merge.py
├── README.md
├── setup.py
└── LICENSE

/dpmm/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Dirichlet Process Mixture Model
2 | ===============================
3 | 
4 | Example code for a Dirichlet Process mixture of Dirichlet-multinomial
5 | distributions. Implements the generative process, four inference
6 | algorithms (Neal's "algorithm 3" and "algorithm 8", Jain & Neal's
7 | conjugate split-merge and nonconjugate split-merge algorithms), and
8 | Geweke's "getting it right" posterior simulator test.
9 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | 
3 | 
4 | setup(name='dpmm',
5 |       version='1.0',
6 |       description='Dirichlet Process Mixture Model',
7 |       url='https://github.com/hannawallach/dpmm/',
8 |       author='Hanna Wallach',
9 |       author_email='hanna@dirichlet.net',
10 |       license='Apache 2.0',
11 |       packages=['dpmm'],
12 |       install_requires=['kale', 'matplotlib', 'numpy', 'scipy'])
13 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2014, Microsoft. All rights reserved.
2 | 
3 | Licensed under the Apache License, Version 2.0 (the "License"); you
4 | may not use this file except in compliance with the License. You may
5 | obtain a copy of the License at
6 | 
7 | http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12 | implied. See the License for the specific language governing
13 | permissions and limitations under the License.
14 | 
--------------------------------------------------------------------------------
/dpmm/pp_plot.py:
--------------------------------------------------------------------------------
1 | from numpy import empty_like, searchsorted, sort
2 | from pylab import plot, show
3 | 
4 | 
5 | def cdf(data):
6 |     """
7 |     Returns the empirical CDF (a function) for the specified data.
8 | 
9 |     Arguments:
10 | 
11 |     data -- data from which to compute the CDF
12 |     """
13 | 
14 |     tmp = empty_like(data)
15 |     tmp[:] = data
16 |     tmp.sort()
17 | 
18 |     def f(x):
19 |         return searchsorted(tmp, x, 'right') / float(len(tmp))
20 | 
21 |     return f
22 | 
23 | 
24 | def pp_plot(a, b):
25 |     """
26 |     Generates a P-P plot.
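
    Arguments:

    a -- first set of samples
    b -- second set of samples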
27 | """ 28 | 29 | x = sort(a) 30 | 31 | if len(x) > 10000: 32 | step = len(x) / 5000 33 | x = x[::step] 34 | 35 | plot(cdf(a)(x), cdf(b)(x), alpha=0.5) 36 | plot([0, 1], [0, 1], ':', c='k', lw=2, alpha=0.5) 37 | 38 | show() 39 | 40 | 41 | def test(num_samples=100000): 42 | 43 | from numpy.random import normal 44 | 45 | a = normal(20.0, 5.0, num_samples) 46 | b = normal(20.0, 5.0, num_samples) 47 | 48 | pp_plot(a, b) 49 | 50 | 51 | if __name__ == '__main__': 52 | test() 53 | -------------------------------------------------------------------------------- /dpmm/generative_process.py: -------------------------------------------------------------------------------- 1 | from numpy import argsort, bincount, ones, where, zeros 2 | from numpy.random import poisson, seed 3 | from numpy.random.mtrand import dirichlet 4 | 5 | from kale.math_utils import sample 6 | 7 | 8 | def generate_data(V, D, l, alpha, beta): 9 | """ 10 | Generates a synthetic corpus of documents from a Dirichlet process 11 | mixture model with multinomial mixture components (topics). The 12 | mixture components are drawn from a symmetric Dirichlet prior. 13 | 14 | Arguments: 15 | 16 | V -- vocabulary size 17 | D -- number of documents 18 | l -- average document length 19 | alpha -- concentration parameter for the Dirichlet process 20 | beta -- concentration parameter for the symmetric Dirichlet prior 21 | """ 22 | 23 | T = D # maximum number of topics 24 | 25 | phi_TV = zeros((T, V)) 26 | z_D = zeros(D, dtype=int) 27 | N_DV = zeros((D, V), dtype=int) 28 | 29 | for d in xrange(D): 30 | 31 | # draw a topic assignment for this document 32 | 33 | dist = bincount(z_D).astype(float) 34 | dist[0] = alpha 35 | [t] = sample(dist) 36 | t = len(dist) if t == 0 else t 37 | z_D[d] = t 38 | 39 | # if it's a new topic, draw the parameters for that topic 40 | 41 | if t == len(dist): 42 | phi_TV[t - 1, :] = dirichlet(beta * ones(V) / V) 43 | 44 | # draw the tokens from the topic 45 | 46 | for v in sample(phi_TV[t - 1, :], num_samples=poisson(l)): 47 | N_DV[d, v] += 1 48 | 49 | z_D = z_D - 1 50 | 51 | return phi_TV, z_D, N_DV 52 | -------------------------------------------------------------------------------- /dpmm/algorithm_3.py: -------------------------------------------------------------------------------- 1 | from numpy import bincount, log, log2, seterr, unique, zeros 2 | from scipy.special import gammaln 3 | 4 | from kale.math_utils import log_sample, vi 5 | 6 | 7 | def iteration(V, D, N_DV, N_D, alpha, beta, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T): 8 | """ 9 | Performs a single iteration of Radford Neal's Algorithm 3. 
10 | """ 11 | 12 | for d in xrange(D): 13 | 14 | old_t = z_D[d] 15 | 16 | if inv_z_T is not None: 17 | inv_z_T[old_t].remove(d) 18 | 19 | N_TV[old_t, :] -= N_DV[d, :] 20 | N_T[old_t] -= N_D[d] 21 | D_T[old_t] -= 1 22 | 23 | seterr(divide='ignore') 24 | log_dist = log(D_T) 25 | seterr(divide='warn') 26 | 27 | idx = old_t if D_T[old_t] == 0 else inactive_topics.pop() 28 | active_topics.add(idx) 29 | log_dist[idx] = log(alpha) 30 | 31 | for t in active_topics: 32 | log_dist[t] += gammaln(N_T[t] + beta) 33 | log_dist[t] -= gammaln(N_D[d] + N_T[t] + beta) 34 | tmp = N_TV[t, :] + beta / V 35 | log_dist[t] += gammaln(N_DV[d, :] + tmp).sum() 36 | log_dist[t] -= gammaln(tmp).sum() 37 | 38 | [t] = log_sample(log_dist) 39 | 40 | z_D[d] = t 41 | 42 | if inv_z_T is not None: 43 | inv_z_T[t].add(d) 44 | 45 | N_TV[t, :] += N_DV[d, :] 46 | N_T[t] += N_D[d] 47 | D_T[t] += 1 48 | 49 | if t != idx: 50 | active_topics.remove(idx) 51 | inactive_topics.add(idx) 52 | 53 | 54 | def inference(N_DV, alpha, beta, z_D, num_itns, true_z_D=None): 55 | """ 56 | Algorithm 3. 57 | """ 58 | 59 | D, V = N_DV.shape 60 | 61 | T = D # maximum number of topics 62 | 63 | N_D = N_DV.sum(1) # document lengths 64 | 65 | active_topics = set(unique(z_D)) 66 | inactive_topics = set(xrange(T)) - active_topics 67 | 68 | N_TV = zeros((T, V), dtype=int) 69 | N_T = zeros(T, dtype=int) 70 | 71 | for d in xrange(D): 72 | N_TV[z_D[d], :] += N_DV[d, :] 73 | N_T[z_D[d]] += N_D[d] 74 | 75 | D_T = bincount(z_D, minlength=T) 76 | 77 | for itn in xrange(num_itns): 78 | 79 | iteration(V, D, N_DV, N_D, alpha, beta, z_D, None, active_topics, inactive_topics, N_TV, N_T, D_T) 80 | 81 | if true_z_D is not None: 82 | 83 | v = vi(true_z_D, z_D) 84 | 85 | print 'Itn. %d' % (itn + 1) 86 | print '%d topics' % len(active_topics) 87 | print 'VI: %f bits (%f bits max.)' % (v, log2(D)) 88 | 89 | if v < 1e-6: 90 | break 91 | 92 | return z_D 93 | -------------------------------------------------------------------------------- /dpmm/test_inference.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 2 | from numpy import argsort, bincount, set_printoptions, where, zeros 3 | from numpy.random import seed 4 | 5 | from generative_process import generate_data 6 | 7 | 8 | def test_inference(algorithm, V, D, l, alpha, beta, num_itns, s): 9 | """ 10 | Generates data via the generative process and then infers the 11 | parameters of the generative process using that data. 12 | """ 13 | 14 | seed(s) 15 | 16 | print 'Generating data...' 17 | 18 | phi_TV, z_D, N_DV = generate_data(V, D, l, alpha, beta) 19 | 20 | set_printoptions(precision=4, suppress=True) 21 | 22 | for t in argsort(bincount(z_D))[::-1]: 23 | idx, = where(z_D[:] == t) 24 | print len(idx), phi_TV[t, :] 25 | 26 | print 'Running inference...' 
27 | 28 | # initialize every document to the same topic 29 | 30 | algorithm.inference(N_DV, alpha, beta, zeros(D, dtype=int), num_itns, z_D) 31 | 32 | 33 | def main(): 34 | 35 | import algorithm_3 36 | import algorithm_8 37 | import conjugate_split_merge 38 | import nonconjugate_split_merge 39 | 40 | functions = { 41 | 'algorithm_3': algorithm_3, 42 | 'algorithm_8': algorithm_8, 43 | 'conjugate_split_merge': conjugate_split_merge, 44 | 'nonconjugate_split_merge': nonconjugate_split_merge 45 | } 46 | 47 | p = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 48 | 49 | p.add_argument('algorithm', metavar='', 50 | choices=['algorithm_3', 51 | 'algorithm_8', 52 | 'conjugate_split_merge', 53 | 'nonconjugate_split_merge'], 54 | help='inference algorithm to test') 55 | p.add_argument('-V', type=int, metavar='', default=5, 56 | help='vocabulary size') 57 | p.add_argument('-D', type=int, metavar='', default=1000, 58 | help='number of documents') 59 | p.add_argument('-l', type=int, metavar='', default=1000, 60 | help='average document length') 61 | p.add_argument('--alpha', type=float, metavar='', default=1.0, 62 | help='concentration parameter for the DP') 63 | p.add_argument('--beta', type=float, metavar='', default=0.5, 64 | help='concentration parameter for the Dirichlet prior') 65 | p.add_argument('--num-itns', type=int, metavar='', default=250, 66 | help='number of iterations') 67 | p.add_argument('--seed', type=int, metavar='', 68 | help='seed for the random number generator') 69 | 70 | args = p.parse_args() 71 | 72 | test_inference(functions[args.algorithm], 73 | args.V, 74 | args.D, 75 | args.l, 76 | args.alpha, 77 | args.beta, 78 | args.num_itns, 79 | args.seed) 80 | 81 | 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /dpmm/algorithm_8.py: -------------------------------------------------------------------------------- 1 | from numpy import bincount, log, log2, ones, seterr, unique, zeros 2 | from numpy.random.mtrand import dirichlet 3 | 4 | from kale.math_utils import log_sample, vi 5 | 6 | 7 | def iteration(V, D, N_DV, N_D, alpha, beta, M, phi_TV, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T): 8 | """ 9 | Performs a single iteration of Radford Neal's Algorithm 8. 
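
    Topic parameters are represented explicitly: each active topic's
    parameters are resampled from their conditional posterior, and M
    auxiliary topics with parameters drawn from the prior stand in for
    the unrepresented topics when resampling a document's assignment,
    so the assignment step does not rely on conjugacy.

    Arguments are those of algorithm_3.iteration, plus:

    M -- number of auxiliary topics
    phi_TV -- topic parameters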
10 | """ 11 | 12 | for t in active_topics: 13 | phi_TV[t, :] = dirichlet(N_TV[t, :] + beta / V) 14 | 15 | for d in xrange(D): 16 | 17 | old_t = z_D[d] 18 | 19 | if inv_z_T is not None: 20 | inv_z_T[old_t].remove(d) 21 | 22 | N_TV[old_t, :] -= N_DV[d, :] 23 | N_T[old_t] -= N_D[d] 24 | D_T[old_t] -= 1 25 | 26 | seterr(divide='ignore') 27 | log_dist = log(D_T) 28 | seterr(divide='warn') 29 | 30 | idx = -1 * ones(M, dtype=int) 31 | idx[0] = old_t if D_T[old_t] == 0 else inactive_topics.pop() 32 | for m in xrange(1, M): 33 | idx[m] = inactive_topics.pop() 34 | active_topics |= set(idx) 35 | log_dist[idx] = log(alpha) - log(M) 36 | 37 | if idx[0] == old_t: 38 | phi_TV[idx[1:], :] = dirichlet(beta * ones(V) / V, M - 1) 39 | else: 40 | phi_TV[idx, :] = dirichlet(beta * ones(V) / V, M) 41 | 42 | for t in active_topics: 43 | log_dist[t] += (N_DV[d, :] * log(phi_TV[t, :])).sum() 44 | 45 | [t] = log_sample(log_dist) 46 | 47 | z_D[d] = t 48 | 49 | if inv_z_T is not None: 50 | inv_z_T[t].add(d) 51 | 52 | N_TV[t, :] += N_DV[d, :] 53 | N_T[t] += N_D[d] 54 | D_T[t] += 1 55 | 56 | idx = set(idx) 57 | idx.discard(t) 58 | active_topics -= idx 59 | inactive_topics |= idx 60 | 61 | 62 | def inference(N_DV, alpha, beta, z_D, num_itns, true_z_D=None): 63 | """ 64 | Algorithm 8. 65 | """ 66 | 67 | M = 10 # number of auxiliary samples 68 | 69 | D, V = N_DV.shape 70 | 71 | T = D + M - 1 # maximum number of topics 72 | 73 | N_D = N_DV.sum(1) # document lengths 74 | 75 | phi_TV = zeros((T, V)) # topic parameters 76 | 77 | active_topics = set(unique(z_D)) 78 | inactive_topics = set(xrange(T)) - active_topics 79 | 80 | N_TV = zeros((T, V), dtype=int) 81 | N_T = zeros(T, dtype=int) 82 | 83 | for d in xrange(D): 84 | N_TV[z_D[d], :] += N_DV[d, :] 85 | N_T[z_D[d]] += N_D[d] 86 | 87 | D_T = bincount(z_D, minlength=T) 88 | 89 | for itn in xrange(num_itns): 90 | 91 | iteration(V, D, N_DV, N_D, alpha, beta, M, phi_TV, z_D, None, active_topics, inactive_topics, N_TV, N_T, D_T) 92 | 93 | if true_z_D is not None: 94 | 95 | v = vi(true_z_D, z_D) 96 | 97 | print 'Itn. %d' % (itn + 1) 98 | print '%d topics' % len(active_topics) 99 | print 'VI: %f bits (%f bits max.)' % (v, log2(D)) 100 | 101 | if v < 1e-6: 102 | break 103 | 104 | return phi_TV, z_D 105 | -------------------------------------------------------------------------------- /dpmm/getting_it_right.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 2 | from numpy import array, empty_like, unique, zeros 3 | from numpy.random import poisson, seed 4 | 5 | from generative_process import generate_data 6 | from kale.iterview import iterview 7 | from kale.math_utils import sample 8 | from pp_plot import pp_plot 9 | 10 | 11 | def getting_it_right(algorithm, V, D, l, alpha, beta, num_itns, s): 12 | """ 13 | Runs Geweke's "getting it right" test. 14 | """ 15 | 16 | seed(s) 17 | 18 | # generate forward samples via the generative process 19 | 20 | print 'Generating forward samples...' 21 | 22 | forward_samples = [] 23 | 24 | for _ in iterview(xrange(num_itns)): 25 | forward_samples.append(generate_data(V, D, l, alpha, beta)[1:]) 26 | 27 | # generate reverse samples via the inference algorithm 28 | 29 | print 'Generating reverse samples...' 
30 | 31 | reverse_samples = [] 32 | 33 | phi_TV, z_D, _ = generate_data(V, D, l, alpha, beta) 34 | 35 | for _ in iterview(xrange(num_itns)): 36 | 37 | N_DV = zeros((D, V), dtype=int) 38 | 39 | if (algorithm.__name__ == 'algorithm_8' or 40 | algorithm.__name__ == 'nonconjugate_split_merge'): 41 | for d in xrange(D): 42 | for v in sample(phi_TV[z_D[d], :], num_samples=poisson(l)): 43 | N_DV[d, v] += 1 44 | 45 | phi_TV, z_D = algorithm.inference(N_DV, alpha, beta, z_D, 1) 46 | 47 | else: 48 | 49 | T = D # maximum number of topics 50 | 51 | N_TV = zeros((T, V), dtype=int) 52 | N_T = zeros(T, dtype=int) 53 | 54 | for d in xrange(D): 55 | t = z_D[d] 56 | for _ in xrange(poisson(l)): 57 | [v] = sample((N_TV[t, :] + beta / V) / (N_T[t] + beta)) 58 | N_DV[d, v] += 1 59 | N_TV[t, v] += 1 60 | N_T[t] += 1 61 | 62 | z_D = algorithm.inference(N_DV, alpha, beta, z_D, 1) 63 | 64 | z_D_copy = empty_like(z_D) 65 | z_D_copy[:] = z_D 66 | 67 | reverse_samples.append((z_D_copy, N_DV)) 68 | 69 | print 'Computing test statistics...' 70 | 71 | # test statistics: number of topics, maximum topic size, mean 72 | # topic size, standard deviation of topic sizes 73 | 74 | # compute test statistics for forward samples 75 | 76 | forward_num_topics = [] 77 | forward_max_topic_size = [] 78 | forward_mean_topic_size = [] 79 | forward_std_topic_size = [] 80 | 81 | for z_D, _ in forward_samples: 82 | forward_num_topics.append(len(unique(z_D))) 83 | topic_sizes = [] 84 | for t in unique(z_D): 85 | topic_sizes.append((z_D[:] == t).sum()) 86 | topic_sizes = array(topic_sizes) 87 | forward_max_topic_size.append(topic_sizes.max()) 88 | forward_mean_topic_size.append(topic_sizes.mean()) 89 | forward_std_topic_size.append(topic_sizes.std()) 90 | 91 | # compute test statistics for reverse samples 92 | 93 | reverse_num_topics = [] 94 | reverse_max_topic_size = [] 95 | reverse_mean_topic_size = [] 96 | reverse_std_topic_size = [] 97 | 98 | for z_D, _ in reverse_samples: 99 | reverse_num_topics.append(len(unique(z_D))) 100 | topic_sizes = [] 101 | for t in unique(z_D): 102 | topic_sizes.append((z_D[:] == t).sum()) 103 | topic_sizes = array(topic_sizes) 104 | reverse_max_topic_size.append(topic_sizes.max()) 105 | reverse_mean_topic_size.append(topic_sizes.mean()) 106 | reverse_std_topic_size.append(topic_sizes.std()) 107 | 108 | # generate P-P plots 109 | 110 | pp_plot(array(forward_num_topics), array(reverse_num_topics)) 111 | pp_plot(array(forward_max_topic_size), array(reverse_max_topic_size)) 112 | pp_plot(array(forward_mean_topic_size), array(reverse_mean_topic_size)) 113 | pp_plot(array(forward_std_topic_size), array(reverse_std_topic_size)) 114 | 115 | 116 | def main(): 117 | 118 | import algorithm_3 119 | import algorithm_8 120 | import conjugate_split_merge 121 | import nonconjugate_split_merge 122 | 123 | functions = { 124 | 'algorithm_3': algorithm_3, 125 | 'algorithm_8': algorithm_8, 126 | 'conjugate_split_merge': conjugate_split_merge, 127 | 'nonconjugate_split_merge': nonconjugate_split_merge 128 | } 129 | 130 | p = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 131 | 132 | p.add_argument('algorithm', metavar='', 133 | choices=['algorithm_3', 134 | 'algorithm_8', 135 | 'conjugate_split_merge', 136 | 'nonconjugate_split_merge'], 137 | help='inference algorithm to test') 138 | p.add_argument('-V', type=int, metavar='', default=3, 139 | help='vocabulary size') 140 | p.add_argument('-D', type=int, metavar='', default=10, 141 | help='number of documents') 142 | p.add_argument('-l', type=int, metavar='', 
default=10, 143 | help='average document length') 144 | p.add_argument('--alpha', type=float, metavar='', default=1.0, 145 | help='concentration parameter for the DP') 146 | p.add_argument('--beta', type=float, metavar='', default=3.0, 147 | help='concentration parameter for the Dirichlet prior') 148 | p.add_argument('--num-itns', type=int, metavar='', default=50000, 149 | help='number of iterations') 150 | p.add_argument('--seed', type=int, metavar='', 151 | help='seed for the random number generator') 152 | 153 | args = p.parse_args() 154 | 155 | getting_it_right(functions[args.algorithm], 156 | args.V, 157 | args.D, 158 | args.l, 159 | args.alpha, 160 | args.beta, 161 | args.num_itns, 162 | args.seed) 163 | 164 | 165 | if __name__ == '__main__': 166 | main() 167 | -------------------------------------------------------------------------------- /dpmm/conjugate_split_merge.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from numpy import bincount, empty, log, log2, unique, zeros 3 | from numpy.random import choice, uniform 4 | from scipy.special import gammaln 5 | 6 | from algorithm_3 import iteration as algorithm_3_iteration 7 | from kale.math_utils import log_sample, log_sum_exp, vi 8 | 9 | 10 | def iteration(V, D, N_DV, N_D, alpha, beta, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T, num_inner_itns): 11 | """ 12 | Performs a single iteration of Metropolis-Hastings (split-merge). 13 | """ 14 | 15 | N_s_V = empty(V, dtype=int) 16 | N_t_V = empty(V, dtype=int) 17 | 18 | log_dist = empty(2) 19 | 20 | d, e = choice(D, 2, replace=False) # choose 2 documents 21 | 22 | if z_D[d] == z_D[e]: 23 | s = inactive_topics.pop() 24 | active_topics.add(s) 25 | else: 26 | s = z_D[d] 27 | 28 | inv_z_s = set([d]) 29 | N_s_V[:] = N_DV[d, :] 30 | N_s = N_D[d] 31 | D_s = 1 32 | 33 | t = z_D[e] 34 | inv_z_t = set([e]) 35 | N_t_V[:] = N_DV[e, :] 36 | N_t = N_D[e] 37 | D_t = 1 38 | 39 | if z_D[d] == z_D[e]: 40 | idx = inv_z_T[t] - set([d, e]) 41 | else: 42 | idx = (inv_z_T[s] | inv_z_T[t]) - set([d, e]) 43 | 44 | for f in idx: 45 | if uniform() < 0.5: 46 | inv_z_s.add(f) 47 | N_s_V += N_DV[f, :] 48 | N_s += N_D[f] 49 | D_s += 1 50 | else: 51 | inv_z_t.add(f) 52 | N_t_V += N_DV[f, :] 53 | N_t += N_D[f] 54 | D_t += 1 55 | 56 | acc = 0.0 57 | 58 | for inner_itn in xrange(num_inner_itns): 59 | for f in idx: 60 | 61 | # (fake) restricted Gibbs sampling scan 62 | 63 | if f in inv_z_s: 64 | inv_z_s.remove(f) 65 | N_s_V -= N_DV[f, :] 66 | N_s -= N_D[f] 67 | D_s -= 1 68 | else: 69 | inv_z_t.remove(f) 70 | N_t_V -= N_DV[f, :] 71 | N_t -= N_D[f] 72 | D_t -= 1 73 | 74 | log_dist[0] = log(D_s) 75 | log_dist[0] += gammaln(N_s + beta) 76 | log_dist[0] -= gammaln(N_D[f] + N_s + beta) 77 | tmp = N_s_V + beta / V 78 | log_dist[0] += gammaln(N_DV[f, :] + tmp).sum() 79 | log_dist[0] -= gammaln(tmp).sum() 80 | 81 | log_dist[1] = log(D_t) 82 | log_dist[1] += gammaln(N_t + beta) 83 | log_dist[1] -= gammaln(N_D[f] + N_t + beta) 84 | tmp = N_t_V + beta / V 85 | log_dist[1] += gammaln(N_DV[f, :] + tmp).sum() 86 | log_dist[1] -= gammaln(tmp).sum() 87 | 88 | log_dist -= log_sum_exp(log_dist) 89 | 90 | if inner_itn == num_inner_itns - 1 and z_D[d] != z_D[e]: 91 | u = 0 if z_D[f] == s else 1 92 | else: 93 | [u] = log_sample(log_dist) 94 | 95 | if u == 0: 96 | inv_z_s.add(f) 97 | N_s_V += N_DV[f, :] 98 | N_s += N_D[f] 99 | D_s += 1 100 | else: 101 | inv_z_t.add(f) 102 | N_t_V += N_DV[f, :] 103 | N_t += N_D[f] 104 | D_t += 1 105 | 106 | if inner_itn 
== num_inner_itns - 1: 107 | acc += log_dist[u] 108 | 109 | if z_D[d] == z_D[e]: 110 | 111 | acc *= -1.0 112 | 113 | acc += log(alpha) 114 | acc += gammaln(D_s) + gammaln(D_t) - gammaln(D_T[t]) 115 | 116 | acc += gammaln(beta) + gammaln(N_T[t] + beta) 117 | acc -= gammaln(N_s + beta) + gammaln(N_t + beta) 118 | tmp = beta / V 119 | acc += gammaln(N_s_V + tmp).sum() + gammaln(N_t_V + tmp).sum() 120 | acc -= V * gammaln(tmp) + gammaln(N_TV[t, :] + tmp).sum() 121 | 122 | if log(uniform()) < min(0.0, acc): 123 | z_D[list(inv_z_s)] = s 124 | z_D[list(inv_z_t)] = t 125 | inv_z_T[s] = inv_z_s 126 | inv_z_T[t] = inv_z_t 127 | N_TV[s, :] = N_s_V 128 | N_TV[t, :] = N_t_V 129 | N_T[s] = N_s 130 | N_T[t] = N_t 131 | D_T[s] = D_s 132 | D_T[t] = D_t 133 | else: 134 | active_topics.remove(s) 135 | inactive_topics.add(s) 136 | 137 | else: 138 | 139 | for f in inv_z_T[s]: 140 | inv_z_t.add(f) 141 | N_t_V += N_DV[f, :] 142 | N_t += N_D[f] 143 | D_t += 1 144 | 145 | acc -= log(alpha) 146 | acc += gammaln(D_t) - gammaln(D_T[s]) - gammaln(D_T[t]) 147 | 148 | acc += gammaln(N_T[s] + beta) + gammaln(N_T[t] + beta) 149 | acc -= gammaln(beta) + gammaln(N_t + beta) 150 | tmp = beta / V 151 | acc += V * gammaln(tmp) + gammaln(N_t_V + tmp).sum() 152 | acc -= (gammaln(N_TV[s, :] + tmp).sum() + 153 | gammaln(N_TV[t, :] + tmp).sum()) 154 | 155 | if log(uniform()) < min(0.0, acc): 156 | active_topics.remove(s) 157 | inactive_topics.add(s) 158 | z_D[list(inv_z_t)] = t 159 | inv_z_T[s].clear() 160 | inv_z_T[t] = inv_z_t 161 | N_TV[s, :] = zeros(V, dtype=int) 162 | N_TV[t, :] = N_t_V 163 | N_T[s] = 0 164 | N_T[t] = N_t 165 | D_T[s] = 0 166 | D_T[t] = D_t 167 | 168 | 169 | def inference(N_DV, alpha, beta, z_D, num_itns, true_z_D=None): 170 | """ 171 | Conjugate split-merge. 172 | """ 173 | 174 | D, V = N_DV.shape 175 | 176 | T = D # maximum number of topics 177 | 178 | N_D = N_DV.sum(1) # document lengths 179 | 180 | inv_z_T = defaultdict(set) 181 | for d in xrange(D): 182 | inv_z_T[z_D[d]].add(d) # inverse mapping from topics to documents 183 | 184 | active_topics = set(unique(z_D)) 185 | inactive_topics = set(xrange(T)) - active_topics 186 | 187 | N_TV = zeros((T, V), dtype=int) 188 | N_T = zeros(T, dtype=int) 189 | 190 | for d in xrange(D): 191 | N_TV[z_D[d], :] += N_DV[d, :] 192 | N_T[z_D[d]] += N_D[d] 193 | 194 | D_T = bincount(z_D, minlength=T) 195 | 196 | for itn in xrange(num_itns): 197 | 198 | for _ in xrange(3): 199 | iteration(V, D, N_DV, N_D, alpha, beta, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T, 6) 200 | 201 | algorithm_3_iteration(V, D, N_DV, N_D, alpha, beta, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T) 202 | 203 | if true_z_D is not None: 204 | 205 | v = vi(true_z_D, z_D) 206 | 207 | print 'Itn. 
%d' % (itn + 1) 208 | print '%d topics' % len(active_topics) 209 | print 'VI: %f bits (%f bits max.)' % (v, log2(D)) 210 | 211 | if v < 1e-6: 212 | break 213 | 214 | return z_D 215 | -------------------------------------------------------------------------------- /dpmm/nonconjugate_split_merge.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from numpy import bincount, empty, log, log2, unique, zeros 3 | from numpy.random import choice, uniform 4 | from numpy.random.mtrand import dirichlet 5 | from scipy.special import gammaln 6 | 7 | from algorithm_8 import iteration as algorithm_8_iteration 8 | from kale.math_utils import log_sample, log_sum_exp, vi 9 | 10 | 11 | def iteration(V, D, N_DV, N_D, alpha, beta, phi_TV, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T, num_inner_itns): 12 | """ 13 | Performs a single iteration of Metropolis-Hastings (split-merge). 14 | """ 15 | 16 | phi_s_V = empty(V) 17 | phi_t_V = empty(V) 18 | phi_merge_t_V = empty(V) 19 | 20 | N_s_V = empty(V, dtype=int) 21 | N_t_V = empty(V, dtype=int) 22 | N_merge_t_V = empty(V, dtype=int) 23 | 24 | log_dist = empty(2) 25 | 26 | d, e = choice(D, 2, replace=False) # choose 2 documents 27 | 28 | if z_D[d] == z_D[e]: 29 | s = inactive_topics.pop() 30 | active_topics.add(s) 31 | else: 32 | s = z_D[d] 33 | 34 | inv_z_s = set([d]) 35 | N_s_V[:] = N_DV[d, :] 36 | N_s = N_D[d] 37 | D_s = 1 38 | 39 | t = z_D[e] 40 | inv_z_t = set([e]) 41 | N_t_V[:] = N_DV[e, :] 42 | N_t = N_D[e] 43 | D_t = 1 44 | 45 | inv_z_merge_t = set([d, e]) 46 | N_merge_t_V[:] = N_DV[d, :] + N_DV[e, :] 47 | N_merge_t = N_D[d] + N_D[e] 48 | D_merge_t = 2 49 | 50 | if z_D[d] == z_D[e]: 51 | idx = inv_z_T[t] - set([d, e]) 52 | else: 53 | idx = (inv_z_T[s] | inv_z_T[t]) - set([d, e]) 54 | 55 | for f in idx: 56 | if uniform() < 0.5: 57 | inv_z_s.add(f) 58 | N_s_V += N_DV[f, :] 59 | N_s += N_D[f] 60 | D_s += 1 61 | else: 62 | inv_z_t.add(f) 63 | N_t_V += N_DV[f, :] 64 | N_t += N_D[f] 65 | D_t += 1 66 | 67 | inv_z_merge_t.add(f) 68 | N_merge_t_V += N_DV[f, :] 69 | N_merge_t += N_D[f] 70 | D_merge_t += 1 71 | 72 | if z_D[d] == z_D[e]: 73 | phi_merge_t_V[:] = phi_TV[t, :] 74 | else: 75 | phi_merge_t_V = dirichlet(N_merge_t_V + beta / V) 76 | 77 | acc = 0.0 78 | 79 | for inner_itn in xrange(num_inner_itns): 80 | 81 | # sample new parameters for topics s and t ... 
but if it's the 82 | # last iteration and we're doing a merge, then just set the 83 | # parameters back to phi_TV[s, :] and phi_TV[t, :] 84 | 85 | if inner_itn == num_inner_itns - 1 and z_D[d] != z_D[e]: 86 | phi_s_V[:] = phi_TV[s, :] 87 | phi_t_V[:] = phi_TV[t, :] 88 | else: 89 | phi_s_V = dirichlet(N_s_V + beta / V) 90 | phi_t_V = dirichlet(N_t_V + beta / V) 91 | 92 | if inner_itn == num_inner_itns - 1: 93 | 94 | acc += gammaln(N_s + beta) 95 | acc -= gammaln(N_s_V + beta / V).sum() 96 | acc += ((N_s_V + beta / V - 1) * log(phi_s_V)).sum() 97 | 98 | acc += gammaln(N_t + beta) 99 | acc -= gammaln(N_t_V + beta / V).sum() 100 | acc += ((N_t_V + beta / V - 1) * log(phi_t_V)).sum() 101 | 102 | acc -= gammaln(N_merge_t + beta) 103 | acc += gammaln(N_merge_t_V + beta / V).sum() 104 | acc -= ((N_merge_t_V + beta / V - 1) * 105 | log(phi_merge_t_V)).sum() 106 | 107 | for f in idx: 108 | 109 | # (fake) restricted Gibbs sampling scan 110 | 111 | if f in inv_z_s: 112 | inv_z_s.remove(f) 113 | N_s_V -= N_DV[f, :] 114 | N_s -= N_D[f] 115 | D_s -= 1 116 | else: 117 | inv_z_t.remove(f) 118 | N_t_V -= N_DV[f, :] 119 | N_t -= N_D[f] 120 | D_t -= 1 121 | 122 | log_dist[0] = log(D_s) 123 | log_dist[0] += (N_DV[f, :] * log(phi_s_V)).sum() 124 | 125 | log_dist[1] = log(D_t) 126 | log_dist[1] += (N_DV[f, :] * log(phi_t_V)).sum() 127 | 128 | log_dist -= log_sum_exp(log_dist) 129 | 130 | if inner_itn == num_inner_itns - 1 and z_D[d] != z_D[e]: 131 | u = 0 if z_D[f] == s else 1 132 | else: 133 | [u] = log_sample(log_dist) 134 | 135 | if u == 0: 136 | inv_z_s.add(f) 137 | N_s_V += N_DV[f, :] 138 | N_s += N_D[f] 139 | D_s += 1 140 | else: 141 | inv_z_t.add(f) 142 | N_t_V += N_DV[f, :] 143 | N_t += N_D[f] 144 | D_t += 1 145 | 146 | if inner_itn == num_inner_itns - 1: 147 | acc += log_dist[u] 148 | 149 | if z_D[d] == z_D[e]: 150 | 151 | acc *= -1.0 152 | 153 | acc += log(alpha) 154 | acc += gammaln(D_s) + gammaln(D_t) - gammaln(D_T[t]) 155 | tmp = beta / V 156 | acc += gammaln(beta) - V * gammaln(tmp) 157 | acc += (tmp - 1) * (log(phi_s_V).sum() + log(phi_t_V).sum()) 158 | acc -= (tmp - 1) * log(phi_TV[t, :]).sum() 159 | 160 | acc += (N_s_V * log(phi_s_V)).sum() + (N_t_V * log(phi_t_V)).sum() 161 | acc -= (N_TV[t, :] * log(phi_TV[t, :])).sum() 162 | 163 | if log(uniform()) < min(0.0, acc): 164 | phi_TV[s, :] = phi_s_V 165 | phi_TV[t, :] = phi_t_V 166 | z_D[list(inv_z_s)] = s 167 | z_D[list(inv_z_t)] = t 168 | inv_z_T[s] = inv_z_s 169 | inv_z_T[t] = inv_z_t 170 | N_TV[s, :] = N_s_V 171 | N_TV[t, :] = N_t_V 172 | N_T[s] = N_s 173 | N_T[t] = N_t 174 | D_T[s] = D_s 175 | D_T[t] = D_t 176 | else: 177 | active_topics.remove(s) 178 | inactive_topics.add(s) 179 | 180 | else: 181 | 182 | acc -= log(alpha) 183 | acc += gammaln(D_merge_t) - gammaln(D_T[s]) - gammaln(D_T[t]) 184 | tmp = beta / V 185 | acc += V * gammaln(tmp) - gammaln(beta) 186 | acc += (tmp - 1) * log(phi_merge_t_V).sum() 187 | acc -= (tmp - 1) * (log(phi_TV[s, :]).sum() + log(phi_TV[t, :]).sum()) 188 | 189 | acc += (N_merge_t_V * log(phi_merge_t_V)).sum() 190 | acc -= ((N_TV[s, :] * log(phi_TV[s, :])).sum() + 191 | (N_TV[t, :] * log(phi_TV[t, :])).sum()) 192 | 193 | if log(uniform()) < min(0.0, acc): 194 | phi_TV[s, :] = zeros(V) 195 | phi_TV[t, :] = phi_merge_t_V 196 | active_topics.remove(s) 197 | inactive_topics.add(s) 198 | z_D[list(inv_z_merge_t)] = t 199 | inv_z_T[s].clear() 200 | inv_z_T[t] = inv_z_merge_t 201 | N_TV[s, :] = zeros(V, dtype=int) 202 | N_TV[t, :] = N_merge_t_V 203 | N_T[s] = 0 204 | N_T[t] = N_merge_t 205 | D_T[s] = 0 206 | D_T[t] = 
D_merge_t
207 | 
208 | 
209 | def inference(N_DV, alpha, beta, z_D, num_itns, true_z_D=None):
210 |     """
211 |     Nonconjugate split-merge.
212 |     """
213 | 
214 |     M = 10  # number of auxiliary samples
215 | 
216 |     D, V = N_DV.shape
217 | 
218 |     T = D + M - 1  # maximum number of topics
219 | 
220 |     N_D = N_DV.sum(1)  # document lengths
221 | 
222 |     phi_TV = zeros((T, V))  # topic parameters
223 | 
224 |     inv_z_T = defaultdict(set)
225 |     for d in xrange(D):
226 |         inv_z_T[z_D[d]].add(d)  # inverse mapping from topics to documents
227 | 
228 |     active_topics = set(unique(z_D))
229 |     inactive_topics = set(xrange(T)) - active_topics
230 | 
231 |     N_TV = zeros((T, V), dtype=int)
232 |     N_T = zeros(T, dtype=int)
233 | 
234 |     for d in xrange(D):
235 |         N_TV[z_D[d], :] += N_DV[d, :]
236 |         N_T[z_D[d]] += N_D[d]
237 | 
238 |     D_T = bincount(z_D, minlength=T)
239 | 
240 |     # initialize topic parameters (necessary for Metropolis-Hastings only)
241 | 
242 |     for t in active_topics:
243 |         phi_TV[t, :] = dirichlet(N_TV[t, :] + beta / V)
244 | 
245 |     for itn in xrange(num_itns):
246 | 
247 |         for _ in xrange(3):
248 |             iteration(V, D, N_DV, N_D, alpha, beta, phi_TV, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T, 6)
249 | 
250 |         algorithm_8_iteration(V, D, N_DV, N_D, alpha, beta, M, phi_TV, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T)
251 | 
252 |         if true_z_D is not None:
253 | 
254 |             v = vi(true_z_D, z_D)
255 | 
256 |             print 'Itn. %d' % (itn + 1)
257 |             print '%d topics' % len(active_topics)
258 |             print 'VI: %f bits (%f bits max.)' % (v, log2(D))
259 | 
260 |             if v < 1e-6:
261 |                 break
262 | 
263 |     return phi_TV, z_D
264 | 
--------------------------------------------------------------------------------
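
Example usage (a minimal sketch, not part of the repository itself;
it assumes a Python 2 environment with kale, matplotlib, numpy, and
scipy installed, with the commands run from the repository root):

    # recover a synthetic clustering with collapsed Gibbs sampling
    python dpmm/test_inference.py algorithm_3 -V 5 -D 1000 --num-itns 250

    # run Geweke's "getting it right" test on algorithm 8
    python dpmm/getting_it_right.py algorithm_8 --num-itns 50000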