├── animate.sh
├── figures
│   └── test.png
├── .gitignore
├── dynamic.py
├── README.md
├── main.py
├── visualize.py
├── data
│   └── mnist_np.py
└── tsne_hack.py

/animate.sh:
--------------------------------------------------------------------------------
cd data
python3 mnist_np.py
cd ..
python3 main.py
--------------------------------------------------------------------------------
/figures/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/tSNE-Animation/HEAD/figures/test.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.ipynb_checkpoints/
__pycache__/
data/
*.pkl
*.gif
*.mat
test.py

--------------------------------------------------------------------------------
/dynamic.py:
--------------------------------------------------------------------------------
import pickle
from visualize import savegif

def main(dataset):

    data_path = './data/%s.pkl' % dataset
    with open(data_path, 'rb') as f:
        X, labels = pickle.load(f)    # only the labels are used here, for coloring

    # previously saved sequence of intermediate embeddings
    with open('mnist.pkl', 'rb') as f:
        Y_seq = pickle.load(f)

    fig_name = '%s-tsne' % dataset
    fig_path = './figures/%s.gif' % fig_name
    savegif(Y_seq, labels, fig_name, fig_path)

if __name__ == '__main__':
    main('mnist70k')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# tSNE-Animation
Hacking sklearn's t-SNE implementation to animate the embedding process

To generate the animation, just run

    source animate.sh

Result on the MNIST dataset, with early exaggeration for the first 250 iterations:

![mnist70k-tsne.gif](https://github.com/KellerJordan/figures/blob/master/mnist70k-tsne.gif)

MNIST, with early exaggeration all the way through and a fixed window:

![mnist70k-500-499-tsne.gif](https://github.com/KellerJordan/figures/blob/master/mnist70k-500-499-tsne.gif)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import sys
import pickle
import argparse

from sklearn.manifold import TSNE

from tsne_hack import extract_sequence
from visualize import savegif

def main(args):
    data_path = './data/%s.pkl' % args.dataset
    with open(data_path, 'rb') as f:
        X, labels = pickle.load(f)

    # run for num_iters total iterations; early exaggeration is applied
    # for the first early_iters of them
    tsne = TSNE(n_iter=args.num_iters, verbose=True)
    tsne._EXPLORATION_N_ITER = args.early_iters
    Y_seq = extract_sequence(tsne, X)
    if not os.path.exists('results'):
        os.mkdir('results')
    with open('results/res.pkl', 'wb') as f:
        pickle.dump(Y_seq, f)

    if not os.path.exists('figures'):
        os.mkdir('figures')

    # fixed square window so the axes stay constant across all frames
    lo = Y_seq.min(axis=0).min(axis=0).max()
    hi = Y_seq.max(axis=0).max(axis=0).min()
    limits = ([lo, hi], [lo, hi])
    fig_name = '%s-%d-%d-tsne' % (args.dataset, args.num_iters, args.early_iters)
    fig_path = './figures/%s.gif' % fig_name
    savegif(Y_seq, labels, fig_name, fig_path, limits=limits)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='mnist70k')
    parser.add_argument('--num_iters', type=int, default=1000)
    parser.add_argument('--early_iters', type=int, default=250)
    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
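Once main.py has written results/res.pkl, the animation can be re-rendered without re-fitting t-SNE, similar in spirit to what dynamic.py does with its own pickle path. The sketch below is illustrative only: it assumes main.py has already been run with the default mnist70k dataset, and the output file name is made up.

    # Re-render the animation from a saved run without re-fitting t-SNE.
    # Assumes main.py has already produced results/res.pkl and figures/,
    # and that data/mnist70k.pkl exists.
    import pickle

    from visualize import savegif

    with open('./data/mnist70k.pkl', 'rb') as f:
        _, labels = pickle.load(f)      # only the labels are needed, for coloring

    with open('./results/res.pkl', 'rb') as f:
        Y_seq = pickle.load(f)          # array of shape (n_frames, n_points, 2)

    savegif(Y_seq, labels, 'mnist70k-tsne-replay', './figures/mnist70k-tsne-replay.gif')
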
/visualize.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'tab10'
from matplotlib.animation import FuncAnimation
import matplotlib.patches as mpatches


def init_plot():
    cmap = plt.get_cmap()
    fig, ax = plt.subplots()
    fig.set_figheight(10)
    fig.set_figwidth(10)
    # one legend patch per digit class, colored to match the scatter colormap
    patches = [mpatches.Patch(color=cmap.colors[i], label=str(i)) for i in range(10)]
    return fig, ax, patches

def savegif(Y_seq, labels, fig_name, path, limits=None):
    fig, ax, patches = init_plot()

    def init():
        # nothing to draw before the first frame; blitting is not used
        return []

    def update(i):
        if (i+1) % 50 == 0:
            print('[%d / %d] Animating frames' % (i+1, len(Y_seq)))
        ax.clear()
        if limits is not None:
            ax.set_xlim(limits[0])
            ax.set_ylim(limits[1])
        plt.legend(handles=patches, loc='upper right')
        points = ax.scatter(Y_seq[i][:, 0], Y_seq[i][:, 1], 1, labels)
        ax.set_title('%s (epoch %d)' % (fig_name, i))
        return points,

    anim = FuncAnimation(fig, update, init_func=init,
                         frames=len(Y_seq), interval=50)
    print('[*] Saving animation as %s' % path)
    anim.save(path, writer='imagemagick', fps=30)

def savepng(Y, labels, fig_name, path):
    fig, ax, patches = init_plot()
    ax.scatter(Y[:, 0], Y[:, 1], 1, labels)
    ax.set_title(fig_name)
    print('[*] Saving figure as %s' % path)
    plt.savefig(path)

def scatter(Y, labels):
    fig, ax, patches = init_plot()
    ax.scatter(Y[:, 0], Y[:, 1], 1, labels)
    plt.show()
--------------------------------------------------------------------------------
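Because savegif only needs a sequence of 2-D point clouds plus integer labels, the plotting code can be smoke-tested without downloading MNIST. This is a minimal sketch under stated assumptions: the toy data and output path are made up, and it relies on the same imagemagick writer that savegif already uses.

    # Smoke test for visualize.savegif using synthetic data (illustrative only).
    import os

    import numpy as np

    from visualize import savegif

    rng = np.random.RandomState(0)
    n_points, n_frames = 500, 40
    labels = rng.randint(0, 10, size=n_points)   # fake class labels 0-9
    start = rng.randn(n_points, 2)               # initial point cloud
    end = start + labels[:, None]                # drift each class by its label value
    # linearly interpolate between the two clouds, one frame per step
    Y_seq = np.array([start + t * (end - start) for t in np.linspace(0, 1, n_frames)])

    os.makedirs('figures', exist_ok=True)
    savegif(Y_seq, labels, 'toy-animation', './figures/toy-animation.gif')
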
/data/mnist_np.py:
--------------------------------------------------------------------------------
"""
Code modified from
https://github.com/hsjeong5/MNIST-for-Numpy/blob/master/mnist.py

Run this from within /data/!
"""

import numpy as np
from urllib import request
import gzip
import pickle

filename = [
    ["training_images","train-images-idx3-ubyte.gz"],
    ["test_images","t10k-images-idx3-ubyte.gz"],
    ["training_labels","train-labels-idx1-ubyte.gz"],
    ["test_labels","t10k-labels-idx1-ubyte.gz"]
]

def download_mnist():
    base_url = "http://yann.lecun.com/exdb/mnist/"
    for name in filename:
        print("Downloading "+name[1]+"...")
        request.urlretrieve(base_url+name[1], name[1])
    print("Download complete.")

def save_mnist():
    mnist = {}
    # images: skip the 16-byte IDX header, then flatten each 28x28 image to 784 pixels
    for name in filename[:2]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 784)
    # labels: skip the 8-byte IDX header
    for name in filename[-2:]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)

    # modification: merge training and test datasets
    mnist_merge = {}
    mnist_merge['images'] = np.concatenate([mnist['training_images'], mnist['test_images']])
    mnist_merge['labels'] = np.concatenate([mnist['training_labels'], mnist['test_labels']])
    datasets = {}
    datasets['mnist70k'] = (mnist_merge['images'], mnist_merge['labels'])
    datasets['mnist10k'] = (mnist_merge['images'][:10000], mnist_merge['labels'][:10000])
    datasets['mnist2500'] = (mnist_merge['images'][:2500], mnist_merge['labels'][:2500])
    datasets['mnist250'] = (mnist_merge['images'][:250], mnist_merge['labels'][:250])

    for name, data in datasets.items():
        with open('%s.pkl' % name, 'wb') as f:
            pickle.dump(data, f)
    print("Save complete.")

if __name__ == '__main__':
    download_mnist()
    save_mnist()
--------------------------------------------------------------------------------
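Each pickle written by this script holds an (images, labels) tuple of NumPy arrays. As a quick sanity check, a sketch like the one below can be run from within data/, assuming mnist2500.pkl was produced:

    # Inspect one of the generated pickles (run from within data/).
    import pickle

    import numpy as np

    with open('mnist2500.pkl', 'rb') as f:
        X, labels = pickle.load(f)

    print(X.shape, X.dtype)            # expected: (2500, 784) uint8, flattened 28x28 images
    print(labels.shape, labels.dtype)  # expected: (2500,) uint8, digit classes
    print(np.unique(labels))           # expected: [0 1 2 3 4 5 6 7 8 9]
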
/tsne_hack.py:
--------------------------------------------------------------------------------
from time import time

import numpy as np
from sklearn import manifold

def extract_sequence(tsne, X):
    sklearn_grad = manifold.t_sne._gradient_descent
    Y_seq = []

    # modified from sklearn source https://github.com/scikit-learn/scikit-learn/blob/a24c8b46/sklearn/manifold/t_sne.py#L442
    # to save the sequence of embeddings at each training iteration
    def _gradient_descent(objective, p0, it, n_iter,
                          n_iter_check=1, n_iter_without_progress=300,
                          momentum=0.8, learning_rate=200.0, min_gain=0.01,
                          min_grad_norm=1e-7, verbose=0, args=None, kwargs=None):
        if args is None:
            args = []
        if kwargs is None:
            kwargs = {}

        p = p0.copy().ravel()
        update = np.zeros_like(p)
        gains = np.ones_like(p)
        error = np.finfo(np.float64).max
        best_error = np.finfo(np.float64).max
        best_iter = i = it

        tic = time()
        for i in range(it, n_iter):

            # save the current state of the embedding before this update
            Y_seq.append(p.copy().reshape(-1, 2))

            error, grad = objective(p, *args, **kwargs)
            grad_norm = np.linalg.norm(grad)

            # adaptive per-parameter gains, as in the original sklearn code
            inc = update * grad < 0.0
            dec = np.invert(inc)
            gains[inc] += 0.2
            gains[dec] *= 0.8
            np.clip(gains, min_gain, np.inf, out=gains)
            grad *= gains
            update = momentum * update - learning_rate * grad
            p += update

            if (i + 1) % n_iter_check == 0:
                toc = time()
                duration = toc - tic
                tic = toc

                if verbose >= 2:
                    print("[t-SNE] Iteration %d: error = %.7f,"
                          " gradient norm = %.7f"
                          " (%s iterations in %0.3fs)"
                          % (i + 1, error, grad_norm, n_iter_check, duration))

                if error < best_error:
                    best_error = error
                    best_iter = i
                elif i - best_iter > n_iter_without_progress:
                    if verbose >= 2:
                        print("[t-SNE] Iteration %d: did not make any progress "
                              "during the last %d episodes. Finished."
                              % (i + 1, n_iter_without_progress))
                    break
                if grad_norm <= min_grad_norm:
                    if verbose >= 2:
                        print("[t-SNE] Iteration %d: gradient norm %f. Finished."
                              % (i + 1, grad_norm))
                    break

        return p, error, i

    # replace sklearn's gradient descent with the recording version
    manifold.t_sne._gradient_descent = _gradient_descent
    # train the given tsne object with the patched gradient function
    X_proj = tsne.fit_transform(X)
    # restore the original implementation
    manifold.t_sne._gradient_descent = sklearn_grad

    return np.array(Y_seq)

--------------------------------------------------------------------------------
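extract_sequence works by temporarily monkey-patching sklearn's private gradient-descent loop, which assumes a scikit-learn version where sklearn.manifold.t_sne is importable (newer releases renamed the module to sklearn.manifold._t_sne). A minimal end-to-end check on random data might look like the sketch below; the sizes are made up for illustration, and the number of recorded frames depends on how many iterations sklearn actually runs, since the patched loop is entered once for the early-exaggeration phase and once for the main phase.

    # Minimal end-to-end check of extract_sequence on random data (illustrative only).
    import numpy as np
    from sklearn.manifold import TSNE

    from tsne_hack import extract_sequence

    X = np.random.RandomState(0).randn(300, 50)   # 300 points, 50 features
    tsne = TSNE(n_iter=400, verbose=2)            # verbose=2 enables the per-iteration prints above
    Y_seq = extract_sequence(tsne, X)

    print(Y_seq.shape)   # (n_recorded_iterations, 300, 2)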