├── README.md
├── figs
├── image_clustering_results.png
├── vmf_mix_results.png
└── vmf_results.png
├── image_clustering_cifar10.py
├── mle_for_mix_vmf.py
├── mle_for_vmf.py
├── models.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Python/PyTorch Implementation of von Mises-Fisher and Its Mixture
2 |
3 | The von Mises-Fisher (vMF) distribution is a well-known density model for directional random variables. The recent surge of deep embedding methodologies for high-dimensional structured data such as images or text, which aim to extract salient directional information, can make the vMF model even more popular. In this article, we review the vMF model and its mixture, and provide detailed recipes for training them in Python/PyTorch, focusing on maximum likelihood estimators. In particular, vMF implementations typically suffer from the notorious numerical issue of evaluating the Bessel function in the density normalizer, especially when the dimensionality is high; we address this issue with the mpmath library, which supports arbitrary precision. For mixture learning, we provide both minibatch-based large-scale SGD learning and the full-batch EM algorithm. We test each estimator on synthetic data and also demonstrate a more realistic use case of image clustering. The technical report with full details can be found at http://arxiv.org/abs/2102.05340.
4 |
5 | ---
6 |
7 |
8 | ## Features
9 |
10 | * vMF density estimation via near-closed-form Maximum Likelihood Estimator (MLE) or Stochastic Gradient Descent (SGD)
11 | * Estimation of a mixture of vMFs via Expectation-Maximization (EM) or SGD
12 | * Numerically stable evaluation of the vMF normalizer through [mpmath](https://mpmath.org/)'s arbitrary-precision Bessel function (see the sketch below)
13 | * Examples of image clustering application
14 |
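As a minimal sketch of the numerically stable normalizer (the helper name `log_vmf_normalizer` below is ours and not part of the library), the log-partition `log C_d(kappa) = (d/2 - 1) log(kappa) - (d/2) log(2 pi) - log I_{d/2-1}(kappa)` can be evaluated with mpmath's arbitrary-precision Bessel function, mirroring what ```vMFLogPartition``` does in ```models.py```:

```python
import numpy as np
import mpmath

def log_vmf_normalizer(d, kappa):
    # log C_d(kappa) = (d/2 - 1)*log(kappa) - (d/2)*log(2*pi) - log I_{d/2-1}(kappa)
    s = 0.5 * d - 1.0
    log_bessel = mpmath.log(mpmath.besseli(s, mpmath.mpf(kappa)))  # arbitrary precision
    return s * np.log(kappa) - 0.5 * d * np.log(2 * np.pi) - float(log_bessel)

# e.g. d = 100, kappa = 500 as in mle_for_vmf.py
print(log_vmf_normalizer(d=100, kappa=500.0))
```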
15 |
16 | ## Requirements
17 |
18 | * Python 3.7
19 | * PyTorch >= 1.4.0
20 | * NumPy >= 1.18.1
21 | * [mpmath](https://mpmath.org/) >= 1.1.0
22 | * scikit-learn >= 0.23.2 (Required for metric computation in image clustering)
23 |
24 |
25 | ## Usage examples
26 |
27 | ### 1) vMF Density Estimation
28 |
29 | We generate samples from a true vMF model and aim to recover the true model parameters with either the near-closed-form full-batch MLE or the SGD estimator. See the demo code in ```mle_for_vmf.py``` for details. The results are briefly summarized in Table 1 below; as shown, both estimators are equally accurate.
30 |
31 | ![Table 1: vMF density estimation results](figs/vmf_results.png)
32 |
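A condensed sketch of the full-batch estimator, adapted from ```mle_for_vmf.py``` (the kappa formula is the standard approximation based on the mean resultant length):

```python
import torch
import models

# ground-truth vMF on the unit sphere in R^100 (as in the demo)
vmf_true = models.vMF(x_dim=100)
vmf_true.set_params(mu=torch.eye(100)[0], kappa=torch.tensor(500.0))
samples = vmf_true.sample(N=10000, rsf=1)

# near-closed-form full-batch MLE
xm = samples.mean(0)
R = xm.norm()                                    # mean resultant length
mu_hat = xm / R
kappa_hat = (len(xm) * R - R**3) / (1 - R**2)    # approximate MLE for kappa
```

The SGD alternative simply maximizes the log-likelihood returned by ```vmf(data)``` with Adam over minibatches; see the demo file.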
33 |
34 |
35 | ### 2) vMF Mixture Estimation
36 |
37 | We also generate samples from a true mixture of three vMFs and aim to recover the true mixture parameters with either the EM or the SGD estimator. See the demo code in ```mle_for_mix_vmf.py``` for details. The results are briefly summarized in Figure 2 below; as shown, both estimators are equally accurate.
38 |
39 | ![Figure 2: vMF mixture estimation results](figs/vmf_mix_results.png)
40 |
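A condensed sketch of the SGD estimator, adapted from ```mle_for_mix_vmf.py``` (component means here are drawn randomly for brevity; the EM variant lives in the same demo file):

```python
import torch
from torch.utils.data import DataLoader
import models, utils

# ground-truth 3-component mixture on the unit sphere in R^5
mix_true = models.MixvMF(x_dim=5, order=3)
mus_true = [torch.randn(5) for _ in range(3)]
mix_true.set_params(
    alpha=torch.tensor([0.3, 0.4, 0.3]),
    mus=[mu / utils.norm(mu, dim=0) for mu in mus_true],
    kappas=[torch.tensor(k) for k in (100.0, 50.0, 100.0)],
)
samples, _ = mix_true.sample(N=1000, rsf=1)

# fit a freshly initialized mixture by minibatch SGD on the negative log-likelihood
mix = models.MixvMF(x_dim=5, order=3)
optim = torch.optim.Adam(mix.parameters(), lr=1e-1, betas=[0.9, 0.99])
loader = DataLoader(samples, batch_size=64, shuffle=True, drop_last=True)
for epoch in range(100):
    for data in loader:
        logliks, _ = mix(data)    # per-sample mixture log-likelihoods
        optim.zero_grad()
        (-logliks.mean()).backward()
        optim.step()
```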
41 |
42 |
43 | ### 3) Image Clustering Application
44 |
45 | We aim to cluster the images in the CIFAR-10 dataset. Our pipeline first learns a unit-hyperspherical latent space by minimizing the reconstruction error of an autoencoder, and then fits a vMF mixture model in that latent space. See the demo code in ```image_clustering_cifar10.py``` for details. The results are briefly summarized in Table 2 below. They indicate that the vMF mixture learning approaches (EM and SGD) significantly outperform the well-known k-means algorithm on the two clustering performance metrics, Adjusted Rand Index (ARI) and Normalized Mutual Information (NMI), and that the EM and SGD estimators perform equally well.
46 |
47 | ![Table 2: image clustering results on CIFAR-10](figs/image_clustering_results.png)
48 |
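Once the demo has been run (so the saved embeddings and the EM-fitted mixture exist under ```result/CIFAR10/```), predicting cluster labels and scoring them boils down to the following sketch, adapted from ```image_clustering_cifar10.py```:

```python
import os
import sklearn.metrics
import torch
import models

result_path = os.path.join('result', 'CIFAR10')
Ztr, Ytr = torch.load(os.path.join(result_path, 'embeds_tr_cifar10.pth'))   # unit-norm embeddings + labels
mix = models.MixvMF(x_dim=100, order=10)
mix.load_state_dict(torch.load(os.path.join(result_path, 'mix_em_cifar10.pth')))

# MAP cluster assignment: argmax_m [ log alpha_m + log vMF(z; mu_m, kappa_m) ]
with torch.no_grad():
    logalpha, mus, kappas = mix.get_params()
    _, logpcs = mix(Ztr)
    clabs = (logalpha.unsqueeze(0) + logpcs).max(1)[1]

print('ARI = %.4f' % sklearn.metrics.adjusted_rand_score(Ytr, clabs))
print('NMI = %.4f' % sklearn.metrics.normalized_mutual_info_score(Ytr, clabs))
```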
49 |
50 |
51 |
52 | ## Citation
53 | If you find this library useful in your research, please cite:
54 | ```
55 | @inproceedings{mkim2021vmf,
56 | title = {On PyTorch Implementation of Density Estimators for von Mises-Fisher and Its Mixture},
57 | author = {Kim, Minyoung},
58 | year = {2021},
59 | URL = {http://arxiv.org/abs/2102.05340},
60 | booktitle = {arXiv preprint}
61 | }
62 | ```
63 |
64 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
/figs/image_clustering_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minyoungkim21/vmf-lib/1aa775ae3dfd31313d3abcc8c62bce9b58f8238b/figs/image_clustering_results.png
--------------------------------------------------------------------------------
/figs/vmf_mix_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minyoungkim21/vmf-lib/1aa775ae3dfd31313d3abcc8c62bce9b58f8238b/figs/vmf_mix_results.png
--------------------------------------------------------------------------------
/figs/vmf_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minyoungkim21/vmf-lib/1aa775ae3dfd31313d3abcc8c62bce9b58f8238b/figs/vmf_results.png
--------------------------------------------------------------------------------
/image_clustering_cifar10.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import sklearn.metrics
5 | import torch
6 | import torch.optim as optim
7 | import torch.nn as nn
8 | from torch.nn import functional as F
9 | from torchvision import datasets, transforms
10 | from torch.utils.data import DataLoader
11 |
12 | import models
13 | import utils
14 |
15 |
16 | '''
17 | Set options
18 | '''
19 |
20 | opts = {}
21 |
22 | # path to data and result files
23 | opts['data_path'] = os.path.join('data', 'CIFAR10') # will download the data if it does not exist
24 | utils.mkdirs(opts['data_path'])
25 | opts['result_path'] = os.path.join('result', 'CIFAR10')
26 | utils.mkdirs(opts['result_path'])
27 |
28 | # options for autoencoder
29 | opts['cuda'] = True
30 | opts['z_dim'] = 100
31 | opts['batch_size'] = 128
32 | opts['ae_hid_dim'] = 512
33 | opts['ae_max_epoch'] = 100
34 | opts['ae_lr'] = 1e-4
35 |
36 | # options for EM
37 | opts_em = {}
38 | opts_em['max_iters'] = 100 # maximum number of EM iterations
39 | opts_em['rll_tol'] = 1e-5 # tolerance of relative loglik improvement
40 | opts_em['batch_size'] = 128 # batch inside E and M steps
41 |
42 | # options for SGD
43 | opts_sgd = {}
44 | opts_sgd['batch_size'] = 256 # batch size
45 | opts_sgd['max_epochs'] = 100 # maximum number of epochs
46 |
47 |
48 | class ConvEncoder(nn.Module):
49 |
50 | def __init__(self, z_dim, h_dim=256):
51 |
52 | super().__init__()
53 |
54 | self.z_dim = z_dim
55 |
56 | self.conv1 = nn.Conv2d(3, 32, 4, 2, 1)
57 | self.conv2 = nn.Conv2d(32, 32, 4, 2, 1)
58 | self.conv3 = nn.Conv2d(32, 64, 4, 2, 1)
59 | self.fc4 = nn.Linear(64*4*4, h_dim)
60 | self.fc5 = nn.Linear(h_dim, z_dim)
61 |
62 | def forward(self, x):
63 |
64 | out = F.leaky_relu(self.conv1(x))
65 | out = F.leaky_relu(self.conv2(out))
66 | out = F.leaky_relu(self.conv3(out))
67 | out = out.view(out.size(0), -1)
68 | out = F.leaky_relu(self.fc4(out))
69 | out = self.fc5(out)
70 |
71 | out = out / utils.norm(out, dim=1)
72 |
73 | return out
74 |
75 |
76 | class ConvDecoder(nn.Module):
77 |
78 | def __init__(self, z_dim, h_dim=256):
79 |
80 | super().__init__()
81 |
82 | self.z_dim = z_dim
83 |
84 | self.fc1 = nn.Linear(z_dim, h_dim)
85 | self.fc2 = nn.Linear(h_dim, 4*4*64)
86 | self.deconv3 = nn.ConvTranspose2d(64, 32, 4, 2, 1)
87 | self.deconv4 = nn.ConvTranspose2d(32, 32, 4, 2, 1)
88 | self.deconv5 = nn.ConvTranspose2d(32, 3, 4, 2, 1)
89 |
90 | def forward(self, z):
91 |
92 | out = F.relu(self.fc1(z))
93 | out = F.relu(self.fc2(out))
94 | out = out.view(out.size(0), 64, 4, 4)
95 | out = F.relu(self.deconv3(out))
96 | out = F.relu(self.deconv4(out))
97 | out = self.deconv5(out)
98 |
99 | return out
100 |
101 |
102 | '''
103 | Train autoencoder: image -> z
104 | '''
105 |
106 | class Autoencoder(nn.Module):
107 |
108 | def __init__(self, z_dim, hid_dim):
109 |
110 | super().__init__()
111 |
112 | self.z_dim = z_dim
113 |
114 | self.encoder = ConvEncoder(z_dim, hid_dim)
115 | self.decoder = ConvDecoder(z_dim, hid_dim)
116 |
117 | def forward(self, x):
118 |
119 | x2 = self.decoder(self.encoder(x))
120 | return ((x2-x)**2).sum() / x.shape[0]
121 |
122 | def encode(self, x):
123 |
124 | return self.encoder(x)
125 |
126 |
127 | ae = Autoencoder(z_dim=opts['z_dim'], hid_dim=opts['ae_hid_dim'])
128 | if opts['cuda']:
129 | ae = ae.cuda()
130 |
131 | # load data (automatically download if not exist)
132 | dset_tr = datasets.CIFAR10(opts['data_path'], train=True, download=True, transform=transforms.ToTensor())
133 | dset_te = datasets.CIFAR10(opts['data_path'], train=False, download=True, transform=transforms.ToTensor())
134 | dl_tr = DataLoader(dset_tr, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
135 | dl_te = DataLoader(dset_te, batch_size=opts['batch_size'], shuffle=False, drop_last=False)
136 |
137 | # optimizer
138 | optimizer = optim.Adam(ae.parameters(), lr=opts['ae_lr'])
139 |
140 | for epoch in range(opts['ae_max_epoch']):
141 |
142 | train_loss = 0
143 |
144 | for batch_idx, (XX, YY) in enumerate(dl_tr):
145 |
146 | if opts['cuda']:
147 | XX = XX.cuda()
148 |
149 | optimizer.zero_grad()
150 | loss = ae(XX)
151 | loss.backward()
152 | optimizer.step()
153 |
154 | train_loss += loss.item() * XX.shape[0]
155 |
156 | if batch_idx % 20 == 0:
157 | prn_str = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
158 | epoch, batch_idx*XX.shape[0], len(dl_tr.dataset),
159 | 100.*batch_idx/len(dl_tr), loss.item() )
160 | print(prn_str)
161 |
162 | epoch_loss = train_loss / len(dl_tr.dataset)
163 | prn_str = '====> Epoch: {} Average loss: {:.4f}'.format(epoch, epoch_loss)
164 | print(prn_str)
165 |
166 | # save the learned autoencoder
167 | torch.save(ae.state_dict(), os.path.join(opts['result_path'], 'ae_cifar10.pth'))
168 |
169 |
170 | # load autoencoder (if necessary)
171 | ae = Autoencoder(z_dim=opts['z_dim'], hid_dim=opts['ae_hid_dim'])
172 | ae.load_state_dict( torch.load(os.path.join(opts['result_path'], 'ae_cifar10.pth')) )
173 | if opts['cuda']:
174 | ae = ae.cuda()
175 |
176 | # embed all images (and their labels)
177 | dl_tr = DataLoader(dset_tr, batch_size=opts['batch_size'], shuffle=False, drop_last=False)
178 | dl_te = DataLoader(dset_te, batch_size=opts['batch_size'], shuffle=False, drop_last=False)
179 | with torch.no_grad():
180 | for split in ['tr', 'te']:
181 | dl = dl_tr if split=='tr' else dl_te
182 | for ii, (xx, yy) in enumerate(dl):
183 | if opts['cuda']:
184 | xx = xx.cuda()
185 | zz = ae.encoder(xx).cpu()
186 | if ii==0:
187 | ZZ, YY = zz, yy
188 | else:
189 | ZZ, YY = torch.cat([ZZ, zz], dim=0), torch.cat([YY, yy], dim=0)
190 | torch.save([ZZ, YY], os.path.join(opts['result_path'], 'embeds_%s_cifar10.pth' % split))
191 |
192 |
193 | '''
194 | Do mixture learning in the embedded space
195 | '''
196 |
197 | seed = 1234
198 |
199 | random.seed(seed)
200 | np.random.seed(seed)
201 | torch.manual_seed(seed)
202 | torch.cuda.manual_seed(seed)
203 | torch.cuda.manual_seed_all(seed)
204 |
205 | # load embedded data and labels
206 | Ztr, Ytr = torch.load(os.path.join(opts['result_path'], 'embeds_tr_cifar10.pth'))
207 | Zte, Yte = torch.load(os.path.join(opts['result_path'], 'embeds_te_cifar10.pth'))
208 |
209 |
210 | '''
211 | EM (done on CPU)
212 | '''
213 |
214 | # randomly initialized mixture
215 | mix = models.MixvMF(x_dim=opts['z_dim'], order=10)
216 |
217 | # data loader for training samples
218 | dataloader = DataLoader(Ztr, batch_size=opts_em['batch_size'], shuffle=False, drop_last=False)
219 |
220 | # EM learning
221 | ll_old = -np.inf
222 | with torch.no_grad():
223 |
224 | for steps in range(opts_em['max_iters']):
225 |
226 | # E-step
227 | logalpha, mus, kappas = mix.get_params()
228 | logliks, logpcs = mix(Ztr)
229 | ll = logliks.sum()
230 | jll = logalpha.unsqueeze(0) + logpcs
231 | qz = jll.log_softmax(1).exp()
232 |
233 | if steps==0:
234 | prn_str = '[Before EM starts] loglik = %.4f\n' % ll.item()
235 | else:
236 | prn_str = '[Steps %03d] loglik (before M-step) = %.4f\n' % (steps, ll.item())
237 | print(prn_str)
238 |
239 | # tolerance check
240 | if steps>0:
241 | rll = (ll-ll_old).abs() / (ll_old.abs()+utils.realmin)
242 | if rll < opts_em['rll_tol']:
243 | prn_str = 'Stop EM since the relative improvement '
244 | prn_str += '(%.6f) < tolerance (%.6f)\n' % (rll.item(), opts_em['rll_tol'])
245 | print(prn_str)
246 | break
247 |
248 | ll_old = ll
249 |
250 | # M-step
251 | qzx = ( qz.unsqueeze(2) * Ztr.unsqueeze(1) ).sum(0)
252 | qzx_norms = utils.norm(qzx, dim=1)
253 | mus_new = qzx / qzx_norms
254 | Rs = qzx_norms[:,0] / (qz.sum(0) + utils.realmin)
255 |             kappas_new = (mix.x_dim*Rs - Rs**3) / (1 - Rs**2)  # approximate kappa MLE (Banerjee et al., 2005)
256 | alpha_new = qz.sum(0) / Ztr.shape[0]
257 |
258 | # assign new params
259 | mix.set_params(alpha_new, mus_new, kappas_new)
260 |
261 | # save model
262 | mix_em = mix
263 | torch.save(mix_em.state_dict(), os.path.join(opts['result_path'], 'mix_em_cifar10.pth'))
264 |
265 | logliks, logpcs = mix_em(Ztr)
266 | ll = logliks.sum()
267 | prn_str = '[Training done] loglik = %.4f\n' % ll.item()
268 | print(prn_str)
269 |
270 | # cluster label predictions
271 | with torch.no_grad():
272 | logalpha, mus, kappas = mix_em.get_params()
273 | clabs_em = ( logalpha.unsqueeze(0) + logpcs ).max(1)[1]
274 | logliks_te, logpcs_te = mix_em(Zte)
275 | clabs_te_em = ( logalpha.unsqueeze(0) + logpcs_te ).max(1)[1]
276 |
277 | # clustering metrics
278 | metrics_em = {}
279 | metrics_em['ARI'] = sklearn.metrics.adjusted_rand_score(Ytr, clabs_em)
280 | metrics_em['ARI_te'] = sklearn.metrics.adjusted_rand_score(Yte, clabs_te_em)
281 | metrics_em['NMI'] = sklearn.metrics.normalized_mutual_info_score(Ytr, clabs_em)
282 | metrics_em['NMI_te'] = sklearn.metrics.normalized_mutual_info_score(Yte, clabs_te_em)
283 |
284 | prn_str = '== EM estimator ==\n'
285 | prn_str += 'Test: ARI = %.4f, NMI = %.4f\n' % (metrics_em['ARI_te'], metrics_em['NMI_te'])
286 | prn_str += 'Train: ARI = %.4f, NMI = %.4f\n' % (metrics_em['ARI'], metrics_em['NMI'])
287 | print(prn_str)
288 |
289 |
290 | '''
291 | SGD
292 | '''
293 |
294 | # data loader for training samples
295 | dataloader = DataLoader(Ztr, batch_size=opts_sgd['batch_size'], shuffle=False, drop_last=False)
296 |
297 | # create a model
298 | mix = models.MixvMF(x_dim=opts['z_dim'], order=10)
299 | mix = mix.cuda()
300 |
301 | # create optimizers and set optim params
302 | params = list(mix.parameters())
303 | optim = torch.optim.Adam(params, lr=1e-1, betas=[0.9, 0.99])
304 | lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.95)
305 |
306 | # SGD training
307 | for epoch in range(opts_sgd['max_epochs']):
308 |
309 | ll, nsamps = 0.0, 0
310 | for ii, data in enumerate(dataloader):
311 | data = data.cuda()
312 | logliks, _ = mix(data)
313 | anll = -logliks.mean()
314 | optim.zero_grad()
315 | anll.backward()
316 | optim.step()
317 | ll += -anll.item()*data.shape[0]
318 |
319 | lr_sched.step()
320 |
321 | if (epoch+1) % 1 == 0:
322 | prn_str = '[Epoch %03d] loglik (accumulated) = %.4f\n' % (epoch, ll)
323 | print(prn_str)
324 |
325 | # save model
326 | mix_sgd = mix
327 | mix_sgd = mix_sgd.cpu()
328 | torch.save(mix_sgd.state_dict(), os.path.join(opts['result_path'], 'mix_sgd_cifar10.pth'))
329 |
330 | # cluster label predictions
331 | with torch.no_grad():
332 |
333 | logalpha, mus, kappas = mix_sgd.get_params()
334 |
335 | # prediction on train data
336 | for ii, data in enumerate(dataloader):
337 | logliks_, logpcs_ = mix_sgd(data)
338 | if ii==0:
339 | logliks, logpcs = logliks_, logpcs_
340 | else:
341 | logliks = torch.cat([logliks, logliks_], dim=0)
342 | logpcs = torch.cat([logpcs, logpcs_], dim=0)
343 | clabs_sgd = ( logalpha.unsqueeze(0) + logpcs ).max(1)[1]
344 |
345 | # prediction on test data
346 | dataloader = DataLoader(Zte, batch_size=opts_sgd['batch_size'], shuffle=False, drop_last=False)
347 | for ii, data in enumerate(dataloader):
348 | logliks_, logpcs_ = mix_sgd(data)
349 | if ii==0:
350 | logliks_te, logpcs_te = logliks_, logpcs_
351 | else:
352 | logliks_te = torch.cat([logliks_te, logliks_], dim=0)
353 | logpcs_te = torch.cat([logpcs_te, logpcs_], dim=0)
354 | clabs_te_sgd = ( logalpha.unsqueeze(0) + logpcs_te ).max(1)[1]
355 |
356 | # clustering metrics
357 | metrics_sgd = {}
358 | metrics_sgd['ARI'] = sklearn.metrics.adjusted_rand_score(Ytr, clabs_sgd)
359 | metrics_sgd['ARI_te'] = sklearn.metrics.adjusted_rand_score(Yte, clabs_te_sgd)
360 | metrics_sgd['NMI'] = sklearn.metrics.normalized_mutual_info_score(Ytr, clabs_sgd)
361 | metrics_sgd['NMI_te'] = sklearn.metrics.normalized_mutual_info_score(Yte, clabs_te_sgd)
362 |
363 | prn_str = '== SGD estimator ==\n'
364 | prn_str += 'Test: ARI = %.4f, NMI = %.4f\n' % (metrics_sgd['ARI_te'], metrics_sgd['NMI_te'])
365 | prn_str += 'Train: ARI = %.4f, NMI = %.4f\n' % (metrics_sgd['ARI'], metrics_sgd['NMI'])
366 | print(prn_str)
367 |
368 |
--------------------------------------------------------------------------------
/mle_for_mix_vmf.py:
--------------------------------------------------------------------------------
1 | import random
2 | import itertools
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 | import models
8 | import utils
9 |
10 |
11 | seed = 1234
12 |
13 | random.seed(seed)
14 | np.random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 |
19 |
20 | '''
21 | Define a true vMF mixture model
22 | '''
23 |
24 | mix_true = models.MixvMF(x_dim=5, order=3)
25 |
26 | mus_true = [
27 | torch.tensor([0.3, -1.2, 2.3, 0.4, 2.1]),
28 | torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0]),
29 | torch.tensor([-0.3, 1.2, -2.3, -0.4, -2.1]),
30 | ]
31 | mus_true = [ mu / utils.norm(mu, dim=0) for mu in mus_true ]
32 | kappas_true = [
33 | torch.tensor(100.0),
34 | torch.tensor(50.0),
35 | torch.tensor(100.0)
36 | ]
37 | alpha_true = torch.tensor([0.3, 0.4, 0.3])
38 |
39 | mix_true.set_params(alpha_true, mus_true, kappas_true)
40 |
41 | # sample data from true mixture
42 | samples, cids = mix_true.sample(N=1000, rsf=1)
43 |
44 |
45 | '''
46 | Full-batch EM learning
47 | '''
48 |
49 | opts = {}
50 | opts['max_iters'] = 100 # maximum number of EM iterations
51 | opts['rll_tol'] = 1e-5 # tolerance of relative loglik improvement
52 |
53 | # randomly initialized mixture
54 | mix = models.MixvMF(x_dim=5, order=3)
55 |
56 | # EM learning
57 | ll_old = -np.inf
58 | with torch.no_grad():
59 |
60 | for steps in range(opts['max_iters']):
61 |
62 | # E-step
63 | logalpha, mus, kappas = mix.get_params()
64 | logliks, logpcs = mix(samples)
65 | ll = logliks.sum()
66 | jll = logalpha.unsqueeze(0) + logpcs
67 | qz = jll.log_softmax(1).exp()
68 |
69 | if steps==0:
70 | prn_str = '[Before EM starts] loglik = %.4f\n' % ll.item()
71 | else:
72 | prn_str = '[Steps %03d] loglik (before M-step) = %.4f\n' % (steps, ll.item())
73 | print(prn_str)
74 |
75 | # tolerance check
76 | if steps>0:
77 | rll = (ll-ll_old).abs() / (ll_old.abs()+utils.realmin)
78 | if rll < opts['rll_tol']:
79 | prn_str = 'Stop EM since the relative improvement '
80 | prn_str += '(%.6f) < tolerance (%.6f)\n' % (rll.item(), opts['rll_tol'])
81 | print(prn_str)
82 | break
83 |
84 | ll_old = ll
85 |
86 | # M-step
87 | qzx = ( qz.unsqueeze(2) * samples.unsqueeze(1) ).sum(0)
88 | qzx_norms = utils.norm(qzx, dim=1)
89 | mus_new = qzx / qzx_norms
90 | Rs = qzx_norms[:,0] / (qz.sum(0) + utils.realmin)
91 |             kappas_new = (mix.x_dim*Rs - Rs**3) / (1 - Rs**2)  # approximate kappa MLE (Banerjee et al., 2005)
92 | alpha_new = qz.sum(0) / samples.shape[0]
93 |
94 | # assign new params
95 | mix.set_params(alpha_new, mus_new, kappas_new)
96 |
97 | logliks, logpcs = mix(samples)
98 | ll = logliks.sum()
99 | prn_str = '[Training done] loglik = %.4f\n' % ll.item()
100 | print(prn_str)
101 |
102 | # find the best matching permutations of components
103 | print('Find the best matching permutations of components')
104 | with torch.no_grad():
105 | logalpha, mus, kappas = mix.get_params()
106 | alpha = logalpha.exp()
107 | perms = list(itertools.permutations(range(mix.order)))
108 | best_perm, best_error = None, 1e10
109 | for perm in perms:
110 | perm = np.array(perm)
111 | error_alpha = (alpha[perm] - alpha_true).abs().sum()
112 | error_mus = (mus[perm,:] - torch.stack(mus_true, dim=0)).abs().sum()
113 | error_kappas = (kappas[perm] - torch.stack(kappas_true, dim=0)).abs().sum()
114 | error = (error_alpha + error_mus + error_kappas).item()
115 | print('perm = %s: error = %.4f' % (perm, error))
116 | if error < best_error:
117 | best_perm, best_error = perm, error
118 | print('best perm has changed to: %s' % best_perm)
119 |
120 | print('For the best components permutation:')
121 | print('----------')
122 | print('alpha_true = %s' % alpha_true)
123 | print('alpha = %s' % alpha[best_perm])
124 | print('error in alpha = %.4f' % (alpha[best_perm] - alpha_true).abs().sum().item())
125 | print('----------')
126 | print('mus_true = %s' % torch.stack(mus_true, dim=0))
127 | print('mus = %s' % mus[best_perm])
128 | print('error in mus = %.4f' % (mus[best_perm,:] - torch.stack(mus_true, dim=0)).abs().sum().item())
129 | print('----------')
130 | print('kappas_true = %s' % torch.stack(kappas_true, dim=0))
131 | print('kappas = %s' % kappas[best_perm])
132 | print('error in kappas = %.4f' % (kappas[best_perm] - torch.stack(kappas_true, dim=0)).abs().sum().item())
133 | print('----------')
134 |
135 | # save model
136 | mix_em = mix
137 |
138 |
139 | '''
140 | SGD-based ML estimator
141 | '''
142 |
143 | B = 64 # batch size
144 | max_epochs = 100 # maximum number of epochs
145 |
146 | dataloader = DataLoader(samples, batch_size=B, shuffle=True, drop_last=True)
147 |
148 | # create a model
149 | mix = models.MixvMF(x_dim=5, order=3)
150 | mix = mix.cuda()
151 |
152 | # create optimizers and set optim params
153 | params = list(mix.parameters())
154 | optim = torch.optim.Adam(params, lr=1e-1, betas=[0.9, 0.99])
155 | lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.95)
156 |
157 | # SGD training
158 | for epoch in range(max_epochs):
159 |
160 | ll, nsamps = 0.0, 0
161 | for ii, data in enumerate(dataloader):
162 | data = data.cuda()
163 | logliks, _ = mix(data)
164 | anll = -logliks.mean()
165 | optim.zero_grad()
166 | anll.backward()
167 | optim.step()
168 | ll += -anll.item()*data.shape[0]
169 |
170 | lr_sched.step()
171 |
172 | if (epoch+1) % 1 == 0:
173 | prn_str = '[Epoch %03d] loglik (accumulated) = %.4f\n' % (epoch, ll)
174 | print(prn_str)
175 |
176 | mix = mix.cpu()
177 |
178 | # find the best matching permutations of components
179 | print('Find the best matching permutations of components')
180 | with torch.no_grad():
181 | logalpha, mus, kappas = mix.get_params() # logalpha = M-dim
182 | alpha = logalpha.exp()
183 | perms = list(itertools.permutations(range(mix.order)))
184 | best_perm, best_error = None, 1e10
185 | for perm in perms:
186 | perm = np.array(perm)
187 | error_alpha = (alpha[perm] - alpha_true).abs().sum()
188 | error_mus = (mus[perm,:] - torch.stack(mus_true, dim=0)).abs().sum()
189 | error_kappas = (kappas[perm] - torch.stack(kappas_true, dim=0)).abs().sum()
190 | error = (error_alpha + error_mus + error_kappas).item()
191 | print('perm = %s: error = %.4f' % (perm, error))
192 | if error < best_error:
193 | best_perm, best_error = perm, error
194 | print('best perm has changed to: %s' % best_perm)
195 |
196 | print('For the best components permutation:')
197 | print('----------')
198 | print('alpha_true = %s' % alpha_true)
199 | print('alpha = %s' % alpha[best_perm])
200 | print('error in alpha = %.4f' % (alpha[best_perm] - alpha_true).abs().sum().item())
201 | print('----------')
202 | print('mus_true = %s' % torch.stack(mus_true, dim=0))
203 | print('mus = %s' % mus[best_perm])
204 | print('error in mus = %.4f' % (mus[best_perm,:] - torch.stack(mus_true, dim=0)).abs().sum().item())
205 | print('----------')
206 | print('kappas_true = %s' % torch.stack(kappas_true, dim=0))
207 | print('kappas = %s' % kappas[best_perm])
208 | print('error in kappas = %.4f' % (kappas[best_perm] - torch.stack(kappas_true, dim=0)).abs().sum().item())
209 | print('----------')
210 |
211 | # save model
212 | mix_sgd = mix
213 |
214 |
--------------------------------------------------------------------------------
/mle_for_vmf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 |
4 | import models
5 | import utils
6 |
7 |
8 | '''
9 | Define a true vMF model
10 | '''
11 |
12 | mu_true = torch.zeros(100) # 5 or 20
13 | mu_true[0] = 1.0
14 | mu_true = mu_true / utils.norm(mu_true, dim=0)
15 | kappa_true = torch.tensor(500.0) # 50.0
16 | vmf_true = models.vMF(x_dim=mu_true.shape[0])
17 | vmf_true.set_params(mu=mu_true, kappa=kappa_true)
18 | vmf_true = vmf_true.cuda()
19 |
20 | # sample from true vMF model
21 | samples = vmf_true.sample(N=10000, rsf=1)
22 |
23 |
24 | '''
25 | Full-batch ML estimator
26 | '''
27 |
28 | xm = samples.mean(0)
29 | xm_norm = (xm**2).sum().sqrt()
30 | mu0 = xm / xm_norm
31 | kappa0 = (len(xm)*xm_norm - xm_norm**3) / (1-xm_norm**2)
32 |
33 | mu_err = ((mu0.cpu() - mu_true)**2).sum().item() # squared L2 error
34 | kappa_err = (kappa0.cpu() - kappa_true).abs().item() / kappa_true.item()
35 | prn_str = '== Batch ML estimator ==\n'
36 | prn_str += 'mu = %s (error = %.8f)\n' % (mu0.cpu().numpy(), mu_err)
37 | prn_str += 'kappa = %s (error = %.8f)\n' % (kappa0.cpu().numpy(), kappa_err)
38 | print(prn_str)
39 |
40 |
41 | '''
42 | SGD-based ML estimator
43 | '''
44 |
45 | B = 128 # batch size
46 | max_epochs = 100 # maximum number of epochs
47 |
48 | dataloader = DataLoader(samples, batch_size=B, shuffle=True, drop_last=True)
49 |
50 | # create a model
51 | vmf = models.vMF(x_dim=mu_true.shape[0])
52 | vmf = vmf.cuda()
53 |
54 | # create optimizers and set optim params
55 | params = list(vmf.parameters())
56 | optim = torch.optim.Adam(params, lr=1e-2, betas=[0.9, 0.99])
57 | lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.95)
58 |
59 | # error of the initial model
60 | with torch.no_grad():
61 | mu, kappa = vmf.get_params()
62 | mu_err = (mu.cpu() - mu_true).abs().mean().item() # dim-wise absolute error
63 | kappa_err = (kappa.cpu() - kappa_true).abs().item()
64 | prn_str = '== Before training starts ==\n'
65 | prn_str += 'mu = %s (error = %.6f)\n' % (mu.cpu().numpy(), mu_err)
66 | prn_str += 'kappa = %s (error = %.6f)\n' % (kappa.cpu().numpy(), kappa_err)
67 | print(prn_str)
68 |
69 | # SGD training
70 | for epoch in range(max_epochs):
71 |
72 | obj, nsamps = 0.0, 0
73 | for ii, data in enumerate(dataloader):
74 | enll = -vmf(data).mean()
75 | optim.zero_grad()
76 | enll.backward()
77 | optim.step()
78 | obj += enll.item()*data.shape[0]
79 | nsamps += data.shape[0]
80 | obj /= nsamps
81 |
82 | lr_sched.step()
83 |
84 | if (epoch+1) % 10 == 0:
85 | with torch.no_grad():
86 | mu, kappa = vmf.get_params()
87 |                 mu_err = ((mu.cpu() - mu_true)**2).sum().item() # squared L2 error
88 | kappa_err = (kappa.cpu() - kappa_true).abs().item() / kappa_true.item()
89 | prn_str = '== After epoch %d ==\n' % epoch
90 |             prn_str += 'Expected negative log-likelihood = %.4f\n' % enll.item()
91 | prn_str += 'mu = %s (error = %.8f)\n' % (mu.cpu().numpy(), mu_err)
92 | prn_str += 'kappa = %s (error = %.8f)\n' % (kappa.cpu().numpy(), kappa_err)
93 | print(prn_str)
94 |
95 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import mpmath
3 | import torch
4 | import torch.nn as nn
5 |
6 | import utils
7 |
8 |
9 | class vMFLogPartition(torch.autograd.Function):
10 |
11 | '''
12 | Evaluates log C_d(kappa) for vMF density
13 | Allows autograd wrt kappa
14 | '''
15 |
16 | besseli = np.vectorize(mpmath.besseli)
17 | log = np.vectorize(mpmath.log)
18 | nhlog2pi = -0.5 * np.log(2*np.pi)
19 |
20 | @staticmethod
21 | def forward(ctx, *args):
22 |
23 | '''
24 | Args:
25 | args[0] = d; scalar (> 0)
26 | args[1] = kappa; (> 0) torch tensor of any shape
27 |
28 | Returns:
29 | logC = log C_d(kappa); torch tensor of the same shape as kappa
30 | '''
31 |
32 | d = args[0]
33 | kappa = args[1]
34 |
35 | s = 0.5*d - 1
36 |
37 | # log I_s(kappa)
38 | mp_kappa = mpmath.mpf(1.0) * kappa.detach().cpu().numpy()
39 | mp_logI = vMFLogPartition.log( vMFLogPartition.besseli(s, mp_kappa) )
40 | logI = torch.from_numpy( np.array(mp_logI.tolist(), dtype=float) ).to(kappa)
41 |
42 | if (logI!=logI).sum().item() > 0: # there is nan
43 | raise ValueError('NaN is detected from the output of log-besseli()')
44 |
45 | logC = d * vMFLogPartition.nhlog2pi + s * kappa.log() - logI
46 |
47 |         # save for backward()
48 | ctx.s, ctx.mp_kappa, ctx.logI = s, mp_kappa, logI
49 |
50 | return logC
51 |
52 | @staticmethod
53 | def backward(ctx, *grad_output):
54 |
55 | s, mp_kappa, logI = ctx.s, ctx.mp_kappa, ctx.logI
56 |
57 | # log I_{s+1}(kappa)
58 | mp_logI2 = vMFLogPartition.log( vMFLogPartition.besseli(s+1, mp_kappa) )
59 | logI2 = torch.from_numpy( np.array(mp_logI2.tolist(), dtype=float) ).to(logI)
60 |
61 | if (logI2!=logI2).sum().item() > 0: # there is nan
62 | raise ValueError('NaN is detected from the output of log-besseli()')
63 |
64 | dlogC_dkappa = -(logI2 - logI).exp()
65 |
66 | return None, grad_output[0] * dlogC_dkappa
67 |
68 |
69 | class vMF(nn.Module):
70 |
71 | '''
72 | vMF(x; mu, kappa)
73 | '''
74 |
75 | def __init__(self, x_dim, reg=1e-6):
76 |
77 | super(vMF, self).__init__()
78 |
79 | self.x_dim = x_dim
80 |
81 | self.mu_unnorm = nn.Parameter(torch.randn(x_dim))
82 | self.logkappa = nn.Parameter(0.01*torch.randn([]))
83 |
84 | self.reg = reg
85 |
86 | def set_params(self, mu, kappa):
87 |
88 | with torch.no_grad():
89 | self.mu_unnorm.copy_(mu)
90 | self.logkappa.copy_(torch.log(kappa+utils.realmin))
91 |
92 | def get_params(self):
93 |
94 | mu = self.mu_unnorm / utils.norm(self.mu_unnorm)
95 | kappa = self.logkappa.exp() + self.reg
96 |
97 | return mu, kappa
98 |
99 | def forward(self, x, utc=False):
100 |
101 | '''
102 | Evaluate logliks, log p(x)
103 |
104 | Args:
105 | x = batch for x
106 | utc = whether to evaluate only up to constant or exactly
107 | if True, no log-partition computed
108 | if False, exact loglik computed
109 |
110 | Returns:
111 | logliks = log p(x)
112 | '''
113 |
114 | mu, kappa = self.get_params()
115 |
116 | dotp = (mu.unsqueeze(0) * x).sum(1)
117 |
118 | if utc:
119 | logliks = kappa * dotp
120 | else:
121 | logC = vMFLogPartition.apply(self.x_dim, kappa)
122 | logliks = kappa * dotp + logC
123 |
124 | return logliks
125 |
126 | def sample(self, N=1, rsf=10):
127 |
128 | '''
129 | Args:
130 | N = number of samples to generate
131 | rsf = multiplicative factor for extra backup samples in rejection sampling
132 |
133 | Returns:
134 | samples; N samples generated
135 |
136 | Notes:
137 | no autodiff
138 | '''
139 |
140 | d = self.x_dim
141 |
142 | with torch.no_grad():
143 |
144 | mu, kappa = self.get_params()
145 |
146 | # Step-1: Sample uniform unit vectors in R^{d-1}
147 | v = torch.randn(N, d-1).to(mu)
148 | v = v / utils.norm(v, dim=1)
149 |
150 | # Step-2: Sample v0
151 | kmr = np.sqrt( 4*kappa.item()**2 + (d-1)**2 )
152 | bb = (kmr - 2*kappa) / (d-1)
153 | aa = (kmr + 2*kappa + d - 1) / 4
154 | dd = (4*aa*bb)/(1+bb) - (d-1)*np.log(d-1)
155 | beta = torch.distributions.Beta( torch.tensor(0.5*(d-1)), torch.tensor(0.5*(d-1)) )
156 | uniform = torch.distributions.Uniform(0.0, 1.0)
157 | v0 = torch.tensor([]).to(mu)
158 | while len(v0) < N:
159 | eps = beta.sample([1, rsf*(N-len(v0))]).squeeze().to(mu)
160 | uns = uniform.sample([1, rsf*(N-len(v0))]).squeeze().to(mu)
161 | w0 = (1 - (1+bb)*eps) / (1 - (1-bb)*eps)
162 | t0 = (2*aa*bb) / (1 - (1-bb)*eps)
163 | det = (d-1)*t0.log() - t0 + dd - uns.log()
164 |                 v0 = torch.cat([v0, w0[det>=0].to(mu)])  # keep accepted proposals
165 | if len(v0) > N:
166 | v0 = v0[:N]
167 | break
168 | v0 = v0.reshape([N,1])
169 |
170 | # Step-3: Form x = [v0; sqrt(1-v0^2)*v]
171 | samples = torch.cat([v0, (1-v0**2).sqrt()*v], 1)
172 |
173 |             # Step-4: Householder transformation to rotate the north pole onto mu
174 | e1mu = torch.zeros(d,1).to(mu); e1mu[0,0] = 1.0
175 | e1mu = e1mu - mu if len(mu.shape)==2 else e1mu - mu.unsqueeze(1)
176 | e1mu = e1mu / utils.norm(e1mu, dim=0)
177 | samples = samples - 2 * (samples @ e1mu) @ e1mu.t()
178 |
179 | return samples
180 |
181 |
182 | class MixvMF(nn.Module):
183 |
184 | '''
185 | MixvMF(x) = \sum_{m=1}^M \alpha_m vMF(x; mu_m, kappa_m)
186 | '''
187 |
188 | def __init__(self, x_dim, order, reg=1e-6):
189 |
190 | super(MixvMF, self).__init__()
191 |
192 | self.x_dim = x_dim
193 | self.order = order
194 | self.reg = reg
195 |
196 | self.alpha_logit = nn.Parameter(0.01*torch.randn(order))
197 | self.comps = nn.ModuleList(
198 | [ vMF(x_dim, reg) for _ in range(order) ]
199 | )
200 |
201 | def set_params(self, alpha, mus, kappas):
202 |
203 | with torch.no_grad():
204 | self.alpha_logit.copy_(torch.log(alpha+utils.realmin))
205 | for m in range(self.order):
206 | self.comps[m].mu_unnorm.copy_(mus[m])
207 | self.comps[m].logkappa.copy_(torch.log(kappas[m]+utils.realmin))
208 |
209 | def get_params(self):
210 |
211 | logalpha = self.alpha_logit.log_softmax(0)
212 |
213 | mus, kappas = [], []
214 | for m in range(self.order):
215 | mu, kappa = self.comps[m].get_params()
216 | mus.append(mu)
217 | kappas.append(kappa)
218 |
219 |         mus = torch.stack(mus, dim=0)
220 |         kappas = torch.stack(kappas, dim=0)
221 |
222 | return logalpha, mus, kappas
223 |
224 | def forward(self, x):
225 |
226 | '''
227 | Evaluate logliks, log p(x)
228 |
229 | Args:
230 | x = batch for x
231 |
232 | Returns:
233 | logliks = log p(x)
234 | logpcs = log p(x|c=m)
235 | '''
236 |
237 | logalpha = self.alpha_logit.log_softmax(0)
238 |
239 | logpcs = []
240 | for m in range(self.order):
241 | logpcs.append(self.comps[m](x))
242 | logpcs = torch.stack(logpcs, dim=1)
243 |
244 | logliks = (logalpha.unsqueeze(0) + logpcs).logsumexp(1)
245 |
246 | return logliks, logpcs
247 |
248 | def sample(self, N=1, rsf=10):
249 |
250 | '''
251 | Args:
252 | N = number of samples to generate
253 | rsf = multiplicative factor for extra backup samples in rejection sampling
254 | (used in sampling from vMF)
255 |
256 | Returns:
257 | samples = N samples generated
258 | cids = which components the samples come from; N-dim {0,1,...,M-1}-valued
259 |
260 | Notes:
261 | no autodiff
262 | '''
263 |
264 | with torch.no_grad():
265 |
266 | alpha = self.alpha_logit.log_softmax(0).exp()
267 |
268 | cids = torch.multinomial(alpha, N, replacement=True)
269 |
270 | samples = torch.zeros(N, self.x_dim)
271 | for c in range(self.order):
272 | Nc = (cids==c).sum()
273 | if Nc > 0:
274 | samples[cids==c,:] = self.comps[c].sample(N=Nc, rsf=rsf)
275 |
276 | return samples, cids
277 |
278 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | realmin = 1e-10
5 |
6 |
7 | def norm(input, p=2, dim=0, eps=1e-12):
8 | return input.norm(p, dim, keepdim=True).clamp(min=eps).expand_as(input)
9 |
10 |
11 | def mkdirs(path):
12 | if not os.path.exists(path):
13 | os.makedirs(path)
14 |
15 |
--------------------------------------------------------------------------------