├── mnist_process.mp4
├── README.md
└── mlp_param_tsne.py


/mnist_process.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zaburo-ch/Parametric-t-SNE-in-Keras/HEAD/mnist_process.mp4
--------------------------------------------------------------------------------


/README.md:
--------------------------------------------------------------------------------
Parametric t-SNE
====
An implementation of ["Parametric t-SNE"](https://lvdmaaten.github.io/publications/papers/AISTATS_2009.pdf) in Keras.
The authors used stacked RBMs in the paper, but I used simple ReLU units instead.

I used [the Python implementation of t-SNE](https://lvdmaaten.github.io/tsne/code/tsne_python.zip) by [Laurens van der Maaten](https://lvdmaaten.github.io/) as a reference.

For some reason, this code does not work on Keras 1.0.4.
If you use 1.0.4, please downgrade with ```pip install Keras==1.0.3```.
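
For reference, the quantity the network minimizes is the t-SNE objective: the KL divergence between the precomputed joint probabilities P over the inputs and the Student-t affinities Q over the low-dimensional outputs. Below is a plain-NumPy sketch of the same computation as the ```KLdivergence``` loss in ```mlp_param_tsne.py``` (the function name and code are my own readability aid, not code used by the script):

```python
import numpy as np

def kl_divergence_np(P, Y, low_dim=2, eps=1e-15):
    # Student-t affinities Q over pairwise squared distances in the embedding Y.
    alpha = low_dim - 1.0
    sum_Y = np.sum(np.square(Y), axis=1)
    D = sum_Y + sum_Y[:, None] - 2.0 * Y.dot(Y.T)
    Q = (1.0 + D / alpha) ** (-(alpha + 1.0) / 2.0)
    np.fill_diagonal(Q, 0.0)  # a point is not its own neighbour
    Q /= Q.sum()
    Q = np.maximum(Q, eps)
    # KL(P || Q): small when the embedding preserves the input affinities.
    return np.sum(P * np.log((P + eps) / (Q + eps)))
```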
--------------------------------------------------------------------------------


/mlp_param_tsne.py:
--------------------------------------------------------------------------------
import numpy as np
np.random.seed(71)

import matplotlib
matplotlib.use('Agg')

from keras import backend as K
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.datasets import cifar10, mnist

import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation

import multiprocessing as mp


batch_size = 5000
low_dim = 2
nb_epoch = 100
shuffle_interval = nb_epoch + 1
n_jobs = 4
perplexity = 30.0


def Hbeta(D, beta):
    # Conditional Gaussian distribution over distances D at precision beta,
    # together with its Shannon entropy H.
    P = np.exp(-D * beta)
    sumP = np.sum(P)
    H = np.log(sumP) + beta * np.sum(D * P) / sumP
    P = P / sumP
    return H, P


def x2p_job(data):
    # Binary search for the precision beta at which the entropy of point i's
    # conditional distribution matches log(perplexity).
    i, Di, tol, logU = data
    beta = 1.0
    betamin = -np.inf
    betamax = np.inf
    H, thisP = Hbeta(Di, beta)

    Hdiff = H - logU
    tries = 0
    while np.abs(Hdiff) > tol and tries < 50:
        if Hdiff > 0:
            betamin = beta
            if betamax == np.inf:
                beta = beta * 2
            else:
                beta = (betamin + betamax) / 2
        else:
            betamax = beta
            if betamin == -np.inf:
                beta = beta / 2
            else:
                beta = (betamin + betamax) / 2

        H, thisP = Hbeta(Di, beta)
        Hdiff = H - logU
        tries += 1

    return i, thisP


def x2p(X):
    # Conditional probabilities P_{j|i} for all points, with the per-point
    # precisions found in parallel by x2p_job.
    tol = 1e-5
    n = X.shape[0]
    logU = np.log(perplexity)

    # Squared Euclidean distances, with the diagonal (self-distances) removed.
    sum_X = np.sum(np.square(X), axis=1)
    D = sum_X + (sum_X.reshape([-1, 1]) - 2 * np.dot(X, X.T))

    idx = (1 - np.eye(n)).astype(bool)
    D = D[idx].reshape([n, -1])

    def generator():
        for i in xrange(n):
            yield i, D[i], tol, logU

    pool = mp.Pool(n_jobs)
    result = pool.map(x2p_job, generator())
    P = np.zeros([n, n])
    for i, thisP in result:
        P[i, idx[i]] = thisP

    return P


def calculate_P(X):
    # Joint probabilities P, computed block by block: each batch gets its own
    # batch_size x batch_size matrix, symmetrized and normalized.
    print "Computing pairwise distances..."
    n = X.shape[0]
    P = np.zeros([n, batch_size])
    for i in xrange(0, n, batch_size):
        P_batch = x2p(X[i:i + batch_size])
        P_batch[np.isnan(P_batch)] = 0
        P_batch = P_batch + P_batch.T
        P_batch = P_batch / P_batch.sum()
        P_batch = np.maximum(P_batch, 1e-12)
        P[i:i + batch_size] = P_batch
    return P


def KLdivergence(P, Y):
    # t-SNE loss: KL(P || Q), where Q is the Student-t distribution over
    # pairwise distances in the low-dimensional embedding Y.
    alpha = low_dim - 1.
    sum_Y = K.sum(K.square(Y), axis=1)
    eps = K.variable(10e-15)
    D = sum_Y + K.reshape(sum_Y, [-1, 1]) - 2 * K.dot(Y, K.transpose(Y))
    Q = K.pow(1 + D / alpha, -(alpha + 1) / 2)
    Q *= K.variable(1 - np.eye(batch_size))
    Q /= K.sum(Q)
    Q = K.maximum(Q, eps)
    C = K.log((P + eps) / (Q + eps))
    C = K.sum(P * C)
    return C


print "load data"
# # cifar-10
# (X_train, y_train), (X_test, y_test) = cifar10.load_data()
# n, channel, row, col = X_train.shape

# mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
n, row, col = X_train.shape
channel = 1

X_train = X_train.reshape(-1, channel * row * col)
X_test = X_test.reshape(-1, channel * row * col)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print "X_train.shape:", X_train.shape
print "X_test.shape:", X_test.shape

batch_num = int(n // batch_size)
m = batch_num * batch_size


print "build model"
model = Sequential()
model.add(Dense(500, input_shape=(X_train.shape[1],)))
model.add(Activation('relu'))
model.add(Dense(500))
model.add(Activation('relu'))
model.add(Dense(2000))
model.add(Activation('relu'))
model.add(Dense(2))

model.compile(loss=KLdivergence, optimizer="adam")


print "fit"
images = []
fig = plt.figure(figsize=(5, 5))

for epoch in range(nb_epoch):
    # shuffle X_train and calculate P
    if epoch % shuffle_interval == 0:
        X = X_train[np.random.permutation(n)[:m]]
        P = calculate_P(X)

    # train on the m samples that fill complete batches
    loss = 0
    for i in xrange(0, m, batch_size):
        loss += model.train_on_batch(X[i:i + batch_size], P[i:i + batch_size])
    print "Epoch: {}/{}, loss: {}".format(epoch + 1, nb_epoch, loss / batch_num)

    # visualize training process
    pred = model.predict(X_test)
    img = plt.scatter(pred[:, 0], pred[:, 1], c=y_test,
                      marker='o', s=3, edgecolor='')
    images.append([img])

ani = ArtistAnimation(fig, images, interval=100, repeat_delay=2000)
ani.save("mlp_process.mp4")

plt.clf()
fig = plt.figure(figsize=(5, 5))
pred = model.predict(X_test)
plt.scatter(pred[:, 0], pred[:, 1], c=y_test, marker='o', s=4, edgecolor='')
fig.tight_layout()

plt.savefig("mlp_result.png")
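

# Sanity-check helper (my addition, not part of the original script): the
# binary search in x2p_job should make the entropy of each conditional
# distribution match log(perplexity), i.e. the effective number of
# neighbours 2**H(P_i) should be close to `perplexity` for every point.
# Never called during training; invoke by hand on a small subset, e.g.
# check_perplexity(X_train[:500]).
def check_perplexity(X_sub):
    P_cond = x2p(X_sub)  # row-conditional probabilities, diagonal is zero
    H = -np.sum(P_cond * np.log2(np.maximum(P_cond, 1e-12)), axis=1)
    return 2 ** H  # each entry should be ~perplexity (30.0)
--------------------------------------------------------------------------------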