├── .gitignore
├── README.md
├── data_utils.py
├── extract.py
├── transfer_cifar10_softmax.py
└── tsne.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TensorFlow Transfer Learning on CIFAR-10

Trains a softmax regression model on CIFAR-10 using pool_3 features (CNN codes) extracted from Inception-v3.

Make sure to pre-compute the CNN codes before training by running `serialize_data()` (defined in `transfer_cifar10_softmax.py`).

This is the code that supplements my original [blog post](https://medium.com/@st553/using-transfer-learning-to-classify-images-with-tensorflow-b0f3142b9366).

--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
import cPickle as pickle
import numpy as np
import os

# Make sure to download and extract the CIFAR-10 python batches before
# running this (https://www.cs.toronto.edu/~kriz/cifar.html)

def load_CIFAR_batch(filename):
    """Load a single batch of CIFAR-10."""
    with open(filename, 'rb') as f:
        datadict = pickle.load(f)
        X = datadict['data']
        Y = datadict['labels']
        # Each row is a flat 3072-vector; reshape to (N, 32, 32, 3) images.
        X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        Y = np.array(Y)
        return X, Y

def load_CIFAR10(ROOT):
    """Load all of CIFAR-10: the five training batches plus the test batch."""
    xs = []
    ys = []
    for b in range(1, 6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b,))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte
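# A minimal sanity check for the loaders above (a sketch; the dataset path is
# the one assumed by transfer_cifar10_softmax.py and may differ on your machine):
if __name__ == '__main__':
    Xtr, Ytr, Xte, Yte = load_CIFAR10('resources/datasets/cifar-10-batches-py')
    # Five training batches of 10,000 images each, plus one test batch:
    print 'Train:', Xtr.shape, Ytr.shape  # (50000, 32, 32, 3) (50000,)
    print 'Test: ', Xte.shape, Yte.shape  # (10000, 32, 32, 3) (10000,)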
--------------------------------------------------------------------------------
/extract.py:
--------------------------------------------------------------------------------
from tensorflow.python.platform import gfile
import tensorflow as tf
import numpy as np

model = 'resources/classify_image_graph_def.pb'

def create_graph():
    """Creates a graph from the saved GraphDef file and returns it."""
    # Creates graph from saved graph_def.pb.
    print 'Loading graph...'
    with tf.Session() as sess:
        with gfile.FastGFile(model, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(graph_def, name='')
        return sess.graph


def pool3_features(sess, X_input):
    """
    Extract pool_3 features for a single image.
    Call create_graph() before calling this.
    """
    pool3 = sess.graph.get_tensor_by_name('pool_3:0')
    pool3_features = sess.run(pool3, {'DecodeJpeg:0': X_input})
    return np.squeeze(pool3_features)

def batch_pool3_features(sess, X_input):
    """
    Currently tensorflow can't extract pool_3 in batch, so this is slow:
    https://github.com/tensorflow/tensorflow/issues/1021
    """
    n_train = X_input.shape[0]
    print 'Extracting features for %i rows' % n_train
    pool3 = sess.graph.get_tensor_by_name('pool_3:0')
    X_pool3 = []
    for i in range(n_train):
        print 'Iteration %i' % i
        pool3_features = sess.run(pool3, {'DecodeJpeg:0': X_input[i, :]})
        X_pool3.append(np.squeeze(pool3_features))
    return np.array(X_pool3)

def iterate_mini_batches(X_input, Y_input, batch_size):
    """Yield successive (X, Y) mini-batches of at most batch_size rows."""
    n_train = X_input.shape[0]
    for ndx in range(0, n_train, batch_size):
        yield X_input[ndx:min(ndx + batch_size, n_train)], Y_input[ndx:min(ndx + batch_size, n_train)]
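# A minimal usage sketch (illustrative; it mirrors how transfer_cifar10_softmax.py
# drives these helpers, and the CIFAR-10 path is an assumption):
#
#   import tensorflow as tf
#   from data_utils import load_CIFAR10
#   from extract import create_graph, batch_pool3_features
#
#   X_train, _, _, _ = load_CIFAR10('resources/datasets/cifar-10-batches-py')
#   sess = tf.InteractiveSession()
#   create_graph()                                   # load Inception-v3 into the default graph
#   codes = batch_pool3_features(sess, X_train[:5])  # -> array of shape (5, 2048)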
--------------------------------------------------------------------------------
/transfer_cifar10_softmax.py:
--------------------------------------------------------------------------------
"""
Trying out the transfer learning example from:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py
"""


import tensorflow as tf
import numpy as np
from sklearn import cross_validation
from data_utils import load_CIFAR10
from extract import create_graph, iterate_mini_batches, batch_pool3_features
from datetime import datetime
import matplotlib.pyplot as plt
from tsne import tsne
import seaborn as sns
import pandas as pd

#
# CIFAR10 Ops
# #

def load_pool3_data():
    # Update these file names after you serialize pool_3 values
    X_test_file = 'X_test_20160212-00:06:14.npy'
    y_test_file = 'y_test_20160212-00:06:14.npy'
    X_train_file = 'X_train_20160212-00:06:14.npy'
    y_train_file = 'y_train_20160212-00:06:14.npy'
    return np.load(X_train_file), np.load(y_train_file), np.load(X_test_file), np.load(y_test_file)

def serialize_cifar_pool3(X, filename):
    print 'About to generate file: %s' % filename
    sess = tf.InteractiveSession()
    X_pool3 = batch_pool3_features(sess, X)
    np.save(filename, X_pool3)

def serialize_data():
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
    datetime_str = datetime.today().strftime('%Y%m%d-%H:%M:%S')
    serialize_cifar_pool3(X_train, 'X_train_' + datetime_str)
    serialize_cifar_pool3(X_test, 'X_test_' + datetime_str)
    np.save('y_train_' + datetime_str, y_train)
    np.save('y_test_' + datetime_str, y_test)

classes = np.array(['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])
cifar10_dir = 'resources/datasets/cifar-10-batches-py'
X_train_orig, y_train_orig, X_test_orig, y_test_orig = load_CIFAR10(cifar10_dir)
X_train_pool3, y_train_pool3, X_test_pool3, y_test_pool3 = load_pool3_data()
X_train, X_validation, Y_train, y_validation = cross_validation.train_test_split(X_train_pool3, y_train_pool3,
                                                                                 test_size=0.20, random_state=42)


print 'Training data shape: ', X_train_pool3.shape
print 'Training labels shape: ', y_train_pool3.shape
print 'Test data shape: ', X_test_pool3.shape
print 'Test labels shape: ', y_test_pool3.shape

#
# Tensorflow stuff
# #

FLAGS = tf.app.flags.FLAGS
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape'
BOTTLENECK_TENSOR_SIZE = 2048
tf.app.flags.DEFINE_integer('how_many_training_steps', 100,
                            """How many training steps to run before ending.""")
tf.app.flags.DEFINE_float('learning_rate', 0.01,
                          """How large a learning rate to use when training.""")
tf.app.flags.DEFINE_string('final_tensor_name', 'final_result',
                           """The name of the output classification layer in"""
                           """ the retrained graph.""")
tf.app.flags.DEFINE_integer('eval_step_interval', 10,
                            """How often to evaluate the training results.""")



# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py
def ensure_name_has_port(tensor_name):
    """Makes sure that there's a port number at the end of the tensor name.
    Args:
      tensor_name: A string representing the name of a tensor in a graph.
    Returns:
      The input string with a :0 appended if no port was specified.
    """
    if ':' not in tensor_name:
        name_with_port = tensor_name + ':0'
    else:
        name_with_port = tensor_name
    return name_with_port
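# For example (illustrative): ensure_name_has_port('pool_3/_reshape') returns
# 'pool_3/_reshape:0', while a name that already carries a port, such as
# 'ground_truth:0', is returned unchanged.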
107 | """ 108 | bottleneck_tensor = graph.get_tensor_by_name(ensure_name_has_port( 109 | BOTTLENECK_TENSOR_NAME)) 110 | layer_weights = tf.Variable( 111 | tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count], stddev=0.001), 112 | name='final_weights') 113 | layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases') 114 | logits = tf.matmul(bottleneck_tensor, layer_weights, 115 | name='final_matmul') + layer_biases 116 | tf.nn.softmax(logits, name=final_tensor_name) 117 | ground_truth_placeholder = tf.placeholder(tf.float32, 118 | [None, class_count], 119 | name=ground_truth_tensor_name) 120 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits( 121 | logits, ground_truth_placeholder) 122 | cross_entropy_mean = tf.reduce_mean(cross_entropy) 123 | train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize( 124 | cross_entropy_mean) 125 | return train_step, cross_entropy_mean 126 | 127 | # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py 128 | def add_evaluation_step(graph, final_tensor_name, ground_truth_tensor_name): 129 | """Inserts the operations we need to evaluate the accuracy of our results. 130 | Args: 131 | graph: Container for the existing model's Graph. 132 | final_tensor_name: Name string for the new final node that produces results. 133 | ground_truth_tensor_name: Name string for the node we feed ground truth data 134 | into. 135 | Returns: 136 | Nothing. 137 | """ 138 | result_tensor = graph.get_tensor_by_name(ensure_name_has_port( 139 | final_tensor_name)) 140 | ground_truth_tensor = graph.get_tensor_by_name(ensure_name_has_port( 141 | ground_truth_tensor_name)) 142 | correct_prediction = tf.equal( 143 | tf.argmax(result_tensor, 1), tf.argmax(ground_truth_tensor, 1)) 144 | evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, 'float')) 145 | return evaluation_step 146 | 147 | def encode_one_hot(nclasses,y): 148 | return np.eye(nclasses)[y] 149 | 150 | def do_train(sess,X_input, Y_input, X_validation, Y_validation): 151 | ground_truth_tensor_name = 'ground_truth' 152 | mini_batch_size = 10 153 | n_train = X_input.shape[0] 154 | 155 | graph = create_graph() 156 | 157 | train_step, cross_entropy = add_final_training_ops( 158 | graph, len(classes), FLAGS.final_tensor_name, 159 | ground_truth_tensor_name) 160 | 161 | init = tf.initialize_all_variables() 162 | sess.run(init) 163 | 164 | evaluation_step = add_evaluation_step(graph, FLAGS.final_tensor_name, ground_truth_tensor_name) 165 | 166 | # Get some layers we'll need to access during training. 167 | bottleneck_tensor = graph.get_tensor_by_name(ensure_name_has_port(BOTTLENECK_TENSOR_NAME)) 168 | ground_truth_tensor = graph.get_tensor_by_name(ensure_name_has_port(ground_truth_tensor_name)) 169 | 170 | i=0 171 | epocs = 1 172 | for epoch in range(epocs): 173 | shuffledRange = np.random.permutation(n_train) 174 | y_one_hot_train = encode_one_hot(len(classes), Y_input) 175 | y_one_hot_validation = encode_one_hot(len(classes), Y_validation) 176 | shuffledX = X_input[shuffledRange,:] 177 | shuffledY = y_one_hot_train[shuffledRange] 178 | for Xi, Yi in iterate_mini_batches(shuffledX, shuffledY, mini_batch_size): 179 | sess.run(train_step, 180 | feed_dict={bottleneck_tensor: Xi, 181 | ground_truth_tensor: Yi}) 182 | # Every so often, print out how well the graph is training. 
def do_train(sess, X_input, Y_input, X_validation, Y_validation):
    ground_truth_tensor_name = 'ground_truth'
    mini_batch_size = 10
    n_train = X_input.shape[0]

    graph = create_graph()

    train_step, cross_entropy = add_final_training_ops(
        graph, len(classes), FLAGS.final_tensor_name,
        ground_truth_tensor_name)

    init = tf.initialize_all_variables()
    sess.run(init)

    evaluation_step = add_evaluation_step(graph, FLAGS.final_tensor_name, ground_truth_tensor_name)

    # Get some layers we'll need to access during training.
    bottleneck_tensor = graph.get_tensor_by_name(ensure_name_has_port(BOTTLENECK_TENSOR_NAME))
    ground_truth_tensor = graph.get_tensor_by_name(ensure_name_has_port(ground_truth_tensor_name))

    i = 0
    epochs = 1
    for epoch in range(epochs):
        shuffledRange = np.random.permutation(n_train)
        y_one_hot_train = encode_one_hot(len(classes), Y_input)
        y_one_hot_validation = encode_one_hot(len(classes), Y_validation)
        shuffledX = X_input[shuffledRange, :]
        shuffledY = y_one_hot_train[shuffledRange]
        for Xi, Yi in iterate_mini_batches(shuffledX, shuffledY, mini_batch_size):
            sess.run(train_step,
                     feed_dict={bottleneck_tensor: Xi,
                                ground_truth_tensor: Yi})
            # Every so often, print out how well the graph is training.
            is_last_step = (i + 1 == FLAGS.how_many_training_steps)
            if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
                train_accuracy, cross_entropy_value = sess.run(
                    [evaluation_step, cross_entropy],
                    feed_dict={bottleneck_tensor: Xi,
                               ground_truth_tensor: Yi})
                validation_accuracy = sess.run(
                    evaluation_step,
                    feed_dict={bottleneck_tensor: X_validation,
                               ground_truth_tensor: y_one_hot_validation})
                print('%s: Step %d: Train accuracy = %.1f%%, Cross entropy = %f, Validation accuracy = %.1f%%' %
                      (datetime.now(), i, train_accuracy * 100, cross_entropy_value, validation_accuracy * 100))
            i += 1

    test_accuracy = sess.run(
        evaluation_step,
        feed_dict={bottleneck_tensor: X_test_pool3,
                   ground_truth_tensor: encode_one_hot(len(classes), y_test_pool3)})
    print('Final test accuracy = %.1f%%' % (test_accuracy * 100))

def show_test_images(sess, X_img, X_features, Y):
    n = X_img.shape[0]

    def rand_ordering():
        return np.random.permutation(n)

    def sequential_ordering():
        return range(n)

    for i in sequential_ordering():
        Xi_img = X_img[i, :]
        Xi_features = X_features[i, :].reshape(1, 2048)
        result_tensor = sess.graph.get_tensor_by_name(ensure_name_has_port(FLAGS.final_tensor_name))
        probs = sess.run(result_tensor,
                         feed_dict={'pool_3/_reshape:0': Xi_features})
        predicted_class = classes[np.argmax(probs)]
        Yi = Y[i]
        Yi_label = classes[Yi]
        plt.title('true=%s, predicted=%s' % (Yi_label, predicted_class))
        plt.imshow(Xi_img.astype('uint8'))
        print Yi_label
        plt.show()
        plt.close()

def plot_tsne(X_input_pool3, Y_input, n=10000):
    indices = np.random.permutation(X_input_pool3.shape[0])[0:n]
    Y = tsne(X_input_pool3[indices])
    num_labels = Y_input[indices]
    labels = classes[num_labels]
    df = pd.DataFrame(np.column_stack((Y, num_labels, labels)), columns=["x1", "x2", "y", "y_label"])
    sns.lmplot("x1", "x2", data=df.convert_objects(convert_numeric=True), hue="y_label", fit_reg=False, legend=True, palette="Set1")
    print 'done'

# plot_tsne(X_train_pool3, y_train_pool3)
sess = tf.InteractiveSession()
do_train(sess, X_train, Y_train, X_validation, y_validation)
# show_test_images(sess, X_test_orig, X_test_pool3, y_test_orig)
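# Suggested run order (a sketch based on the README; the .npy file names in
# load_pool3_data() must be updated to match whatever serialize_data()
# produces on your machine):
#   1. serialize_data()       # one-off: cache pool_3 CNN codes as .npy files
#   2. do_train(...)          # train the softmax layer on the cached codes (runs above)
#   3. show_test_images(...)  # optional: inspect individual predictions
#   4. plot_tsne(...)         # optional: t-SNE view of the pool_3 feature space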
--------------------------------------------------------------------------------
/tsne.py:
--------------------------------------------------------------------------------
#
# tsne.py
#
# Implementation of t-SNE in Python. The implementation was tested on Python 2.7.10, and it requires a working
# installation of NumPy. The implementation comes with an example on the MNIST dataset. In order to plot the
# results of this example, a working installation of matplotlib is required.
#
# The example can be run by executing: `ipython tsne.py`
#
#
# Created by Laurens van der Maaten on 20-12-08.
# Copyright (c) 2008 Tilburg University. All rights reserved.

import numpy as Math
import pylab as Plot

def Hbeta(D = Math.array([]), beta = 1.0):
    """Compute the perplexity and the P-row for a specific value of the precision of a Gaussian distribution."""

    # Compute P-row and corresponding perplexity
    P = Math.exp(-D.copy() * beta);
    sumP = sum(P);
    H = Math.log(sumP) + beta * Math.sum(D * P) / sumP;
    P = P / sumP;
    return H, P;


def x2p(X = Math.array([]), tol = 1e-5, perplexity = 30.0):
    """Performs a binary search to get P-values in such a way that each conditional Gaussian has the same perplexity."""

    # Initialize some variables
    print "Computing pairwise distances..."
    (n, d) = X.shape;
    sum_X = Math.sum(Math.square(X), 1);
    D = Math.add(Math.add(-2 * Math.dot(X, X.T), sum_X).T, sum_X);
    P = Math.zeros((n, n));
    beta = Math.ones((n, 1));
    logU = Math.log(perplexity);

    # Loop over all datapoints
    for i in range(n):

        # Print progress
        if i % 500 == 0:
            print "Computing P-values for point ", i, " of ", n, "..."

        # Compute the Gaussian kernel and entropy for the current precision
        betamin = -Math.inf;
        betamax = Math.inf;
        Di = D[i, Math.concatenate((Math.r_[0:i], Math.r_[i+1:n]))];
        (H, thisP) = Hbeta(Di, beta[i]);

        # Evaluate whether the perplexity is within tolerance
        Hdiff = H - logU;
        tries = 0;
        while Math.abs(Hdiff) > tol and tries < 50:

            # If not, increase or decrease precision
            if Hdiff > 0:
                betamin = beta[i].copy();
                if betamax == Math.inf or betamax == -Math.inf:
                    beta[i] = beta[i] * 2;
                else:
                    beta[i] = (beta[i] + betamax) / 2;
            else:
                betamax = beta[i].copy();
                if betamin == Math.inf or betamin == -Math.inf:
                    beta[i] = beta[i] / 2;
                else:
                    beta[i] = (beta[i] + betamin) / 2;

            # Recompute the values
            (H, thisP) = Hbeta(Di, beta[i]);
            Hdiff = H - logU;
            tries = tries + 1;

        # Set the final row of P
        P[i, Math.concatenate((Math.r_[0:i], Math.r_[i+1:n]))] = thisP;

    # Return final P-matrix
    print "Mean value of sigma: ", Math.mean(Math.sqrt(1 / beta));
    return P;


def pca(X = Math.array([]), no_dims = 50):
    """Runs PCA on the NxD array X in order to reduce its dimensionality to no_dims dimensions."""

    print "Preprocessing the data using PCA..."
    (n, d) = X.shape;
    X = X - Math.tile(Math.mean(X, 0), (n, 1));
    (l, M) = Math.linalg.eig(Math.dot(X.T, X));
    Y = Math.dot(X, M[:,0:no_dims]);
    return Y;
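# A brief note on the perplexity calibration above (an explanatory aside, not
# part of the original implementation): for a Gaussian kernel with precision
# beta, Hbeta() computes the row P_j proportional to exp(-beta * D_j),
# normalised over j, together with the Shannon entropy H of that row in nats,
# so the row's perplexity is exp(H). x2p() binary-searches beta for each point
# until exp(H) matches the requested perplexity, which is why it compares H
# against logU = log(perplexity).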
def tsne(X = Math.array([]), no_dims = 2, initial_dims = 50, perplexity = 30.0):
    """Runs t-SNE on the dataset in the NxD array X to reduce its dimensionality to no_dims dimensions.
    The syntax of the function is Y = tsne.tsne(X, no_dims, perplexity), where X is an NxD NumPy array."""

    # Check inputs
    if isinstance(no_dims, float):
        print "Error: number of dimensions should be an integer, not a float.";
        return -1;
    if round(no_dims) != no_dims:
        print "Error: number of dimensions should be an integer.";
        return -1;

    # Initialize variables
    X = pca(X, initial_dims).real;
    (n, d) = X.shape;
    max_iter = 1000;
    initial_momentum = 0.5;
    final_momentum = 0.8;
    eta = 500;
    min_gain = 0.01;
    Y = Math.random.randn(n, no_dims);
    dY = Math.zeros((n, no_dims));
    iY = Math.zeros((n, no_dims));
    gains = Math.ones((n, no_dims));

    # Compute P-values
    P = x2p(X, 1e-5, perplexity);
    P = P + Math.transpose(P);
    P = P / Math.sum(P);
    P = P * 4;  # early exaggeration
    P = Math.maximum(P, 1e-12);

    # Run iterations
    for iter in range(max_iter):

        # Compute pairwise affinities
        sum_Y = Math.sum(Math.square(Y), 1);
        num = 1 / (1 + Math.add(Math.add(-2 * Math.dot(Y, Y.T), sum_Y).T, sum_Y));
        num[range(n), range(n)] = 0;
        Q = num / Math.sum(num);
        Q = Math.maximum(Q, 1e-12);

        # Compute gradient
        PQ = P - Q;
        for i in range(n):
            dY[i,:] = Math.sum(Math.tile(PQ[:,i] * num[:,i], (no_dims, 1)).T * (Y[i,:] - Y), 0);

        # Perform the update
        if iter < 20:
            momentum = initial_momentum
        else:
            momentum = final_momentum
        gains = (gains + 0.2) * ((dY > 0) != (iY > 0)) + (gains * 0.8) * ((dY > 0) == (iY > 0));
        gains[gains < min_gain] = min_gain;
        iY = momentum * iY - eta * (gains * dY);
        Y = Y + iY;
        Y = Y - Math.tile(Math.mean(Y, 0), (n, 1));

        # Compute current value of cost function
        if (iter + 1) % 10 == 0:
            C = Math.sum(P * Math.log(P / Q));
            print "Iteration ", (iter + 1), ": error is ", C

        # Stop lying about P-values
        if iter == 100:
            P = P / 4;

    # Return solution
    return Y;


if __name__ == "__main__":
    print "Run Y = tsne.tsne(X, no_dims, perplexity) to perform t-SNE on your dataset."
    print "Running example on 2,500 MNIST digits..."
    X = Math.loadtxt("mnist2500_X.txt");
    labels = Math.loadtxt("mnist2500_labels.txt");
    Y = tsne(X, 2, 50, 20.0);
    Plot.scatter(Y[:,0], Y[:,1], 20, labels);
    Plot.show();
    print 'done'
--------------------------------------------------------------------------------