├── DEC.ipynb
├── DEC.py
└── README.md

/DEC.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import os\n",
10 |     "import torch\n",
11 |     "import torch.nn as nn\n",
12 |     "from torch.autograd import Variable\n",
13 |     "import torchvision.datasets as dset\n",
14 |     "import torchvision.transforms as transforms\n",
15 |     "import torch.nn.functional as F\n",
16 |     "import torch.optim as optim\n",
17 |     "from sklearn.cluster import MiniBatchKMeans, KMeans\n",
18 |     "## load mnist dataset\n",
19 |     "use_cuda = torch.cuda.is_available()\n",
20 |     "root = './data'\n",
21 |     "if not os.path.exists(root):\n",
22 |     "    os.mkdir(root)\n",
23 |     "#trans = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])\n",
24 |     "trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])\n",
25 |     "# download the mnist dataset if it is not already present\n",
26 |     "train_set = dset.MNIST(root=root, train=True, transform=trans, download=True)\n",
27 |     "test_set = dset.MNIST(root=root, train=False, transform=trans, download=True)\n",
28 |     "batch_size = 100\n",
29 |     "train_loader = torch.utils.data.DataLoader(\n",
30 |     "    dataset=train_set,\n",
31 |     "    batch_size=batch_size,\n",
32 |     "    shuffle=True)\n",
33 |     "test_loader = torch.utils.data.DataLoader(\n",
34 |     "    dataset=test_set,\n",
35 |     "    batch_size=batch_size,\n",
36 |     "    shuffle=False)"
37 |    ]
38 |   },
39 |   {
40 |    "cell_type": "code",
41 |    "execution_count": 3,
42 |    "metadata": {},
43 |    "outputs": [],
44 |    "source": [
45 |     "import time\n",
46 |     "timingResult = {}\n",
47 |     "def logTime(theName, currentTime):\n",
48 |     "    if theName not in timingResult:\n",
49 |     "        timingResult[theName] = time.time() - currentTime\n",
50 |     "    else:\n",
51 |     "        timingResult[theName] = timingResult[theName] + (time.time() - currentTime)\n",
52 |     "    currentTime = time.time()\n",
53 |     "    return currentTime\n",
54 |     "\n",
55 |     "def printTiming(name):\n",
56 |     "    print('======== timing for {}: {} ======='.format(name,timingResult[name]))\n"
57 |    ]
58 |   },
59 |   {
60 |    "cell_type": "code",
61 |    "execution_count": 4,
62 |    "metadata": {},
63 |    "outputs": [],
64 |    "source": [
65 |     "class DEC_AE(nn.Module):\n",
66 |     "    def __init__(self, num_classes, num_features):\n",
67 |     "        super(DEC_AE,self).__init__()\n",
68 |     "        self.dropout = nn.Dropout(p=0.1)\n",
69 |     "        self.fc1 = nn.Linear(28*28,500)\n",
70 |     "        self.fc2 = nn.Linear(500,500)\n",
71 |     "        self.fc3 = nn.Linear(500,2000)\n",
72 |     "        self.fc4 = nn.Linear(2000,num_features)\n",
73 |     "        self.relu = nn.ReLU()\n",
74 |     "        self.fc_d1 = nn.Linear(500,28*28)\n",
75 |     "        self.fc_d2 = nn.Linear(500,500)\n",
76 |     "        self.fc_d3 = nn.Linear(2000,500)\n",
77 |     "        self.fc_d4 = nn.Linear(num_features,2000)\n",
78 |     "        self.alpha = 1.0\n",
79 |     "        self.clusterCenter = nn.Parameter(torch.zeros(num_classes,num_features))\n",
80 |     "        self.pretrainMode = True\n",
81 |     "        for m in self.modules():\n",
82 |     "            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):\n",
83 |     "                torch.nn.init.xavier_uniform(m.weight)\n",
84 |     "\n",
85 |     "    def setPretrain(self,mode):\n",
86 |     "        \"\"\"Set whether the model runs in pretrain mode,\n",
87 |     "        which controls whether only the encoder or the full encoder+decoder is run\"\"\"\n",
88 |     "        self.pretrainMode = mode\n",
89 |     "    def updateClusterCenter(self, cc):\n",
90 |     "        \"\"\"\n",
91 |     "        To update the cluster centers. This is a method for the pre-train phase.\n",
92 |     "        When a center is provided by kmeans, we need to update it so\n",
93 |     "        that it is available for further training.\n",
94 |     "        :param cc: the cluster centers to update, of size num_classes x num_features\n",
95 |     "        \"\"\"\n",
96 |     "        self.clusterCenter.data = torch.from_numpy(cc)\n",
97 |     "    def getTDistribution(self,x, clusterCenter):\n",
98 |     "        \"\"\"\n",
99 |     "        Student's t-distribution, the same as used in the t-SNE algorithm.\n",
100 |     "        q_ij = 1/(1+dist(x_i, u_j)^2), then normalize it.\n",
101 |     "        \n",
102 |     "        :param x: input data, in this context the encoder output\n",
103 |     "        :param clusterCenter: the cluster center from kmeans\n",
104 |     "        \"\"\"\n",
105 |     "        xe = torch.unsqueeze(x,1).cuda() - clusterCenter.cuda()\n",
106 |     "        q = 1.0 / (1.0 + (torch.sum(torch.mul(xe,xe), 2) / self.alpha))\n",
107 |     "        q = q ** ((self.alpha + 1.0) / 2.0)\n",
108 |     "        q = (q.t() / torch.sum(q, 1)).t() #due to division broadcasting, we need to transpose q\n",
109 |     "        return q\n",
110 |     "    \n",
111 |     "    def forward(self,x):\n",
112 |     "        x = x.view(-1, 1*28*28)\n",
113 |     "        # (batch, 784)\n",
114 |     "        x = self.dropout(x)\n",
115 |     "        # (batch, 784)\n",
116 |     "        x = self.fc1(x)\n",
117 |     "        # (batch, 500)\n",
118 |     "        x = self.relu(x)\n",
119 |     "        # (batch, 500)\n",
120 |     "        x = self.fc2(x)\n",
121 |     "        # (batch, 500)\n",
122 |     "        x = self.relu(x)\n",
123 |     "        # (batch, 500)\n",
124 |     "        x = self.fc3(x)\n",
125 |     "        # (batch, 2000)\n",
126 |     "        x = self.relu(x)\n",
127 |     "        x = self.fc4(x)\n",
128 |     "        # (batch, num_features)\n",
129 |     "        x_ae = x\n",
130 |     "        #if not in pretrain mode, we only need the encoder\n",
131 |     "        if self.pretrainMode == False:\n",
132 |     "            return x, self.getTDistribution(x,self.clusterCenter)\n",
133 |     "\n",
134 |     "        ##### encoder is done, followed by decoder #####\n",
135 |     "        # (batch, num_features)\n",
136 |     "        x = self.fc_d4(x)\n",
137 |     "        # (batch, 2000)\n",
138 |     "        x = self.relu(x)\n",
139 |     "        # (batch, 2000)\n",
140 |     "        x = self.fc_d3(x)\n",
141 |     "        # (batch, 500)\n",
142 |     "        x = self.relu(x)\n",
143 |     "        x = self.fc_d2(x)\n",
144 |     "        # (batch, 500)\n",
145 |     "        x = self.relu(x)\n",
146 |     "        x = self.fc_d1(x)\n",
147 |     "        x_de = x.view(-1,1,28,28)\n",
148 |     "        # (batch, 1, 28, 28)\n",
149 |     "        return x_ae, x_de"
150 |    ]
151 |   },
152 |   {
153 |    "cell_type": "code",
154 |    "execution_count": 5,
155 |    "metadata": {},
156 |    "outputs": [],
157 |    "source": [
158 |     "import numpy as np\n",
159 |     "from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score\n",
160 |     "\n",
161 |     "nmi = normalized_mutual_info_score\n",
162 |     "ari = adjusted_rand_score\n",
163 |     "\n",
164 |     "\n",
165 |     "def acc(y_true, y_pred):\n",
166 |     "    \"\"\"\n",
167 |     "    Calculate clustering accuracy. Requires scikit-learn.\n",
168 |     "    # Arguments\n",
169 |     "        y: true labels, numpy.array with shape `(n_samples,)`\n",
170 |     "        y_pred: predicted labels, numpy.array with shape `(n_samples,)`\n",
171 |     "    # Return\n",
172 |     "        accuracy, in [0,1]\n",
173 |     "    \"\"\"\n",
174 |     "    y_true = y_true.astype(np.int64)\n",
175 |     "    assert y_pred.size == y_true.size\n",
176 |     "    D = max(y_pred.max(), y_true.max()) + 1\n",
177 |     "    w = np.zeros((D, D), dtype=np.int64)\n",
178 |     "    for i in range(y_pred.size):\n",
179 |     "        w[y_pred[i], y_true[i]] += 1\n",
180 |     "    from sklearn.utils.linear_assignment_ import linear_assignment #removed in newer scikit-learn; scipy.optimize.linear_sum_assignment is the replacement\n",
181 |     "    ind = linear_assignment(w.max() - w)\n",
182 |     "    return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size"
183 |    ]
184 |   },
185 |   {
186 |    "cell_type": "code",
187 |    "execution_count": 10,
188 |    "metadata": {},
189 |    "outputs": [],
190 |    "source": [
191 |     "class DEC:\n",
192 |     "    \"\"\"The class for controlling the training process of DEC\"\"\"\n",
193 |     "    def __init__(self,n_clusters,alpha=1.0):\n",
194 |     "        self.n_clusters=n_clusters\n",
195 |     "        self.alpha = alpha\n",
196 |     "    \n",
197 |     "    @staticmethod\n",
198 |     "    def target_distribution(q):\n",
199 |     "        weight = q ** 2 / q.sum(0)\n",
200 |     "        #print('q',q)\n",
201 |     "        return Variable((weight.t() / weight.sum(1)).t().data, requires_grad=True)\n",
202 |     "    def logAccuracy(self,pred,label):\n",
203 |     "        print(' '*8 + '|==> acc: %.4f, nmi: %.4f <==|'\n",
204 |     "              % (acc(label, pred), nmi(label, pred)))\n",
205 |     "    @staticmethod\n",
206 |     "    def kld(q,p):\n",
207 |     "        res = torch.sum(p*torch.log(p/q),dim=-1)\n",
208 |     "        return res\n",
209 |     "    \n",
210 |     "    def validateOnCompleteTestData(self,test_loader,model):\n",
211 |     "        model.eval()\n",
212 |     "        to_eval = np.array([model(d[0].cuda())[0].data.cpu().numpy() for i,d in enumerate(test_loader)])\n",
213 |     "        true_labels = np.array([d[1].cpu().numpy() for i,d in enumerate(test_loader)])\n",
214 |     "        to_eval = np.reshape(to_eval,(to_eval.shape[0]*to_eval.shape[1],to_eval.shape[2]))\n",
215 |     "        true_labels = np.reshape(true_labels,true_labels.shape[0]*true_labels.shape[1])\n",
216 |     "        km = KMeans(n_clusters=len(np.unique(true_labels)), n_init=20, n_jobs=4)\n",
217 |     "        y_pred = km.fit_predict(to_eval)\n",
218 |     "        print(' '*8 + '|==> acc: %.4f, nmi: %.4f <==|'\n",
219 |     "              % (acc(true_labels, y_pred), nmi(true_labels, y_pred)))\n",
220 |     "        currentAcc = acc(true_labels, y_pred)\n",
221 |     "        return currentAcc\n",
222 |     "    \n",
223 |     "    def pretrain(self,train_loader, test_loader, epochs):\n",
224 |     "        dec_ae = DEC_AE(10,10).cuda() #auto encoder\n",
225 |     "        mseloss = nn.MSELoss()\n",
226 |     "        optimizer = optim.SGD(dec_ae.parameters(),lr = 1, momentum=0.9)\n",
227 |     "        best_acc = 0.0\n",
228 |     "        for epoch in range(epochs):\n",
229 |     "            dec_ae.train()\n",
230 |     "            running_loss=0.0\n",
231 |     "            for i,data in enumerate(train_loader):\n",
232 |     "                x, label = data\n",
233 |     "                x, label = Variable(x).cuda(),Variable(label).cuda()\n",
234 |     "                optimizer.zero_grad()\n",
235 |     "                x_ae,x_de = dec_ae(x)\n",
236 |     "                loss = F.mse_loss(x_de,x,reduce=True)\n",
237 |     "                loss.backward()\n",
238 |     "                optimizer.step()\n",
239 |     "                x_eval = x.data.cpu().numpy()\n",
240 |     "                label_eval = label.data.cpu().numpy()\n",
241 |     "                running_loss += loss.data.cpu().numpy()[0]\n",
242 |     "                if i % 100 == 99:    # print every 100 mini-batches\n",
243 |     "                    print('[%d, %5d] loss: %.7f' %\n",
244 |     "                          (epoch + 1, i + 1, running_loss / 100))\n",
245 |     "                    running_loss = 0.0\n",
246 |     "            #now we evaluate the accuracy with the AE\n",
247 |     "            dec_ae.eval()\n",
248 |     "            currentAcc = self.validateOnCompleteTestData(test_loader,dec_ae)\n",
249 |     "            if currentAcc > best_acc:\n",
250 |     "                torch.save(dec_ae,'bestModel')\n",
251 |     "                best_acc = currentAcc\n",
252 |     "    def clustering(self,mbk,x,model):\n",
253 |     "        model.eval()\n",
254 |     "        y_pred_ae,_ = model(x)\n",
255 |     "        y_pred_ae = y_pred_ae.data.cpu().numpy()\n",
256 |     "        y_pred = mbk.partial_fit(y_pred_ae) #partial_fit updates the centres from this batch\n",
257 |     "        self.cluster_centers = mbk.cluster_centers_ #keep the cluster centers\n",
258 |     "        model.updateClusterCenter(self.cluster_centers)\n",
259 |     "    def train(self,train_loader, test_loader, epochs):\n",
260 |     "        \"\"\"This method starts the training of the DEC cluster\"\"\"\n",
261 |     "        ct = time.time()\n",
262 |     "        model = torch.load(\"bestModel\").cuda()\n",
263 |     "        model.setPretrain(False)\n",
264 |     "        optimizer = optim.SGD([\\\n",
265 |     "            {'params': model.parameters()}, \\\n",
266 |     "            ],lr = 0.01, momentum=0.9)\n",
267 |     "        print('Initializing cluster center with pre-trained weights')\n",
268 |     "        mbk = MiniBatchKMeans(n_clusters=self.n_clusters, n_init=20, batch_size=batch_size)\n",
269 |     "        got_cluster_center = False\n",
270 |     "        for epoch in range(epochs):\n",
271 |     "            for i,data in enumerate(train_loader):\n",
272 |     "                x, label = data\n",
273 |     "                x = Variable(x).cuda()\n",
274 |     "                optimizer.zero_grad()\n",
275 |     "                #step 1 - get the cluster centers from the batch\n",
276 |     "                #minibatch kmeans is used here so the method can cope with larger datasets\n",
277 |     "                if not got_cluster_center:\n",
278 |     "                    self.clustering(mbk,x,model)\n",
279 |     "                    if epoch > 1:\n",
280 |     "                        got_cluster_center = True\n",
281 |     "                else:\n",
282 |     "                    model.train()\n",
283 |     "                    #now we start training with the acquired cluster centers\n",
284 |     "                    feature_pred,q = model(x)\n",
285 |     "                    #get the target distribution\n",
286 |     "                    p = self.target_distribution(q)\n",
287 |     "                    #print('q',q,'p',p)\n",
288 |     "                    loss = self.kld(q,p).mean()\n",
289 |     "                    loss.backward()\n",
290 |     "                    optimizer.step()\n",
291 |     "            currentAcc = self.validateOnCompleteTestData(test_loader,model)\n"
292 |    ]
293 |   },
294 |   {
295 |    "cell_type": "code",
296 |    "execution_count": 11,
297 |    "metadata": {},
298 |    "outputs": [
299 |     {
300 |      "name": "stdout",
301 |      "output_type": "stream",
302 |      "text": [
303 |       "Initializing cluster center with pre-trained weights\n",
304 |       "        |==> acc: 0.8434, nmi: 0.7760 <==|\n",
305 |       "        |==> acc: 0.8415, nmi: 0.7748 <==|\n",
306 |       "        |==> acc: 0.8543, nmi: 0.8295 <==|\n",
307 |       "        |==> acc: 0.8629, nmi: 0.8423 <==|\n",
308 |       "        |==> acc: 0.8663, nmi: 0.8458 <==|\n",
309 |       "        |==> acc: 0.8677, nmi: 0.8521 <==|\n",
310 |       "        |==> acc: 0.8696, nmi: 0.8550 <==|\n",
311 |       "        |==> acc: 0.8715, nmi: 0.8586 <==|\n",
312 |       "        |==> acc: 0.8733, nmi: 0.8600 <==|\n",
313 |       "        |==> acc: 0.8734, nmi: 0.8600 <==|\n",
314 |       "        |==> acc: 0.8714, nmi: 0.8575 <==|\n",
315 |       "        |==> acc: 0.8745, nmi: 0.8625 <==|\n",
316 |       "        |==> acc: 0.8736, nmi: 0.8622 <==|\n",
317 |       "        |==> acc: 0.8736, nmi: 0.8622 <==|\n",
318 |       "        |==> acc: 0.8742, nmi: 0.8621 <==|\n",
319 |       "        |==> acc: 0.8725, nmi: 0.8615 <==|\n",
320 |       "        |==> acc: 0.8741, nmi: 0.8638 <==|\n",
321 |       "        |==> acc: 0.8726, nmi: 0.8610 <==|\n",
322 |       "        |==> acc: 0.8738, nmi: 0.8649 <==|\n",
323 |       "        |==> acc: 0.8744, nmi: 0.8626 <==|\n",
324 |       "        |==> acc: 0.8758, nmi: 0.8660 <==|\n",
325 |       "        |==> acc: 0.8769, nmi: 0.8683 <==|\n",
326 |       "        |==> acc: 0.8761, nmi: 0.8668 <==|\n",
327 |       "        |==> acc: 0.8762, nmi: 0.8674 <==|\n",
328 |       "        |==> acc: 0.8759, nmi: 0.8666 <==|\n",
329 |       "        |==> acc: 0.8743, nmi: 0.8662 <==|\n",
330 |       "        |==> acc: 0.8755, nmi: 0.8662 <==|\n",
331 |       "        |==> acc: 0.8760, nmi: 0.8672 <==|\n",
332 |       "        |==> acc: 0.8770, nmi: 0.8686 <==|\n",
333 |       "        |==> acc: 0.8757, nmi: 0.8654 <==|\n",
334 |       "        |==> acc: 0.8757, nmi: 0.8656 <==|\n",
335 |       "        |==> acc: 0.8760, nmi: 0.8670 <==|\n",
336 |       "        |==> acc: 0.8765, nmi: 0.8669 <==|\n",
337 |       "        |==> acc: 0.8762, nmi: 0.8658 <==|\n",
338 |       "        |==> acc: 0.8767, nmi: 0.8659 <==|\n",
339 |       "        |==> acc: 0.8772, nmi: 0.8689 <==|\n",
340 |       "        |==> acc: 0.8759, nmi: 0.8667 <==|\n",
341 |       "        |==> acc: 0.8754, nmi: 0.8654 <==|\n",
342 |       "        |==> acc: 0.8771, nmi: 0.8683 <==|\n",
343 |       "        |==> acc: 0.8758, nmi: 0.8663 <==|\n",
344 |       "        |==> acc: 0.8751, nmi: 0.8656 <==|\n",
345 |       "        |==> acc: 0.8754, nmi: 0.8670 <==|\n",
346 |       "        |==> acc: 0.8764, nmi: 0.8678 <==|\n",
347 |       "        |==> acc: 0.8730, nmi: 0.8640 <==|\n",
348 |       "        |==> acc: 0.8743, nmi: 0.8668 <==|\n",
349 |       "        |==> acc: 0.8746, nmi: 0.8673 <==|\n",
350 |       "        |==> acc: 0.8736, nmi: 0.8656 <==|\n",
351 |       "        |==> acc: 0.8752, nmi: 0.8674 <==|\n",
352 |       "        |==> acc: 0.8759, nmi: 0.8682 <==|\n",
353 |       "        |==> acc: 0.8743, nmi: 0.8645 <==|\n",
354 |       "        |==> acc: 0.8747, nmi: 0.8667 <==|\n",
355 |       "        |==> acc: 0.8751, nmi: 0.8677 <==|\n",
356 |       "        |==> acc: 0.8744, nmi: 0.8661 <==|\n",
357 |       "        |==> acc: 0.8756, nmi: 0.8695 <==|\n",
358 |       "        |==> acc: 0.8751, nmi: 0.8681 <==|\n",
359 |       "        |==> acc: 0.8752, nmi: 0.8675 <==|\n",
360 |       "        |==> acc: 0.8753, nmi: 0.8696 <==|\n",
361 |       "        |==> acc: 0.8753, nmi: 0.8693 <==|\n",
362 |       "        |==> acc: 0.8758, nmi: 0.8696 <==|\n",
363 |       "        |==> acc: 0.8754, nmi: 0.8679 <==|\n",
364 |       "        |==> acc: 0.8757, nmi: 0.8678 <==|\n",
365 |       "        |==> acc: 0.8747, nmi: 0.8683 <==|\n",
366 |       "        |==> acc: 0.8751, nmi: 0.8678 <==|\n",
367 |       "        |==> acc: 0.8747, nmi: 0.8671 <==|\n",
368 |       "        |==> acc: 0.8743, nmi: 0.8672 <==|\n",
369 |       "        |==> acc: 0.8732, nmi: 0.8641 <==|\n",
370 |       "        |==> acc: 0.8751, nmi: 0.8677 <==|\n",
371 |       "        |==> acc: 0.8758, nmi: 0.8682 <==|\n",
372 |       "        |==> acc: 0.8745, nmi: 0.8663 <==|\n",
373 |       "        |==> acc: 0.8757, nmi: 0.8680 <==|\n",
374 |       "        |==> acc: 0.8747, nmi: 0.8672 <==|\n",
375 |       "        |==> acc: 0.8758, nmi: 0.8688 <==|\n",
376 |       "        |==> acc: 0.8744, nmi: 0.8675 <==|\n",
377 |       "        |==> acc: 0.8750, nmi: 0.8678 <==|\n",
378 |       "        |==> acc: 0.8739, nmi: 0.8648 <==|\n",
379 |       "        |==> acc: 0.8752, nmi: 0.8674 <==|\n",
380 |       "        |==> acc: 0.8738, nmi: 0.8656 <==|\n",
381 |       "        |==> acc: 0.8730, nmi: 0.8641 <==|\n",
382 |       "        |==> acc: 0.8750, nmi: 0.8662 <==|\n",
383 |       "        |==> acc: 0.8747, nmi: 0.8671 <==|\n",
384 |       "        |==> acc: 0.8738, nmi: 0.8665 <==|\n",
385 |       "        |==> acc: 0.8744, nmi: 0.8671 <==|\n",
386 |       "        |==> acc: 0.8742, nmi: 0.8665 <==|\n",
387 |       "        |==> acc: 0.8740, nmi: 0.8661 <==|\n",
388 |       "        |==> acc: 0.8747, nmi: 0.8680 <==|\n",
389 |       "        |==> acc: 0.8736, nmi: 0.8650 <==|\n",
390 |       "        |==> acc: 0.8747, nmi: 0.8677 <==|\n",
391 |       "        |==> acc: 0.8743, nmi: 0.8667 <==|\n",
392 |       "        |==> acc: 0.8742, nmi: 0.8655 <==|\n",
393 |       "        |==> acc: 0.8747, nmi: 0.8672 <==|\n",
394 |       "        |==> acc: 0.8745, nmi: 0.8659 <==|\n",
395 |       "        |==> acc: 0.8746, nmi: 0.8661 <==|\n",
396 |       "        |==> acc: 0.8758, nmi: 0.8675 <==|\n",
397 |       "        |==> acc: 0.8757, nmi: 0.8675 <==|\n",
398 |       "        |==> acc: 0.8753, nmi: 0.8668 <==|\n",
399 |       "        |==> acc: 0.8755, nmi: 0.8673 <==|\n",
400 |       "        |==> acc: 0.8756, nmi: 0.8676 <==|\n",
401 |       "        |==> acc: 0.8757, nmi: 0.8667 <==|\n",
402 |       "        |==> acc: 0.8736, nmi: 0.8652 <==|\n",
403 |       "        |==> acc: 0.8766, nmi: 0.8693 <==|\n",
404 |       "        |==> acc: 0.8743, nmi: 0.8663 <==|\n",
405 |       "        |==> acc: 0.8760, nmi: 0.8680 <==|\n",
406 |       "        |==> acc: 0.8749, nmi: 0.8663 <==|\n",
407 |       "        |==> acc: 0.8749, nmi: 0.8660 <==|\n",
408 |       "        |==> acc: 0.8759, nmi: 0.8668 <==|\n",
409 |       "        |==> acc: 0.8753, nmi: 0.8656 <==|\n",
410 |       "        |==> acc: 0.8761, nmi: 0.8691 <==|\n",
411 |       "        |==> acc: 0.8744, nmi: 0.8645 <==|\n",
412 |       "        |==> acc: 0.8764, nmi: 0.8666 <==|\n",
413 |       "        |==> acc: 0.8738, nmi: 0.8637 <==|\n",
414 |       "        |==> acc: 0.8756, nmi: 0.8659 <==|\n",
415 |       "        |==> acc: 0.8742, nmi: 0.8647 <==|\n",
416 |       "        |==> acc: 0.8763, nmi: 0.8682 <==|\n",
417 |       "        |==> acc: 0.8765, nmi: 0.8674 <==|\n",
418 |       "        |==> acc: 0.8759, nmi: 0.8669 <==|\n",
419 |       "        |==> acc: 0.8743, nmi: 0.8655 <==|\n",
420 |       "        |==> acc: 0.8757, nmi: 0.8663 <==|\n",
421 |       "        |==> acc: 0.8763, nmi: 0.8675 <==|\n",
422 |       "        |==> acc: 0.8778, nmi: 0.8698 <==|\n",
423 |       "        |==> acc: 0.8774, nmi: 0.8690 <==|\n",
424 |       "        |==> acc: 0.8780, nmi: 0.8698 <==|\n",
425 |       "        |==> acc: 0.8761, nmi: 0.8665 <==|\n",
426 |       "        |==> acc: 0.8763, nmi: 0.8665 <==|\n",
427 |       "        |==> acc: 0.8756, nmi: 0.8671 <==|\n",
428 |       "        |==> acc: 0.8753, nmi: 0.8677 <==|\n",
429 |       "        |==> acc: 0.8751, nmi: 0.8674 <==|\n",
430 |       "        |==> acc: 0.8757, nmi: 0.8681 <==|\n",
431 |       "        |==> acc: 0.8747, nmi: 0.8643 <==|\n",
432 |       "        |==> acc: 0.8751, nmi: 0.8675 <==|\n",
433 |       "        |==> acc: 0.8757, nmi: 0.8661 <==|\n",
434 |       "        |==> acc: 0.8762, nmi: 0.8668 <==|\n",
435 |       "        |==> acc: 0.8759, nmi: 0.8665 <==|\n",
436 |       "        |==> acc: 0.8749, nmi: 0.8669 <==|\n",
437 |       "        |==> acc: 0.8751, nmi: 0.8662 <==|\n",
438 |       "        |==> acc: 0.8751, nmi: 0.8661 <==|\n",
439 |       "        |==> acc: 0.8769, nmi: 0.8683 <==|\n",
440 |       "        |==> acc: 0.8757, nmi: 0.8687 <==|\n",
441 |       "        |==> acc: 0.8747, nmi: 0.8670 <==|\n",
442 |       "        |==> acc: 0.8766, nmi: 0.8689 <==|\n",
443 |       "        |==> acc: 0.8752, nmi: 0.8654 <==|\n",
444 |       "        |==> acc: 0.8765, nmi: 0.8679 <==|\n",
445 |       "        |==> acc: 0.8753, nmi: 0.8659 <==|\n",
446 |       "        |==> acc: 0.8761, nmi: 0.8666 <==|\n",
447 |       "        |==> acc: 0.8766, nmi: 0.8692 <==|\n",
448 |       "        |==> acc: 0.8747, nmi: 0.8657 <==|\n",
449 |       "        |==> acc: 0.8756, nmi: 0.8676 <==|\n",
450 |       "        |==> acc: 0.8767, nmi: 0.8680 <==|\n",
451 |       "        |==> acc: 0.8773, nmi: 0.8687 <==|\n",
452 |       "        |==> acc: 0.8761, nmi: 0.8653 <==|\n",
453 |       "        |==> acc: 0.8765, nmi: 0.8659 <==|\n",
454 |       "        |==> acc: 0.8773, nmi: 0.8674 <==|\n",
455 |       "        |==> acc: 0.8764, nmi: 0.8673 <==|\n",
456 |       "        |==> acc: 0.8778, nmi: 0.8675 <==|\n",
457 |       "        |==> acc: 0.8772, nmi: 0.8670 <==|\n",
458 |       "        |==> acc: 0.8773, nmi: 0.8677 <==|\n",
459 |       "        |==> acc: 0.8780, nmi: 0.8678 <==|\n",
460 |       "        |==> acc: 0.8775, nmi: 0.8666 <==|\n",
461 |       "        |==> acc: 0.8772, nmi: 0.8667 <==|\n",
462 |       "        |==> acc: 0.8768, nmi: 0.8666 <==|\n",
463 |       "        |==> acc: 0.8779, nmi: 0.8667 <==|\n",
464 |       "        |==> acc: 0.8793, nmi: 0.8694 <==|\n",
465 |       "        |==> acc: 0.8785, nmi: 0.8688 <==|\n",
466 |       "        |==> acc: 0.8797, nmi: 0.8683 <==|\n",
467 |       "        |==> acc: 0.8789, nmi: 0.8695 <==|\n",
468 |       "        |==> acc: 0.8803, nmi: 0.8699 <==|\n",
469 |       "        |==> acc: 0.8794, nmi: 0.8683 <==|\n",
470 |       "        |==> acc: 0.8805, nmi: 0.8690 <==|\n",
471 |       "        |==> acc: 0.8791, nmi: 0.8679 <==|\n",
472 |       "        |==> acc: 0.8791, nmi: 0.8678 <==|\n",
473 |       "        |==> acc: 0.8776, nmi: 0.8663 <==|\n",
474 |       "        |==> acc: 0.8783, nmi: 0.8683 <==|\n",
475 |       "        |==> acc: 0.8774, nmi: 0.8685 <==|\n",
476 |       "        |==> acc: 0.8779, nmi: 0.8694 <==|\n",
477 |       "        |==> acc: 0.8775, nmi: 0.8683 <==|\n",
478 |       "        |==> acc: 0.8778, nmi: 0.8700 <==|\n",
479 |       "        |==> acc: 0.8783, nmi: 0.8701 <==|\n",
480 |       "        |==> acc: 0.8783, nmi: 0.8691 <==|\n"
481 |      ]
482 |     },
483 |     {
484 |      "name": "stdout",
485 |      "output_type": "stream",
486 |      "text": [
487 |       "        |==> acc: 0.8778, nmi: 0.8690 <==|\n",
488 |       "        |==> acc: 0.8769, nmi: 0.8657 <==|\n",
489 |       "        |==> acc: 0.8772, nmi: 0.8662 <==|\n",
490 |       "        |==> acc: 0.8772, nmi: 0.8666 <==|\n",
491 |       "        |==> acc: 0.8786, nmi: 0.8692 <==|\n",
492 |       "        |==> acc: 0.8775, nmi: 0.8686 <==|\n",
493 |       "        |==> acc: 0.8776, nmi: 0.8680 <==|\n",
494 |       "        |==> acc: 0.8782, nmi: 0.8694 <==|\n",
495 |       "        |==> acc: 0.8781, nmi: 0.8681 <==|\n",
496 |       "        |==> acc: 0.8779, nmi: 0.8694 <==|\n",
497 |       "        |==> acc: 0.8775, nmi: 0.8691 <==|\n",
498 |       "        |==> acc: 0.8770, nmi: 0.8685 <==|\n",
499 |       "        |==> acc: 0.8761, nmi: 0.8666 <==|\n",
500 |       "        |==> acc: 0.8762, nmi: 0.8669 <==|\n",
501 |       "        |==> acc: 0.8768, nmi: 0.8665 <==|\n",
502 |       "        |==> acc: 0.8773, nmi: 0.8668 <==|\n",
503 |       "        |==> acc: 0.8760, nmi: 0.8676 <==|\n",
504 |       "        |==> acc: 0.8775, nmi: 0.8687 <==|\n",
505 |       "        |==> acc: 0.8758, nmi: 0.8655 <==|\n",
506 |       "        |==> acc: 0.8769, nmi: 0.8681 <==|\n",
507 |       "        |==> acc: 0.8757, nmi: 0.8665 <==|\n",
508 |       "        |==> acc: 0.8777, nmi: 0.8688 <==|\n",
509 |       "        |==> acc: 0.8768, nmi: 0.8688 <==|\n"
510 |      ]
511 |     }
512 |    ],
513 |    "source": [
514 |     "#now start training\n",
515 |     "import random\n",
516 |     "random.seed(7)\n",
517 |     "dec = DEC(10)\n",
518 |     "#dec.pretrain(train_loader, test_loader, 200)\n",
519 |     "dec.train(train_loader, test_loader, 200)"
520 |    ]
521 |   }
522 |  ],
523 |  "metadata": {
524 |   "kernelspec": {
525 |    "display_name": "Python 3",
526 |    "language": "python",
527 |    "name": "python3"
528 |   },
529 |   "language_info": {
530 |    "codemirror_mode": {
531 |     "name": "ipython",
532 |     "version": 3
533 |    },
534 |    "file_extension": ".py",
535 |    "mimetype": "text/x-python",
536 |    "name": "python",
537 |    "nbconvert_exporter": "python",
538 |    "pygments_lexer": "ipython3",
539 |    "version": "3.5.4"
540 |   }
541 |  },
542 |  "nbformat": 4,
543 |  "nbformat_minor": 2
544 | }
545 | 
--------------------------------------------------------------------------------
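
A note on the `acc` metric used in both the notebook and DEC.py below: predicted cluster IDs are arbitrary, so they are first matched one-to-one to the ground-truth labels with the Hungarian algorithm, and accuracy is computed under that best matching. The `sklearn.utils.linear_assignment_` module imported for this was removed in newer scikit-learn releases; a minimal equivalent sketch using scipy instead (the helper name `cluster_acc` is illustrative, and scipy is assumed to be available) would be:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    def cluster_acc(y_true, y_pred):
        # confusion matrix between predicted cluster IDs and true labels
        D = max(y_pred.max(), y_true.max()) + 1
        w = np.zeros((D, D), dtype=np.int64)
        for yp, yt in zip(y_pred, y_true):
            w[yp, yt] += 1
        # Hungarian matching: minimizing (w.max() - w) maximizes the matched counts
        row, col = linear_sum_assignment(w.max() - w)
        return w[row, col].sum() / y_pred.size
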
/DEC.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import argparse
5 | import numpy as np
6 | import random
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torchvision.datasets as dset
10 | from torch.autograd import Variable
11 | import torch.nn.functional as F
12 | import torchvision.transforms as transforms
13 | from sklearn.cluster import MiniBatchKMeans, KMeans
14 | from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
15 | 
16 | nmi = normalized_mutual_info_score
17 | ari = adjusted_rand_score
18 | 
19 | def acc(y_true, y_pred):
20 |     """
21 |     Calculate clustering accuracy. Requires scikit-learn.
22 |     # Arguments
23 |         y: true labels, numpy.array with shape `(n_samples,)`
24 |         y_pred: predicted labels, numpy.array with shape `(n_samples,)`
25 |     # Return
26 |         accuracy, in [0,1]
27 |     """
28 |     y_true = y_true.astype(np.int64)
29 |     assert y_pred.size == y_true.size
30 |     D = max(y_pred.max(), y_true.max()) + 1
31 |     w = np.zeros((D, D), dtype=np.int64)
32 |     for i in range(y_pred.size):
33 |         w[y_pred[i], y_true[i]] += 1
34 |     from sklearn.utils.linear_assignment_ import linear_assignment #removed in newer scikit-learn; see the note above
35 |     ind = linear_assignment(w.max() - w)
36 |     return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size
37 | 
38 | class DEC_AE(nn.Module):
39 |     """
40 |     DEC auto encoder - the encoder produces the embedding used for clustering, and the decoder reconstructs the input during pre-training
41 |     """
42 |     def __init__(self, num_classes, num_features):
43 |         super(DEC_AE,self).__init__()
44 |         self.dropout = nn.Dropout(p=0.1)
45 |         self.fc1 = nn.Linear(28*28,500)
46 |         self.fc2 = nn.Linear(500,500)
47 |         self.fc3 = nn.Linear(500,2000)
48 |         self.fc4 = nn.Linear(2000,num_features)
49 |         self.relu = nn.ReLU()
50 |         self.fc_d1 = nn.Linear(500,28*28)
51 |         self.fc_d2 = nn.Linear(500,500)
52 |         self.fc_d3 = nn.Linear(2000,500)
53 |         self.fc_d4 = nn.Linear(num_features,2000)
54 |         self.alpha = 1.0
55 |         self.clusterCenter = nn.Parameter(torch.zeros(num_classes,num_features))
56 |         self.pretrainMode = True
57 |         for m in self.modules():
58 |             if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
59 |                 torch.nn.init.xavier_uniform(m.weight)
60 | 
61 |     def setPretrain(self,mode):
62 |         """Set whether the model runs in pretrain mode,
63 |         which controls whether only the encoder or the full encoder+decoder is run"""
64 |         self.pretrainMode = mode
65 |     def updateClusterCenter(self, cc):
66 |         """
67 |         To update the cluster centers. This is a method for the pre-train phase.
68 |         When a center is provided by kmeans, we need to update it so
69 |         that it is available for further training.
70 |         :param cc: the cluster centers to update, of size num_classes x num_features
71 |         """
72 |         self.clusterCenter.data = torch.from_numpy(cc)
73 |     def getTDistribution(self,x, clusterCenter):
74 |         """
75 |         Student's t-distribution, the same as used in the t-SNE algorithm.
76 |         q_ij = 1/(1+dist(x_i, u_j)^2), then normalize it.
77 | 
78 |         :param x: input data, in this context the encoder output
79 |         :param clusterCenter: the cluster center from kmeans
80 |         """
81 |         xe = torch.unsqueeze(x,1).cuda() - clusterCenter.cuda()
82 |         q = 1.0 / (1.0 + (torch.sum(torch.mul(xe,xe), 2) / self.alpha))
83 |         q = q ** ((self.alpha + 1.0) / 2.0)
84 |         q = (q.t() / torch.sum(q, 1)).t() #due to division broadcasting, we need to transpose q
85 |         return q
86 | 
87 |     def getDistance(self,x, clusterCenter,alpha=1.0):
88 |         """
89 |         Compute the squared distance from each embedded point to each cluster center.
90 |         """
91 |         # self.clusterCenter always exists (it is created in __init__),
92 |         # so no lazy initialization is needed here
93 |         xe = torch.unsqueeze(x,1).cuda() - clusterCenter.cuda()
94 |         # sum the squared differences over the feature dimension (dim 2)
95 |         d = torch.sum(torch.mul(xe,xe), 2)
96 |         return d
97 | 
98 |     def forward(self,x):
99 |         x = x.view(-1, 1*28*28)
100 |         x = self.dropout(x)
101 |         x = self.fc1(x)
102 |         x = self.relu(x)
103 |         x = self.fc2(x)
104 |         x = self.relu(x)
105 |         x = self.fc3(x)
106 |         x = self.relu(x)
107 |         x = self.fc4(x)
108 |         x_e = x
109 |         #if not in pretrain mode, we need the encoder and t-distribution output
110 |         if self.pretrainMode == False:
111 |             return x, self.getTDistribution(x,self.clusterCenter), self.getDistance(x_e,self.clusterCenter),F.softmax(x_e,dim=1)
112 |         ##### encoder is done, followed by decoder #####
113 |         x = self.fc_d4(x)
114 |         x = self.relu(x)
115 |         x = self.fc_d3(x)
116 |         x = self.relu(x)
117 |         x = self.fc_d2(x)
118 |         x = self.relu(x)
119 |         x = self.fc_d1(x)
120 |         x_de = x.view(-1,1,28,28)
121 |         return x_e, x_de
122 | 
123 | class DEC:
124 |     """The class for controlling the training process of DEC"""
125 |     def __init__(self,n_clusters,n_features,alpha=1.0):
126 |         self.n_clusters=n_clusters
127 |         self.n_features=n_features
128 |         self.alpha = alpha
129 |     @staticmethod
130 |     def target_distribution(q):
131 |         weight = (q ** 2 ) / q.sum(0)
132 |         #print('q',q)
133 |         return Variable((weight.t() / weight.sum(1)).t().data, requires_grad=True)
134 |     def logAccuracy(self,pred,label):
135 |         print(' '*8 + '|==> acc: %.4f, nmi: %.4f <==|'
136 |               % (acc(label, pred), nmi(label, pred)))
137 |     @staticmethod
138 |     def kld(q,p):
139 |         return torch.sum(p*torch.log(p/q),dim=-1)
140 |     @staticmethod
141 |     def cross_entropy(q,p):
142 |         return torch.sum(torch.sum(p*torch.log(1/(q+1e-7)),dim=-1))
143 |     @staticmethod
144 |     def depict_q(p):
145 |         q1 = p / torch.sqrt(torch.sum(p,dim=0))
146 |         qik = q1 / q1.sum()
147 |         return qik
148 |     @staticmethod
149 |     def distancePerClusterCenter(dist):
150 |         totalDist = torch.sum(torch.sum(dist, dim=0)/(torch.max(dist) * dist.size(1)))
151 |         return totalDist
152 | 
153 |     def validateOnCompleteTestData(self,test_loader,model):
154 |         model.eval()
155 |         to_eval = np.array([model(d[0].cuda())[0].data.cpu().numpy() for i,d in enumerate(test_loader)])
156 |         true_labels = np.array([d[1].cpu().numpy() for i,d in enumerate(test_loader)])
157 |         to_eval = np.reshape(to_eval,(to_eval.shape[0]*to_eval.shape[1],to_eval.shape[2]))
158 |         true_labels = np.reshape(true_labels,true_labels.shape[0]*true_labels.shape[1])
159 |         km = KMeans(n_clusters=len(np.unique(true_labels)), n_init=20, n_jobs=4)
160 |         y_pred = km.fit_predict(to_eval)
161 |         print(' '*8 + '|==> acc: %.4f, nmi: %.4f <==|'
162 |               % (acc(true_labels, y_pred), nmi(true_labels, y_pred)))
163 |         currentAcc = acc(true_labels, y_pred)
164 |         return currentAcc
165 |     def pretrain(self,train_loader, test_loader, epochs):
166 |         dec_ae = DEC_AE(self.n_clusters,self.n_features).cuda() #auto encoder
167 |         mseloss = nn.MSELoss()
168 |         optimizer = optim.SGD(dec_ae.parameters(),lr = 1, momentum=0.9)
169 |         best_acc = 0.0
170 |         for epoch in range(epochs):
171 |             dec_ae.train()
172 |             running_loss=0.0
173 |             for i,data in enumerate(train_loader):
174 |                 x, label = data
175 |                 x, label = Variable(x).cuda(),Variable(label).cuda()
176 |                 optimizer.zero_grad()
177 |                 x_ae,x_de = dec_ae(x)
178 |                 loss = F.mse_loss(x_de,x,reduce=True)
179 |                 loss.backward()
180 |                 optimizer.step()
181 |                 x_eval = x.data.cpu().numpy()
182 |                 label_eval = label.data.cpu().numpy()
183 |                 running_loss += loss.data.cpu().numpy()
184 |                 if i % 100 == 99:    # print every 100 mini-batches
185 |                     print('[%d, %5d] loss: %.7f' %
186 |                           (epoch + 1, i + 1, running_loss / 100))
187 |                     running_loss = 0.0
188 |             #now we evaluate the accuracy with the AE
189 |             dec_ae.eval()
190 |             currentAcc = self.validateOnCompleteTestData(test_loader,dec_ae)
191 |             if currentAcc > best_acc:
192 |                 torch.save(dec_ae,'bestModel')
193 |                 best_acc = currentAcc
194 |     def clustering(self,mbk,x,model):
195 |         model.eval()
196 |         y_pred_ae,_,_,_ = model(x)
197 |         y_pred_ae = y_pred_ae.data.cpu().numpy()
198 |         y_pred = mbk.partial_fit(y_pred_ae) #partial_fit updates the centres from this batch
199 |         self.cluster_centers = mbk.cluster_centers_ #keep the cluster centers
200 |         model.updateClusterCenter(self.cluster_centers)
201 |     def train(self,train_loader, test_loader, epochs):
202 |         """This method starts the training of the DEC cluster"""
203 |         ct = time.time()
204 |         model = torch.load("bestModel").cuda()
205 |         model.setPretrain(False)
206 |         optimizer = optim.SGD([\
207 |             {'params': model.parameters()}, \
208 |             ],lr = 0.01, momentum=0.9)
209 |         print('Initializing cluster center with pre-trained weights')
210 |         mbk = MiniBatchKMeans(n_clusters=self.n_clusters, n_init=20, batch_size=batch_size)
211 |         got_cluster_center = False
212 |         for epoch in range(epochs):
213 |             for i,data in enumerate(train_loader):
214 |                 x, label = data
215 |                 x = Variable(x).cuda()
216 |                 optimizer.zero_grad()
217 |                 #step 1 - get the cluster centers from the batch
218 |                 #minibatch kmeans is used here so the method can cope with larger datasets
219 |                 if not got_cluster_center:
220 |                     self.clustering(mbk,x,model)
221 |                     if epoch > 1:
222 |                         got_cluster_center = True
223 |                 else:
224 |                     model.train()
225 |                     #now we start training with the acquired cluster centers
226 |                     feature_pred,q,dist,clssfied = model(x)
227 |                     d = self.distancePerClusterCenter(dist)
228 |                     qik = self.depict_q(clssfied)
229 |                     loss1 = self.cross_entropy(clssfied,qik)
230 |                     loss = d + loss1
231 |                     loss.backward()
232 |                     optimizer.step()
233 |             currentAcc = self.validateOnCompleteTestData(test_loader,model)
234 | if __name__ == "__main__":
235 |     use_cuda = torch.cuda.is_available()
236 |     root = './data'
237 |     if not os.path.exists(root):
238 |         os.mkdir(root)
239 |     random.seed(7)
240 | 
241 |     parser = argparse.ArgumentParser(description='train',
242 |                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
243 |     parser.add_argument('--batch_size', default=100, type=int)
244 |     parser.add_argument('--pretrain_epochs', default=200, type=int)
245 |     parser.add_argument('--train_epochs', default=200, type=int)
246 |     args = parser.parse_args()
247 | 
248 |     trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
249 |     # download the mnist dataset if it is not already present
250 |     train_set = dset.MNIST(root=root, train=True, transform=trans, download=True)
251 |     test_set = dset.MNIST(root=root, train=False, transform=trans, download=True)
252 |     batch_size = args.batch_size
253 |     train_loader = torch.utils.data.DataLoader(
254 |         dataset=train_set,
255 |         batch_size=batch_size,
256 |         shuffle=True)
257 |     test_loader = torch.utils.data.DataLoader(
258 |         dataset=test_set,
259 |         batch_size=batch_size,
260 |         shuffle=False)
261 |     dec = DEC(10,10)
262 |     if args.pretrain_epochs > 0:
263 |         dec.pretrain(train_loader, test_loader, args.pretrain_epochs)
264 |     dec.train(train_loader, test_loader, args.train_epochs)
265 | 
--------------------------------------------------------------------------------
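
A note on usage: the `if __name__ == "__main__"` block above exposes the hyper-parameters through argparse, so a run can be launched as, for example:

    python DEC.py --batch_size 100 --pretrain_epochs 200 --train_epochs 200

Passing `--pretrain_epochs 0` skips the auto-encoder pre-training and goes straight to the clustering phase, which then expects a saved "bestModel" file from an earlier run.
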
/README.md:
--------------------------------------------------------------------------------
1 | # DEC clustering in PyTorch
2 | This is an implementation of Junyuan Xie, Ross Girshick, and Ali Farhadi, "Unsupervised Deep Embedding for Clustering Analysis", ICML 2016: https://arxiv.org/pdf/1511.06335.pdf
3 | ## Prerequisites
4 | * PyTorch 0.3+ with a CUDA environment
5 | * torchvision
6 | * scikit-learn
7 | ## Usage
8 | python DEC.py
9 | ## Results
10 | This code reaches around 87% accuracy on the MNIST test set http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz (trained on the MNIST training set)
11 | ## References
12 | * The code references https://github.com/XifengGuo/DEC-keras and reuses some of its code (thanks for the good work :)
13 | * To cope with larger datasets, MiniBatchKMeans is used to obtain the cluster centers.
14 | * Only the MNIST dataset has been tested so far; it should be easy to add more datasets through the torchvision DataLoader.
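15 | 
16 | ## Appendix: the DEC objective in brief
17 | For reference, below is a minimal, self-contained sketch of the notebook's clustering objective - the soft assignment q (Student's t kernel with alpha = 1), the sharpened target distribution p, and the KL divergence loss. The shapes and random inputs are illustrative only, not values from this repository:
18 | 
19 |     import torch
20 | 
21 |     z = torch.randn(100, 10)   # stand-in for encoder outputs, (batch, features)
22 |     mu = torch.randn(10, 10)   # stand-in for cluster centers, (k, features)
23 |     # soft assignment: q_ij = 1/(1 + ||z_i - mu_j||^2), normalized per sample
24 |     dist2 = torch.sum((z.unsqueeze(1) - mu) ** 2, dim=2)
25 |     q = 1.0 / (1.0 + dist2)
26 |     q = q / q.sum(dim=1, keepdim=True)
27 |     # target distribution: sharpen q, then re-normalize per sample
28 |     w = q ** 2 / q.sum(dim=0)
29 |     p = w / w.sum(dim=1, keepdim=True)
30 |     # loss = KL(P || Q), averaged over the batch
31 |     loss = torch.sum(p * torch.log(p / q), dim=1).mean()
32 | 
--------------------------------------------------------------------------------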