├── README.md
├── figs
├── image_clustering_results.png
├── vmf_mix_results.png
└── vmf_results.png
├── image_clustering_cifar10.py
├── mle_for_mix_vmf.py
├── mle_for_vmf.py
├── models.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Python/PyTorch Implementation of von Mises-Fisher and Its Mixture
2 |
3 | The von Mises-Fisher (vMF) distribution is a well-known density model for directional random variables. The recent surge of deep embedding methodologies for high-dimensional structured data such as images or text, which aim to extract salient directional information, can make the vMF model even more popular. In this article, we review the vMF model and its mixture, and provide detailed recipes for training them in Python/PyTorch, focusing on maximum likelihood estimators. In particular, vMF implementations typically suffer from the notorious numerical issue of evaluating the Bessel function in the density normalizer, especially when the dimensionality is high; we address this issue with the mpmath library, which supports arbitrary precision. For mixture learning, we provide both minibatch-based large-scale SGD learning and the full-batch EM algorithm. We test each estimator on synthetic data and also demonstrate a more realistic use case of image clustering. The technical report with full details can be found at http://arxiv.org/abs/2102.05340.
4 |
5 | ---
6 |
7 |
8 | ## Features
9 |
10 | * vMF density estimation via near-closed-form Maximum Likelihood Estimator (MLE) or Stochastic Gradient Descent (SGD)
11 | * Estimation of a mixture of vMFs via Expectation-Maximization (EM) or SGD
12 | * Numerically stable evaluation of the vMF normalizer through [mpmath](https://mpmath.org/)'s arbitrary-precision Bessel function (see the sketch below)
13 | * Examples of image clustering application
14 |
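As a minimal sketch of the numerically stable normalizer (the helper name `log_vmf_normalizer` below is ours and not part of the library), the log-partition `log C_d(kappa) = (d/2 - 1) log(kappa) - (d/2) log(2 pi) - log I_{d/2-1}(kappa)` can be evaluated with mpmath's arbitrary-precision Bessel function, mirroring what ```vMFLogPartition``` does in ```models.py```:

```python
import numpy as np
import mpmath

def log_vmf_normalizer(d, kappa):
    # log C_d(kappa) = (d/2 - 1)*log(kappa) - (d/2)*log(2*pi) - log I_{d/2-1}(kappa)
    s = 0.5 * d - 1.0
    log_bessel = mpmath.log(mpmath.besseli(s, mpmath.mpf(kappa)))  # arbitrary precision
    return s * np.log(kappa) - 0.5 * d * np.log(2 * np.pi) - float(log_bessel)

# e.g. d = 100, kappa = 500 as in mle_for_vmf.py
print(log_vmf_normalizer(d=100, kappa=500.0))
```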
15 |
16 | ## Requirements
17 |
18 | * Python 3.7
19 | * PyTorch >= 1.4.0
20 | * NumPy >= 1.18.1
21 | * [mpmath](https://mpmath.org/) >= 1.1.0
22 | * scikit-learn >= 0.23.2 (Required for metric computation in image clustering)
23 |
24 |
25 | ## Usage examples
26 |
27 | ### 1) vMF Density Estimation
28 |
29 | We generate samples from a true vMF model and aim to recover the true model parameters with either the near-closed-form full-batch MLE or the SGD estimator. See the demo code in ```mle_for_vmf.py``` for details. The results are briefly summarized in Table 1 below; as shown, both estimators are equally accurate.
30 |
31 | ![Table 1: vMF density estimation results](figs/vmf_results.png)
32 |
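A condensed sketch of the full-batch estimator, adapted from ```mle_for_vmf.py``` (the kappa formula is the standard approximation based on the mean resultant length):

```python
import torch
import models

# ground-truth vMF on the unit sphere in R^100 (as in the demo)
vmf_true = models.vMF(x_dim=100)
vmf_true.set_params(mu=torch.eye(100)[0], kappa=torch.tensor(500.0))
samples = vmf_true.sample(N=10000, rsf=1)

# near-closed-form full-batch MLE
xm = samples.mean(0)
R = xm.norm()                                    # mean resultant length
mu_hat = xm / R
kappa_hat = (len(xm) * R - R**3) / (1 - R**2)    # approximate MLE for kappa
```

The SGD alternative simply maximizes the log-likelihood returned by ```vmf(data)``` with Adam over minibatches; see the demo file.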
33 |
34 |
35 | ### 2) vMF Mixture Estimation
36 |
37 | We also generate samples from a true mixture of three vMFs and aim to recover the true mixture parameters with either the EM or the SGD estimator. See the demo code in ```mle_for_mix_vmf.py``` for details. The results are briefly summarized in Figure 2 below; as shown, both estimators are equally accurate.
38 |
39 | ![Figure 2: vMF mixture estimation results](figs/vmf_mix_results.png)
40 |
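A condensed sketch of the SGD estimator, adapted from ```mle_for_mix_vmf.py``` (component means here are drawn randomly for brevity; the EM variant lives in the same demo file):

```python
import torch
from torch.utils.data import DataLoader
import models, utils

# ground-truth 3-component mixture on the unit sphere in R^5
mix_true = models.MixvMF(x_dim=5, order=3)
mus_true = [torch.randn(5) for _ in range(3)]
mix_true.set_params(
    alpha=torch.tensor([0.3, 0.4, 0.3]),
    mus=[mu / utils.norm(mu, dim=0) for mu in mus_true],
    kappas=[torch.tensor(k) for k in (100.0, 50.0, 100.0)],
)
samples, _ = mix_true.sample(N=1000, rsf=1)

# fit a freshly initialized mixture by minibatch SGD on the negative log-likelihood
mix = models.MixvMF(x_dim=5, order=3)
optim = torch.optim.Adam(mix.parameters(), lr=1e-1, betas=[0.9, 0.99])
loader = DataLoader(samples, batch_size=64, shuffle=True, drop_last=True)
for epoch in range(100):
    for data in loader:
        logliks, _ = mix(data)    # per-sample mixture log-likelihoods
        optim.zero_grad()
        (-logliks.mean()).backward()
        optim.step()
```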
41 |
42 |
43 | ### 3) Image Clustering Application
44 |
45 | We aim to cluster the images in the CIFAR-10 dataset. Our pipeline first learns a unit-hyperspherical latent space by minimizing the reconstruction error of an autoencoder, and then fits a vMF mixture model in that latent space. See the demo code in ```image_clustering_cifar10.py``` for details. The results are briefly summarized in Table 2 below. They indicate that the vMF mixture learning approaches (EM and SGD) significantly outperform the well-known k-means algorithm on the two clustering performance metrics, Adjusted Rand Index (ARI) and Normalized Mutual Information (NMI), and that the EM and SGD estimators perform equally well.
46 |
47 | ![Table 2: image clustering results on CIFAR-10](figs/image_clustering_results.png)
48 |
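Once the demo has been run (so the saved embeddings and the EM-fitted mixture exist under ```result/CIFAR10/```), predicting cluster labels and scoring them boils down to the following sketch, adapted from ```image_clustering_cifar10.py```:

```python
import os
import sklearn.metrics
import torch
import models

result_path = os.path.join('result', 'CIFAR10')
Ztr, Ytr = torch.load(os.path.join(result_path, 'embeds_tr_cifar10.pth'))   # unit-norm embeddings + labels
mix = models.MixvMF(x_dim=100, order=10)
mix.load_state_dict(torch.load(os.path.join(result_path, 'mix_em_cifar10.pth')))

# MAP cluster assignment: argmax_m [ log alpha_m + log vMF(z; mu_m, kappa_m) ]
with torch.no_grad():
    logalpha, mus, kappas = mix.get_params()
    _, logpcs = mix(Ztr)
    clabs = (logalpha.unsqueeze(0) + logpcs).max(1)[1]

print('ARI = %.4f' % sklearn.metrics.adjusted_rand_score(Ytr, clabs))
print('NMI = %.4f' % sklearn.metrics.normalized_mutual_info_score(Ytr, clabs))
```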
49 |
50 |
51 |
52 | ## Citation
53 | If you find this library useful in your research, please cite:
54 | ```
55 | @inproceedings{mkim2021vmf,
56 | title = {On PyTorch Implementation of Density Estimators for von Mises-Fisher and Its Mixture},
57 | author = {Kim, Minyoung},
58 | year = {2021},
59 | URL = {http://arxiv.org/abs/2102.05340},
60 | booktitle = {arXiv preprint}
61 | }
62 | ```
63 |
64 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
/figs/image_clustering_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minyoungkim21/vmf-lib/1aa775ae3dfd31313d3abcc8c62bce9b58f8238b/figs/image_clustering_results.png
--------------------------------------------------------------------------------
/figs/vmf_mix_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minyoungkim21/vmf-lib/1aa775ae3dfd31313d3abcc8c62bce9b58f8238b/figs/vmf_mix_results.png
--------------------------------------------------------------------------------
/figs/vmf_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minyoungkim21/vmf-lib/1aa775ae3dfd31313d3abcc8c62bce9b58f8238b/figs/vmf_results.png
--------------------------------------------------------------------------------
/image_clustering_cifar10.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import sklearn.metrics
5 | import torch
6 | import torch.optim as optim
7 | import torch.nn as nn
8 | from torch.nn import functional as F
9 | from torchvision import datasets, transforms
10 | from torch.utils.data import DataLoader
11 |
12 | import models
13 | import utils
14 |
15 |
16 | '''
17 | Set options
18 | '''
19 |
20 | opts = {}
21 |
22 | # path to data and result files
23 | opts['data_path'] = os.path.join('data', 'CIFAR10') # will download the data if it does not exist
24 | utils.mkdirs(opts['data_path'])
25 | opts['result_path'] = os.path.join('result', 'CIFAR10')
26 | utils.mkdirs(opts['result_path'])
27 |
28 | # options for autoencoder
29 | opts['cuda'] = True
30 | opts['z_dim'] = 100
31 | opts['batch_size'] = 128
32 | opts['ae_hid_dim'] = 512
33 | opts['ae_max_epoch'] = 100
34 | opts['ae_lr'] = 1e-4
35 |
36 | # options for EM
37 | opts_em = {}
38 | opts_em['max_iters'] = 100 # maximum number of EM iterations
39 | opts_em['rll_tol'] = 1e-5 # tolerance of relative loglik improvement
40 | opts_em['batch_size'] = 128 # batch inside E and M steps
41 |
42 | # options for SGD
43 | opts_sgd = {}
44 | opts_sgd['batch_size'] = 256 # batch size
45 | opts_sgd['max_epochs'] = 100 # maximum number of epochs
46 |
47 |
48 | class ConvEncoder(nn.Module):
49 |
50 | def __init__(self, z_dim, h_dim=256):
51 |
52 | super().__init__()
53 |
54 | self.z_dim = z_dim
55 |
56 | self.conv1 = nn.Conv2d(3, 32, 4, 2, 1)
57 | self.conv2 = nn.Conv2d(32, 32, 4, 2, 1)
58 | self.conv3 = nn.Conv2d(32, 64, 4, 2, 1)
59 | self.fc4 = nn.Linear(64*4*4, h_dim)
60 | self.fc5 = nn.Linear(h_dim, z_dim)
61 |
62 | def forward(self, x):
63 |
64 | out = F.leaky_relu(self.conv1(x))
65 | out = F.leaky_relu(self.conv2(out))
66 | out = F.leaky_relu(self.conv3(out))
67 | out = out.view(out.size(0), -1)
68 | out = F.leaky_relu(self.fc4(out))
69 | out = self.fc5(out)
70 |
71 | out = out / utils.norm(out, dim=1)
72 |
73 | return out
74 |
75 |
76 | class ConvDecoder(nn.Module):
77 |
78 | def __init__(self, z_dim, h_dim=256):
79 |
80 | super().__init__()
81 |
82 | self.z_dim = z_dim
83 |
84 | self.fc1 = nn.Linear(z_dim, h_dim)
85 | self.fc2 = nn.Linear(h_dim, 4*4*64)
86 | self.deconv3 = nn.ConvTranspose2d(64, 32, 4, 2, 1)
87 | self.deconv4 = nn.ConvTranspose2d(32, 32, 4, 2, 1)
88 | self.deconv5 = nn.ConvTranspose2d(32, 3, 4, 2, 1)
89 |
90 | def forward(self, z):
91 |
92 | out = F.relu(self.fc1(z))
93 | out = F.relu(self.fc2(out))
94 | out = out.view(out.size(0), 64, 4, 4)
95 | out = F.relu(self.deconv3(out))
96 | out = F.relu(self.deconv4(out))
97 | out = self.deconv5(out)
98 |
99 | return out
100 |
101 |
102 | '''
103 | Train autoencoder: image -> z
104 | '''
105 |
106 | class Autoencoder(nn.Module):
107 |
108 | def __init__(self, z_dim, hid_dim):
109 |
110 | super().__init__()
111 |
112 | self.z_dim = z_dim
113 |
114 | self.encoder = ConvEncoder(z_dim, hid_dim)
115 | self.decoder = ConvDecoder(z_dim, hid_dim)
116 |
117 | def forward(self, x):
118 |
119 | x2 = self.decoder(self.encoder(x))
120 | return ((x2-x)**2).sum() / x.shape[0]
121 |
122 | def encode(self, x):
123 |
124 | return self.encoder(x)
125 |
126 |
127 | ae = Autoencoder(z_dim=opts['z_dim'], hid_dim=opts['ae_hid_dim'])
128 | if opts['cuda']:
129 | ae = ae.cuda()
130 |
131 | # load data (automatically download if not exist)
132 | dset_tr = datasets.CIFAR10(opts['data_path'], train=True, download=True, transform=transforms.ToTensor())
133 | dset_te = datasets.CIFAR10(opts['data_path'], train=False, download=True, transform=transforms.ToTensor())
134 | dl_tr = DataLoader(dset_tr, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
135 | dl_te = DataLoader(dset_te, batch_size=opts['batch_size'], shuffle=False, drop_last=False)
136 |
137 | # optimizer
138 | optimizer = optim.Adam(ae.parameters(), lr=opts['ae_lr'])
139 |
140 | for epoch in range(opts['ae_max_epoch']):
141 |
142 | train_loss = 0
143 |
144 | for batch_idx, (XX, YY) in enumerate(dl_tr):
145 |
146 | if opts['cuda']:
147 | XX = XX.cuda()
148 |
149 | optimizer.zero_grad()
150 | loss = ae(XX)
151 | loss.backward()
152 | optimizer.step()
153 |
154 | train_loss += loss.item() * XX.shape[0]
155 |
156 | if batch_idx % 20 == 0:
157 | prn_str = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
158 | epoch, batch_idx*XX.shape[0], len(dl_tr.dataset),
159 | 100.*batch_idx/len(dl_tr), loss.item() )
160 | print(prn_str)
161 |
162 | epoch_loss = train_loss / len(dl_tr.dataset)
163 | prn_str = '====> Epoch: {} Average loss: {:.4f}'.format(epoch, epoch_loss)
164 | print(prn_str)
165 |
166 | # save the learned autoencoder
167 | torch.save(ae.state_dict(), os.path.join(opts['result_path'], 'ae_cifar10.pth'))
168 |
169 |
170 | # load autoencoder (if necessary)
171 | ae = Autoencoder(z_dim=opts['z_dim'], hid_dim=opts['ae_hid_dim'])
172 | ae.load_state_dict( torch.load(os.path.join(opts['result_path'], 'ae_cifar10.pth')) )
173 | if opts['cuda']:
174 | ae = ae.cuda()
175 |
176 | # embed all images (and their labels)
177 | dl_tr = DataLoader(dset_tr, batch_size=opts['batch_size'], shuffle=False, drop_last=False)
178 | dl_te = DataLoader(dset_te, batch_size=opts['batch_size'], shuffle=False, drop_last=False)
179 | with torch.no_grad():
180 | for split in ['tr', 'te']:
181 | dl = dl_tr if split=='tr' else dl_te
182 | for ii, (xx, yy) in enumerate(dl):
183 | if opts['cuda']:
184 | xx = xx.cuda()
185 | zz = ae.encoder(xx).cpu()
186 | if ii==0:
187 | ZZ, YY = zz, yy
188 | else:
189 | ZZ, YY = torch.cat([ZZ, zz], dim=0), torch.cat([YY, yy], dim=0)
190 | torch.save([ZZ, YY], os.path.join(opts['result_path'], 'embeds_%s_cifar10.pth' % split))
191 |
192 |
193 | '''
194 | Do mixture learning in the embedded space
195 | '''
196 |
197 | seed = 1234
198 |
199 | random.seed(seed)
200 | np.random.seed(seed)
201 | torch.manual_seed(seed)
202 | torch.cuda.manual_seed(seed)
203 | torch.cuda.manual_seed_all(seed)
204 |
205 | # load embedded data and labels
206 | Ztr, Ytr = torch.load(os.path.join(opts['result_path'], 'embeds_tr_cifar10.pth'))
207 | Zte, Yte = torch.load(os.path.join(opts['result_path'], 'embeds_te_cifar10.pth'))
208 |
209 |
210 | '''
211 | EM (done on CPU)
212 | '''
213 |
214 | # randomly initialized mixture
215 | mix = models.MixvMF(x_dim=opts['z_dim'], order=10)
216 |
217 | # data loader for training samples
218 | dataloader = DataLoader(Ztr, batch_size=opts_em['batch_size'], shuffle=False, drop_last=False)
219 |
220 | # EM learning
221 | ll_old = -np.inf
222 | with torch.no_grad():
223 |
224 | for steps in range(opts_em['max_iters']):
225 |
226 | # E-step
227 | logalpha, mus, kappas = mix.get_params()
228 | logliks, logpcs = mix(Ztr)
229 | ll = logliks.sum()
230 | jll = logalpha.unsqueeze(0) + logpcs
231 | qz = jll.log_softmax(1).exp()
232 |
233 | if steps==0:
234 | prn_str = '[Before EM starts] loglik = %.4f\n' % ll.item()
235 | else:
236 | prn_str = '[Steps %03d] loglik (before M-step) = %.4f\n' % (steps, ll.item())
237 | print(prn_str)
238 |
239 | # tolerance check
240 | if steps>0:
241 | rll = (ll-ll_old).abs() / (ll_old.abs()+utils.realmin)
242 | if rll < opts_em['rll_tol']:
243 | prn_str = 'Stop EM since the relative improvement '
244 | prn_str += '(%.6f) < tolerance (%.6f)\n' % (rll.item(), opts_em['rll_tol'])
245 | print(prn_str)
246 | break
247 |
248 | ll_old = ll
249 |
250 | # M-step
251 | qzx = ( qz.unsqueeze(2) * Ztr.unsqueeze(1) ).sum(0)
252 | qzx_norms = utils.norm(qzx, dim=1)
253 | mus_new = qzx / qzx_norms
254 | Rs = qzx_norms[:,0] / (qz.sum(0) + utils.realmin)
255 |             kappas_new = (mix.x_dim*Rs - Rs**3) / (1 - Rs**2)  # approximate kappa MLE (Banerjee et al., 2005)
256 | alpha_new = qz.sum(0) / Ztr.shape[0]
257 |
258 | # assign new params
259 | mix.set_params(alpha_new, mus_new, kappas_new)
260 |
261 | # save model
262 | mix_em = mix
263 | torch.save(mix_em.state_dict(), os.path.join(opts['result_path'], 'mix_em_cifar10.pth'))
264 |
265 | logliks, logpcs = mix_em(Ztr)
266 | ll = logliks.sum()
267 | prn_str = '[Training done] loglik = %.4f\n' % ll.item()
268 | print(prn_str)
269 |
270 | # cluster label predictions
271 | with torch.no_grad():
272 | logalpha, mus, kappas = mix_em.get_params()
273 | clabs_em = ( logalpha.unsqueeze(0) + logpcs ).max(1)[1]
274 | logliks_te, logpcs_te = mix_em(Zte)
275 | clabs_te_em = ( logalpha.unsqueeze(0) + logpcs_te ).max(1)[1]
276 |
277 | # clustering metrics
278 | metrics_em = {}
279 | metrics_em['ARI'] = sklearn.metrics.adjusted_rand_score(Ytr, clabs_em)
280 | metrics_em['ARI_te'] = sklearn.metrics.adjusted_rand_score(Yte, clabs_te_em)
281 | metrics_em['NMI'] = sklearn.metrics.normalized_mutual_info_score(Ytr, clabs_em)
282 | metrics_em['NMI_te'] = sklearn.metrics.normalized_mutual_info_score(Yte, clabs_te_em)
283 |
284 | prn_str = '== EM estimator ==\n'
285 | prn_str += 'Test: ARI = %.4f, NMI = %.4f\n' % (metrics_em['ARI_te'], metrics_em['NMI_te'])
286 | prn_str += 'Train: ARI = %.4f, NMI = %.4f\n' % (metrics_em['ARI'], metrics_em['NMI'])
287 | print(prn_str)
288 |
289 |
290 | '''
291 | SGD
292 | '''
293 |
294 | # data loader for training samples
295 | dataloader = DataLoader(Ztr, batch_size=opts_sgd['batch_size'], shuffle=False, drop_last=False)
296 |
297 | # create a model
298 | mix = models.MixvMF(x_dim=opts['z_dim'], order=10)
299 | mix = mix.cuda()
300 |
301 | # create optimizers and set optim params
302 | params = list(mix.parameters())
303 | optim = torch.optim.Adam(params, lr=1e-1, betas=[0.9, 0.99])
304 | lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.95)
305 |
306 | # SGD training
307 | for epoch in range(opts_sgd['max_epochs']):
308 |
309 | ll, nsamps = 0.0, 0
310 | for ii, data in enumerate(dataloader):
311 | data = data.cuda()
312 | logliks, _ = mix(data)
313 | anll = -logliks.mean()
314 | optim.zero_grad()
315 | anll.backward()
316 | optim.step()
317 | ll += -anll.item()*data.shape[0]
318 |
319 | lr_sched.step()
320 |
321 | if (epoch+1) % 1 == 0:
322 | prn_str = '[Epoch %03d] loglik (accumulated) = %.4f\n' % (epoch, ll)
323 | print(prn_str)
324 |
325 | # save model
326 | mix_sgd = mix
327 | mix_sgd = mix_sgd.cpu()
328 | torch.save(mix_sgd.state_dict(), os.path.join(opts['result_path'], 'mix_sgd_cifar10.pth'))
329 |
330 | # cluster label predictions
331 | with torch.no_grad():
332 |
333 | logalpha, mus, kappas = mix_sgd.get_params()
334 |
335 | # prediction on train data
336 | for ii, data in enumerate(dataloader):
337 | logliks_, logpcs_ = mix_sgd(data)
338 | if ii==0:
339 | logliks, logpcs = logliks_, logpcs_
340 | else:
341 | logliks = torch.cat([logliks, logliks_], dim=0)
342 | logpcs = torch.cat([logpcs, logpcs_], dim=0)
343 | clabs_sgd = ( logalpha.unsqueeze(0) + logpcs ).max(1)[1]
344 |
345 | # prediction on test data
346 | dataloader = DataLoader(Zte, batch_size=opts_sgd['batch_size'], shuffle=False, drop_last=False)
347 | for ii, data in enumerate(dataloader):
348 | logliks_, logpcs_ = mix_sgd(data)
349 | if ii==0:
350 | logliks_te, logpcs_te = logliks_, logpcs_
351 | else:
352 | logliks_te = torch.cat([logliks_te, logliks_], dim=0)
353 | logpcs_te = torch.cat([logpcs_te, logpcs_], dim=0)
354 | clabs_te_sgd = ( logalpha.unsqueeze(0) + logpcs_te ).max(1)[1]
355 |
356 | # clustering metrics
357 | metrics_sgd = {}
358 | metrics_sgd['ARI'] = sklearn.metrics.adjusted_rand_score(Ytr, clabs_sgd)
359 | metrics_sgd['ARI_te'] = sklearn.metrics.adjusted_rand_score(Yte, clabs_te_sgd)
360 | metrics_sgd['NMI'] = sklearn.metrics.normalized_mutual_info_score(Ytr, clabs_sgd)
361 | metrics_sgd['NMI_te'] = sklearn.metrics.normalized_mutual_info_score(Yte, clabs_te_sgd)
362 |
363 | prn_str = '== SGD estimator ==\n'
364 | prn_str += 'Test: ARI = %.4f, NMI = %.4f\n' % (metrics_sgd['ARI_te'], metrics_sgd['NMI_te'])
365 | prn_str += 'Train: ARI = %.4f, NMI = %.4f\n' % (metrics_sgd['ARI'], metrics_sgd['NMI'])
366 | print(prn_str)
367 |
368 |
--------------------------------------------------------------------------------
/mle_for_mix_vmf.py:
--------------------------------------------------------------------------------
1 | import random
2 | import itertools
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 | import models
8 | import utils
9 |
10 |
11 | seed = 1234
12 |
13 | random.seed(seed)
14 | np.random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 |
19 |
20 | '''
21 | Define a true vMF mixture model
22 | '''
23 |
24 | mix_true = models.MixvMF(x_dim=5, order=3)
25 |
26 | mus_true = [
27 | torch.tensor([0.3, -1.2, 2.3, 0.4, 2.1]),
28 | torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0]),
29 | torch.tensor([-0.3, 1.2, -2.3, -0.4, -2.1]),
30 | ]
31 | mus_true = [ mu / utils.norm(mu, dim=0) for mu in mus_true ]
32 | kappas_true = [
33 | torch.tensor(100.0),
34 | torch.tensor(50.0),
35 | torch.tensor(100.0)
36 | ]
37 | alpha_true = torch.tensor([0.3, 0.4, 0.3])
38 |
39 | mix_true.set_params(alpha_true, mus_true, kappas_true)
40 |
41 | # sample data from true mixture
42 | samples, cids = mix_true.sample(N=1000, rsf=1)
43 |
44 |
45 | '''
46 | Full-batch EM learning
47 | '''
48 |
49 | opts = {}
50 | opts['max_iters'] = 100 # maximum number of EM iterations
51 | opts['rll_tol'] = 1e-5 # tolerance of relative loglik improvement
52 |
53 | # randomly initialized mixture
54 | mix = models.MixvMF(x_dim=5, order=3)
55 |
56 | # EM learning
57 | ll_old = -np.inf
58 | with torch.no_grad():
59 |
60 | for steps in range(opts['max_iters']):
61 |
62 | # E-step
63 | logalpha, mus, kappas = mix.get_params()
64 | logliks, logpcs = mix(samples)
65 | ll = logliks.sum()
66 | jll = logalpha.unsqueeze(0) + logpcs
67 | qz = jll.log_softmax(1).exp()
68 |
69 | if steps==0:
70 | prn_str = '[Before EM starts] loglik = %.4f\n' % ll.item()
71 | else:
72 | prn_str = '[Steps %03d] loglik (before M-step) = %.4f\n' % (steps, ll.item())
73 | print(prn_str)
74 |
75 | # tolerance check
76 | if steps>0:
77 | rll = (ll-ll_old).abs() / (ll_old.abs()+utils.realmin)
78 | if rll < opts['rll_tol']:
79 | prn_str = 'Stop EM since the relative improvement '
80 | prn_str += '(%.6f) < tolerance (%.6f)\n' % (rll.item(), opts['rll_tol'])
81 | print(prn_str)
82 | break
83 |
84 | ll_old = ll
85 |
86 | # M-step
87 | qzx = ( qz.unsqueeze(2) * samples.unsqueeze(1) ).sum(0)
88 | qzx_norms = utils.norm(qzx, dim=1)
89 | mus_new = qzx / qzx_norms
90 | Rs = qzx_norms[:,0] / (qz.sum(0) + utils.realmin)
91 |             kappas_new = (mix.x_dim*Rs - Rs**3) / (1 - Rs**2)  # approximate kappa MLE (Banerjee et al., 2005)
92 | alpha_new = qz.sum(0) / samples.shape[0]
93 |
94 | # assign new params
95 | mix.set_params(alpha_new, mus_new, kappas_new)
96 |
97 | logliks, logpcs = mix(samples)
98 | ll = logliks.sum()
99 | prn_str = '[Training done] loglik = %.4f\n' % ll.item()
100 | print(prn_str)
101 |
102 | # find the best matching permutations of components
103 | print('Find the best matching permutations of components')
104 | with torch.no_grad():
105 | logalpha, mus, kappas = mix.get_params()
106 | alpha = logalpha.exp()
107 | perms = list(itertools.permutations(range(mix.order)))
108 | best_perm, best_error = None, 1e10
109 | for perm in perms:
110 | perm = np.array(perm)
111 | error_alpha = (alpha[perm] - alpha_true).abs().sum()
112 | error_mus = (mus[perm,:] - torch.stack(mus_true, dim=0)).abs().sum()
113 | error_kappas = (kappas[perm] - torch.stack(kappas_true, dim=0)).abs().sum()
114 | error = (error_alpha + error_mus + error_kappas).item()
115 | print('perm = %s: error = %.4f' % (perm, error))
116 | if error < best_error:
117 | best_perm, best_error = perm, error
118 | print('best perm has changed to: %s' % best_perm)
119 |
120 | print('For the best components permutation:')
121 | print('----------')
122 | print('alpha_true = %s' % alpha_true)
123 | print('alpha = %s' % alpha[best_perm])
124 | print('error in alpha = %.4f' % (alpha[best_perm] - alpha_true).abs().sum().item())
125 | print('----------')
126 | print('mus_true = %s' % torch.stack(mus_true, dim=0))
127 | print('mus = %s' % mus[best_perm])
128 | print('error in mus = %.4f' % (mus[best_perm,:] - torch.stack(mus_true, dim=0)).abs().sum().item())
129 | print('----------')
130 | print('kappas_true = %s' % torch.stack(kappas_true, dim=0))
131 | print('kappas = %s' % kappas[best_perm])
132 | print('error in kappas = %.4f' % (kappas[best_perm] - torch.stack(kappas_true, dim=0)).abs().sum().item())
133 | print('----------')
134 |
135 | # save model
136 | mix_em = mix
137 |
138 |
139 | '''
140 | SGD-based ML estimator
141 | '''
142 |
143 | B = 64 # batch size
144 | max_epochs = 100 # maximum number of epochs
145 |
146 | dataloader = DataLoader(samples, batch_size=B, shuffle=True, drop_last=True)
147 |
148 | # create a model
149 | mix = models.MixvMF(x_dim=5, order=3)
150 | mix = mix.cuda()
151 |
152 | # create optimizers and set optim params
153 | params = list(mix.parameters())
154 | optim = torch.optim.Adam(params, lr=1e-1, betas=[0.9, 0.99])
155 | lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.95)
156 |
157 | # SGD training
158 | for epoch in range(max_epochs):
159 |
160 | ll, nsamps = 0.0, 0
161 | for ii, data in enumerate(dataloader):
162 | data = data.cuda()
163 | logliks, _ = mix(data)
164 | anll = -logliks.mean()
165 | optim.zero_grad()
166 | anll.backward()
167 | optim.step()
168 | ll += -anll.item()*data.shape[0]
169 |
170 | lr_sched.step()
171 |
172 | if (epoch+1) % 1 == 0:
173 | prn_str = '[Epoch %03d] loglik (accumulated) = %.4f\n' % (epoch, ll)
174 | print(prn_str)
175 |
176 | mix = mix.cpu()
177 |
178 | # find the best matching permutations of components
179 | print('Find the best matching permutations of components')
180 | with torch.no_grad():
181 | logalpha, mus, kappas = mix.get_params() # logalpha = M-dim
182 | alpha = logalpha.exp()
183 | perms = list(itertools.permutations(range(mix.order)))
184 | best_perm, best_error = None, 1e10
185 | for perm in perms:
186 | perm = np.array(perm)
187 | error_alpha = (alpha[perm] - alpha_true).abs().sum()
188 | error_mus = (mus[perm,:] - torch.stack(mus_true, dim=0)).abs().sum()
189 | error_kappas = (kappas[perm] - torch.stack(kappas_true, dim=0)).abs().sum()
190 | error = (error_alpha + error_mus + error_kappas).item()
191 | print('perm = %s: error = %.4f' % (perm, error))
192 | if error < best_error:
193 | best_perm, best_error = perm, error
194 | print('best perm has changed to: %s' % best_perm)
195 |
196 | print('For the best components permutation:')
197 | print('----------')
198 | print('alpha_true = %s' % alpha_true)
199 | print('alpha = %s' % alpha[best_perm])
200 | print('error in alpha = %.4f' % (alpha[best_perm] - alpha_true).abs().sum().item())
201 | print('----------')
202 | print('mus_true = %s' % torch.stack(mus_true, dim=0))
203 | print('mus = %s' % mus[best_perm])
204 | print('error in mus = %.4f' % (mus[best_perm,:] - torch.stack(mus_true, dim=0)).abs().sum().item())
205 | print('----------')
206 | print('kappas_true = %s' % torch.stack(kappas_true, dim=0))
207 | print('kappas = %s' % kappas[best_perm])
208 | print('error in kappas = %.4f' % (kappas[best_perm] - torch.stack(kappas_true, dim=0)).abs().sum().item())
209 | print('----------')
210 |
211 | # save model
212 | mix_sgd = mix
213 |
214 |
--------------------------------------------------------------------------------
/mle_for_vmf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 |
4 | import models
5 | import utils
6 |
7 |
8 | '''
9 | Define a true vMF model
10 | '''
11 |
12 | mu_true = torch.zeros(100) # 5 or 20
13 | mu_true[0] = 1.0
14 | mu_true = mu_true / utils.norm(mu_true, dim=0)
15 | kappa_true = torch.tensor(500.0) # 50.0
16 | vmf_true = models.vMF(x_dim=mu_true.shape[0])
17 | vmf_true.set_params(mu=mu_true, kappa=kappa_true)
18 | vmf_true = vmf_true.cuda()
19 |
20 | # sample from true vMF model
21 | samples = vmf_true.sample(N=10000, rsf=1)
22 |
23 |
24 | '''
25 | Full-batch ML estimator
26 | '''
27 |
28 | xm = samples.mean(0)
29 | xm_norm = (xm**2).sum().sqrt()
30 | mu0 = xm / xm_norm
31 | kappa0 = (len(xm)*xm_norm - xm_norm**3) / (1-xm_norm**2)
32 |
33 | mu_err = ((mu0.cpu() - mu_true)**2).sum().item() # squared L2 error
34 | kappa_err = (kappa0.cpu() - kappa_true).abs().item() / kappa_true.item()
35 | prn_str = '== Batch ML estimator ==\n'
36 | prn_str += 'mu = %s (error = %.8f)\n' % (mu0.cpu().numpy(), mu_err)
37 | prn_str += 'kappa = %s (error = %.8f)\n' % (kappa0.cpu().numpy(), kappa_err)
38 | print(prn_str)
39 |
40 |
41 | '''
42 | SGD-based ML estimator
43 | '''
44 |
45 | B = 128 # batch size
46 | max_epochs = 100 # maximum number of epochs
47 |
48 | dataloader = DataLoader(samples, batch_size=B, shuffle=True, drop_last=True)
49 |
50 | # create a model
51 | vmf = models.vMF(x_dim=mu_true.shape[0])
52 | vmf = vmf.cuda()
53 |
54 | # create optimizers and set optim params
55 | params = list(vmf.parameters())
56 | optim = torch.optim.Adam(params, lr=1e-2, betas=[0.9, 0.99])
57 | lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.95)
58 |
59 | # error of the initial model
60 | with torch.no_grad():
61 | mu, kappa = vmf.get_params()
62 | mu_err = (mu.cpu() - mu_true).abs().mean().item() # dim-wise absolute error
63 | kappa_err = (kappa.cpu() - kappa_true).abs().item()
64 | prn_str = '== Before training starts ==\n'
65 | prn_str += 'mu = %s (error = %.6f)\n' % (mu.cpu().numpy(), mu_err)
66 | prn_str += 'kappa = %s (error = %.6f)\n' % (kappa.cpu().numpy(), kappa_err)
67 | print(prn_str)
68 |
69 | # SGD training
70 | for epoch in range(max_epochs):
71 |
72 | obj, nsamps = 0.0, 0
73 | for ii, data in enumerate(dataloader):
74 | enll = -vmf(data).mean()
75 | optim.zero_grad()
76 | enll.backward()
77 | optim.step()
78 | obj += enll.item()*data.shape[0]
79 | nsamps += data.shape[0]
80 | obj /= nsamps
81 |
82 | lr_sched.step()
83 |
84 | if (epoch+1) % 10 == 0:
85 | with torch.no_grad():
86 | mu, kappa = vmf.get_params()
87 |                 mu_err = ((mu.cpu() - mu_true)**2).sum().item() # squared L2 error
88 | kappa_err = (kappa.cpu() - kappa_true).abs().item() / kappa_true.item()
89 | prn_str = '== After epoch %d ==\n' % epoch
90 |             prn_str += 'Expected negative log-likelihood = %.4f\n' % enll.item()
91 | prn_str += 'mu = %s (error = %.8f)\n' % (mu.cpu().numpy(), mu_err)
92 | prn_str += 'kappa = %s (error = %.8f)\n' % (kappa.cpu().numpy(), kappa_err)
93 | print(prn_str)
94 |
95 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import mpmath
3 | import torch
4 | import torch.nn as nn
5 |
6 | import utils
7 |
8 |
9 | class vMFLogPartition(torch.autograd.Function):
10 |
11 | '''
12 | Evaluates log C_d(kappa) for vMF density
13 | Allows autograd wrt kappa
14 | '''
15 |
16 | besseli = np.vectorize(mpmath.besseli)
17 | log = np.vectorize(mpmath.log)
18 | nhlog2pi = -0.5 * np.log(2*np.pi)
19 |
20 | @staticmethod
21 | def forward(ctx, *args):
22 |
23 | '''
24 | Args:
25 | args[0] = d; scalar (> 0)
26 | args[1] = kappa; (> 0) torch tensor of any shape
27 |
28 | Returns:
29 | logC = log C_d(kappa); torch tensor of the same shape as kappa
30 | '''
31 |
32 | d = args[0]
33 | kappa = args[1]
34 |
35 | s = 0.5*d - 1
36 |
37 | # log I_s(kappa)
38 | mp_kappa = mpmath.mpf(1.0) * kappa.detach().cpu().numpy()
39 | mp_logI = vMFLogPartition.log( vMFLogPartition.besseli(s, mp_kappa) )
40 | logI = torch.from_numpy( np.array(mp_logI.tolist(), dtype=float) ).to(kappa)
41 |
42 | if (logI!=logI).sum().item() > 0: # there is nan
43 | raise ValueError('NaN is detected from the output of log-besseli()')
44 |
45 | logC = d * vMFLogPartition.nhlog2pi + s * kappa.log() - logI
46 |
47 |         # save for backward()
48 | ctx.s, ctx.mp_kappa, ctx.logI = s, mp_kappa, logI
49 |
50 | return logC
51 |
52 | @staticmethod
53 | def backward(ctx, *grad_output):
54 |
55 | s, mp_kappa, logI = ctx.s, ctx.mp_kappa, ctx.logI
56 |
57 | # log I_{s+1}(kappa)
58 | mp_logI2 = vMFLogPartition.log( vMFLogPartition.besseli(s+1, mp_kappa) )
59 | logI2 = torch.from_numpy( np.array(mp_logI2.tolist(), dtype=float) ).to(logI)
60 |
61 | if (logI2!=logI2).sum().item() > 0: # there is nan
62 | raise ValueError('NaN is detected from the output of log-besseli()')
63 |
64 | dlogC_dkappa = -(logI2 - logI).exp()
65 |
66 | return None, grad_output[0] * dlogC_dkappa
67 |
68 |
69 | class vMF(nn.Module):
70 |
71 | '''
72 | vMF(x; mu, kappa)
73 | '''
74 |
75 | def __init__(self, x_dim, reg=1e-6):
76 |
77 | super(vMF, self).__init__()
78 |
79 | self.x_dim = x_dim
80 |
81 | self.mu_unnorm = nn.Parameter(torch.randn(x_dim))
82 | self.logkappa = nn.Parameter(0.01*torch.randn([]))
83 |
84 | self.reg = reg
85 |
86 | def set_params(self, mu, kappa):
87 |
88 | with torch.no_grad():
89 | self.mu_unnorm.copy_(mu)
90 | self.logkappa.copy_(torch.log(kappa+utils.realmin))
91 |
92 | def get_params(self):
93 |
94 | mu = self.mu_unnorm / utils.norm(self.mu_unnorm)
95 | kappa = self.logkappa.exp() + self.reg
96 |
97 | return mu, kappa
98 |
99 | def forward(self, x, utc=False):
100 |
101 | '''
102 | Evaluate logliks, log p(x)
103 |
104 | Args:
105 | x = batch for x
106 | utc = whether to evaluate only up to constant or exactly
107 | if True, no log-partition computed
108 | if False, exact loglik computed
109 |
110 | Returns:
111 | logliks = log p(x)
112 | '''
113 |
114 | mu, kappa = self.get_params()
115 |
116 | dotp = (mu.unsqueeze(0) * x).sum(1)
117 |
118 | if utc:
119 | logliks = kappa * dotp
120 | else:
121 | logC = vMFLogPartition.apply(self.x_dim, kappa)
122 | logliks = kappa * dotp + logC
123 |
124 | return logliks
125 |
126 | def sample(self, N=1, rsf=10):
127 |
128 | '''
129 | Args:
130 | N = number of samples to generate
131 | rsf = multiplicative factor for extra backup samples in rejection sampling
132 |
133 | Returns:
134 | samples; N samples generated
135 |
136 | Notes:
137 | no autodiff
138 | '''
139 |
140 | d = self.x_dim
141 |
142 | with torch.no_grad():
143 |
144 | mu, kappa = self.get_params()
145 |
146 | # Step-1: Sample uniform unit vectors in R^{d-1}
147 | v = torch.randn(N, d-1).to(mu)
148 | v = v / utils.norm(v, dim=1)
149 |
150 | # Step-2: Sample v0
151 | kmr = np.sqrt( 4*kappa.item()**2 + (d-1)**2 )
152 | bb = (kmr - 2*kappa) / (d-1)
153 | aa = (kmr + 2*kappa + d - 1) / 4
154 | dd = (4*aa*bb)/(1+bb) - (d-1)*np.log(d-1)
155 | beta = torch.distributions.Beta( torch.tensor(0.5*(d-1)), torch.tensor(0.5*(d-1)) )
156 | uniform = torch.distributions.Uniform(0.0, 1.0)
157 | v0 = torch.tensor([]).to(mu)
158 | while len(v0) < N:
159 | eps = beta.sample([1, rsf*(N-len(v0))]).squeeze().to(mu)
160 | uns = uniform.sample([1, rsf*(N-len(v0))]).squeeze().to(mu)
161 | w0 = (1 - (1+bb)*eps) / (1 - (1-bb)*eps)
162 | t0 = (2*aa*bb) / (1 - (1-bb)*eps)
163 | det = (d-1)*t0.log() - t0 + dd - uns.log()
164 |                 v0 = torch.cat([v0, w0[det>=0].to(mu)])  # keep accepted proposals
165 | if len(v0) > N:
166 | v0 = v0[:N]
167 | break
168 | v0 = v0.reshape([N,1])
169 |
170 | # Step-3: Form x = [v0; sqrt(1-v0^2)*v]
171 | samples = torch.cat([v0, (1-v0**2).sqrt()*v], 1)
172 |
173 |             # Step-4: Householder transformation to rotate the north pole onto mu
174 | e1mu = torch.zeros(d,1).to(mu); e1mu[0,0] = 1.0
175 | e1mu = e1mu - mu if len(mu.shape)==2 else e1mu - mu.unsqueeze(1)
176 | e1mu = e1mu / utils.norm(e1mu, dim=0)
177 | samples = samples - 2 * (samples @ e1mu) @ e1mu.t()
178 |
179 | return samples
180 |
181 |
182 | class MixvMF(nn.Module):
183 |
184 | '''
185 | MixvMF(x) = \sum_{m=1}^M \alpha_m vMF(x; mu_m, kappa_m)
186 | '''
187 |
188 | def __init__(self, x_dim, order, reg=1e-6):
189 |
190 | super(MixvMF, self).__init__()
191 |
192 | self.x_dim = x_dim
193 | self.order = order
194 | self.reg = reg
195 |
196 | self.alpha_logit = nn.Parameter(0.01*torch.randn(order))
197 | self.comps = nn.ModuleList(
198 | [ vMF(x_dim, reg) for _ in range(order) ]
199 | )
200 |
201 | def set_params(self, alpha, mus, kappas):
202 |
203 | with torch.no_grad():
204 | self.alpha_logit.copy_(torch.log(alpha+utils.realmin))
205 | for m in range(self.order):
206 | self.comps[m].mu_unnorm.copy_(mus[m])
207 | self.comps[m].logkappa.copy_(torch.log(kappas[m]+utils.realmin))
208 |
209 | def get_params(self):
210 |
211 | logalpha = self.alpha_logit.log_softmax(0)
212 |
213 | mus, kappas = [], []
214 | for m in range(self.order):
215 | mu, kappa = self.comps[m].get_params()
216 | mus.append(mu)
217 | kappas.append(kappa)
218 |
219 |         mus = torch.stack(mus, dim=0)
220 |         kappas = torch.stack(kappas, dim=0)
221 |
222 | return logalpha, mus, kappas
223 |
224 | def forward(self, x):
225 |
226 | '''
227 | Evaluate logliks, log p(x)
228 |
229 | Args:
230 | x = batch for x
231 |
232 | Returns:
233 | logliks = log p(x)
234 | logpcs = log p(x|c=m)
235 | '''
236 |
237 | logalpha = self.alpha_logit.log_softmax(0)
238 |
239 | logpcs = []
240 | for m in range(self.order):
241 | logpcs.append(self.comps[m](x))
242 | logpcs = torch.stack(logpcs, dim=1)
243 |
244 | logliks = (logalpha.unsqueeze(0) + logpcs).logsumexp(1)
245 |
246 | return logliks, logpcs
247 |
248 | def sample(self, N=1, rsf=10):
249 |
250 | '''
251 | Args:
252 | N = number of samples to generate
253 | rsf = multiplicative factor for extra backup samples in rejection sampling
254 | (used in sampling from vMF)
255 |
256 | Returns:
257 | samples = N samples generated
258 | cids = which components the samples come from; N-dim {0,1,...,M-1}-valued
259 |
260 | Notes:
261 | no autodiff
262 | '''
263 |
264 | with torch.no_grad():
265 |
266 | alpha = self.alpha_logit.log_softmax(0).exp()
267 |
268 | cids = torch.multinomial(alpha, N, replacement=True)
269 |
270 | samples = torch.zeros(N, self.x_dim)
271 | for c in range(self.order):
272 | Nc = (cids==c).sum()
273 | if Nc > 0:
274 | samples[cids==c,:] = self.comps[c].sample(N=Nc, rsf=rsf)
275 |
276 | return samples, cids
277 |
278 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | realmin = 1e-10
5 |
6 |
7 | def norm(input, p=2, dim=0, eps=1e-12):
8 | return input.norm(p, dim, keepdim=True).clamp(min=eps).expand_as(input)
9 |
10 |
11 | def mkdirs(path):
12 | if not os.path.exists(path):
13 | os.makedirs(path)
14 |
15 |
--------------------------------------------------------------------------------