├── README.md
├── deploy.py
├── imgs
│   ├── architecture.png
│   ├── prune.png
│   ├── quantize.png
│   ├── sparse_conv1.png
│   ├── sparse_conv2.png
│   └── train_accuracy.png
├── layers.py
├── train.py
└── unit_tests.py

/README.md:
--------------------------------------------------------------------------------
## Deep compression

TensorFlow implementation of the paper: Song Han, Huizi Mao, William J. Dally. Deep Compression: Compressing Deep Neural Networks with Pruning, Trained Quantization and Huffman Coding.

The goal is to compress the neural network using weight pruning and quantization with no loss of accuracy.

Neural network architecture:
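The network built in `train.py` consists of two 5x5 convolutions with stride 2 and `'SAME'` padding (1 -> 32 and 32 -> 64 channels), followed by fully connected layers of size 7 * 7 * 64 -> 1024 and 1024 -> 10, all without bias terms. The sketch below is a hypothetical helper, not part of the repository. It tallies output shapes and weight counts, and shows that `fc1` holds roughly 98% of all weights, so that layer dominates the model size.

```python
import math


# Hypothetical helper: per-layer weight counts for the network in train.py
# (conv1 -> conv2 -> fc1 -> fc2, no biases, 28x28x1 input).
def same_out(size, stride):
    # With 'SAME' padding the output size is ceil(input / stride).
    return math.ceil(size / stride)


def layer_table(h=28, w=28, depth=1):
    rows = []
    # (name, kernel size, output channels, stride) as configured in train.py
    for name, k, d_out, s in [("conv1", 5, 32, 2), ("conv2", 5, 64, 2)]:
        h, w = same_out(h, s), same_out(w, s)
        rows.append((name, k * k * depth * d_out, (h, w, d_out)))
        depth = d_out
    n_in = h * w * depth  # flattened conv2 output: 7 * 7 * 64 = 3136
    for name, n_out in [("fc1", 1024), ("fc2", 10)]:
        rows.append((name, n_in * n_out, (n_out,)))
        n_in = n_out
    return rows


rows = layer_table()
total = sum(n for _, n, _ in rows)
for name, n, shape in rows:
    print(f"{name}: {n} weights, output {shape}, {100 * n / total:.1f}% of total")
print("total:", total)  # 3273504
```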

Test accuracy during training:

### 1. Full training.

Train for a number of iterations with gradient descent, adjusting all the weights in every layer.

### 2. Pruning and finetuning.

Once in a while, remove the weights whose absolute value is below a threshold. In the meantime, finetune the remaining weights to recover accuracy.
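A minimal NumPy sketch of the pruning step, mirroring what `prune_weights`, `prune_weights_gradient` and `prune_weights_update` in `layers.py` do: the mask keeps weights with |w| >= threshold, gradients of pruned weights are zeroed, and the mask is re-applied after the update. As an assumption for brevity, plain SGD with a hypothetical learning rate `lr` stands in for the Adam optimizer used in `train.py`.

```python
import numpy as np


def prune_mask(w, threshold):
    # 1.0 marks a kept weight, 0.0 a pruned one (same convention as layers.py).
    return (np.abs(w) >= threshold).astype(np.float32)


def masked_sgd_step(w, grad, mask, lr=0.1):
    # Zero the gradients of pruned weights so they receive no update...
    w = w - lr * (grad * mask)
    # ...then re-apply the mask so pruned weights stay exactly zero.
    return w * mask


w = np.array([0.05, -0.3, 0.12, -0.08])
mask = prune_mask(w, threshold=0.1)  # mask == [0, 1, 1, 0]
w = w * mask                         # prune once
w = masked_sgd_step(w, np.ones_like(w), mask)
print(mask, w)
```

Masking the gradient and re-masking the weights is redundant in exact arithmetic; the second step only guards against numerical drift, which is the reason `layers.py` gives for `prune_weights_update`.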

### 3. Quantization and finetuning.

Quantization is done after pruning. The remaining weights are clustered with k-means. After that, the centroids of the quantized weights are finetuned to recover accuracy. Each layer's weights are quantized independently.
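The clustering and finetuning steps can be sketched in NumPy as follows. As in `quantize_weights`, centroids are initialised linearly between the smallest and largest weight and refined with at most 20 k-means iterations, and pruned weights never contribute to a centroid (here min and max are also taken over unpruned weights only, the choice the comment in `layers.py` calls theoretically correct). As in `group_and_reduce_gradient`, every weight in a cluster receives the sum of that cluster's gradients. The function names are illustrative, not the repository's API.

```python
import numpy as np


def assign(w, centroids):
    # Index of the nearest centroid for every weight (works for any weight shape).
    return np.argmin(np.abs(w[..., None] - centroids), axis=-1)


def kmeans_quantize(w, mask, n_clusters=5, max_iters=20):
    valid = w[mask > 0]
    # Linear initialisation between the smallest and largest unpruned weight.
    centroids = np.linspace(valid.min(), valid.max(), n_clusters)
    for _ in range(max_iters):
        labels = assign(w, centroids)
        new = centroids.copy()
        for k in range(n_clusters):
            members = (labels == k) & (mask > 0)  # pruned weights are ignored
            if members.any():                     # empty clusters keep their centroid
                new[k] = w[members].mean()
        if np.array_equal(new, centroids):
            break
        centroids = new
    return centroids, assign(w, centroids)


def grouped_gradient(grad, labels, mask, n_clusters):
    # Every weight in cluster k gets the sum of the gradients of cluster k.
    sums = np.array([grad[(labels == k) & (mask > 0)].sum() for k in range(n_clusters)])
    return sums[labels] * mask


w = np.array([-0.5, -0.45, 0.0, 0.3, 0.35, 0.9])
mask = np.array([1, 1, 0, 1, 1, 1], dtype=np.float32)
centroids, labels = kmeans_quantize(w, mask, n_clusters=3)
w_quantized = centroids[labels] * mask
grad = np.array([1.0, 2.0, 5.0, 3.0, 4.0, 10.0])
print(centroids, w_quantized, grouped_gradient(grad, labels, mask, 3))
```

With `N_clusters=5`, as in `train.py`, each unpruned weight could in principle be stored as a ceil(log2 5) = 3-bit cluster index plus one shared table of five centroids per layer; this repository does not implement such an encoding.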

### 4. Deployment.

Fully connected layers are computed as sparse matrix multiplications. TensorFlow doesn't provide sparse convolutions, so convolution layers are explicitly transformed into sparse matrix operations, which gives full control over which (valid, i.e. unpruned) weights are stored.

Simple (input_depth=1, output_depth=1) convolution as a matrix operation (notice the padding type and stride value):
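The single-channel case can be checked with a small NumPy sketch that is independent of TensorFlow. It builds a matrix M such that the flattened output equals the flattened input times M (the `x @ W` convention used by `forward_matmul`), assuming TensorFlow's documented `'SAME'` rule: the output size is ceil(n / stride) and, when the total padding is odd, the extra row or column goes at the bottom/right. A direct zero-padded cross-correlation serves as the reference. All function names are hypothetical.

```python
import numpy as np


def same_padding(n_in, k, stride):
    # TF 'SAME': n_out = ceil(n_in / stride); the extra pixel (if any) goes after.
    n_out = -(-n_in // stride)
    total = max((n_out - 1) * stride + k - n_in, 0)
    return n_out, total // 2, total - total // 2


def conv_as_matrix(kernel, H, W, stride):
    kH, kW = kernel.shape
    Ho, top, _ = same_padding(H, kH, stride)
    Wo, left, _ = same_padding(W, kW, stride)
    M = np.zeros((H * W, Ho * Wo))
    for io in range(Ho):
        for jo in range(Wo):
            for a in range(kH):
                for b in range(kW):
                    i, j = io * stride + a - top, jo * stride + b - left
                    # taps that fall on the zero padding contribute nothing
                    if 0 <= i < H and 0 <= j < W:
                        M[i * W + j, io * Wo + jo] = kernel[a, b]
    return M, (Ho, Wo)


def conv_direct(x, kernel, stride):
    # Reference: zero-pad, then slide the kernel (cross-correlation, like tf.nn.conv2d).
    kH, kW = kernel.shape
    Ho, top, bottom = same_padding(x.shape[0], kH, stride)
    Wo, left, right = same_padding(x.shape[1], kW, stride)
    xp = np.pad(x, ((top, bottom), (left, right)))
    y = np.zeros((Ho, Wo))
    for io in range(Ho):
        for jo in range(Wo):
            patch = xp[io * stride:io * stride + kH, jo * stride:jo * stride + kW]
            y[io, jo] = np.sum(patch * kernel)
    return y


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 6))
kernel = rng.standard_normal((3, 3))
M, (Ho, Wo) = conv_as_matrix(kernel, 6, 6, stride=2)
y = (x.reshape(-1) @ M).reshape(Ho, Wo)
print(np.allclose(y, conv_direct(x, kernel, 2)), np.count_nonzero(M), M.size)
```

Here M has 64 non-zero entries out of 36 * 9 = 324. The fraction of zeros grows with the image size, and pruned kernel taps remove further entries, which is why `ConvLayerDeploy` stores the matrix as a `tf.SparseTensor`.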

Full (input_depth>1, output_depth>1) convolution as a matrix operation:
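With several channels, `forward_matmul_preprocess` transposes the NHWC input to NCHW before flattening, so the layout is channel-major: input pixel (d_in, i, j) maps to row d_in * H * W + i * W + j and output pixel (d_out, io, jo) to column d_out * Ho * Wo + io * Wo + jo. The NumPy sketch below (hypothetical names, same `'SAME'` padding rule as TensorFlow) builds that matrix, skipping zero kernel taps the way `tensor_to_matrix` skips entries removed by the pruning mask, and checks it against a direct NHWC convolution.

```python
import numpy as np


def same_padding(n_in, k, stride):
    # TF 'SAME': n_out = ceil(n_in / stride), smaller half of the padding first.
    n_out = -(-n_in // stride)
    total = max((n_out - 1) * stride + k - n_in, 0)
    return n_out, total // 2


def conv_tensor_to_matrix(kernel, H, W, stride):
    # kernel uses the tf.nn.conv2d filter layout (kH, kW, D_in, D_out).
    kH, kW, D_in, D_out = kernel.shape
    Ho, top = same_padding(H, kH, stride)
    Wo, left = same_padding(W, kW, stride)
    M = np.zeros((D_in * H * W, D_out * Ho * Wo))
    for a, b, d_in, d_out in zip(*np.nonzero(kernel)):  # pruned (zero) taps are skipped
        for io in range(Ho):
            for jo in range(Wo):
                i, j = io * stride + a - top, jo * stride + b - left
                if 0 <= i < H and 0 <= j < W:
                    M[d_in * H * W + i * W + j, d_out * Ho * Wo + io * Wo + jo] = kernel[a, b, d_in, d_out]
    return M, (Ho, Wo)


def forward_via_matrix(x_nhwc, M, D_out, Ho, Wo):
    n = x_nhwc.shape[0]
    x = x_nhwc.transpose(0, 3, 1, 2).reshape(n, -1)                  # NHWC -> NCHW, flatten
    return (x @ M).reshape(n, D_out, Ho, Wo).transpose(0, 2, 3, 1)  # back to NHWC


def conv_direct(x, kernel, stride):
    n, H, W, _ = x.shape
    kH, kW, _, D_out = kernel.shape
    Ho, top = same_padding(H, kH, stride)
    Wo, left = same_padding(W, kW, stride)
    # Pad generously at the end; only the windows that are needed get read.
    xp = np.pad(x, ((0, 0), (top, kH), (left, kW), (0, 0)))
    y = np.zeros((n, Ho, Wo, D_out))
    for io in range(Ho):
        for jo in range(Wo):
            patch = xp[:, io * stride:io * stride + kH, jo * stride:jo * stride + kW, :]
            y[:, io, jo, :] = np.einsum('nabc,abcd->nd', patch, kernel)
    return y


rng = np.random.default_rng(1)
# random kernel with roughly half of its taps pruned
kernel = rng.standard_normal((3, 3, 4, 5)) * rng.integers(0, 2, size=(3, 3, 4, 5))
x = rng.standard_normal((2, 6, 6, 4))
M, (Ho, Wo) = conv_tensor_to_matrix(kernel, 6, 6, stride=2)
y = forward_via_matrix(x, M, kernel.shape[3], Ho, Wo)
print(M.shape, np.allclose(y, conv_direct(x, kernel, 2)))
```

Because rows are grouped by input channel and columns by output channel, block (d_in, d_out) of M is exactly the single-channel matrix of `kernel[:, :, d_in, d_out]`.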

I do not make efficient use of quantization during deployment. It is possible to do so with TensorFlow operations, but it would be very slow: for each output unit we would need to create N_clusters sparse tensors from the input data, reduce_sum each tensor, multiply each sum by its cluster centroid, and add the results to obtain the output unit value. Doing this efficiently requires writing a custom GPU kernel, which I intend to do in the future.

--------------------------------------------------------------------------------
/deploy.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import sys
from layers import FcLayerDeploy, ConvLayerDeploy

def load_weights(directory, name):

    weights = np.load(directory + '/' + name + '-weights.npy')
    prune_mask = np.load(directory + '/' + name + '-prune-mask.npy')

    return weights, prune_mask

if __name__ == "__main__":

    weights_dir = './weights'

    x_PH = tf.placeholder(tf.float32, [None, 28, 28, 1])

    weights, prune_mask = load_weights(weights_dir, 'conv1')
    L1 = ConvLayerDeploy(weights, prune_mask, x_PH.shape[1], x_PH.shape[2], 2, 'conv1')
    x = L1.forward_matmul_preprocess(x_PH)
    x = tf.nn.relu(L1.forward_matmul(x))
    x = L1.forward_matmul_postprocess(x)

    weights, prune_mask = load_weights(weights_dir, 'conv2')
    L2 = ConvLayerDeploy(weights, prune_mask, x.shape[1], x.shape[2], 2, 'conv2')
    x = L2.forward_matmul_preprocess(x)
    x = tf.nn.relu(L2.forward_matmul(x))
    x = L2.forward_matmul_postprocess(x)

    x = tf.reshape(x, [-1, 7 * 7 * 64])

    weights, prune_mask = load_weights(weights_dir, 'fc1')
    L3 = FcLayerDeploy(weights, prune_mask, 'fc1')
    x = tf.nn.relu(L3.forward_matmul(x))

    weights, prune_mask = load_weights(weights_dir, 'fc2')
    L4 = FcLayerDeploy(weights, prune_mask, 'fc2')
    logits = L4.forward_matmul(x)

    labels = tf.placeholder(tf.float32, [None, 10])
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    sess = tf.Session()

    batches_acc = []
    for i in range(10):

        batch_x, batch_y = mnist.test.next_batch(1000)
        batch_x = np.reshape(batch_x, (-1, 28, 28, 1))

        batch_acc = sess.run(accuracy, feed_dict={x_PH: batch_x, labels: batch_y})
        batches_acc.append(batch_acc)

    acc = np.mean(batches_acc)

    print('deploy accuracy:', acc)

--------------------------------------------------------------------------------
/imgs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wojciechmo/deep-compression/dcbf90df9f6bb0639542fd5f34bbe680d524be44/imgs/architecture.png
--------------------------------------------------------------------------------
/imgs/prune.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wojciechmo/deep-compression/dcbf90df9f6bb0639542fd5f34bbe680d524be44/imgs/prune.png
--------------------------------------------------------------------------------
/imgs/quantize.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wojciechmo/deep-compression/dcbf90df9f6bb0639542fd5f34bbe680d524be44/imgs/quantize.png
--------------------------------------------------------------------------------
/imgs/sparse_conv1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wojciechmo/deep-compression/dcbf90df9f6bb0639542fd5f34bbe680d524be44/imgs/sparse_conv1.png
--------------------------------------------------------------------------------
/imgs/sparse_conv2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wojciechmo/deep-compression/dcbf90df9f6bb0639542fd5f34bbe680d524be44/imgs/sparse_conv2.png
--------------------------------------------------------------------------------
/imgs/train_accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wojciechmo/deep-compression/dcbf90df9f6bb0639542fd5f34bbe680d524be44/imgs/train_accuracy.png
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import sys

class LayerTrain(object):

    def __init__(self, in_depth, out_depth, N_clusters, name):

        self.name = name

        if 'conv' in name:
            self.w = tf.Variable(tf.random.normal([5, 5, in_depth, out_depth], stddev=0.1))

        elif 'fc' in name:
            self.w = tf.Variable(tf.random.normal([in_depth, out_depth], stddev=0.1))

        self.w_PH = tf.placeholder(tf.float32, self.w.shape)
        self.assign_w = tf.assign(self.w, self.w_PH)
        self.num_total_weights = np.prod(self.w.shape)

        # pruning mask (NumPy array): ones - valid weights, zeros - pruned weights
        self.pruning_mask_data = np.ones(self.w.shape, dtype=np.float32)

        self.N_clusters = N_clusters  # number of clusters for quantization

    def forward(self, x):

        if 'conv' in self.name:
            return tf.nn.conv2d(x, self.w, strides=[1, 2, 2, 1], padding='SAME')

        elif 'fc' in self.name:
            return tf.matmul(x, self.w)

    def save_weights_histogram(self, sess, directory, iteration):

        w_data = sess.run(self.w).reshape(-1)
        valid_w_data = [x for x in w_data if x != 0.0]

        plt.grid(True)
        plt.hist(valid_w_data, 100, color='0.4')
        plt.gca().set_xlim([-0.4, 0.4])
        plt.savefig(directory + '/' + self.name + '-' + str(iteration), dpi=100)
        plt.gcf().clear()

    def save_weights(self, sess, directory):

        w_data = sess.run(self.w)
        np.save(directory + '/' + self.name + '-weights', w_data)
        np.save(directory + '/' + self.name + '-prune-mask', self.pruning_mask_data)

    # ------------------------------------------------------------------------
    # ----------------------------- pruning ----------------------------------
    # ------------------------------------------------------------------------

    # Masking only the forward pass in each training step (leaving gradient updates intact) would be enough.
    # Here only the backward pass is masked in each step instead, to stay consistent with quantization.
    # Theoretically, weights need to be pruned only once, after which only the gradients have to be masked in every iteration,
    # but for numerical stability the weights are also re-masked in every iteration.

    def prune_weights(self, sess, threshold):

        w_data = sess.run(self.w)
        # prune weights whose absolute value is smaller than the threshold - set them to zero
        self.pruning_mask_data = (np.abs(w_data) >= threshold).astype(np.float32)

        print('layer:', self.name)
        print('\tremaining weights:', int(np.sum(self.pruning_mask_data)))
        print('\ttotal weights:', self.num_total_weights)

        sess.run(self.assign_w, feed_dict={self.w_PH: self.pruning_mask_data * w_data})

    def prune_weights_gradient(self, grad):

        return grad * self.pruning_mask_data

    # for numerical stability
    def prune_weights_update(self, sess):

        w_data = sess.run(self.w)
        sess.run(self.assign_w, feed_dict={self.w_PH: self.pruning_mask_data * w_data})

    # ------------------------------------------------------------------------
    # --------------------------- quantization -------------------------------
    # ------------------------------------------------------------------------

    def quantize_weights(self, sess):

        w_data = sess.run(self.w)

        # theoretically, the pruning mask should be taken into account so that max and min are computed only among valid weights,
        # but in practice, with normal distribution initialization, the min and max values are virtually always valid (non-zero) weights
        max_val = np.max(w_data)
        min_val = np.min(w_data)

        # linearly initialize centroids between min and max
        self.centroids = np.linspace(min_val, max_val, self.N_clusters)
        w_data = np.expand_dims(w_data, 0)

        centroids_prev = np.copy(self.centroids)

        for i in range(20):

            if 'conv' in self.name:
                distances = np.abs(w_data - np.reshape(self.centroids, (-1, 1, 1, 1, 1)))
                distances = np.transpose(distances, (1, 2, 3, 4, 0))

            elif 'fc' in self.name:
                distances = np.abs(w_data - np.reshape(self.centroids, (-1, 1, 1)))
                distances = np.transpose(distances, (1, 2, 0))

            classes = np.argmin(distances, axis=-1)

            self.clusters_masks = []
            for k in range(self.N_clusters):

                cluster_mask = (classes == k).astype(np.float32) * self.pruning_mask_data
                self.clusters_masks.append(cluster_mask)

                num_weights_assigned = np.sum(cluster_mask)

                if num_weights_assigned != 0:
                    self.centroids[k] = np.sum(cluster_mask * w_data) / num_weights_assigned
                else:  # do not modify
                    pass

            if np.array_equal(centroids_prev, self.centroids):
                break

            centroids_prev = np.copy(self.centroids)

        self.quantize_weights_update(sess)

        print('layer:', self.name)
        print('\tcentroids:', self.centroids)

    def group_and_reduce_gradient(self, grad):

        grad_out = np.zeros(self.w.shape, dtype=np.float32)

        for i in range(self.N_clusters):

            cluster_mask = self.clusters_masks[i]
            centroid_grad = np.sum(grad * cluster_mask)

            grad_out = grad_out + cluster_mask * centroid_grad

        return grad_out

    # for numerical stability
    def quantize_centroids_update(self, sess):

        w_data = sess.run(self.w)

        for i in range(self.N_clusters):

            cluster_mask = self.clusters_masks[i]
            cluster_count = np.sum(cluster_mask)

            if cluster_count != 0:
                self.centroids[i] = np.sum(cluster_mask * w_data) / cluster_count
            else:  # do not modify
                pass

    # for numerical stability
    def quantize_weights_update(self, sess):

        w_data_updated = np.zeros(self.w.shape, dtype=np.float32)

        for i in range(self.N_clusters):

            cluster_mask = self.clusters_masks[i]
            centroid = self.centroids[i]

            w_data_updated = w_data_updated + cluster_mask * centroid

        sess.run(self.assign_w, feed_dict={self.w_PH: self.pruning_mask_data * w_data_updated})

#-------------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------------

class FcLayerDeploy(object):

    def __init__(self, matrix, prune_mask, name, dense=False):

        assert matrix.shape == prune_mask.shape

        self.dense = dense
        self.N_in, self.N_out = matrix.shape

        if not self.dense:

            indices, values, dense_shape = [], [], [self.N_in, self.N_out]  # sparse matrix representation

            for i in range(self.N_in):
                for j in range(self.N_out):

                    # pruning mask: ones - valid weights, zeros - pruned weights
                    if prune_mask[i][j] == 0.0:
                        continue

                    indices.append([i, j])
                    values.append(matrix[i][j])

            self.w_matrix = tf.SparseTensor(indices, values, dense_shape)  # tf sparse matrix

        else:

            self.w_matrix = tf.constant(matrix * prune_mask)  # tf dense matrix

        print('layer:', name)
        print('\tvalid matrix weights:', np.sum(prune_mask))
        print('\ttotal matrix weights:', np.prod(self.w_matrix.shape))

    def forward_matmul(self, x):

        if not self.dense:

            w = tf.sparse.transpose(self.w_matrix, (1, 0))
            x = tf.transpose(x, (1, 0))
            x = tf.sparse.matmul(w, x)  # only the left matrix can be sparse, hence the transpositions
            x = tf.transpose(x, (1, 0))

        else:

            x = tf.matmul(x, self.w_matrix)

        return x

#-------------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------------

class ConvLayerDeploy(object):

    def __init__(self, tensor, prune_mask, H_in, W_in, stride, name, dense=False):

        assert tensor.shape == prune_mask.shape

        self.stride = stride
        self.dense = dense

        if not self.dense:
            indices, values, dense_shape = self.tensor_to_matrix(tensor, prune_mask, H_in, W_in, stride)
            self.w_matrix = tf.SparseTensor(indices, values, dense_shape)  # tf sparse matrix

        else:
            matrix = self.tensor_to_matrix(tensor, prune_mask, H_in, W_in, stride)
            self.w_matrix = tf.constant(matrix)  # tf dense matrix

        self.w_tensor = tf.constant(tensor * prune_mask)

        print('layer:', name)
        print('\tvalid matrix weights:', int(np.sum(prune_mask)))
        print('\ttotal tensor weights:', np.prod(self.w_tensor.shape))
        print('\ttotal matrix weights:', np.prod(self.w_matrix.shape))

    def get_linear_pos(self, i, j, W):  # row major

        return i * W + j

    def tensor_to_matrix(self, tensor, prune_mask, H_in, W_in, stride):

        # assume padding type 'SAME' and padding value 0

        H_in = int(H_in)
        W_in = int(W_in)
        H_out = (H_in + 1) // stride  # padding 'SAME': ceil(H_in / stride) for stride 2
        W_out = (W_in + 1) // stride  # padding 'SAME': ceil(W_in / stride) for stride 2

        kH, kW, D_in, D_out = tensor.shape

        self.D_out = D_out
        self.H_out = H_out
        self.W_out = W_out

        if not self.dense:
            indices, values, dense_shape = [], [], [H_in * W_in * D_in, H_out * W_out * D_out]  # sparse matrix
        else:
            matrix = np.zeros((H_in * W_in * D_in, H_out * W_out * D_out), dtype=np.float32)  # dense matrix

        for d_in in range(D_in):
            for d_out in range(D_out):

                # with 'SAME' padding, tf.nn.conv2d puts any extra padding at the bottom/right,
                # so kernel centers are aligned to the bottom-right corner rather than the top-left one
                for i_in_center in np.arange(H_in - 1, -1, -stride):  # kernel input center for first axis
                    for j_in_center in np.arange(W_in - 1, -1, -stride):  # kernel input center for second axis

                        i_out = i_in_center // stride
                        j_out = j_in_center // stride

                        for i in range(kH):

                            i_in = i_in_center + i - kH // 2

                            if i_in < 0 or i_in >= H_in:  # padding value 0
                                continue

                            for j in range(kW):

                                j_in = j_in_center + j - kW // 2

                                if j_in < 0 or j_in >= W_in:  # padding value 0
                                    continue

                                # pruning mask: ones - valid weights, zeros - pruned weights
                                if prune_mask[i][j][d_in][d_out] == 0.0:
                                    continue

                                pos_in = self.get_linear_pos(i_in, j_in, W_in) + d_in * H_in * W_in
                                pos_out = self.get_linear_pos(i_out, j_out, W_out) + d_out * H_out * W_out

                                if not self.dense:
                                    indices.append([pos_in, pos_out])
                                    values.append(tensor[i][j][d_in][d_out])
                                else:
                                    matrix[pos_in][pos_out] = tensor[i][j][d_in][d_out]

        if not self.dense:
            return indices, values, dense_shape
        else:
            return matrix

    def forward_matmul_preprocess(self, x):

        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW, so flattening is channel-major
        x = tf.reshape(x, (-1, int(np.prod(x.shape[1:]))))

        return x

    def forward_matmul_postprocess(self, x):

        x = tf.reshape(x, (-1, self.D_out, self.H_out, self.W_out))
        x = tf.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC

        return x

    def forward_matmul(self, x):

        if not self.dense:
            w = tf.sparse.transpose(self.w_matrix, (1, 0))
            x = tf.transpose(x, (1, 0))
            x = tf.sparse.matmul(w, x)  # only the left matrix can be sparse, hence the transpositions
            x = tf.transpose(x, (1, 0))
        else:
            x = tf.matmul(x, self.w_matrix)

        return x

    def forward_conv(self, x):

        return tf.nn.conv2d(x, self.w_tensor, strides=[1, self.stride, self.stride, 1], padding='SAME')

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import sys, os, shutil

from layers import LayerTrain

def make_dir(directory):

    if os.path.exists(directory):
        shutil.rmtree(directory, ignore_errors=True)
    os.makedirs(directory)

if __name__ == "__main__":

    histograms_dir = './histograms'
    weights_dir = './weights'

    make_dir(histograms_dir)
    make_dir(weights_dir)

    L1 = LayerTrain(1, 32, N_clusters=5, name='conv1')
    L2 = LayerTrain(32, 64, N_clusters=5, name='conv2')
    L3 = LayerTrain(7 * 7 * 64, 1024, N_clusters=5, name='fc1')
    L4 = LayerTrain(1024, 10, N_clusters=5, name='fc2')

    LAYERS = [L1, L2, L3, L4]
    LAYERS_WIEGHTS = [L1.w, L2.w, L3.w, L4.w]

    x_PH = tf.placeholder(tf.float32, [None, 28, 28, 1])
    x = tf.nn.relu(L1.forward(x_PH))
    x = tf.nn.relu(L2.forward(x))
    x = tf.reshape(x, (-1, int(np.prod(x.shape[1:]))))
    x = tf.nn.relu(L3.forward(x))
    logits = L4.forward(x)

    preds = tf.nn.softmax(logits)
    labels = tf.placeholder(tf.float32, [None, 10])
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits))

    optimizer = tf.train.AdamOptimizer(1e-4)
    gradients_vars = optimizer.compute_gradients(loss, LAYERS_WIEGHTS)
    grads = [grad for grad, var in gradients_vars]
    train_step = optimizer.apply_gradients(gradients_vars)

    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    iters = []
    iters_acc = []

    for i in range(1500):

        batch_x, batch_y = mnist.train.next_batch(50)
        batch_x = np.reshape(batch_x, (-1, 28, 28, 1))

        feed_dict = {x_PH: batch_x, labels: batch_y}

        # ------------------------------------------------------
        # --------------- full network training ----------------
        # ------------------------------------------------------

        if i < 500:
            sess.run(train_step, feed_dict=feed_dict)

        # ------------------------------------------------------
        # ---------------------- pruning -----------------------
        # ------------------------------------------------------

        elif i >= 500 and i < 1000:

            # prune whenever i % 500 == 0 (within this phase only at i = 500), finetune in between
            if i % 500 == 0:
                print('iter:', i, 'prune weights')
                for L in LAYERS:
                    L.prune_weights(sess, threshold=0.1)

            grads_data = sess.run(grads, feed_dict={x_PH: batch_x, labels: batch_y})
            feed_dict = {}
            for L, grad, grad_data in zip(LAYERS, grads, grads_data):
                pruned_grad_data = L.prune_weights_gradient(grad_data)
                feed_dict[grad] = pruned_grad_data

            sess.run(train_step, feed_dict=feed_dict)

            # for numerical stability
            for L in LAYERS:
                L.prune_weights_update(sess)

        # ------------------------------------------------------
        # ------------------- quantization ---------------------
        # ------------------------------------------------------

        else:

            # quantize only once and then finetune
            if i == 1000:
                print('iter:', i, 'quantize weights')
                for L in LAYERS:
                    L.quantize_weights(sess)

            grads_data = sess.run(grads, feed_dict={x_PH: batch_x, labels: batch_y})
            feed_dict = {}
            for L, grad, grad_data in zip(LAYERS, grads, grads_data):
                grouped_grad_data = L.group_and_reduce_gradient(grad_data)
                feed_dict[grad] = grouped_grad_data

            sess.run(train_step, feed_dict=feed_dict)

            # for numerical stability
            for L in LAYERS:
                L.quantize_centroids_update(sess)
                L.quantize_weights_update(sess)

        # ------------------------------------------------------
        # --------------------- evaluation ---------------------
        # ------------------------------------------------------

        if i % 10 == 0:

            batches_acc = []
            for j in range(10):

                batch_x, batch_y = mnist.test.next_batch(1000)
                batch_x = np.reshape(batch_x, (-1, 28, 28, 1))

                batch_acc = sess.run(accuracy, feed_dict={x_PH: batch_x, labels: batch_y})
                batches_acc.append(batch_acc)

            acc = np.mean(batches_acc)

            iters.append(i)
            iters_acc.append(acc)
            print('iter:', i, 'test accuracy:', acc)

            for L in LAYERS:
                L.save_weights_histogram(sess, histograms_dir, i)

    for L in LAYERS:
        L.save_weights(sess, weights_dir)

    plt.figure(figsize=(10, 4))
    plt.ylabel('accuracy', fontsize=12)
    plt.xlabel('iteration', fontsize=12)
    plt.grid(True)
    plt.plot(iters, iters_acc, color='0.4')
    plt.savefig('./train_acc', dpi=1200)

    print('Training finished')

--------------------------------------------------------------------------------
/unit_tests.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np

from layers import FcLayerDeploy, ConvLayerDeploy

def test_ConvLayerDeploy():

    weights = np.random.uniform(-10.0, 10.0, size=(3, 3, 4, 5)).astype(np.float32)
    # random 0/1 pruning mask
    prune_mask = np.random.uniform(0.0, 2.0, size=(3, 3, 4, 5)).astype(np.int32).astype(np.float32)

    x = np.random.uniform(0.0, 1.0, size=(2, 14, 14, 4)).astype(np.float32)
    x = tf.constant(x)

    L = ConvLayerDeploy(weights, prune_mask, 14, 14, 2, 'conv')

    x_matmul_sparse = L.forward_matmul_preprocess(x)
    y_matmul_sparse = L.forward_matmul(x_matmul_sparse)
    y_matmul_sparse = L.forward_matmul_postprocess(y_matmul_sparse)

    y_conv = L.forward_conv(x)

    sess = tf.Session()
    y_matmul_sparse_data, y_conv_data = sess.run([y_matmul_sparse, y_conv])

    assert np.mean(np.abs(y_matmul_sparse_data - y_conv_data)) < 1e-6

def test_FcLayerDeploy():

    weights = np.random.uniform(-10.0, 10.0, size=(5, 10)).astype(np.float32)
    # random 0/1 pruning mask
    prune_mask = np.random.uniform(0.0, 2.0, size=(5, 10)).astype(np.int32).astype(np.float32)

    x = np.random.uniform(0.0, 1.0, size=(2, 5)).astype(np.float32)
    x = tf.constant(x)

    L_sparse = FcLayerDeploy(weights, prune_mask, 'fc')
    L_dense = FcLayerDeploy(weights, prune_mask, 'fc', dense=True)

    y_sparse = L_sparse.forward_matmul(x)
    y_dense = L_dense.forward_matmul(x)

    sess = tf.Session()
    y_sparse_data, y_dense_data = sess.run([y_sparse, y_dense])

    assert np.mean(np.abs(y_sparse_data - y_dense_data)) < 1e-6

--------------------------------------------------------------------------------