├── README.md ├── embedder.py ├── mnist_embedding_visualization.jpg ├── model.ckpt.data-00000-of-00001 ├── model.ckpt.index ├── model.ckpt.meta ├── test_mnist.py └── test_mnist_large_data.py /README.md: -------------------------------------------------------------------------------- 1 | # tensorboard-embedding-visualization 2 | Easily visualize embeddings in TensorBoard with thumbnail images and labels. 3 | 4 | Currently this repo is compatible with TensorFlow r1.0.1. 5 | 6 | ![MNIST embedding visualization](https://raw.githubusercontent.com/jireh-father/tensorboard-embedding-visualization/master/mnist_embedding_visualization.jpg) 7 | 8 | 9 | ## Getting Started 10 | 11 | ```python 12 | import embedder 13 | 14 | # create the model graph and get the last layer's output. 15 | logits = model() 16 | 17 | # init session and restore pre-trained model file 18 | sess = tf.Session() 19 | sess.run(tf.global_variables_initializer()) 20 | saver = tf.train.Saver() 21 | saver.restore(sess, os.path.join(test_path, 'model.ckpt')) 22 | 23 | # read your dataset and labels 24 | data_sets, labels = read_data_sets() 25 | 26 | # run your model 27 | feed_dict = {input_placeholder: data_sets, label_placeholder: labels} 28 | activations = sess.run(logits, feed_dict) 29 | 30 | embedder.summary_embedding(sess=sess, dataset=data_sets, embedding_list=[activations], 31 | embedding_path="your embedding path", image_size=your_image_size, channel=3, 32 | labels=labels) 33 | ``` 34 | 35 | 36 | ### If you want to use a large dataset (process it in batches)
37 | ```python 38 | import embedder 39 | 40 | # create the model graph 41 | logits = model() 42 | 43 | # init session and restore pre-trained model file 44 | sess = tf.Session() 45 | sess.run(tf.global_variables_initializer()) 46 | saver = tf.train.Saver() 47 | saver.restore(sess, os.path.join(test_path, 'model.ckpt')) 48 | 49 | total_dataset = None 50 | total_labels = None 51 | total_activations = None 52 | for i in range(10): 53 | data_sets, labels = get_batch(i) 54 | feed_dict = {input_placeholder: data_sets, label_placeholder: labels} 55 | activations = sess.run(logits, feed_dict) 56 | if total_dataset is None: 57 | total_dataset = data_sets 58 | total_labels = labels 59 | total_activations = activations 60 | else: 61 | total_dataset = np.append(data_sets, total_dataset, axis=0) 62 | total_labels = np.append(labels, total_labels, axis=0) 63 | total_activations = np.append(activations, total_activations, axis=0) 64 | 65 | embedder.summary_embedding(sess=sess, dataset=total_dataset, embedding_list=[total_activations], 66 | embedding_path="your embedding path", image_size=your_image_size, channel=3, 67 | labels=total_labels) 68 | ``` 69 | 70 | --- 71 | 72 | 73 | ## Running the MNIST test 74 | 75 | ```shell 76 | python test_mnist.py 77 | # or: python test_mnist_large_data.py 78 | tensorboard --logdir=./embedding 79 | ``` 80 | 81 | This should print that TensorBoard has started. Next, open http://localhost:6006 in your browser and click the EMBEDDINGS tab.
82 | 83 | 84 | --- 85 | 86 | 87 | ## API Reference 88 | 89 | ```python 90 | def summary_embedding(sess, dataset, embedding_list, embedding_path, image_size, channel=3, labels=None): 91 | pass 92 | 93 | ``` 94 | 95 | 96 | --- 97 | 98 | 99 | ## Acknowledgments 100 | http://www.pinchofintelligence.com/simple-introduction-to-tensorboard-embedding-visualisation/ 101 | https://github.com/tensorflow/tensorflow/issues/6322 102 | 103 | 104 | -------------------------------------------------------------------------------- /embedder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tensorflow.contrib.tensorboard.plugins import projector 6 | import scipy.misc 7 | import os 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | 12 | def summary_embedding(sess, dataset, embedding_list, embedding_path, image_size, channel=3, labels=None): 13 | if not os.path.exists(embedding_path): 14 | os.makedirs(embedding_path) 15 | 16 | if len(dataset.shape) == 2: 17 | dataset = dataset.reshape((-1, image_size * image_size * channel)) 18 | 19 | summary_writer = tf.summary.FileWriter(embedding_path, sess.graph) 20 | 21 | config = projector.ProjectorConfig() 22 | 23 | for embed_idx, embed_vectors in enumerate(embedding_list): 24 | embed_tensor = make_embed_tensor(sess, embed_vectors, embed_idx) 25 | write_projector_config(config, embed_tensor.name, embedding_path, image_size, channel, summary_writer, labels) 26 | 27 | summary_writer.close() 28 | 29 | save_model(sess, embedding_path) 30 | 31 | # Make sprite and labels. 
32 | make_sprite(dataset, image_size, channel, embedding_path) 33 | if labels is not None and len(labels) > 0: 34 | make_metadata(labels, embedding_path) 35 | 36 | 37 | def images_to_sprite(data): 38 | """Creates the sprite image along with any necessary padding 39 | 40 | Args: 41 | data: NxHxW[x3] tensor containing the images. 42 | 43 | Returns: 44 | data: Properly shaped HxWx3 image with any necessary padding. 45 | """ 46 | if len(data.shape) == 3: 47 | data = np.tile(data[..., np.newaxis], (1, 1, 1, 3)) 48 | data = data.astype(np.float32) 49 | min_data = np.min(data.reshape((data.shape[0], -1)), axis=1) 50 | data = (data.transpose(1, 2, 3, 0) - min_data).transpose(3, 0, 1, 2) 51 | max_data = np.max(data.reshape((data.shape[0], -1)), axis=1) 52 | data = (data.transpose(1, 2, 3, 0) / max_data).transpose(3, 0, 1, 2) 53 | 54 | n = int(np.ceil(np.sqrt(data.shape[0]))) 55 | padding = ((0, n ** 2 - data.shape[0]), (0, 0), 56 | (0, 0)) + ((0, 0),) * (data.ndim - 3) 57 | data = np.pad(data, padding, mode='constant', 58 | constant_values=0) 59 | # Tile the individual thumbnails into an image. 
60 | data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1))) 61 | data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:]) 62 | data = (data * 255).astype(np.uint8) 63 | return data 64 | 65 | 66 | def make_sprite(dataset, image_size, channel, output_path): 67 | if channel == 1: 68 | images = np.array(dataset).reshape((-1, image_size, image_size)).astype(np.float32) 69 | else: 70 | images = np.array(dataset).reshape((-1, image_size, image_size, channel)).astype(np.float32) 71 | sprite = images_to_sprite(images) 72 | scipy.misc.imsave(os.path.join(output_path, 'sprite.png'), sprite) 73 | 74 | 75 | def make_metadata(labels, output_path): 76 | if len(labels.shape) == 2: 77 | labels = labels.argmax(axis=1) 78 | metadata_file = open(os.path.join(output_path, 'labels.tsv'), 'w') 79 | metadata_file.write('Name\tClass\n') 80 | for i in range(len(labels)): 81 | metadata_file.write('%06d\t%d\n' % (i, labels[i])) 82 | metadata_file.close() 83 | 84 | 85 | def make_embed_tensor(sess, embed_vectors, embed_idx): 86 | if len(embed_vectors.shape) != 2: 87 | embed_tensor = tf.Variable(np.array(embed_vectors).reshape(len(embed_vectors), -1), 88 | name=('embed_%s' % embed_idx)) 89 | else: 90 | embed_tensor = tf.Variable(embed_vectors, name=('embed_%s' % embed_idx)) 91 | 92 | sess.run(embed_tensor.initializer) 93 | return embed_tensor 94 | 95 | 96 | def write_projector_config(config, tensor_name, output_path, image_size, channel, summary_writer, labels): 97 | embedding = config.embeddings.add() 98 | embedding.tensor_name = tensor_name 99 | if labels is not None and len(labels) > 0: 100 | embedding.metadata_path = os.path.join(output_path, 'labels.tsv') 101 | embedding.sprite.image_path = os.path.join(output_path, 'sprite.png') 102 | if channel == 1: 103 | embedding.sprite.single_image_dim.extend([image_size, image_size]) 104 | else: 105 | embedding.sprite.single_image_dim.extend([image_size, image_size, channel]) 106 | 
projector.visualize_embeddings(summary_writer, config) 107 | 108 | 109 | def save_model(sess, output_path): 110 | # saver = tf.train.Saver([embed_tensor]) 111 | saver = tf.train.Saver() 112 | saver.save(sess, os.path.join(output_path, 'model_embed.ckpt')) 113 | -------------------------------------------------------------------------------- /mnist_embedding_visualization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jireh-father/tensorboard-embedding-visualization/9cd52d942482c486fd86be50b0016118393ed8d5/mnist_embedding_visualization.jpg -------------------------------------------------------------------------------- /model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jireh-father/tensorboard-embedding-visualization/9cd52d942482c486fd86be50b0016118393ed8d5/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jireh-father/tensorboard-embedding-visualization/9cd52d942482c486fd86be50b0016118393ed8d5/model.ckpt.index -------------------------------------------------------------------------------- /model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jireh-father/tensorboard-embedding-visualization/9cd52d942482c486fd86be50b0016118393ed8d5/model.ckpt.meta -------------------------------------------------------------------------------- /test_mnist.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import embedder 3 | import os 4 | import numpy as np 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | 7 | IMAGE_SIZE = 28 8 | NUM_CHANNELS = 1 9 | NUM_LABELS = 
10 10 | BATCH_SIZE = 64 11 | 12 | test_path = os.path.dirname(os.path.realpath(__file__)) 13 | 14 | if not os.path.exists(os.path.join(test_path, 'embedding')): 15 | os.makedirs(os.path.join(test_path, 'embedding')) 16 | 17 | 18 | # 1. load model graph 19 | def model(): 20 | input_placeholder = tf.placeholder(tf.float32, shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)) 21 | 22 | conv1_weights = tf.Variable(tf.truncated_normal([5, 5, NUM_CHANNELS, 32], stddev=0.1, dtype=tf.float32)) 23 | conv1_biases = tf.Variable(tf.zeros([32], dtype=tf.float32)) 24 | conv2_weights = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1, dtype=tf.float32)) 25 | conv2_biases = tf.Variable(tf.constant(0.1, shape=[64], dtype=tf.float32)) 26 | fc1_weights = tf.Variable( 27 | tf.truncated_normal([IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64, 512], stddev=0.1, dtype=tf.float32)) 28 | fc1_biases = tf.Variable(tf.constant(0.1, shape=[512], dtype=tf.float32)) 29 | fc2_weights = tf.Variable(tf.truncated_normal([512, NUM_LABELS], stddev=0.1, dtype=tf.float32)) 30 | fc2_biases = tf.Variable(tf.constant(0.1, shape=[NUM_LABELS], dtype=tf.float32)) 31 | 32 | conv = tf.nn.conv2d(input_placeholder, conv1_weights, strides=[1, 1, 1, 1], padding='SAME') 33 | relu = tf.nn.relu(tf.nn.bias_add(conv, conv1_biases)) 34 | pool = tf.nn.max_pool(relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 35 | conv = tf.nn.conv2d(pool, conv2_weights, strides=[1, 1, 1, 1], padding='SAME') 36 | relu = tf.nn.relu(tf.nn.bias_add(conv, conv2_biases)) 37 | pool = tf.nn.max_pool(relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 38 | pool_shape = pool.get_shape().as_list() 39 | reshape = tf.reshape(pool, [pool_shape[0], pool_shape[1] * pool_shape[2] * pool_shape[3]]) 40 | hidden = tf.nn.relu(tf.matmul(reshape, fc1_weights) + fc1_biases) 41 | 42 | return input_placeholder, tf.matmul(hidden, fc2_weights) + fc2_biases 43 | 44 | 45 | input_placeholder, logits = model() 46 | 47 | # 2. 
load dataset to visualize embedding 48 | data_sets = input_data.read_data_sets(test_path, validation_size=BATCH_SIZE) 49 | batch_dataset, batch_labels = data_sets.validation.next_batch(BATCH_SIZE) 50 | 51 | # 3. init session 52 | sess = tf.Session() 53 | sess.run(tf.global_variables_initializer()) 54 | 55 | # 4. load pre-trained model file 56 | saver = tf.train.Saver() 57 | saver.restore(sess, os.path.join(test_path, 'model.ckpt')) 58 | 59 | feed_dict = {input_placeholder: batch_dataset.reshape([BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS])} 60 | activations = sess.run(logits, feed_dict) 61 | 62 | # 5. summary embedding 63 | embedder.summary_embedding(sess=sess, dataset=batch_dataset, embedding_list=[activations], 64 | embedding_path=os.path.join(test_path, 'embedding'), 65 | image_size=IMAGE_SIZE, channel=NUM_CHANNELS, labels=batch_labels) 66 | -------------------------------------------------------------------------------- /test_mnist_large_data.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import embedder 3 | import os 4 | import numpy as np 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | 7 | IMAGE_SIZE = 28 8 | NUM_CHANNELS = 1 9 | NUM_LABELS = 10 10 | BATCH_SIZE = 64 11 | 12 | test_path = os.path.dirname(os.path.realpath(__file__)) 13 | 14 | if not os.path.exists(os.path.join(test_path, 'embedding')): 15 | os.makedirs(os.path.join(test_path, 'embedding')) 16 | 17 | 18 | # 1. 
load model graph 19 | def model(): 20 | input_placeholder = tf.placeholder(tf.float32, shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)) 21 | 22 | conv1_weights = tf.Variable(tf.truncated_normal([5, 5, NUM_CHANNELS, 32], stddev=0.1, dtype=tf.float32)) 23 | conv1_biases = tf.Variable(tf.zeros([32], dtype=tf.float32)) 24 | conv2_weights = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1, dtype=tf.float32)) 25 | conv2_biases = tf.Variable(tf.constant(0.1, shape=[64], dtype=tf.float32)) 26 | fc1_weights = tf.Variable( 27 | tf.truncated_normal([IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64, 512], stddev=0.1, dtype=tf.float32)) 28 | fc1_biases = tf.Variable(tf.constant(0.1, shape=[512], dtype=tf.float32)) 29 | fc2_weights = tf.Variable(tf.truncated_normal([512, NUM_LABELS], stddev=0.1, dtype=tf.float32)) 30 | fc2_biases = tf.Variable(tf.constant(0.1, shape=[NUM_LABELS], dtype=tf.float32)) 31 | 32 | conv = tf.nn.conv2d(input_placeholder, conv1_weights, strides=[1, 1, 1, 1], padding='SAME') 33 | relu = tf.nn.relu(tf.nn.bias_add(conv, conv1_biases)) 34 | pool = tf.nn.max_pool(relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 35 | conv = tf.nn.conv2d(pool, conv2_weights, strides=[1, 1, 1, 1], padding='SAME') 36 | relu = tf.nn.relu(tf.nn.bias_add(conv, conv2_biases)) 37 | pool = tf.nn.max_pool(relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 38 | pool_shape = pool.get_shape().as_list() 39 | reshape = tf.reshape(pool, [pool_shape[0], pool_shape[1] * pool_shape[2] * pool_shape[3]]) 40 | hidden = tf.nn.relu(tf.matmul(reshape, fc1_weights) + fc1_biases) 41 | 42 | return input_placeholder, tf.matmul(hidden, fc2_weights) + fc2_biases 43 | 44 | 45 | input_placeholder, logits = model() 46 | 47 | # 2. load dataset to visualize embedding 48 | data_sets = input_data.read_data_sets(test_path, validation_size=BATCH_SIZE) 49 | 50 | # 3. init session 51 | sess = tf.Session() 52 | sess.run(tf.global_variables_initializer()) 53 | 54 | # 4. 
load pre-trained model file 55 | saver = tf.train.Saver() 56 | saver.restore(sess, os.path.join(test_path, 'model.ckpt')) 57 | 58 | # 6. if you want to use large data 59 | total_dataset = None 60 | total_labels = None 61 | total_activations = None 62 | for i in range(10): 63 | batch_dataset, batch_labels = data_sets.validation.next_batch(BATCH_SIZE) 64 | feed_dict = {input_placeholder: batch_dataset.reshape([BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS])} 65 | activations = sess.run(logits, feed_dict) 66 | 67 | if total_dataset is None: 68 | total_dataset = batch_dataset 69 | total_labels = batch_labels 70 | total_activations = activations 71 | else: 72 | total_dataset = np.append(batch_dataset, total_dataset, axis=0) 73 | total_labels = np.append(batch_labels, total_labels, axis=0) 74 | total_activations = np.append(activations, total_activations, axis=0) 75 | 76 | embedder.summary_embedding(sess=sess, dataset=total_dataset, embedding_list=[total_activations], 77 | embedding_path=os.path.join(test_path, 'embedding'), image_size=IMAGE_SIZE, 78 | channel=NUM_CHANNELS, labels=total_labels) 79 | --------------------------------------------------------------------------------