├── scatter.gif
├── README.md
├── tsne.py
├── wrapper.py
├── run.py
├── vtsne.py
└── topic_sne.py

/scatter.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cemoody/topicsne/HEAD/scatter.gif
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This repo contains a simple t-SNE model implemented in PyTorch.

![MNIST t-SNE](scatter.gif)
--------------------------------------------------------------------------------
/tsne.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


def pairwise(data):
    """Return the n x n matrix of squared pairwise distances."""
    n_obs, dim = data.size()
    xk = data.unsqueeze(0).expand(n_obs, n_obs, dim)
    xl = data.unsqueeze(1).expand(n_obs, n_obs, dim)
    dkl2 = ((xk - xl) ** 2.0).sum(2).squeeze()
    return dkl2


class TSNE(nn.Module):
    def __init__(self, n_points, n_topics, n_dim):
        self.n_points = n_points
        self.n_dim = n_dim
        super(TSNE, self).__init__()
        # Logit of datapoint-to-topic weight
        self.logits = nn.Embedding(n_points, n_topics)

    def forward(self, pij, i, j):
        # Get embeddings for all points
        x = self.logits.weight
        # Compute squared pairwise distances
        dkl2 = pairwise(x)
        # Compute the partition function; each diagonal entry contributes
        # (1 + 0)^-1 = 1, so subtract n_diagonal to drop the self-pairs
        n_diagonal = dkl2.size()[0]
        part = (1 + dkl2).pow(-1.0).sum() - n_diagonal
        # Compute the numerator: the Student-t kernel (1 + ||x_i - x_j||^2)^-1
        xi = self.logits(i)
        xj = self.logits(j)
        num = (1. + ((xi - xj) ** 2.0).sum(1)).pow(-1.0).squeeze()
        # This probability is the probability of picking the (i, j)
        # relationship out of N^2 other possible pairs in the 2D embedding.
        qij = num / part.expand_as(num)
        # Compute KLD(pij || qij)
        loss_kld = pij * (torch.log(pij) - torch.log(qij))
        return loss_kld.sum()

    def __call__(self, *args):
        return self.forward(*args)
--------------------------------------------------------------------------------
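A quick way to sanity-check `pairwise` above is to compare it against the distance definition directly. A minimal sketch, assuming tsne.py is importable from the working directory and using toy random points:

import torch

from tsne import pairwise

x = torch.randn(5, 2)          # five toy points in 2-D
d2 = pairwise(x)               # 5 x 5 matrix of squared distances
for k in range(5):
    for l in range(5):
        # Compare each entry to ||x_k - x_l||^2 computed explicitly
        expected = float(((x[k] - x[l]) ** 2).sum())
        assert abs(float(d2[k, l]) - expected) < 1e-5
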
/wrapper.py:
--------------------------------------------------------------------------------
import random

import torch
import torch.optim as optim
from torch.autograd import Variable


def chunks(n, *args):
    """Yield successive n-sized chunks from each array in `args`,
    in a shuffled order. Any trailing partial chunk is dropped."""
    endpoints = []
    start = 0
    for stop in range(0, len(args[0]), n):
        if stop - start > 0:
            endpoints.append((start, stop))
            start = stop
    random.shuffle(endpoints)
    for start, stop in endpoints:
        yield [a[start: stop] for a in args]


class Wrapper():
    def __init__(self, model, cuda=True, log_interval=100, epochs=1000,
                 batchsize=1024):
        self.batchsize = batchsize
        self.epochs = epochs
        self.cuda = cuda
        self.model = model
        if cuda:
            self.model.cuda()
        self.optimizer = optim.Adam(model.parameters(), lr=1e-2)
        self.log_interval = log_interval

    def fit(self, *args):
        self.model.train()
        if self.cuda:
            self.model.cuda()
        for epoch in range(self.epochs):
            total = 0.0
            for itr, datas in enumerate(chunks(self.batchsize, *args)):
                datas = [Variable(torch.from_numpy(data)) for data in datas]
                if self.cuda:
                    datas = [data.cuda() for data in datas]
                self.optimizer.zero_grad()
                loss = self.model(*datas)
                loss.backward()
                self.optimizer.step()
                total += loss.data[0]
            msg = 'Train Epoch: {} \tLoss: {:.6e}'
            msg = msg.format(epoch, total / (len(args[0]) * 1.0))
            print(msg)
--------------------------------------------------------------------------------
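Note that `chunks` shuffles the order of the batches but not the items within them, and silently drops any trailing partial window. A minimal sketch of that behaviour, using hypothetical toy arrays:

import numpy as np

from wrapper import chunks

a = np.arange(10)
b = np.arange(10) * 2
for batch_a, batch_b in chunks(4, a, b):
    print(batch_a, batch_b)
# Prints the windows [0:4) and [4:8) in a random order; the last two
# elements never appear, because only full n-sized windows are yielded.
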
18 | """ 19 | digits = datasets.load_digits(n_class=6) 20 | pos = digits.data 21 | y = digits.target 22 | n_points = pos.shape[0] 23 | distances2 = pairwise_distances(pos, metric=metric, squared=True) 24 | # This return a n x (n-1) prob array 25 | pij = manifold.t_sne._joint_probabilities(distances2, perplexity, False) 26 | # Convert to n x n prob array 27 | pij = squareform(pij) 28 | return n_points, pij, y 29 | 30 | 31 | draw_ellipse = True 32 | n_points, pij2d, y = preprocess() 33 | i, j = np.indices(pij2d.shape) 34 | i = i.ravel() 35 | j = j.ravel() 36 | pij = pij2d.ravel().astype('float32') 37 | # Remove self-indices 38 | idx = i != j 39 | i, j, pij = i[idx], j[idx], pij[idx] 40 | 41 | n_topics = 2 42 | n_dim = 2 43 | print(n_points, n_dim, n_topics) 44 | 45 | model = VTSNE(n_points, n_topics, n_dim) 46 | wrap = Wrapper(model, batchsize=4096, epochs=1) 47 | for itr in range(500): 48 | wrap.fit(pij, i, j) 49 | 50 | # Visualize the results 51 | embed = model.logits.weight.cpu().data.numpy() 52 | f = plt.figure() 53 | if not draw_ellipse: 54 | plt.scatter(embed[:, 0], embed[:, 1], c=y * 1.0 / y.max()) 55 | plt.axis('off') 56 | plt.savefig('scatter_{:03d}.png'.format(itr), bbox_inches='tight') 57 | plt.close(f) 58 | else: 59 | # Visualize with ellipses 60 | var = np.sqrt(model.logits_lv.weight.clone().exp_().cpu().data.numpy()) 61 | ax = plt.gca() 62 | for xy, (w, h), c in zip(embed, var, y): 63 | e = Ellipse(xy=xy, width=w, height=h, ec=None, lw=0.0) 64 | e.set_facecolor(plt.cm.Paired(c * 1.0 / y.max())) 65 | e.set_alpha(0.5) 66 | ax.add_artist(e) 67 | ax.set_xlim(-9, 9) 68 | ax.set_ylim(-9, 9) 69 | plt.axis('off') 70 | plt.savefig('scatter_{:03d}.png'.format(itr), bbox_inches='tight') 71 | plt.close(f) 72 | 73 | 74 | -------------------------------------------------------------------------------- /vtsne.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch import nn 6 | 7 | import numpy as np 8 | 9 | 10 | def pairwise(data): 11 | n_obs, dim = data.size() 12 | xk = data.unsqueeze(0).expand(n_obs, n_obs, dim) 13 | xl = data.unsqueeze(1).expand(n_obs, n_obs, dim) 14 | dkl2 = ((xk - xl)**2.0).sum(2).squeeze() 15 | return dkl2 16 | 17 | 18 | class VTSNE(nn.Module): 19 | def __init__(self, n_points, n_topics, n_dim): 20 | self.n_points = n_points 21 | self.n_dim = n_dim 22 | super(VTSNE, self).__init__() 23 | # Logit of datapoint-to-topic weight 24 | self.logits_mu = nn.Embedding(n_points, n_topics) 25 | self.logits_lv = nn.Embedding(n_points, n_topics) 26 | 27 | @property 28 | def logits(self): 29 | return self.logits_mu 30 | 31 | def reparametrize(self, mu, logvar): 32 | # From VAE example 33 | # https://github.com/pytorch/examples/blob/master/vae/main.py 34 | std = logvar.mul(0.5).exp_() 35 | eps = torch.cuda.FloatTensor(std.size()).normal_() 36 | eps = Variable(eps) 37 | z = eps.mul(std).add_(mu) 38 | kld = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 39 | kld = torch.sum(kld).mul_(-0.5) 40 | return z, kld 41 | 42 | def sample_logits(self, i=None): 43 | if i is None: 44 | return self.reparametrize(self.logits_mu.weight, self.logits_lv.weight) 45 | else: 46 | return self.reparametrize(self.logits_mu(i), self.logits_lv(i)) 47 | 48 | def forward(self, pij, i, j): 49 | # Get for all points 50 | x, loss_kldrp = self.sample_logits() 51 | # Compute squared pairwise distances 52 | dkl2 = pairwise(x) 53 | # Compute 
/vtsne.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
from torch import nn


def pairwise(data):
    """Return the n x n matrix of squared pairwise distances."""
    n_obs, dim = data.size()
    xk = data.unsqueeze(0).expand(n_obs, n_obs, dim)
    xl = data.unsqueeze(1).expand(n_obs, n_obs, dim)
    dkl2 = ((xk - xl) ** 2.0).sum(2).squeeze()
    return dkl2


class VTSNE(nn.Module):
    def __init__(self, n_points, n_topics, n_dim):
        self.n_points = n_points
        self.n_dim = n_dim
        super(VTSNE, self).__init__()
        # Mean and log-variance of each datapoint-to-topic weight
        self.logits_mu = nn.Embedding(n_points, n_topics)
        self.logits_lv = nn.Embedding(n_points, n_topics)

    @property
    def logits(self):
        return self.logits_mu

    def reparametrize(self, mu, logvar):
        # Reparameterization trick, as in the PyTorch VAE example
        # https://github.com/pytorch/examples/blob/master/vae/main.py
        # (assumes a CUDA device, matching Wrapper's default)
        std = logvar.mul(0.5).exp_()
        eps = torch.cuda.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        z = eps.mul(std).add_(mu)
        # Closed-form KL(N(mu, sigma^2) || N(0, 1))
        kld = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
        kld = torch.sum(kld).mul_(-0.5)
        return z, kld

    def sample_logits(self, i=None):
        if i is None:
            return self.reparametrize(self.logits_mu.weight, self.logits_lv.weight)
        else:
            return self.reparametrize(self.logits_mu(i), self.logits_lv(i))

    def forward(self, pij, i, j):
        # Sample embeddings for all points
        x, loss_kldrp = self.sample_logits()
        # Compute squared pairwise distances
        dkl2 = pairwise(x)
        # Compute the partition function; subtract n_diagonal to drop the
        # (1 + 0)^-1 = 1 contribution of each self-pair
        n_diagonal = dkl2.size()[0]
        part = (1 + dkl2).pow(-1.0).sum() - n_diagonal
        # Compute the numerator: the Student-t kernel (1 + ||x_i - x_j||^2)^-1
        xi, _ = self.sample_logits(i)
        xj, _ = self.sample_logits(j)
        num = (1. + ((xi - xj) ** 2.0).sum(1)).pow(-1.0).squeeze()
        qij = num / part.expand_as(num)
        # Compute KLD(pij || qij)
        loss_kld = pij * (torch.log(pij) - torch.log(qij))
        # Add the (weakly weighted) variational KL term
        return loss_kld.sum() + loss_kldrp.sum() * 1e-7

    def __call__(self, *args):
        return self.forward(*args)
--------------------------------------------------------------------------------
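The `kld` expression in `reparametrize` is the closed-form KL divergence between the variational posterior N(mu, sigma^2) and a standard normal prior, KL = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2). A minimal check of that identity at the prior itself:

import torch

mu = torch.zeros(3)
logvar = torch.zeros(3)        # log sigma^2 = 0, i.e. sigma = 1
kld = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - 1 - logvar)
print(float(kld))              # 0.0: N(0, 1) carries no KL cost against the prior
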
/topic_sne.py:
--------------------------------------------------------------------------------
# Build a topic-like embedding for t-SNE
# 0. Measure distances d_ij of points
#
# 1. Construct the input p_ij
#
# 2. Define qij = (1 + dij2)^-1 / SUM(over k, l, k != l) (1 + dkl2)^-1
#           qij = (1 + dij2)^-1 / (-N + SUM(over k, l) (1 + dkl2)^-1)
#           qij = (1 + dij2)^-1 / Z
#    dij2 = ||x_i - x_j||^2
#    where x_i = gs(r_i) . M
#          r_i = a loading of a document onto topics
#          M   = translation from topics to vector space
#          gs  = gumbel-softmax of the input rep
#
# 3. Algorithm:
# 3.a Precompute p_ij
# 3.b Build the pairwise matrix of dkl2
#     For all points, sample x_i = gs(r_i) . M
#     Build the N^2 matrix of pairwise distances: dkl2 = ||xk||^2 + ||xl||^2 - 2 xk . xl
#     Z = sum over all entries, then subtract N to compensate for the diagonal entries
# 3.c For an input minibatch of (i, j), minimize p_ij (log(p_ij) - log(q_ij))
# 3.d SGD minimizes p_ij log(p_ij / q_ij)

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch import nn

import numpy as np


def gumbel_sample(logits, tau=1.0, temperature=0.0):
    # Gumbel noise via the double-log trick: -log(-log(U)), U ~ Uniform(0, 1)
    with torch.cuda.device(logits.get_device()):
        noise = torch.rand(logits.size())
        noise.add_(1e-9).log_().neg_()
        noise.add_(1e-9).log_().neg_()
        gumbel = Variable(noise).cuda()
    sample = (logits + gumbel) / tau + temperature
    sample = F.softmax(sample.view(sample.size(0), -1))
    return sample.view_as(logits)


def pairwise(data):
    """Return the n x n matrix of squared pairwise distances."""
    n_obs, dim = data.size()
    xk = data.unsqueeze(0).expand(n_obs, n_obs, dim)
    xl = data.unsqueeze(1).expand(n_obs, n_obs, dim)
    dkl2 = ((xk - xl) ** 2.0).sum(2).squeeze()
    return dkl2


class TopicSNE(nn.Module):
    def __init__(self, n_points, n_topics, n_dim):
        self.n_points = n_points
        self.n_dim = n_dim
        super(TopicSNE, self).__init__()
        # Logit of datapoint-to-topic weight
        self.logits = nn.Embedding(n_points, n_topics)
        # Vector for each topic
        self.topic = nn.Linear(n_topics, n_dim)

    def positions(self):
        # x = self.topic(F.softmax(self.logits.weight))
        x = self.logits.weight
        return x

    def dirichlet_ll(self):
        # Placeholder for a Dirichlet log-likelihood prior over topic loadings
        pass

    def forward(self, pij, i, j):
        # Get embeddings for all points
        with torch.cuda.device(pij.get_device()):
            alli = torch.from_numpy(np.arange(self.n_points))
            alli = Variable(alli).cuda()
        # x = self.topic(gumbel_sample(self.logits(alli)))
        x = self.logits(alli)
        # Compute squared pairwise distances
        dkl2 = pairwise(x)
        # Compute the partition function; subtract n_diagonal to drop the self-pairs
        n_diagonal = dkl2.size()[0]
        part = (1 + dkl2).pow(-1.0).sum() - n_diagonal
        # Compute the numerator: the Student-t kernel (1 + ||x_i - x_j||^2)^-1
        # xi = self.topic(gumbel_sample(self.logits(i)))
        # xj = self.topic(gumbel_sample(self.logits(j)))
        xi = self.logits(i)
        xj = self.logits(j)
        num = (1. + ((xi - xj) ** 2.0).sum(1)).pow(-1.0).squeeze()
        qij = num / part.expand_as(num)
        # Compute KLD(pij || qij)
        loss_kld = pij * (torch.log(pij) - torch.log(qij))
        # Compute Dirichlet likelihood
        # loss_dir = self.dirichlet_ll()
        return loss_kld.sum()  # + loss_dir

    def __call__(self, *args):
        return self.forward(*args)
--------------------------------------------------------------------------------
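As a worked instance of the q_ij definition in the header comments of topic_sne.py: the off-diagonal Student-t kernel values, normalized by Z, form a proper distribution over pairs. A minimal sketch on hypothetical random positions:

import torch

x = torch.randn(6, 2)                  # hypothetical 2-D positions for N = 6 points
d2 = ((x.unsqueeze(0) - x.unsqueeze(1)) ** 2).sum(2)
n = d2.size(0)
kernel = (1 + d2).pow(-1.0)            # Student-t kernel, (1 + dkl2)^-1
Z = kernel.sum() - n                   # subtract N: each diagonal term is (1 + 0)^-1 = 1
qij = kernel / Z
print(float(qij.sum() - n / Z))        # 1.0: q_ij sums to one over all i != j pairs
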