├── DAC_mnist.py
├── README.md
├── module.py
└── util.py

--------------------------------------------------------------------------------
/DAC_mnist.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import module
import util


mode = 'Training'  # set to 'Testing' to evaluate a saved checkpoint
num_cluster = 10
eps = 1e-10  # added inside log() for numerical stability


# ------------------------------ build the computation graph ------------------------------
image_pool_input = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name='image_pool_input')
u_thres = tf.placeholder(shape=[], dtype=tf.float32, name='u_thres')
l_thres = tf.placeholder(shape=[], dtype=tf.float32, name='l_thres')
lr = tf.placeholder(shape=[], dtype=tf.float32, name='learning_rate')

# cosine-similarity matrix between the L2-normalized label features of a batch
label_feat = module.mnistNetwork(image_pool_input, num_cluster, name='mnistNetwork', reuse=False)
label_feat_norm = tf.nn.l2_normalize(label_feat, axis=1)  # 'axis' replaces the deprecated 'dim'
sim_mat = tf.matmul(label_feat_norm, label_feat_norm, transpose_b=True)

# pairs above u_thres are treated as similar, pairs below l_thres as dissimilar;
# pairs in between are ignored at this step
pos_loc = tf.greater(sim_mat, u_thres, name='greater')
neg_loc = tf.less(sim_mat, l_thres, name='less')
pos_loc_mask = tf.cast(pos_loc, dtype=tf.float32)
neg_loc_mask = tf.cast(neg_loc, dtype=tf.float32)

# cluster assignment is the argmax over the softmax label features
pred_label = tf.argmax(label_feat, axis=1)

# losses and train op: binary cross-entropy on the selected pairs
pos_entropy = tf.multiply(-tf.log(tf.clip_by_value(sim_mat, eps, 1.0)), pos_loc_mask)
neg_entropy = tf.multiply(-tf.log(tf.clip_by_value(1 - sim_mat, eps, 1.0)), neg_loc_mask)

loss_sum = tf.reduce_mean(pos_entropy) + tf.reduce_mean(neg_entropy)
train_op = tf.train.RMSPropOptimizer(lr).minimize(loss_sum)
# To also run the batch-norm update ops, wrap the train op instead:
# update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# with tf.control_dependencies(update_ops):
#     train_op = tf.train.RMSPropOptimizer(lr).minimize(loss_sum)


# --------------------------------- prepare the datasets ----------------------------------
# read MNIST data (1 channel); it is downloaded to 'MNIST-data' on first use
mnist = input_data.read_data_sets('MNIST-data')
mnist_train = np.reshape(mnist.train.images, (-1, 28, 28, 1))  # reshape into 1-channel images
mnist_train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
mnist_test = np.reshape(mnist.test.images, (-1, 28, 28, 1))
mnist_test_labels = np.asarray(mnist.test.labels, dtype=np.int32)

# train and test images are pooled into one unlabeled set (65,000 images)
mnist_data = np.concatenate([mnist_train, mnist_test], axis=0)
mnist_labels = np.concatenate([mnist_train_labels, mnist_test_labels], axis=0)

# # read CIFAR-10 data (unused for MNIST; kept for reference)
# import cPickle
# cifar_data = []
# cifar_label = []
# for i in range(1, 6):
#     file_name = 'cifar-10-data/' + 'data_batch_' + str(i)
#     with open(file_name, 'rb') as fo:
#         cifar_dict = cPickle.load(fo)
#     data = cifar_dict['data']
#     label = cifar_dict['labels']
#     data = data.astype('float32') / 255
#     data = np.reshape(data, (-1, 3, 32, 32))
#     data = np.transpose(data, (0, 2, 3, 1))
#     cifar_data.append(data)
#     cifar_label.append(label)
# cifar_data = np.concatenate(cifar_data, axis=0)
# cifar_label = np.concatenate(cifar_label, axis=0)
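
# A note on the adaptive threshold schedule below (editor's annotation, derived
# from the constants in this file): each "epoch" increments lamda by
# 1.1 * 0.009 = 0.0099, so u = 0.95 - lamda falls while l = 0.455 + 0.1 * lamda
# rises. They cross when 0.95 - lamda = 0.455 + 0.1 * lamda, i.e. at
# lamda = 0.495 / 1.1 = 0.45, which is reached after 0.45 / 0.0099 ≈ 45.5
# epochs -- so training runs for roughly 46 epochs, and DAC_ep_45.ckpt
# (restored in Testing mode) is the last checkpoint saved at a multiple of 5.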

# ------------------------------------ run the graph --------------------------------------
saver = tf.train.Saver()
if mode == 'Training':
    batch_size = 128
    base_lr = 0.001
    if not os.path.exists('DAC_models'):  # saver.save fails if the directory is missing
        os.makedirs('DAC_models')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        lamda = 0
        epoch = 1
        u = 0.95
        l = 0.455
        while u > l:  # train until the upper and lower thresholds cross
            u = 0.95 - lamda
            l = 0.455 + 0.1 * lamda
            for i in range(1, 1001):  # 1000 iterations per threshold update ("epoch")
                data_samples, _ = util.get_mnist_batch(batch_size, mnist_data, mnist_labels)
                feed_dict = {image_pool_input: data_samples,
                             u_thres: u,
                             l_thres: l,
                             lr: base_lr}
                train_loss, _ = sess.run([loss_sum, train_op], feed_dict=feed_dict)
                if i % 20 == 0:
                    print('training loss at iter %d is %f' % (i, train_loss))

            lamda += 1.1 * 0.009

            # evaluate on a random batch at the end of every epoch
            data_samples, data_labels = util.get_mnist_batch(512, mnist_data, mnist_labels)
            feed_dict = {image_pool_input: data_samples}
            pred_cluster = sess.run(pred_label, feed_dict=feed_dict)

            acc = util.clustering_acc(data_labels, pred_cluster)
            nmi = util.NMI(data_labels, pred_cluster)
            ari = util.ARI(data_labels, pred_cluster)
            print('testing NMI, ARI, ACC at epoch %d are %f, %f, %f.' % (epoch, nmi, ari, acc))

            if epoch % 5 == 0:  # save the model every 5 epochs
                model_name = 'DAC_ep_' + str(epoch) + '.ckpt'
                save_path = saver.save(sess, 'DAC_models/' + model_name)
                print("Model saved in file: %s" % save_path)

            epoch += 1

elif mode == 'Testing':
    batch_size = 1000
    with tf.Session() as sess:
        saver.restore(sess, "DAC_models/DAC_ep_45.ckpt")
        print('model restored!')
        all_predictions = np.zeros([len(mnist_labels)], dtype=np.float32)
        for i in range(65):  # 65 batches of 1000 cover all 65,000 pooled images
            data_samples = util.get_mnist_batch_test(batch_size, mnist_data, i)
            feed_dict = {image_pool_input: data_samples}
            pred_cluster = sess.run(pred_label, feed_dict=feed_dict)
            all_predictions[i * batch_size:(i + 1) * batch_size] = pred_cluster

        acc = util.clustering_acc(mnist_labels.astype(int), all_predictions.astype(int))
        nmi = util.NMI(mnist_labels.astype(int), all_predictions.astype(int))
        ari = util.ARI(mnist_labels.astype(int), all_predictions.astype(int))
        print('testing NMI, ARI, ACC are %f, %f, %f.' % (nmi, ari, acc))

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DAC-tensorflow
TensorFlow implementation of Deep Adaptive Image Clustering (DAC).

The original ICCV 2017 paper: http://openaccess.thecvf.com/content_ICCV_2017/papers/Chang_Deep_Adaptive_Image_ICCV_2017_paper.pdf

Author implementation in Keras: https://github.com/vector-1127/DAC

Code tested with TensorFlow 1.8.
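
## Usage

Set `mode` at the top of `DAC_mnist.py` to `'Training'` or `'Testing'` and run the script. MNIST is downloaded to `MNIST-data/` on first use; checkpoints are written to `DAC_models/` every 5 epochs, and testing mode restores `DAC_models/DAC_ep_45.ckpt`.

The objective, in sketch form (this pseudocode summarizes what `DAC_mnist.py` builds; it is not runnable on its own):

```python
# sim  = cosine similarity of L2-normalized softmax features, per batch
# loss = mean(-log(sim)     over pairs with sim > u)  # pull "similar" pairs together
#      + mean(-log(1 - sim) over pairs with sim < l)  # push "dissimilar" pairs apart
# each epoch, u decreases and l increases until they cross (~46 epochs)
```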

## Results on MNIST

NMI: 0.9414, ARI: 0.9416, ACC: 0.9731

--------------------------------------------------------------------------------
/module.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def mnistNetwork(in_img, num_cluster, name='mnistNetwork', reuse=False):
    # All-conv feature extractor followed by two dense layers and a softmax over
    # num_cluster "label features". Note that every batch-norm layer uses
    # training=True (always normalizes with batch statistics) and
    # trainable=False (scale/offset are not learned).
    with tf.variable_scope(name, reuse=reuse):
        # conv1
        conv1 = tf.layers.conv2d(in_img, 64, [3, 3], [1, 1], padding='valid', activation=None, kernel_initializer=tf.keras.initializers.he_normal())
        conv1 = tf.layers.batch_normalization(conv1, axis=-1, epsilon=1e-5, training=True, trainable=False)
        conv1 = tf.nn.relu(conv1)
        # conv2
        conv2 = tf.layers.conv2d(conv1, 64, [3, 3], [1, 1], padding='valid', activation=None, kernel_initializer=tf.keras.initializers.he_normal())
        conv2 = tf.layers.batch_normalization(conv2, axis=-1, epsilon=1e-5, training=True, trainable=False)
        conv2 = tf.nn.relu(conv2)
        # conv3
        conv3 = tf.layers.conv2d(conv2, 64, [3, 3], [1, 1], padding='valid', activation=None, kernel_initializer=tf.keras.initializers.he_normal())
        conv3 = tf.layers.batch_normalization(conv3, axis=-1, epsilon=1e-5, training=True, trainable=False)
        conv3 = tf.nn.relu(conv3)
        conv3 = tf.layers.max_pooling2d(conv3, [2, 2], [2, 2])
        conv3 = tf.layers.batch_normalization(conv3, axis=-1, epsilon=1e-5, training=True, trainable=False)
        # conv4
        conv4 = tf.layers.conv2d(conv3, 128, [3, 3], [1, 1], padding='valid', activation=None, kernel_initializer=tf.keras.initializers.he_normal())
        conv4 = tf.layers.batch_normalization(conv4, axis=-1, epsilon=1e-5, training=True, trainable=False)
        conv4 = tf.nn.relu(conv4)
        # conv5
        conv5 = tf.layers.conv2d(conv4, 128, [3, 3], [1, 1], padding='valid', activation=None, kernel_initializer=tf.keras.initializers.he_normal())
        conv5 = tf.layers.batch_normalization(conv5, axis=-1, epsilon=1e-5, training=True, trainable=False)
        conv5 = tf.nn.relu(conv5)
        # conv6
        conv6 = tf.layers.conv2d(conv5, 128, [3, 3], [1, 1], padding='valid', activation=None, kernel_initializer=tf.keras.initializers.he_normal())
        conv6 = tf.layers.batch_normalization(conv6, axis=-1, epsilon=1e-5, training=True, trainable=False)
        conv6 = tf.nn.relu(conv6)
        conv6 = tf.layers.max_pooling2d(conv6, [2, 2], [2, 2])
        conv6 = tf.layers.batch_normalization(conv6, axis=-1, epsilon=1e-5, training=True, trainable=False)
        # conv7: 1x1 conv down to 10 channels, then global average pooling
        conv7 = tf.layers.conv2d(conv6, 10, [1, 1], [1, 1], padding='valid', activation=None, kernel_initializer=tf.keras.initializers.he_normal())
        conv7 = tf.layers.batch_normalization(conv7, axis=-1, epsilon=1e-5, training=True, trainable=False)
        conv7 = tf.nn.relu(conv7)
        conv7 = tf.layers.average_pooling2d(conv7, [2, 2], [2, 2])
        conv7 = tf.layers.batch_normalization(conv7, axis=-1, epsilon=1e-5, training=True, trainable=False)
        conv7_flat = tf.layers.flatten(conv7)

        # dense8
        fc8 = tf.layers.dense(conv7_flat, 10, kernel_initializer=tf.initializers.identity())
        fc8 = tf.layers.batch_normalization(fc8, axis=-1, epsilon=1e-5, training=True, trainable=False)
        fc8 = tf.nn.relu(fc8)
        # dense9
        fc9 = tf.layers.dense(fc8, num_cluster, kernel_initializer=tf.initializers.identity())
        fc9 = tf.layers.batch_normalization(fc9, axis=-1, epsilon=1e-5, training=True, trainable=False)
        fc9 = tf.nn.relu(fc9)

        out = tf.nn.softmax(fc9)

        return out
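
# Editor's sketch (not part of the original file): a quick shape check for
# mnistNetwork. For 28x28 inputs, the valid 3x3 convolutions and 2x2 poolings
# reduce the feature map to 2x2 before the 1x1 conv and average pooling, so the
# flattened feature feeding the dense layers is (batch, 10).
if __name__ == '__main__':
    x = tf.placeholder(tf.float32, [None, 28, 28, 1])
    feat = mnistNetwork(x, num_cluster=10)
    print(feat.shape)  # (?, 10): a softmax distribution over the 10 clusters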
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import random

import numpy as np
from sklearn import metrics
# note: sklearn.utils.linear_assignment_ was deprecated in scikit-learn 0.21
# and removed in 0.23; see the SciPy-based alternative sketched below
from sklearn.utils.linear_assignment_ import linear_assignment


def get_cifar_batch(batch_size, cifar_data, cifar_label):
    # sample a random batch (without replacement) from the CIFAR pool
    batch_index = random.sample(range(len(cifar_label)), batch_size)

    batch_data = np.empty([batch_size, 32, 32, 3], dtype=np.float32)
    batch_label = np.empty([batch_size], dtype=np.int32)
    for n, i in enumerate(batch_index):
        batch_data[n, ...] = cifar_data[i, ...]
        batch_label[n] = cifar_label[i]

    return batch_data, batch_label


def get_mnist_batch(batch_size, mnist_data, mnist_labels):
    # sample a random batch (without replacement) from the pooled MNIST set
    batch_index = random.sample(range(len(mnist_labels)), batch_size)

    batch_data = np.empty([batch_size, 28, 28, 1], dtype=np.float32)
    batch_label = np.empty([batch_size], dtype=np.int32)
    for n, i in enumerate(batch_index):
        batch_data[n, ...] = mnist_data[i, ...]
        batch_label[n] = mnist_labels[i]

    return batch_data, batch_label


def get_mnist_batch_test(batch_size, mnist_data, i):
    # deterministic i-th slice of the data, used to sweep the whole set at test time
    batch_data = np.copy(mnist_data[batch_size * i:batch_size * (i + 1), ...])

    return batch_data


def get_svhn_batch(batch_size, svhn_data, svhn_labels):
    # sample a random batch (without replacement) from the SVHN pool
    batch_index = random.sample(range(len(svhn_labels)), batch_size)

    batch_data = np.empty([batch_size, 32, 32, 3], dtype=np.float32)
    batch_label = np.empty([batch_size], dtype=np.int32)
    for n, i in enumerate(batch_index):
        batch_data[n, ...] = svhn_data[i, ...]
        batch_label[n] = svhn_labels[i]

    return batch_data, batch_label


def clustering_acc(y_true, y_pred):
    # unsupervised clustering accuracy: build the cluster/class contingency
    # matrix and find the best one-to-one label mapping (Hungarian algorithm)
    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1
    ind = linear_assignment(w.max() - w)

    return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size


def NMI(y_true, y_pred):
    return metrics.normalized_mutual_info_score(y_true, y_pred)


def ARI(y_true, y_pred):
    return metrics.adjusted_rand_score(y_true, y_pred)
--------------------------------------------------------------------------------
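
Editor's sketch (not part of the original repository): `clustering_acc` depends on `sklearn.utils.linear_assignment_`, which newer scikit-learn releases no longer ship. The same Hungarian matching is available in SciPy, so an equivalent drop-in for `util.py` could look like this:

from scipy.optimize import linear_sum_assignment


def clustering_acc_scipy(y_true, y_pred):
    # identical contingency-matrix construction as clustering_acc above
    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1
    # maximizing matched counts == minimizing the (max - w) cost matrix
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / y_pred.size


# sanity check: a label permutation of a perfect clustering should score 1.0
if __name__ == '__main__':
    y_true = np.array([0, 0, 1, 1, 2, 2])
    y_pred = np.array([1, 1, 2, 2, 0, 0])
    print(clustering_acc_scipy(y_true, y_pred))  # 1.0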