├── README.md
├── deploy.py
├── imgs
│   ├── architecture.png
│   ├── prune.png
│   ├── quantize.png
│   ├── sparse_conv1.png
│   ├── sparse_conv2.png
│   └── train_accuracy.png
├── layers.py
├── train.py
└── unit_tests.py
/README.md:
--------------------------------------------------------------------------------
1 | ## Deep compression
2 |
3 | TensorFlow implementation of the paper: Song Han, Huizi Mao, William J. Dally. Deep Compression: Compressing Deep Neural Networks with Pruning, Trained Quantization and Huffman Coding.
4 |
5 | The goal is to compress a neural network using weight pruning and quantization, with no loss of accuracy.
6 |
7 | Neural network architecture:
8 |
9 |
10 | Test accuracy during training:
11 |
12 |
13 | ### 1. Full training.
14 |
15 | Train for a number of iterations with gradient descent, adjusting all the weights in every layer.
16 |
17 | ### 2. Pruning and finetuning.
18 |
19 | Once in a while, remove weights whose magnitude is below a threshold; in the meantime, finetune the remaining weights to recover accuracy. A minimal sketch of this masking scheme is shown below.
20 |
21 |
22 | ### 3. Quantization and finetuning.
23 |
24 | Quantization is done after pruning. Cluster the remaining weights using k-means, then finetune the centroids of the quantized weights to recover accuracy. Each layer's weights are quantized independently. A sketch of the clustering step is shown below.
25 |
26 |
27 | ### 4. Deployment.
28 |
29 | Fully connected layers are implemented as a sparse matmul operation (see the sketch below). TensorFlow doesn't support sparse convolutions, so convolution layers are explicitly lowered to sparse matrix operations, with full control over which weights are valid.
30 |
31 | Simple (input_depth=1, output_depth=1) convolution as a matrix operation (note the padding type and stride value):
32 |
33 |
34 | Full (input_depth>1, output_depth>1) convolution as a matrix operation (a NumPy sketch of the construction follows):
35 |
36 |
37 | I do not make efficient use of quantization during deployment. It could be done with TensorFlow operations, but it would be very slow: for each output unit we would need to build N_clusters sparse tensors from the input data, reduce_sum each of them, multiply each sum by its centroid, and add the results to obtain the output unit value (see the sketch below). Doing this efficiently requires writing a GPU kernel, which I intend to do in the future.
38 |
--------------------------------------------------------------------------------
/deploy.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import sys
4 | from layers import FcLayerDeploy, ConvLayerDeploy
5 |
6 | def load_weights(directory, name):
7 |
8 | weights = np.load(directory + '/' + name + '-weights.npy')
9 | prune_mask = np.load(directory + '/' + name + '-prune-mask.npy')
10 |
11 | return weights, prune_mask
12 |
13 | if __name__ == "__main__":
14 |
15 | weights_dir = './weights'
16 |
17 | x_PH = tf.placeholder(tf.float32, [None, 28, 28, 1])
18 |
19 | weights, prune_mask = load_weights(weights_dir, 'conv1')
20 | L1 = ConvLayerDeploy(weights, prune_mask, x_PH.shape[1], x_PH.shape[2], 2, 'conv1')
21 | x = L1.forward_matmul_preprocess(x_PH)
22 | x = tf.nn.relu(L1.forward_matmul(x))
23 | x = L1.forward_matmul_postprocess(x)
24 |
25 | weights, prune_mask = load_weights(weights_dir, 'conv2')
26 | L2 = ConvLayerDeploy(weights, prune_mask, x.shape[1], x.shape[2], 2, 'conv2')
27 | x = L2.forward_matmul_preprocess(x)
28 | x = tf.nn.relu(L2.forward_matmul(x))
29 | x = L2.forward_matmul_postprocess(x)
30 |
31 | x = tf.reshape(x, [-1, 7 * 7 * 64])
32 |
33 | weights, prune_mask = load_weights(weights_dir, 'fc1')
34 | L3 = FcLayerDeploy(weights, prune_mask, 'fc1')
35 | x = tf.nn.relu(L3.forward_matmul(x))
36 |
37 | weights, prune_mask = load_weights(weights_dir, 'fc2')
38 | L4 = FcLayerDeploy(weights, prune_mask, 'fc2')
39 | logits = L4.forward_matmul(x)
40 |
41 | labels = tf.placeholder(tf.float32, [None, 10])
42 | correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(labels, 1))
43 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
44 |
45 | from tensorflow.examples.tutorials.mnist import input_data
46 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
47 |
48 | sess = tf.Session()
49 |
50 | batches_acc = []
51 | for i in range(10):
52 |
53 | batch_x, batch_y = mnist.test.next_batch(1000)
54 | batch_x = np.reshape(batch_x,(-1, 28, 28, 1))
55 |
56 | batch_acc = sess.run(accuracy,feed_dict={x_PH: batch_x, labels: batch_y})
57 | batches_acc.append(batch_acc)
58 |
59 | acc = np.mean(batches_acc)
60 |
61 | print('deploy accuracy:', acc)
62 |
--------------------------------------------------------------------------------
/imgs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wojciechmo/deep-compression/dcbf90df9f6bb0639542fd5f34bbe680d524be44/imgs/architecture.png
--------------------------------------------------------------------------------
/imgs/prune.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wojciechmo/deep-compression/dcbf90df9f6bb0639542fd5f34bbe680d524be44/imgs/prune.png
--------------------------------------------------------------------------------
/imgs/quantize.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wojciechmo/deep-compression/dcbf90df9f6bb0639542fd5f34bbe680d524be44/imgs/quantize.png
--------------------------------------------------------------------------------
/imgs/sparse_conv1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wojciechmo/deep-compression/dcbf90df9f6bb0639542fd5f34bbe680d524be44/imgs/sparse_conv1.png
--------------------------------------------------------------------------------
/imgs/sparse_conv2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wojciechmo/deep-compression/dcbf90df9f6bb0639542fd5f34bbe680d524be44/imgs/sparse_conv2.png
--------------------------------------------------------------------------------
/imgs/train_accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wojciechmo/deep-compression/dcbf90df9f6bb0639542fd5f34bbe680d524be44/imgs/train_accuracy.png
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import sys
5 |
6 | class LayerTrain(object):
7 |
8 | def __init__(self, in_depth, out_depth, N_clusters, name):
9 |
10 | self.name = name
11 |
12 | if 'conv' in name:
13 | self.w = tf.Variable(tf.random.normal([5, 5, in_depth, out_depth], stddev=0.1))
14 |
15 | elif 'fc' in name:
16 | self.w = tf.Variable(tf.random.normal([in_depth, out_depth], stddev=0.1))
17 |
18 | self.w_PH = tf.placeholder(tf.float32, self.w.shape)
19 | self.assign_w = tf.assign(self.w, self.w_PH)
20 | self.num_total_weights = np.prod(self.w.shape)
21 |
22 | # mask placeholder for pruning
23 | # ones - valid weights, zero - pruned weights
24 | self.pruning_mask_data = np.ones(self.w.shape, dtype=np.float32)
25 |
26 | self.N_clusters = N_clusters  # for quantization
27 |
28 | def forward(self,x):
29 |
30 | if 'conv' in self.name:
31 | return tf.nn.conv2d(x, self.w, strides=[1, 2, 2, 1], padding='SAME')
32 |
33 | elif 'fc' in self.name:
34 | return tf.matmul(x, self.w)
35 |
36 | def save_weights_histogram(self, sess, directory, iteration):
37 |
38 | w_data = sess.run(self.w).reshape(-1)
39 | valid_w_data = [x for x in w_data if x!=0.0]
40 |
41 | plt.grid(True)
42 | plt.hist(valid_w_data, 100, color='0.4')
43 | plt.gca().set_xlim([-0.4, 0.4])
44 | plt.savefig(directory + '/' + self.name + '-' + str(iteration), dpi=100)
45 | plt.gcf().clear()
46 |
47 | def save_weights(self, sess, directory):
48 |
49 | w_data = sess.run(self.w)
50 | np.save(directory + '/' + self.name + '-weights', w_data)
51 | np.save(directory + '/' + self.name + '-prune-mask', self.pruning_mask_data)
52 |
53 | # ------------------------------------------------------------------------
54 | # ----------------------------- pruning ----------------------------------
55 | # ------------------------------------------------------------------------
56 |
57 | # even though it would be enough to mask the weights (forward pass) in every train step and leave the gradient updates intact,
58 | # here only the gradients (backward pass) are masked in each step - to stay consistent with quantization
59 | # theoretically the weights could be pruned just once, with only the gradients masked from then on,
60 | # but for numerical stability the weights are re-masked in every iteration as well
61 |
62 | def prune_weights(self, sess, threshold):
63 |
64 | w_data = sess.run(self.w)
65 | # prune weights that are smaller than the threshold - set them to zero
66 | self.pruning_mask_data = (np.abs(w_data) >= threshold).astype(np.float32)
67 |
68 | print('layer:', self.name)
69 | print('\tremaining weights:', int(np.sum(self.pruning_mask_data)))
70 | print('\ttotal weights:', self.num_total_weights)
71 |
72 | sess.run(self.assign_w, feed_dict={self.w_PH: self.pruning_mask_data*w_data})
73 |
74 | def prune_weights_gradient(self, grad):
75 |
76 | return grad * self.pruning_mask_data
77 |
78 | # for numerical stability
79 | def prune_weights_update(self, sess):
80 |
81 | w_data = sess.run(self.w)
82 | sess.run(self.assign_w, feed_dict={self.w_PH: self.pruning_mask_data*w_data})
83 |
84 | # ------------------------------------------------------------------------
85 | # --------------------------- quantization -------------------------------
86 | # ------------------------------------------------------------------------
87 |
88 | def quantize_weights(self, sess):
89 |
90 | w_data = sess.run(self.w)
91 |
92 | # theoretically the pruning mask should be taken into account so that max and min are computed only over valid weights,
93 | # but in practice, with normally distributed initial weights, the min and max values are almost surely among the valid weights
94 | max_val = np.max(w_data)
95 | min_val = np.min(w_data)
96 |
97 | # linearly initialize centroids between max and min
98 | self.centroids = np.linspace(min_val, max_val, self.N_clusters)
99 | w_data = np.expand_dims(w_data, 0)
100 |
101 | centroids_prev = np.copy(self.centroids)
102 |
103 | for _ in range(20):  # k-means iterations (the inner loop below reuses i)
104 |
105 | if 'conv' in self.name:
106 | distances = np.abs(w_data - np.reshape(self.centroids,(-1, 1, 1, 1, 1)))
107 | distances = np.transpose(distances, (1,2,3,4,0))
108 |
109 | elif 'fc' in self.name:
110 | distances = np.abs(w_data - np.reshape(self.centroids,(-1, 1, 1)))
111 | distances = np.transpose(distances, (1,2,0))
112 |
113 | classes = np.argmin(distances, axis=-1)
114 |
115 | self.clusters_masks = []
116 | for i in range(self.N_clusters):
117 |
118 | cluster_mask = (classes == i).astype(np.float32) * self.pruning_mask_data
119 | self.clusters_masks.append(cluster_mask)
120 |
121 | num_weights_assigned = np.sum(cluster_mask)
122 |
123 | if num_weights_assigned!=0:
124 | self.centroids[i] = np.sum(cluster_mask * w_data) / num_weights_assigned
125 | else: # do not modify
126 | pass
127 |
128 | if np.array_equal(centroids_prev, self.centroids):
129 | break
130 |
131 | centroids_prev = np.copy(self.centroids)
132 |
133 | self.quantize_weights_update(sess)
134 |
135 | print('layer:', self.name)
136 | print('\tcentroids:', self.centroids)
137 |
138 | def group_and_reduce_gradient(self, grad):
139 |
140 | grad_out = np.zeros(self.w.shape, dtype=np.float32)
141 |
142 | for i in range(self.N_clusters):
143 |
144 | cluster_mask = self.clusters_masks[i]
145 | centroid_grad = np.sum(grad * cluster_mask)
146 |
147 | grad_out = grad_out + cluster_mask * centroid_grad
148 |
149 | return grad_out
150 |
151 | # for numerical stability
152 | def quantize_centroids_update(self, sess):
153 |
154 | w_data = sess.run(self.w)
155 |
156 | for i in range(self.N_clusters):
157 |
158 | cluster_mask = self.clusters_masks[i]
159 | cluster_count = np.sum(cluster_mask)
160 |
161 | if cluster_count!=0:
162 | self.centroids[i] = np.sum(cluster_mask * w_data) / cluster_count
163 | else: # do not modify
164 | pass
165 |
166 | # for numerical stability
167 | def quantize_weights_update(self, sess):
168 |
169 | w_data_updated = np.zeros(self.w.shape, dtype=np.float32)
170 |
171 | for i in range(self.N_clusters):
172 |
173 | cluster_mask = self.clusters_masks[i]
174 | centroid = self.centroids[i]
175 |
176 | w_data_updated = w_data_updated + cluster_mask * centroid
177 |
178 | sess.run(self.assign_w, feed_dict={self.w_PH: self.pruning_mask_data * w_data_updated})
179 |
180 | #-------------------------------------------------------------------------------------------------
181 | #-------------------------------------------------------------------------------------------------
182 | #-------------------------------------------------------------------------------------------------
183 |
184 | class FcLayerDeploy(object):
185 |
186 | def __init__(self, matrix, prune_mask, name, dense=False):
187 |
188 | assert matrix.shape == prune_mask.shape
189 |
190 | self.dense = dense
191 | self.N_in, self.N_out = matrix.shape
192 |
193 | if self.dense == False:
194 |
195 | indices, values, dense_shape = [], [], [self.N_in, self.N_out] # sparse matrix representation
196 |
197 | for i in range(self.N_in):
198 | for j in range(self.N_out):
199 |
200 | # pruning mask: ones - valid weights, zero - pruned weights
201 | if prune_mask[i][j] == 0.0:
202 | continue
203 |
204 | indices.append([i,j])
205 | values.append(matrix[i][j])
206 |
207 | self.w_matrix = tf.SparseTensor(indices, values, dense_shape) # tf sparse matrix
208 |
209 | else:
210 |
211 | self.w_matrix = tf.constant(matrix * prune_mask) # tf dense matrix
212 |
213 | print('layer:', name)
214 | print('\tvalid matrix weights:', int(np.sum(prune_mask)))
215 | print('\ttotal matrix weights:', np.prod(self.w_matrix.shape))
216 |
217 | def forward_matmul(self, x):
218 |
219 | if self.dense == False:
220 |
221 | w = tf.sparse.transpose(self.w_matrix, (1, 0))
222 | x = tf.transpose(x, (1, 0))
223 | x = tf.sparse.matmul(w, x) # only left matrix can be sparse hence transpositions
224 | x = tf.transpose(x ,(1, 0))
225 |
226 | else:
227 |
228 | x = tf.matmul(x, self.w_matrix)
229 |
230 | return x
231 |
232 | #-------------------------------------------------------------------------------------------------
233 | #-------------------------------------------------------------------------------------------------
234 | #-------------------------------------------------------------------------------------------------
235 |
236 | class ConvLayerDeploy(object):
237 |
238 | def __init__(self, tensor, prune_mask, H_in, W_in, stride, name, dense=False):
239 |
240 | assert tensor.shape == prune_mask.shape
241 |
242 | self.stride = stride
243 | self.dense = dense
244 |
245 | if self.dense == False:
246 | indices, values, dense_shape = self.tensor_to_matrix(tensor, prune_mask, H_in, W_in, stride)
247 | self.w_matrix = tf.SparseTensor(indices, values, dense_shape) # tf sparse matrix
248 |
249 | else:
250 | matrix = self.tensor_to_matrix(tensor, prune_mask, H_in, W_in, stride)
251 | self.w_matrix = tf.constant(matrix) # tf dense matrix
252 |
253 | self.w_tensor = tf.constant(tensor * prune_mask)
254 |
255 | print('layer:', name)
256 | print('\tvalid matrix weights:', int(np.sum(prune_mask)))
257 | print('\ttotal tensor weights:', np.prod(self.w_tensor.shape))
258 | print('\ttotal matrix weights:', np.prod(self.w_matrix.shape))
259 |
260 | def get_linear_pos(self, i, j, W): # row major
261 |
262 | return i * W + j
263 |
264 | def tensor_to_matrix(self, tensor, prune_mask, H_in, W_in, stride):
265 |
266 | # assume padding type 'SAME' and padding value 0
267 |
268 | H_out = (int(H_in) + stride - 1) // stride # 'SAME' padding: ceil(H_in / stride)
269 | W_out = (int(W_in) + stride - 1) // stride # 'SAME' padding: ceil(W_in / stride)
270 | H_in = int(H_in)
271 | W_in = int(W_in)
272 |
273 | kH, kW, D_in, D_out = tensor.shape
274 |
275 | self.D_out = D_out
276 | self.H_out = H_out
277 | self.W_out = W_out
278 |
279 | if self.dense == False:
280 | indices, values, dense_shape = [], [], [H_in * W_in * D_in, H_out* W_out * D_out] # sparse matrix
281 | else:
282 | matrix = np.zeros((H_in * W_in * D_in, H_out* W_out * D_out), dtype=np.float32) # dense matrix
283 |
284 | for d_in in range(D_in):
285 | for d_out in range(D_out):
286 |
287 | # tf.nn.conv2d doesn't anchor kernel positions starting from the top-left spatial location but from the bottom-right
288 | for i_in_center in np.arange(H_in-1, -1, -stride): # kernel input center for first axis
289 | for j_in_center in np.arange(W_in-1, -1, -stride): # kernel input center for second axis
290 |
291 | i_out = int(i_in_center / stride)
292 | j_out = int(j_in_center / stride)
293 |
294 | for i in range(kH):
295 |
296 | i_in = i_in_center + i - kH // 2
297 |
298 | if i_in < 0 or i_in >= H_in: # padding value 0
299 | continue
300 |
301 | for j in range(kW):
302 |
303 | j_in = j_in_center + j - kW // 2
304 |
305 | if j_in < 0 or j_in >= W_in: # padding value 0
306 | continue
307 |
308 | # pruning mask: ones - valid weights, zero - pruned weights
309 | if prune_mask[i][j][d_in][d_out] == 0.0:
310 | continue
311 |
312 | pos_in = self.get_linear_pos(i_in, j_in, W_in) + d_in * H_in * W_in
313 | pos_out = self.get_linear_pos(i_out, j_out, W_out) + d_out * H_out * W_out
314 |
315 | if self.dense == False:
316 | indices.append([pos_in, pos_out])
317 | values.append(tensor[i][j][d_in][d_out])
318 | else:
319 | matrix[pos_in][pos_out] = tensor[i][j][d_in][d_out]
320 |
321 | if self.dense == False:
322 | return indices, values, dense_shape
323 | else:
324 | return matrix
325 |
326 | def forward_matmul_preprocess(self, x):
327 |
328 | x = tf.transpose(x, (0, 3, 1, 2))
329 | x = tf.reshape(x, (-1, int(np.prod(x.shape[1:]))))
330 |
331 | return x
332 |
333 | def forward_matmul_postprocess(self, x):
334 |
335 | x = tf.reshape(x,(-1, self.D_out, self.H_out, self.W_out))
336 | x = tf.transpose(x, (0, 2, 3, 1))
337 |
338 | return x
339 |
340 | def forward_matmul(self, x):
341 |
342 | if self.dense == False:
343 | w = tf.sparse.transpose(self.w_matrix, (1, 0))
344 | x = tf.transpose(x, (1, 0))
345 | x = tf.sparse.matmul(w, x) # only left matrix can be sparse hence transpositions
346 | x = tf.transpose(x, (1, 0))
347 | else:
348 | x = tf.matmul(x, self.w_matrix)
349 |
350 | return x
351 |
352 | def forward_conv(self, x):
353 |
354 | return tf.nn.conv2d(x, self.w_tensor, strides=[1, self.stride, self.stride, 1], padding='SAME')
355 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import sys, os, shutil
5 |
6 | from layers import LayerTrain
7 |
8 | def make_dir(directory):
9 |
10 | if os.path.exists(directory):
11 | shutil.rmtree(directory, ignore_errors=True)
12 | os.makedirs(directory)
13 |
14 | if __name__ == "__main__":
15 |
16 | histograms_dir = './histograms'
17 | weights_dir = './weights'
18 |
19 | make_dir(histograms_dir)
20 | make_dir(weights_dir)
21 |
22 | L1 = LayerTrain(1, 32, N_clusters=5, name='conv1')
23 | L2 = LayerTrain(32, 64, N_clusters=5, name='conv2')
24 | L3 = LayerTrain(7 * 7 * 64, 1024, N_clusters=5, name='fc1')
25 | L4 = LayerTrain(1024, 10, N_clusters=5, name='fc2')
26 |
27 | LAYERS = [L1, L2, L3, L4]
28 | LAYERS_WEIGHTS = [L1.w, L2.w, L3.w, L4.w]
29 |
30 | x_PH = tf.placeholder(tf.float32, [None, 28, 28, 1])
31 | x = tf.nn.relu(L1.forward(x_PH))
32 | x = tf.nn.relu(L2.forward(x))
33 | x = tf.reshape(x, (-1, int(np.prod(x.shape[1:]))))
34 | x = tf.nn.relu(L3.forward(x))
35 | logits = L4.forward(x)
36 |
37 | preds = tf.nn.softmax(logits)
38 | labels = tf.placeholder(tf.float32, [None, 10])
39 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits))
40 |
41 | optimizer = tf.train.AdamOptimizer(1e-4)
42 | gradients_vars = optimizer.compute_gradients(loss, LAYERS_WEIGHTS)
43 | grads = [grad for grad, var in gradients_vars]
44 | train_step = optimizer.apply_gradients(gradients_vars)
45 |
46 | correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
47 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
48 |
49 | sess = tf.Session()
50 | sess.run(tf.global_variables_initializer())
51 |
52 | from tensorflow.examples.tutorials.mnist import input_data
53 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
54 |
55 | iters = []
56 | iters_acc = []
57 |
58 | for i in range(1500):
59 |
60 | batch_x, batch_y = mnist.train.next_batch(50)
61 | batch_x = np.reshape(batch_x,(-1, 28, 28,1))
62 |
63 | feed_dict={x_PH: batch_x, labels: batch_y}
64 |
65 | # ------------------------------------------------------
66 | # --------------- full network training ----------------
67 | # ------------------------------------------------------
68 |
69 | if i < 500:
70 | sess.run(train_step, feed_dict=feed_dict)
71 |
72 | # ------------------------------------------------------
73 | # ---------------------- pruning -----------------------
74 | # ------------------------------------------------------
75 |
76 | elif i >= 500 and i < 1000:
77 |
78 | # prune at iteration 500 (i % 500 == 0 fires only once in this range), finetune in the meantime
79 | if i%500==0:
80 | print('iter:', i, 'prune weights')
81 | for L in LAYERS:
82 | L.prune_weights(sess, threshold=0.1)
83 |
84 | grads_data = sess.run(grads, feed_dict={x_PH: batch_x, labels: batch_y})
85 | feed_dict = {}
86 | for L, grad, grad_data in zip(LAYERS, grads, grads_data):
87 | pruned_grad_data = L.prune_weights_gradient(grad_data)
88 | feed_dict[grad] = pruned_grad_data
89 |
90 | sess.run(train_step, feed_dict=feed_dict)
91 |
92 | # for numerical stability
93 | for L in LAYERS:
94 | L.prune_weights_update(sess)
95 |
96 | # ------------------------------------------------------
97 | # ------------------- quantization ---------------------
98 | # ------------------------------------------------------
99 |
100 | else:
101 |
102 | # quantize only once and then finetune
103 | if i==1000:
104 | print('iter:', i, 'quantize weights')
105 | for L in LAYERS:
106 | L.quantize_weights(sess)
107 |
108 | grads_data = sess.run(grads, feed_dict={x_PH: batch_x, labels: batch_y})
109 | feed_dict = {}
110 | for L, grad, grad_data in zip(LAYERS, grads, grads_data):
111 | grouped_grad_data = L.group_and_reduce_gradient(grad_data)
112 | feed_dict[grad] = grouped_grad_data
113 |
114 | sess.run(train_step, feed_dict=feed_dict)
115 |
116 | # for numerical stability
117 | for L in LAYERS:
118 | L.quantize_centroids_update(sess)
119 | L.quantize_weights_update(sess)
120 |
121 | # ------------------------------------------------------
122 | # --------------------- evaluation ---------------------
123 | # ------------------------------------------------------
124 |
125 | if i%10 == 0:
126 |
127 | batches_acc = []
128 | for j in range(10):
129 |
130 | batch_x, batch_y = mnist.test.next_batch(1000)
131 | batch_x = np.reshape(batch_x,(-1, 28, 28,1))
132 |
133 | batch_acc = sess.run(accuracy,feed_dict={x_PH: batch_x, labels: batch_y})
134 | batches_acc.append(batch_acc)
135 |
136 | acc = np.mean(batches_acc)
137 |
138 | iters.append(i)
139 | iters_acc.append(acc)
140 | print('iter:', i, 'test accuracy:', acc)
141 |
142 | for L in LAYERS:
143 | L.save_weights_histogram(sess, histograms_dir, i)
144 |
145 | for L in LAYERS:
146 | L.save_weights(sess, weights_dir)
147 |
148 | plt.figure(figsize=(10, 4))
149 | plt.ylabel('accuracy', fontsize=12)
150 | plt.xlabel('iteration', fontsize=12)
151 | plt.grid(True)
152 | plt.plot(iters, iters_acc, color='0.4')
153 | plt.savefig('./train_acc', dpi=1200)
154 |
155 | print('Training finished')
156 |
--------------------------------------------------------------------------------
/unit_tests.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | from layers import FcLayerDeploy, ConvLayerDeploy
5 |
6 | def test_ConvLayerDeploy():
7 |
8 | weights = np.random.uniform(-10.0, 10.0, size=(3, 3, 4, 5)).astype(np.float32)
9 | prune_mask = np.random.uniform(0.0, 2.0, size=(3, 3, 4, 5)).astype(np.int32).astype(np.float32)
10 |
11 | x = np.random.uniform(0.0, 1.0, size=(2, 14, 14, 4)).astype(np.float32)
12 | x = tf.constant(x)
13 |
14 | L = ConvLayerDeploy(weights, prune_mask, 14, 14, 2, 'conv')
15 |
16 | x_matmul_sparse = L.forward_matmul_preprocess(x)
17 | y_matmul_sparse = L.forward_matmul(x_matmul_sparse)
18 | y_matmul_sparse = L.forward_matmul_postprocess(y_matmul_sparse)
19 |
20 | y_conv = L.forward_conv(x)
21 |
22 | sess = tf.Session()
23 | y_matmul_sparse_data, y_conv_data = sess.run([y_matmul_sparse, y_conv])
24 |
25 | assert np.mean(np.abs(y_matmul_sparse_data - y_conv_data)) < 1e-6
26 |
27 | def test_FcLayerDeploy():
28 |
29 | weights = np.random.uniform(-10.0, 10.0, size=(5, 10)).astype(np.float32)
30 | prune_mask = np.random.uniform(0.0, 2.0, size=(5, 10)).astype(np.int32).astype(np.float32)
31 |
32 | x = np.random.uniform(0.0, 1.0, size=(2, 5)).astype(np.float32)
33 | x = tf.constant(x)
34 |
35 | L_sparse = FcLayerDeploy(weights, prune_mask, 'fc')
36 | L_dense = FcLayerDeploy(weights, prune_mask, 'fc', dense=True)
37 |
38 | y_sparse = L_sparse.forward_matmul(x)
39 | y_dense = L_dense.forward_matmul(x)
40 |
41 | sess = tf.Session()
42 | y_sparse_data, y_dense_data = sess.run([y_sparse, y_dense])
43 |
44 | assert np.mean(np.abs(y_sparse_data - y_dense_data)) < 1e-6
45 |
46 | if __name__ == "__main__":
47 |
48 | test_ConvLayerDeploy()
49 | test_FcLayerDeploy()
50 | print('unit tests passed')
--------------------------------------------------------------------------------