├── Images
│   └── chart.png
├── README.md
├── metrics.py
└── main.py

/Images/chart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepayan137/DeepClustering/HEAD/Images/chart.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DeepClustering

A PyTorch implementation of the paper [Unsupervised Deep Embedding for Clustering Analysis](https://arxiv.org/pdf/1511.06335.pdf).

## Getting Started

Clone the project to your local machine:

```
git clone https://github.com/Deepayan137/DeepClustering.git
cd DeepClustering
```

### Prerequisites

* Python 3.5+
* PyTorch 1.0+

### Installation

`conda install --file requirements.txt`

### Usage

`python -m main --train_epochs 200 --pretrain_epochs 100`

`--pretrain_epochs` controls the denoising-autoencoder pre-training stage and `--train_epochs` controls the DEC clustering stage.

### Results

![Results](Images/chart.png)

### Cluster Visualization

During training, t-SNE plots of the first 2000 embedded samples are written to `plots/mnist_<epoch>.png` every 20 epochs.

--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score

nmi = normalized_mutual_info_score
ari = adjusted_rand_score


def acc(y_true, y_pred):
    """
    Calculate clustering accuracy. Requires scipy and scikit-learn.

    # Arguments
        y_true: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, numpy.array with shape `(n_samples,)`
    # Return
        accuracy, in [0, 1]
    """
    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size
    # Build the confusion matrix between predicted cluster ids and true labels.
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1
    # Hungarian algorithm: find the cluster-to-label mapping that maximises the
    # number of correctly assigned samples. (scipy's linear_sum_assignment
    # replaces sklearn.utils.linear_assignment_, which was removed in
    # scikit-learn 0.23.)
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / y_pred.size
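
# Illustrative example: because of the Hungarian matching above, `acc` is
# invariant to any relabelling of the predicted cluster ids, e.g.
#
#   >>> import numpy as np
#   >>> acc(np.array([0, 0, 1, 1, 2]), np.array([2, 2, 0, 0, 1]))
#   1.0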
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import time

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import Parameter
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from tqdm import tqdm

from metrics import *

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class AutoEncoder(nn.Module):
    """Stacked autoencoder: 784-500-500-500-2000-10 encoder with a mirrored decoder."""

    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 500),
            nn.ReLU(True),
            nn.Linear(500, 500),
            nn.ReLU(True),
            nn.Linear(500, 500),
            nn.ReLU(True),
            nn.Linear(500, 2000),
            nn.ReLU(True),
            nn.Linear(2000, 10))
        self.decoder = nn.Sequential(
            nn.Linear(10, 2000),
            nn.ReLU(True),
            nn.Linear(2000, 500),
            nn.ReLU(True),
            nn.Linear(500, 500),
            nn.ReLU(True),
            nn.Linear(500, 500),
            nn.ReLU(True),
            nn.Linear(500, 28 * 28))
        self.model = nn.Sequential(self.encoder, self.decoder)

    def encode(self, x):
        return self.encoder(x)

    def forward(self, x):
        return self.model(x)


class ClusteringLayer(nn.Module):
    def __init__(self, n_clusters=10, hidden=10, cluster_centers=None, alpha=1.0):
        super(ClusteringLayer, self).__init__()
        self.n_clusters = n_clusters
        self.alpha = alpha
        self.hidden = hidden
        if cluster_centers is None:
            initial_cluster_centers = torch.zeros(
                self.n_clusters,
                self.hidden,
                dtype=torch.float
            )
            nn.init.xavier_uniform_(initial_cluster_centers)
        else:
            initial_cluster_centers = cluster_centers
        self.cluster_centers = Parameter(initial_cluster_centers)

    def forward(self, x):
        # Soft assignment of each embedded point to each cluster centre using a
        # Student's t-distribution kernel (see the note below).
        norm_squared = torch.sum((x.unsqueeze(1) - self.cluster_centers) ** 2, 2)
        numerator = 1.0 / (1.0 + (norm_squared / self.alpha))
        power = float(self.alpha + 1) / 2
        numerator = numerator ** power
        t_dist = (numerator.t() / torch.sum(numerator, 1)).t()
        return t_dist
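
# -----------------------------------------------------------------------------
# Note on the objective implemented by the classes above and below
# (Xie et al., "Unsupervised Deep Embedding for Clustering Analysis"):
#
#   q_ij = (1 + ||z_i - mu_j||^2 / alpha)^(-(alpha+1)/2)
#          / sum_j' (1 + ||z_i - mu_j'||^2 / alpha)^(-(alpha+1)/2)   soft assignment
#
#   f_j  = sum_i q_ij                                                soft cluster frequency
#   p_ij = (q_ij^2 / f_j) / sum_j' (q_ij'^2 / f_j')                  target distribution
#
#   L    = KL(P || Q) = sum_i sum_j p_ij * log(p_ij / q_ij)
#
# ClusteringLayer.forward computes Q from the embeddings z_i and the cluster
# centres mu_j, DEC.target_distribution computes P from Q, and train() minimises
# KL(P || Q) with nn.KLDivLoss applied to log Q.
# -----------------------------------------------------------------------------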

class DEC(nn.Module):
    def __init__(self, n_clusters=10, autoencoder=None, hidden=10, cluster_centers=None, alpha=1.0):
        super(DEC, self).__init__()
        self.n_clusters = n_clusters
        self.alpha = alpha
        self.hidden = hidden
        self.cluster_centers = cluster_centers
        self.autoencoder = autoencoder
        self.clusteringlayer = ClusteringLayer(self.n_clusters, self.hidden, self.cluster_centers, self.alpha)

    def target_distribution(self, q_):
        # Sharpen the soft assignments: square q, normalise by cluster frequency,
        # then renormalise per sample.
        weight = (q_ ** 2) / torch.sum(q_, 0)
        return (weight.t() / torch.sum(weight, 1)).t()

    def forward(self, x):
        x = self.autoencoder.encode(x)
        return self.clusteringlayer(x)

    def visualize(self, epoch, x):
        fig = plt.figure()
        ax = plt.subplot(111)
        x = self.autoencoder.encode(x).detach()
        x = x.cpu().numpy()[:2000]
        x_embedded = TSNE(n_components=2).fit_transform(x)
        ax.scatter(x_embedded[:, 0], x_embedded[:, 1])
        os.makedirs('plots', exist_ok=True)
        fig.savefig('plots/mnist_{}.png'.format(epoch))
        plt.close(fig)


def add_noise(img):
    noise = torch.randn(img.size()) * 0.2
    noisy_img = img + noise
    return noisy_img


def save_checkpoint(state, filename, is_best):
    """Save checkpoint if a new best is achieved."""
    if is_best:
        print("=> Saving new checkpoint")
        torch.save(state, filename)
    else:
        print("=> Loss did not improve")


def pretrain(**kwargs):
    data = kwargs['data']
    model = kwargs['model']
    num_epochs = kwargs['num_epochs']
    savepath = kwargs['savepath']
    checkpoint = kwargs['checkpoint']
    start_epoch = checkpoint['epoch']
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    criterion = nn.MSELoss()
    train_loader = DataLoader(dataset=data,
                              batch_size=128,
                              shuffle=True)
    for epoch in range(start_epoch, num_epochs):
        for batch in train_loader:
            img = batch.float()
            noisy_img = add_noise(img)
            noisy_img = noisy_img.to(device)
            img = img.to(device)
            # ===================forward=====================
            output = model(noisy_img)
            output = output.squeeze(1)
            output = output.view(output.size(0), 28 * 28)
            loss = criterion(output, img)
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # ===================log=========================
        print('epoch [{}/{}], MSE_loss:{:.4f}'
              .format(epoch + 1, num_epochs, loss.item()))
        state = loss.item()
        is_best = False
        if state < checkpoint['best']:
            checkpoint['best'] = state
            is_best = True

        save_checkpoint({
            'state_dict': model.state_dict(),
            'best': state,
            'epoch': epoch
        }, savepath, is_best)


def train(**kwargs):
    data = kwargs['data']
    labels = kwargs['labels']
    model = kwargs['model']
    num_epochs = kwargs['num_epochs']
    savepath = kwargs['savepath']
    checkpoint = kwargs['checkpoint']
    start_epoch = checkpoint['epoch']
    features = []
    train_loader = DataLoader(dataset=data,
                              batch_size=128,
                              shuffle=False)

    for i, batch in enumerate(train_loader):
        img = batch.float()
        img = img.to(device)
        features.append(model.autoencoder.encode(img).detach().cpu())
    features = torch.cat(features)
    # ============K-means: initialise the cluster centres=====================
    kmeans = KMeans(n_clusters=10, random_state=0).fit(features)
    cluster_centers = kmeans.cluster_centers_
    cluster_centers = torch.tensor(cluster_centers, dtype=torch.float, device=device)
    model.clusteringlayer.cluster_centers = torch.nn.Parameter(cluster_centers)
    # =========================================================================
    y_pred = kmeans.predict(features)
    accuracy = acc(labels.cpu().numpy(), y_pred)
    print('Initial Accuracy: {}'.format(accuracy))

    loss_function = nn.KLDivLoss(reduction='sum')
    optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9)
    print('Training')
    row = []
    for epoch in range(start_epoch, num_epochs):
        # Full-batch update: the whole dataset is pushed through the model at once.
        batch = data
        img = batch.float()
        img = img.to(device)
        output = model(img)
        target = model.target_distribution(output).detach()
        out = output.argmax(1)
        if epoch % 20 == 0:
            print('plotting')
            model.visualize(epoch, img)
        loss = loss_function(output.log(), target) / output.shape[0]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        accuracy = acc(labels.cpu().numpy(), out.cpu().numpy())
        row.append([epoch, accuracy])
        print('Epochs: [{}/{}] Accuracy: {}, Loss: {}'.format(epoch, num_epochs, accuracy, loss.item()))
        state = loss.item()
        is_best = False
        if state < checkpoint['best']:
            checkpoint['best'] = state
            is_best = True

        save_checkpoint({
            'state_dict': model.state_dict(),
            'best': state,
            'epoch': epoch
        }, savepath, is_best)

    df = pd.DataFrame(row, columns=['epochs', 'accuracy'])
    df.to_csv('log.csv')
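
# -----------------------------------------------------------------------------
# Overall pipeline, as wired up in __main__ below:
#   1. pretrain(): train the autoencoder as a denoising autoencoder on MNIST
#      (Gaussian noise with std 0.2 added by add_noise, MSE reconstruction loss).
#   2. train(): encode the whole dataset, run k-means on the 10-d embeddings to
#      initialise the cluster centres, then jointly refine encoder and centres
#      by minimising KL(P || Q).
# Note that train() performs full-batch updates (the entire ~70k-sample MNIST
# tensor is fed through the model every epoch), so it needs enough GPU memory;
# a mini-batch variant would iterate over a DataLoader instead.
# -----------------------------------------------------------------------------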

def load_mnist():
    # The data, shuffled and split between train and test sets; both splits are
    # concatenated and clustered together.
    train = MNIST(root='./data/',
                  train=True,
                  transform=transforms.ToTensor(),
                  download=True)

    test = MNIST(root='./data/',
                 train=False,
                 transform=transforms.ToTensor())
    x_train, y_train = train.data, train.targets
    x_test, y_test = test.data, test.targets
    x = torch.cat((x_train, x_test), 0)
    y = torch.cat((y_train, y_test), 0)
    x = x.reshape((x.shape[0], -1))
    x = x.float() / 255.0  # flatten to (n_samples, 784) and scale to [0, 1]
    print('MNIST samples', x.shape)
    return x, y


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='train',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--pretrain_epochs', default=20, type=int)
    parser.add_argument('--train_epochs', default=200, type=int)
    parser.add_argument('--save_dir', default='saves')
    args = parser.parse_args()
    print(args)
    epochs_pre = args.pretrain_epochs
    batch_size = args.batch_size
    os.makedirs(args.save_dir, exist_ok=True)

    x, y = load_mnist()
    autoencoder = AutoEncoder().to(device)
    ae_save_path = os.path.join(args.save_dir, 'sim_autoencoder.pth')

    if os.path.isfile(ae_save_path):
        print('Loading {}'.format(ae_save_path))
        checkpoint = torch.load(ae_save_path, map_location=device)
        autoencoder.load_state_dict(checkpoint['state_dict'])
    else:
        print("=> no checkpoint found at '{}'".format(ae_save_path))
        checkpoint = {
            "epoch": 0,
            "best": float("inf")
        }
    pretrain(data=x, model=autoencoder, num_epochs=epochs_pre, savepath=ae_save_path, checkpoint=checkpoint)

    dec_save_path = os.path.join(args.save_dir, 'dec.pth')
    dec = DEC(n_clusters=10, autoencoder=autoencoder, hidden=10, cluster_centers=None, alpha=1.0).to(device)
    if os.path.isfile(dec_save_path):
        print('Loading {}'.format(dec_save_path))
        checkpoint = torch.load(dec_save_path, map_location=device)
        dec.load_state_dict(checkpoint['state_dict'])
    else:
        print("=> no checkpoint found at '{}'".format(dec_save_path))
        checkpoint = {
            "epoch": 0,
            "best": float("inf")
        }
    train(data=x, labels=y, model=dec, num_epochs=args.train_epochs, savepath=dec_save_path, checkpoint=checkpoint)
--------------------------------------------------------------------------------
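
For reference, a minimal evaluation sketch (illustrative, under the assumption that `main.py` has already been run and has written `saves/dec.pth`): it reuses `DEC`, `AutoEncoder` and `load_mnist` from `main.py` and the metrics from `metrics.py` to report clustering accuracy, NMI and ARI for the saved model.

```python
# evaluate.py (illustrative sketch): load the DEC checkpoint written by main.py
# and report clustering metrics on the full MNIST set.
import torch

from main import DEC, AutoEncoder, load_mnist, device
from metrics import acc, nmi, ari

if __name__ == '__main__':
    x, y = load_mnist()

    # Rebuild the model and load the trained weights (checkpoint path assumed).
    dec = DEC(n_clusters=10, autoencoder=AutoEncoder(), hidden=10,
              cluster_centers=None, alpha=1.0).to(device)
    checkpoint = torch.load('saves/dec.pth', map_location=device)
    dec.load_state_dict(checkpoint['state_dict'])
    dec.eval()

    with torch.no_grad():
        q = dec(x.float().to(device))      # soft assignments, shape (n_samples, 10)
    y_pred = q.argmax(1).cpu().numpy()     # hard cluster assignment per sample
    y_true = y.cpu().numpy()

    print('ACC: {:.4f}'.format(acc(y_true, y_pred)))
    print('NMI: {:.4f}'.format(nmi(y_true, y_pred)))
    print('ARI: {:.4f}'.format(ari(y_true, y_pred)))
```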