├── dataset.py ├── empirical_priors.py ├── extended_keras.py ├── figures ├── arrow.jpg ├── han_clustering.png ├── han_cutandtrain.gif ├── han_final_cut.png ├── han_pretraining.gif ├── han_retraining.gif ├── post-processed.png ├── post-processed_log.png ├── reference.png ├── reference_log.png ├── retrained.png ├── retrained_log.png └── retraining.gif ├── helpers.py ├── optimizers.py └── tutorial.ipynb /dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Loading the MNIST in the right format. 6 | 7 | Implementation is close to [1]. 8 | 9 | Karen Ullrich, Jan 2017 10 | 11 | ... [1] [Keras Tutorial on CNNs](https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py) 12 | """ 13 | 14 | from numpy import transpose 15 | 16 | from keras import backend as K 17 | from keras.datasets import mnist as MNIST 18 | from keras.utils import np_utils 19 | 20 | 21 | def mnist(): 22 | img_rows, img_cols = 28, 28 23 | nb_classes = 10 24 | # the data, shuffled and split between train and test sets 25 | (X_train, y_train), (X_test, y_test) = MNIST.load_data() 26 | 27 | X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols) 28 | X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols) 29 | 30 | if K._BACKEND == "tensorflow": 31 | X_train = transpose(X_train, axes=[0, 2, 3, 1]) 32 | X_test = transpose(X_test, axes=[0, 2, 3, 1]) 33 | 34 | X_train = X_train.astype('float32') 35 | X_test = X_test.astype('float32') 36 | X_train /= 255 37 | X_test /= 255 38 | 39 | print("Successfully loaded %d train samples and %d test samples." % (X_train.shape[0], X_test.shape[0])) 40 | 41 | # convert class vectors to binary class matrices 42 | Y_train = np_utils.to_categorical(y_train, nb_classes) 43 | Y_test = np_utils.to_categorical(y_test, nb_classes) 44 | 45 | return [X_train, X_test], [Y_train, Y_test], [img_rows, img_cols], nb_classes 46 | -------------------------------------------------------------------------------- /empirical_priors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Keras Layer for weight quantization regularizers. 6 | 7 | Karen Ullrich, Jan 2017 8 | """ 9 | import numpy as np 10 | 11 | from keras import backend as K 12 | from keras.engine.topology import Layer 13 | 14 | from helpers import special_flatten 15 | from extended_keras import logsumexp 16 | 17 | if K._BACKEND == "tensorflow": 18 | import tensorflow as tf 19 | 20 | class GaussianMixturePrior(Layer): 21 | """A Gaussian Mixture prior for Neural Networks """ 22 | def __init__(self, nb_components, network_weights, pretrained_weights, pi_zero, **kwargs): 23 | self.nb_components = nb_components 24 | self.network_weights = [K.flatten(w) for w in network_weights] 25 | self.pretrained_weights = special_flatten(pretrained_weights) 26 | self.pi_zero = pi_zero 27 | 28 | super(GaussianMixturePrior, self).__init__(**kwargs) 29 | 30 | def build(self, input_shape): 31 | J = self.nb_components 32 | 33 | # create trainable ... 34 | # ... means 35 | init_mean = np.linspace(-0.6, 0.6, J - 1) 36 | self.means = K.variable(init_mean, name='means') 37 | # ... the variance (we will work in log-space for more stability) 38 | init_stds = np.tile(0.25, J) 39 | init_gamma = - np.log(np.power(init_stds, 2)) 40 | self.gammas = K.variable(init_gamma, name='gammas') 41 | # ... 
the mixing proportions 42 | init_mixing_proportions = np.ones((J - 1)) 43 | init_mixing_proportions *= (1. - self.pi_zero) / (J - 1) 44 | self.rhos = K.variable(np.log(init_mixing_proportions), name='rhos') 45 | # Finally, add the variables to the trainable parameters 46 | self.trainable_weights = [self.means] + [self.gammas] + [self.rhos] 47 | 48 | def call(self, x, mask=None): 49 | J = self.nb_components 50 | loss = K.variable(0.) 51 | # here we stack together the trainable and non-trainable params 52 | # ... the mean vector 53 | means = K.concatenate([K.variable([0.]), self.means],axis=0) 54 | # ... the variances 55 | precision = K.exp(self.gammas) 56 | # ... the mixing proportions (we are using the log-sum-exp trick here) 57 | min_rho = K.min(self.rhos) 58 | mixing_proportions = K.exp(self.rhos - min_rho) 59 | mixing_proportions = (1 - self.pi_zero) * mixing_proportions / K.sum(mixing_proportions) 60 | mixing_proportions = K.concatenate([K.variable([self.pi_zero]), mixing_proportions],axis=0) 61 | 62 | # compute the loss given by the gaussian mixture 63 | for weights in self.network_weights: 64 | loss = loss + self.compute_loss(weights, mixing_proportions, means, precision) 65 | 66 | # GAMMA PRIOR ON PRECISION 67 | # ... for the zero component 68 | (alpha, beta) = (5e3,20e-1) 69 | neglogprop = (1 - alpha) * K.gather(self.gammas, [0]) + beta * K.gather(precision, [0]) 70 | loss = loss + K.sum(neglogprop) 71 | # ... and all other components 72 | alpha, beta = (2.5e2,1e-1) 73 | idx = np.arange(1, J) 74 | neglogprop = (1 - alpha) * K.gather(self.gammas, idx) + beta * K.gather(precision, idx) 75 | loss = loss + K.sum(neglogprop) 76 | 77 | return loss 78 | 79 | def compute_loss(self, weights, mixing_proportions, means, precision): 80 | if K._BACKEND == "tensorflow": 81 | diff = tf.expand_dims(weights, 1) - tf.expand_dims(means, 0) 82 | else: 83 | diff = weights[:, None] - means # shape: (nb_params, nb_components) 84 | unnormalized_log_likelihood = - (diff ** 2) / 2 * K.flatten(precision) 85 | Z = K.sqrt(precision / (2 * np.pi)) 86 | log_likelihood = logsumexp(unnormalized_log_likelihood, w=K.flatten(mixing_proportions * Z), axis=1) 87 | 88 | # return the neg. log-likelihood for the prior 89 | return - K.sum(log_likelihood) 90 | 91 | def get_output_shape_for(self, input_shape): 92 | return (input_shape[0], 1) 93 | -------------------------------------------------------------------------------- /extended_keras.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Methods to compliment the keras engine 6 | 7 | Karen Ullrich, Jan 2017 8 | """ 9 | import os 10 | import numpy as np 11 | 12 | import matplotlib.pyplot as plt 13 | import seaborn as sns 14 | from IPython import display 15 | import imageio 16 | 17 | import keras 18 | from keras import backend as K 19 | 20 | from helpers import special_flatten 21 | 22 | 23 | # --------------------------------------------------------- 24 | # helpers 25 | # --------------------------------------------------------- 26 | 27 | def collect_trainable_weights(layer): 28 | """Collects all `trainable_weights` attributes, 29 | excluding any sublayers where `trainable` is set the `False`. 
30 | """ 31 | trainable = getattr(layer, 'trainable', True) 32 | if not trainable: 33 | return [] 34 | weights = [] 35 | if layer.__class__.__name__ == 'Sequential': 36 | for sublayer in layer.flattened_layers: 37 | weights += collect_trainable_weights(sublayer) 38 | elif layer.__class__.__name__ == 'Model': 39 | for sublayer in layer.layers: 40 | weights += collect_trainable_weights(sublayer) 41 | else: 42 | weights += layer.trainable_weights 43 | # dedupe weights 44 | weights = list(set(weights)) 45 | # TF variables have auto-generated the name, while Theano has auto-generated the auto_name variable. 46 | # name in Theano is sometimes None. 47 | # However, to work save_model() and load_model() properly, weights must be sorted by names. 48 | if weights: 49 | if "theano" == K.backend(): 50 | weights.sort(key=lambda x: x.name if x.name else x.auto_name) 51 | else: 52 | weights.sort(key=lambda x: x.name) 53 | return weights 54 | 55 | 56 | def extract_weights(model): 57 | """Extract symbolic, trainable weights from a Model.""" 58 | trainable_weights = [] 59 | for layer in model.layers: 60 | trainable_weights += collect_trainable_weights(layer) 61 | return trainable_weights 62 | 63 | 64 | # --------------------------------------------------------- 65 | # objectives 66 | # --------------------------------------------------------- 67 | 68 | def identity_objective(y_true, y_pred): 69 | """Hack to turn Keras' Layer engine into an empirical prior on the weights""" 70 | return y_pred 71 | 72 | 73 | # --------------------------------------------------------- 74 | # logsumexp 75 | # --------------------------------------------------------- 76 | 77 | def logsumexp(t, w=None, axis=1): 78 | """ 79 | t... tensor 80 | w... weight tensor 81 | """ 82 | 83 | t_max = K.max(t, axis=axis, keepdims=True) 84 | 85 | if w is not None: 86 | tmp = w * K.exp(t - t_max) 87 | else: 88 | tmp = K.exp(t - t_max) 89 | 90 | out = K.sum(tmp, axis=axis) 91 | out = K.log(out) 92 | 93 | t_max = K.max(t, axis=axis) 94 | 95 | return out + t_max 96 | 97 | # --------------------------------------------------------- 98 | # Callbacks 99 | # --------------------------------------------------------- 100 | 101 | class VisualisationCallback(keras.callbacks.Callback): 102 | """A callback for visualizing the progress in training.""" 103 | 104 | def __init__(self, model, X_test, Y_test, epochs): 105 | 106 | self.model = model 107 | self.X_test = X_test 108 | self.Y_test = Y_test 109 | self.epochs = epochs 110 | 111 | super(VisualisationCallback, self).__init__() 112 | 113 | def on_train_begin(self, logs={}): 114 | self.W_0 = self.model.get_weights() 115 | 116 | def on_epoch_begin(self, epoch, logs={}): 117 | self.plot_histogram(epoch) 118 | 119 | def on_train_end(self, logs={}): 120 | self.plot_histogram(epoch=self.epochs) 121 | images = [] 122 | filenames = ["./.tmp%d.png" % epoch for epoch in np.arange(self.epochs + 1)] 123 | for filename in filenames: 124 | images.append(imageio.imread(filename)) 125 | os.remove(filename) 126 | imageio.mimsave('./figures/retraining.gif', images, duration=.5) 127 | 128 | def plot_histogram(self, epoch): 129 | # get network weights 130 | W_T = self.model.get_weights() 131 | W_0 = self.W_0 132 | weights_0 = np.squeeze(special_flatten(W_0[:-3])) 133 | weights_T = np.squeeze(special_flatten(W_T[:-3])) 134 | # get means, variances and mixing proportions 135 | mu_T = np.concatenate([np.zeros(1), W_T[-3]]).flatten() 136 | prec_T = np.exp(W_T[-2]) 137 | var_T = 1. 
/ prec_T 138 | std_T = np.sqrt(var_T) 139 | pi_T = (np.exp(W_T[-1])) 140 | # plot histograms and GMM 141 | x0 = -1.2 142 | x1 = 1.2 143 | I = np.random.permutation(len(weights_0)) 144 | f = sns.jointplot(weights_0[I], weights_T[I], size=8, kind="scatter", color="g", stat_func=None, edgecolor='w', 145 | marker='o', joint_kws={"s": 8}, marginal_kws=dict(bins=1000), ratio=4) 146 | f.ax_joint.hlines(mu_T, x0, x1, lw=0.5) 147 | 148 | for k in range(len(mu_T)): 149 | if k == 0: 150 | f.ax_joint.fill_between(np.linspace(x0, x1, 10), mu_T[k] - 2 * std_T[k], mu_T[k] + 2 * std_T[k], 151 | color='blue', alpha=0.1) 152 | else: 153 | f.ax_joint.fill_between(np.linspace(x0, x1, 10), mu_T[k] - 2 * std_T[k], mu_T[k] + 2 * std_T[k], 154 | color='red', alpha=0.1) 155 | score = \ 156 | self.model.evaluate({'input': self.X_test, }, {"error_loss": self.Y_test, "complexity_loss": self.Y_test, }, 157 | verbose=0)[3] 158 | sns.plt.title("Epoch: %d /%d\nTest accuracy: %.4f " % (epoch, self.epochs, score)) 159 | f.ax_marg_y.set_xscale("log") 160 | f.set_axis_labels("Pretrained", "Retrained") 161 | f.ax_marg_x.set_xlim(-1, 1) 162 | f.ax_marg_y.set_ylim(-1, 1) 163 | display.clear_output() 164 | f.savefig("./.tmp%d.png" % epoch, bbox_inches='tight') 165 | plt.show() 166 | -------------------------------------------------------------------------------- /figures/arrow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/arrow.jpg -------------------------------------------------------------------------------- /figures/han_clustering.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/han_clustering.png -------------------------------------------------------------------------------- /figures/han_cutandtrain.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/han_cutandtrain.gif -------------------------------------------------------------------------------- /figures/han_final_cut.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/han_final_cut.png -------------------------------------------------------------------------------- /figures/han_pretraining.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/han_pretraining.gif -------------------------------------------------------------------------------- /figures/han_retraining.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/han_retraining.gif -------------------------------------------------------------------------------- /figures/post-processed.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/post-processed.png -------------------------------------------------------------------------------- /figures/post-processed_log.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/post-processed_log.png -------------------------------------------------------------------------------- /figures/reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/reference.png -------------------------------------------------------------------------------- /figures/reference_log.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/reference_log.png -------------------------------------------------------------------------------- /figures/retrained.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/retrained.png -------------------------------------------------------------------------------- /figures/retrained_log.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/retrained_log.png -------------------------------------------------------------------------------- /figures/retraining.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarenUllrich/Tutorial-SoftWeightSharingForNNCompression/d1cb9025fe886f7b57b47ffc6d11370ea76a7a6c/figures/retraining.gif -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Methods for gerneral purpose 6 | 7 | Karen Ullrich, Sep 2016 8 | """ 9 | 10 | import numpy as np 11 | from scipy.misc import logsumexp 12 | 13 | import matplotlib.pyplot as plt 14 | import seaborn as sns 15 | # --------------------------------------------------------- 16 | # RESHAPING LISTS FILLED WITH ARRAYS 17 | # --------------------------------------------------------- 18 | 19 | def special_flatten(arraylist): 20 | """Flattens the output of model.get_weights()""" 21 | out = np.concatenate([array.flatten() for array in arraylist]) 22 | return out.reshape((len(out), 1)) 23 | 24 | 25 | def reshape_like(in_array, shaped_array): 26 | "Inverts special_flatten" 27 | flattened_array = list(in_array) 28 | out = np.copy(shaped_array) 29 | for i, array in enumerate(shaped_array): 30 | num_samples = array.size 31 | dummy = flattened_array[:num_samples] 32 | del flattened_array[:num_samples] 33 | out[i] = np.asarray(dummy).reshape(array.shape) 34 | return out 35 | 36 | 37 | # --------------------------------------------------------- 38 | # DISCRETESIZE 39 | # 
--------------------------------------------------------- 40 | 41 | def merger(inputs): 42 | """Comparing and merging components.""" 43 | for _ in xrange(3): 44 | lists = [] 45 | for inpud in inputs: 46 | for i in inpud: 47 | tmp = 1 48 | for l in lists: 49 | if i in l: 50 | for j in inpud: 51 | l.append(j) 52 | tmp = 0 53 | if tmp is 1: 54 | lists.append(list(inpud)) 55 | lists = [np.unique(l) for l in lists] 56 | inputs = lists 57 | return lists 58 | 59 | 60 | def KL(means, logprecisions): 61 | """Compute the KL-divergence between 2 Gaussian Components.""" 62 | precisions = np.exp(logprecisions) 63 | return 0.5 * (logprecisions[0] - logprecisions[1]) + precisions[1] / 2. * ( 64 | 1. / precisions[0] + (means[0] - means[1]) ** 2) - 0.5 65 | 66 | 67 | def compute_responsibilies(xs, mus, logprecisions, pis): 68 | "Computing the unnormalized responsibilities." 69 | xs = xs.flatten() 70 | K = len(pis) 71 | W = len(xs) 72 | responsibilies = np.zeros((K, len(xs))) 73 | for k in xrange(K): 74 | # Not normalized!!! 75 | responsibilies[k] = pis[k] * np.exp(0.5 * logprecisions[k]) * np.exp( 76 | - np.exp(logprecisions[k]) / 2 * (xs - mus[k]) ** 2) 77 | return np.argmax(responsibilies, axis=0) 78 | 79 | 80 | def discretesize(W, pi_zero=0.999): 81 | # flattening hte weights 82 | weights = special_flatten(W[:-3]) 83 | 84 | means = np.concatenate([np.zeros(1), W[-3]]) 85 | logprecisions = W[-2] 86 | logpis = np.concatenate([np.log(pi_zero) * np.ones(1), W[-1]]) 87 | 88 | # classes K 89 | J = len(logprecisions) 90 | # compute KL-divergence 91 | K = np.zeros((J, J)) 92 | L = np.zeros((J, J)) 93 | 94 | for i, (m1, pr1, pi1) in enumerate(zip(means, logprecisions, logpis)): 95 | for j, (m2, pr2, pi2) in enumerate(zip(means, logprecisions, logpis)): 96 | K[i, j] = KL([m1, m2], [pr1, pr2]) 97 | L[i, j] = np.exp(pi1) * (pi1 - pi2 + K[i, j]) 98 | 99 | # merge 100 | idx, idy = np.where(K < 1e-10) 101 | lists = merger(np.asarray(zip(idx, idy))) 102 | # compute merged components 103 | # print lists 104 | new_means, new_logprecisions, new_logpis = [], [], [] 105 | 106 | for l in lists: 107 | new_logpis.append(logsumexp(logpis[l])) 108 | new_means.append( 109 | np.sum(means[l] * np.exp(logpis[l] - np.min(logpis[l]))) / np.sum(np.exp(logpis[l] - np.min(logpis[l])))) 110 | new_logprecisions.append(np.log( 111 | np.sum(np.exp(logprecisions[l]) * np.exp(logpis[l] - np.min(logpis[l]))) / np.sum( 112 | np.exp(logpis[l] - np.min(logpis[l]))))) 113 | 114 | new_means[np.argmin(np.abs(new_means))] = 0.0 115 | 116 | # compute responsibilities 117 | argmax_responsibilities = compute_responsibilies(weights, new_means, new_logprecisions, np.exp(new_logpis)) 118 | out = [new_means[i] for i in argmax_responsibilities] 119 | 120 | out = reshape_like(out, shaped_array=W[:-3]) 121 | return out 122 | 123 | 124 | def save_histogram(W_T,save, upper_bound=200): 125 | w = np.squeeze(special_flatten(W_T[:-3])) 126 | plt.figure(figsize=(10, 7)) 127 | sns.set(color_codes=True) 128 | plt.xlim(-1,1) 129 | plt.ylim(0,upper_bound) 130 | sns.distplot(w, kde=False, color="g",bins=200,norm_hist=True) 131 | plt.savefig("./"+save+".png", bbox_inches='tight') 132 | plt.close() 133 | 134 | 135 | plt.figure(figsize=(10, 7)) 136 | plt.yscale("log") 137 | sns.set(color_codes=True) 138 | plt.xlim(-1,1) 139 | plt.ylim(0.001,upper_bound*5) 140 | sns.distplot(w, kde=False, color="g",bins=200,norm_hist=True) 141 | plt.savefig("./"+save+"_log.png", bbox_inches='tight') 142 | plt.close() 143 | 
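A quick standalone check of the helpers above (not part of helpers.py; it assumes the repository's Python environment, with helpers.py importable from the working directory). special_flatten and reshape_like are inverses of each other, and KL is the closed-form KL divergence between two univariate Gaussians parameterized by mean and log-precision; discretesize later merges mixture components whose pairwise KL falls below 1e-10.

import numpy as np

from helpers import KL, reshape_like, special_flatten

# special_flatten stacks a list of arrays into one (n, 1) column;
# reshape_like undoes this, given arrays of the original shapes.
shaped = [np.arange(6, dtype='float32').reshape(2, 3), np.ones(4)]
flat = special_flatten(shaped)                    # shape (10, 1)
restored = reshape_like(flat.flatten(), shaped)
assert all(np.allclose(a, b) for a, b in zip(shaped, restored))

# Identical Gaussians have KL = 0, so they would be merged ...
print(KL(np.array([0., 0.]), np.log(np.array([1600., 1600.]))))   # ~0.0
# ... while well-separated components have a large KL and are kept apart.
print(KL(np.array([0., 0.5]), np.log(np.array([400., 400.]))))    # 50.0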
-------------------------------------------------------------------------------- /optimizers.py: -------------------------------------------------------------------------------- 1 | """ 2 | A modified copy of Keras Adam Optimizer. 3 | 4 | Author: Karen Ullrich, Sep 2016 5 | 6 | """ 7 | 8 | from __future__ import print_function 9 | import numpy as np 10 | 11 | from keras import backend as K 12 | from keras.utils.generic_utils import get_from_module 13 | 14 | from keras.optimizers import Optimizer 15 | 16 | 17 | class Adam(Optimizer): 18 | """Adam optimizer. 19 | An extended Version. parameters that have been named can be trained with 20 | different hyperparams. 21 | 22 | Default parameters follow those provided in the original paper. 23 | 24 | # Arguments 25 | lr: float >= 0. Learning rate. 26 | beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1. 27 | epsilon: float >= 0. Fuzz factor. 28 | 29 | # References 30 | - [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8) 31 | """ 32 | 33 | def __init__(self, 34 | lr=[0.001], 35 | beta_1=None, 36 | beta_2=None, 37 | epsilon=1e-8, 38 | decay=None, 39 | param_types_dict=[], 40 | **kwargs): 41 | 42 | super(Adam, self).__init__(**kwargs) 43 | 44 | if lr is None: 45 | lr = [0.001] 46 | self.__dict__.update(locals()) 47 | 48 | self.iterations = K.variable(0) 49 | # init params if not set 50 | l = len(lr) 51 | if beta_1 is None: 52 | beta_1 = list(np.tile([0.9], l)) 53 | if beta_2 is None: 54 | beta_2 = list(np.tile([0.999], l)) 55 | if decay is None: 56 | decay = list(np.tile([0.], l)) 57 | # add a tag for non-tagged variables 58 | self.param_types_dict = ['other'] + param_types_dict 59 | 60 | self.lr = {} 61 | self.beta_1, self.beta_2 = {}, {} 62 | self.decay, self.inital_decay = {}, {} 63 | 64 | for param_type in self.param_types_dict: 65 | self.lr[param_type] = K.variable(lr.pop(0)) 66 | self.beta_1[param_type] = K.variable(beta_1.pop(0)) 67 | self.beta_2[param_type] = K.variable(beta_2.pop(0)) 68 | tmp = decay.pop(0) 69 | self.decay[param_type] = K.variable(tmp) 70 | self.inital_decay[param_type] = tmp 71 | 72 | self.epsilon = epsilon 73 | 74 | def get_updates(self, params, constraints, loss): 75 | grads = self.get_gradients(loss, params) 76 | self.updates = [K.update_add(self.iterations, 1)] 77 | 78 | t = self.iterations + 1 79 | 80 | lr_t = {} 81 | for param_type in self.param_types_dict: 82 | lr = self.lr[param_type] 83 | if self.inital_decay[param_type] > 0: 84 | lr *= (1. / (1. + self.decay[param_type] * self.iterations[param_type])) 85 | lr_t[param_type] = lr * K.sqrt(1. - K.pow(self.beta_2[param_type], t)) / ( 86 | 1. - K.pow(self.beta_1[param_type], t)) 87 | 88 | shapes = [K.get_variable_shape(p) for p in params] 89 | # add param type here 90 | param_types = [] 91 | for param in params: 92 | tmp = None 93 | for param_type in self.param_types_dict: 94 | if param_type in param.name: 95 | tmp = param_type 96 | if tmp is None: 97 | tmp = 'other' 98 | param_types.append(tmp) 99 | 100 | if len(param_types) != len(params): 101 | print('Something went wrong with the naming of variables.') 102 | 103 | ms = [K.zeros(shape) for shape in shapes] 104 | vs = [K.zeros(shape) for shape in shapes] 105 | self.weights = [self.iterations] + ms + vs 106 | 107 | for p, param_type, g, m, v in zip(params, param_types, grads, ms, vs): 108 | m_t = (self.beta_1[param_type] * m) + (1. - self.beta_1[param_type]) * g 109 | v_t = (self.beta_2[param_type] * v) + (1. 
- self.beta_2[param_type]) * K.square(g)
110 |             p_t = p - lr_t[param_type] * m_t / (K.sqrt(v_t) + self.epsilon)
111 | 
112 |             self.updates.append(K.update(m, m_t))
113 |             self.updates.append(K.update(v, v_t))
114 |             new_p = p_t
115 |             # apply constraints
116 |             if p in constraints:
117 |                 c = constraints[p]
118 |                 new_p = c(new_p)
119 |             self.updates.append(K.update(p, new_p))
120 | 
121 |         return self.updates
122 | 
123 |     @property
124 |     def get_config(self):
125 | 
126 |         lr = {}
127 |         beta_1, beta_2 = {}, {}
128 |         decay, inital_decay = {}, {}
129 | 
130 |         for param_type in self.param_types_dict:
131 |             lr[param_type] = float(K.get_value(self.lr[param_type]))
132 |             beta_1[param_type] = float(K.get_value(self.beta_1[param_type]))
133 |             beta_2[param_type] = float(K.get_value(self.beta_2[param_type]))
134 |             decay[param_type] = float(K.get_value(self.decay[param_type]))
135 |             inital_decay[param_type] = float(K.get_value(self.inital_decay[param_type]))
136 | 
137 |         config = {'lr': lr,
138 |                   'beta_1': beta_1,
139 |                   'beta_2': beta_2,
140 |                   'epsilon': self.epsilon}
141 | 
142 |         base_config = super(Adam, self).get_config()
143 | 
144 |         return dict(list(base_config.items()) + list(config.items()))
145 | 
146 | 
147 | # aliases
148 | adam = Adam
149 | 
150 | 
151 | def get(identifier, kwargs=None):
152 |     return get_from_module(identifier, globals(), 'optimizer',
153 |                            instantiate=True, kwargs=kwargs)
154 | 
--------------------------------------------------------------------------------
/tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {
6 |     "deletable": true,
7 |     "editable": true
8 |    },
9 |    "source": [
10 |     "# A tutorial on 'Soft weight-sharing for Neural Network compression' \n",
11 |     "\n",
12 |     "## Introduction\n",
13 |     "\n",
14 |     "Recently, compression of neural networks has been a much-discussed topic. One reason for this is the desire to run and store them on mobile devices such as smartphones, robots or [Raspberry Pis](https://github.com/samjabrahams/tensorflow-on-raspberry-pi). Another is the energy consumption of a multimillion-parameter network: when inference is run at scale, the costs quickly add up.\n",
15 |     "\n",
16 |     "Compressing a given neural network is a challenging task. While the idea of storing weights at low precision has been around for some time, the idea of storing the weight matrix itself in a different, sparse format is more recent. \n",
17 |     "The latter proposal has led to the most successful compression scheme so far. \n",
18 |     "Under the assumption that there is great redundancy within the weights, one can prune most of them, so that only the non-zero weights need to be stored. These remaining weights can be compressed further by quantizing them.\n",
19 |     "However, it is not obvious how to identify the redundant weights in a trained neural network, nor how best to quantize the remaining ones. In the following, we state the problem more clearly and review one recent attempt to tackle it. For the illustrations below, I work with MNIST and LeNet-300-100, a simple fully connected network with two hidden layers of 300 and 100 units.\n",
20 |     "\n",
21 |     "Usually, when training an unregularized neural network, the distribution of weights looks somewhat like a Normal distribution centered at zero. For the proposed storage format this is not ideal. We would like a distribution that is sharply (ideally delta) peaked around a few values, with significant mass in the zero peak.\n",
22 |     "\n",
23 |     "Weight distribution during standard training | | Desired distribution \n",
24 |     ":-------------------------:|:------------:|:-------------------------:\n",
25 |     "![](./figures/han_pretraining.gif \"title-1\")|| ![](./figures/han_clustering.png \"title-1\")\n",
26 |     "\n",
27 |     "In the following, we briefly review how [Han et al. (2016)](https://arxiv.org/abs/1510.00149), the proposers of this compression format and the current state of the art in compression, tackle the problem. \n",
28 |     "The authors use a multi-stage pipeline: (i) re-training a pre-trained network with a Gaussian prior (aka L2-norm) on the weights, (ii) repeatedly cutting off all weights within a small threshold around zero and then continuing training with the L2-norm, and (iii) clustering all weights and retraining the cluster means.\n",
29 |     "\n",
30 |     "(i) Re-training with L2 regularization | (ii) Repeated cutting and training \n",
31 |     " :-------------------------:|:-------------------------:\n",
32 |     "![](./figures/han_retraining.gif \"title-1\")|![](./figures/han_cutandtrain.gif \"title-1\")\n",
33 |     "**(ii) Final stage before clustering** | **(iii) Clustering**\n",
34 |     "![](./figures/han_final_cut.png \"title-1\")|![](./figures/han_clustering.png \"title-1\")\n",
35 |     "\n",
36 |     "\n",
37 |     "Note that this pipeline is not a differentiable function. Furthermore, pruning and quantization are distinct stages. \n",
38 |     "\n",
39 |     "In contrast, we propose to sparsify and cluster weights in one differentiable retraining procedure. More precisely, we retrain the network weights under a Gaussian mixture model prior. \n",
40 |     "This is an instance of an empirical Bayesian prior because the parameters of the prior are learned as well. \n",
41 |     "With this prior present, the weights naturally cluster together, since tight clusters allow the Gaussian mixture to lower its variances and thus achieve higher probability. \n",
42 |     "It is important to initialize the learning procedure for these priors carefully, because one might otherwise end up in a situation where the weights \"chase\" the mixture and the mixture chases the weights. \n",
43 |     "\n",
44 |     "Note that, even though compression is a core topic of information theory, this angle has so far received little attention. While the emphasis in our paper lies on this information-theoretic view, here we restrict ourselves to a more practical one.\n",
45 |     "\n",
46 |     "In the following, we give a tutorial that serves as a practical guide to implementing empirical priors, in particular a Gaussian mixture with an Inverse-Gamma prior on the variances. It is divided into 3 parts.\n",
47 |     "\n",
48 |     "\n",
49 |     "* **PART 1:** Pretraining a neural network\n",
50 |     "\n",
51 |     "* **PART 2:** Re-training the network with an empirical Gaussian mixture prior with an Inverse-Gamma hyper-prior on the variances \n",
52 |     "\n",
53 |     "* **PART 3:** Post-processing the re-trained network weights\n",
54 |     "\n",
55 |     "## PART 1: Pretraining a Neural Network \n",
56 |     "\n",
57 |     "\n",
58 |     "First of all, we need a parameter-heavy network to compress. In this first part of the tutorial, we train a simple neural network with 2 convolutional and 2 fully connected layers (642K parameters) on MNIST. \n",
59 |     "___________________________\n",
60 |     "\n",
61 |     "We start by loading some essential libraries."
62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 1, 67 | "metadata": { 68 | "collapsed": false, 69 | "deletable": true, 70 | "editable": true 71 | }, 72 | "outputs": [ 73 | { 74 | "name": "stderr", 75 | "output_type": "stream", 76 | "text": [ 77 | "Using Theano backend.\n", 78 | "Using gpu device 0: GeForce GTX TITAN X (CNMeM is enabled with initial size: 70.0% of memory, cuDNN 5103)\n", 79 | "/home/karen/anaconda2/lib/python2.7/site-packages/theano/sandbox/cuda/__init__.py:600: UserWarning: Your cuDNN version is more recent than the one Theano officially supports. If you see any problems, try updating Theano or downgrading cuDNN to version 5.\n", 80 | " warnings.warn(warn)\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "from __future__ import print_function\n", 86 | "import numpy as np\n", 87 | "%matplotlib inline\n", 88 | "import keras\n", 89 | "from IPython import display" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": { 95 | "deletable": true, 96 | "editable": true 97 | }, 98 | "source": [ 99 | "______________________\n", 100 | "Following, we load the MNIST dataset into memory." 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 2, 106 | "metadata": { 107 | "collapsed": false, 108 | "deletable": true, 109 | "editable": true 110 | }, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "Successfully loaded 60000 train samples and 10000 test samples.\n" 117 | ] 118 | } 119 | ], 120 | "source": [ 121 | "from dataset import mnist\n", 122 | "\n", 123 | "[X_train, X_test], [Y_train, Y_test], [img_rows, img_cols], nb_classes = mnist()" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": { 129 | "deletable": true, 130 | "editable": true 131 | }, 132 | "source": [ 133 | " ___________________________________________________\n", 134 | "\n", 135 | "Next, we choose a model. We decide in favor of a classical 2 convolutional, 2 fully connected layer network with ReLu activation." 
136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 3, 141 | "metadata": { 142 | "collapsed": false, 143 | "deletable": true, 144 | "editable": true 145 | }, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "____________________________________________________________________________________________________\n", 152 | "Layer (type) Output Shape Param # Connected to \n", 153 | "====================================================================================================\n", 154 | "input (InputLayer) (None, 1, 28, 28) 0 \n", 155 | "____________________________________________________________________________________________________\n", 156 | "convolution2d_1 (Convolution2D) (None, 25, 12, 12) 650 input[0][0] \n", 157 | "____________________________________________________________________________________________________\n", 158 | "convolution2d_2 (Convolution2D) (None, 50, 5, 5) 11300 convolution2d_1[0][0] \n", 159 | "____________________________________________________________________________________________________\n", 160 | "flatten_1 (Flatten) (None, 1250) 0 convolution2d_2[0][0] \n", 161 | "____________________________________________________________________________________________________\n", 162 | "dense_1 (Dense) (None, 500) 625500 flatten_1[0][0] \n", 163 | "____________________________________________________________________________________________________\n", 164 | "activation_1 (Activation) (None, 500) 0 dense_1[0][0] \n", 165 | "____________________________________________________________________________________________________\n", 166 | "dense_2 (Dense) (None, 10) 5010 activation_1[0][0] \n", 167 | "____________________________________________________________________________________________________\n", 168 | "error_loss (Activation) (None, 10) 0 dense_2[0][0] \n", 169 | "====================================================================================================\n", 170 | "Total params: 642,460\n", 171 | "Trainable params: 642,460\n", 172 | "Non-trainable params: 0\n", 173 | "____________________________________________________________________________________________________\n" 174 | ] 175 | } 176 | ], 177 | "source": [ 178 | "from keras import backend as K\n", 179 | "\n", 180 | "from keras.models import Model\n", 181 | "from keras.layers import Input, Dense, Activation, Flatten, Convolution2D\n", 182 | "\n", 183 | "# We configure the input here to match the backend. If properly done this is a lot faster. \n", 184 | "if K._BACKEND == \"theano\":\n", 185 | " InputLayer = Input(shape=(1, img_rows, img_cols), name=\"input\")\n", 186 | "elif K._BACKEND == \"tensorflow\":\n", 187 | " InputLayer = Input(shape=(img_rows, img_cols,1), name=\"input\")\n", 188 | "\n", 189 | "# A classical architecture ...\n", 190 | "# ... with 3 convolutional layers,\n", 191 | "Layers = Convolution2D(25, 5, 5, subsample = (2,2), activation = \"relu\")(InputLayer)\n", 192 | "Layers = Convolution2D(50, 3, 3, subsample = (2,2), activation = \"relu\")(Layers)\n", 193 | "# ... 
and 2 fully connected layers.\n", 194 | "Layers = Flatten()(Layers)\n", 195 | "Layers = Dense(500)(Layers)\n", 196 | "Layers = Activation(\"relu\")(Layers)\n", 197 | "Layers = Dense(nb_classes)(Layers)\n", 198 | "PredictionLayer = Activation(\"softmax\", name =\"error_loss\")(Layers)\n", 199 | "\n", 200 | "# Fianlly, we create a model object:\n", 201 | "model = Model(input=[InputLayer], output=[PredictionLayer])\n", 202 | "\n", 203 | "model.summary()" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": { 209 | "deletable": true, 210 | "editable": true 211 | }, 212 | "source": [ 213 | "---------------------------------------------------------------------------------------------------\n", 214 | "Next, we train the network for 100 epochs with the Adam optimizer. Let's see where our model gets us..." 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 4, 220 | "metadata": { 221 | "collapsed": false, 222 | "deletable": true, 223 | "editable": true 224 | }, 225 | "outputs": [], 226 | "source": [ 227 | "from keras import optimizers\n", 228 | "\n", 229 | "epochs = 100\n", 230 | "batch_size = 256\n", 231 | "opt = optimizers.Adam(lr=0.001)\n", 232 | "\n", 233 | "model.compile(optimizer= opt,\n", 234 | " loss = {\"error_loss\": \"categorical_crossentropy\",},\n", 235 | " metrics=[\"accuracy\"])" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 5, 241 | "metadata": { 242 | "collapsed": false, 243 | "deletable": true, 244 | "editable": true 245 | }, 246 | "outputs": [ 247 | { 248 | "name": "stdout", 249 | "output_type": "stream", 250 | "text": [ 251 | "Test accuracy: 0.9908\n" 252 | ] 253 | } 254 | ], 255 | "source": [ 256 | "model.fit({\"input\": X_train, }, {\"error_loss\": Y_train},\n", 257 | " nb_epoch = epochs, batch_size = batch_size,\n", 258 | " verbose = 0, validation_data=(X_test, Y_test))\n", 259 | "\n", 260 | "score = model.evaluate(X_test, Y_test, verbose=0)\n", 261 | "print('Test accuracy:', score[1])" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": { 267 | "deletable": true, 268 | "editable": true 269 | }, 270 | "source": [ 271 | "***Note: *** The model should end up with approx. 0.9% error rate.\n", 272 | "___________________________________\n", 273 | "Fianlly, we save the model in case we need to reload it later, e.g. if you want to play around with the code ..." 
274 |    ]
275 |   },
276 |   {
277 |    "cell_type": "code",
278 |    "execution_count": 6,
279 |    "metadata": {
280 |     "collapsed": false,
281 |     "deletable": true,
282 |     "editable": true
283 |    },
284 |    "outputs": [],
285 |    "source": [
286 |     "keras.models.save_model(model, \"./my_pretrained_net\")"
287 |    ]
288 |   },
289 |   {
290 |    "cell_type": "markdown",
291 |    "metadata": {
292 |     "deletable": true,
293 |     "editable": true
294 |    },
295 |    "source": [
296 |     "________________________________________\n",
297 |     "__________________________________________\n",
298 |     "\n",
299 |     "## PART 2: Re-training the network with an empirical prior\n",
300 |     "\n",
301 |     "_____________________________________________________\n",
302 |     "\n",
303 |     "First of all, we load our pretrained model."
304 |    ]
305 |   },
306 |   {
307 |    "cell_type": "code",
308 |    "execution_count": 7,
309 |    "metadata": {
310 |     "collapsed": true,
311 |     "deletable": true,
312 |     "editable": true
313 |    },
314 |    "outputs": [],
315 |    "source": [
316 |     "pretrained_model = keras.models.load_model(\"./my_pretrained_net\")"
317 |    ]
318 |   },
319 |   {
320 |    "cell_type": "markdown",
321 |    "metadata": {
322 |     "deletable": true,
323 |     "editable": true
324 |    },
325 |    "source": [
326 |     "Next, we initialize a 16-component Gaussian mixture as our empirical prior. We learn all parameters of the prior except the mean and the mixing proportion of the zero component, which we fix to $\mu_0=0$ and $\pi_0=0.99$, respectively. Furthermore, we put a Gamma hyper-prior on the precisions of the Gaussian mixture and set its mean such that the corresponding standard deviation of a component is $0.02$. The variance of the hyper-prior determines how strongly the precisions are regularized. Note that the variance of the zero component has much more data (i.e. weight) evidence than the other components, so we put a stronger prior on it. Somewhat counterintuitively, we found it beneficial to allow wider, and thus noisier, expected variances. A short numerical sketch of these choices is given below."
327 |    ]
328 |   },
329 |   {
330 |    "cell_type": "code",
331 |    "execution_count": 8,
332 |    "metadata": {
333 |     "collapsed": false,
334 |     "deletable": true,
335 |     "editable": true
336 |    },
337 |    "outputs": [],
338 |    "source": [
339 |     "from empirical_priors import GaussianMixturePrior\n",
340 |     "from extended_keras import extract_weights\n",
341 |     "\n",
342 |     "pi_zero = 0.99\n",
343 |     "\n",
344 |     "RegularizationLayer = GaussianMixturePrior(nb_components=16, \n",
345 |     "                                           network_weights=extract_weights(model),\n",
346 |     "                                           pretrained_weights=pretrained_model.get_weights(), \n",
347 |     "                                           pi_zero=pi_zero,\n",
348 |     "                                           name=\"complexity_loss\")(Layers)\n",
349 |     "\n",
350 |     "model = Model(input = [InputLayer], output = [PredictionLayer, RegularizationLayer])"
351 |    ]
352 |   },
353 |   {
354 |    "cell_type": "markdown",
355 |    "metadata": {
356 |     "deletable": true,
357 |     "editable": true
358 |    },
359 |    "source": [
360 |     "We optimize the network again with Adam; the learning rates for the network parameters, means, variances and mixing proportions may differ, though.\n",
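    "\n",
    "As a quick numerical check of the Gamma hyper-prior described above (a standalone sketch, not part of the training pipeline), the shape/rate pairs hard-coded in `empirical_priors.py` imply the following scales:\n",
    "\n",
    "```python\n",
    "# For a Gamma(alpha, beta) prior on a precision lambda: E[lambda] = alpha / beta,\n",
    "# and the relative width of the prior, std(lambda) / E[lambda], is 1 / sqrt(alpha).\n",
    "for name, (alpha, beta) in [('zero component', (5e3, 2e-1)),\n",
    "                            ('other components', (2.5e2, 1e-1))]:\n",
    "    mean_precision = alpha / beta              # 25000 resp. 2500\n",
    "    implied_std = mean_precision ** -0.5       # ~0.006 resp. 0.02\n",
    "    relative_width = alpha ** -0.5             # smaller alpha = looser prior\n",
    "    print(name, mean_precision, implied_std, relative_width)\n",
    "```\n",
    "\n",
    "So the zero component is pushed towards a much narrower spike, and its larger `alpha` also makes its hyper-prior considerably tighter, i.e. its precision is more strongly regularized."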
361 |    ]
362 |   },
363 |   {
364 |    "cell_type": "code",
365 |    "execution_count": 9,
366 |    "metadata": {
367 |     "collapsed": false,
368 |     "deletable": true,
369 |     "editable": true
370 |    },
371 |    "outputs": [],
372 |    "source": [
373 |     "import optimizers \n",
374 |     "from extended_keras import identity_objective\n",
375 |     "\n",
376 |     "tau = 0.003\n",
377 |     "N = X_train.shape[0] \n",
378 |     "\n",
379 |     "opt = optimizers.Adam(lr = [5e-4,1e-4,3e-3,3e-3], #[unnamed, means, log(precision), log(mixing proportions)]\n",
380 |     "                      param_types_dict = ['means','gammas','rhos'])\n",
381 |     "\n",
382 |     "model.compile(optimizer = opt,\n",
383 |     "              loss = {\"error_loss\": \"categorical_crossentropy\", \"complexity_loss\": identity_objective},\n",
384 |     "              loss_weights = {\"error_loss\": 1. , \"complexity_loss\": tau/N},\n",
385 |     "              metrics = ['accuracy'])"
386 |    ]
387 |   },
388 |   {
389 |    "cell_type": "markdown",
390 |    "metadata": {
391 |     "deletable": true,
392 |     "editable": true
393 |    },
394 |    "source": [
395 |     "We train our network for 30 epochs, each taking about 45s. You can watch the progress yourself: at each epoch, we compare the original weight distribution (histogram, top) to the current distribution (log-scaled histogram, right). The joint scatter plot in the middle shows how each weight has changed.\n",
396 |     "\n",
397 |     "*Note* that we had to scale the histogram logarithmically; otherwise it would be rather uninformative due to the zero spike."
398 |    ]
399 |   },
400 |   {
401 |    "cell_type": "code",
402 |    "execution_count": 10,
403 |    "metadata": {
404 |     "collapsed": false,
405 |     "deletable": true,
406 |     "editable": true
407 |    },
408 |    "outputs": [],
409 |    "source": [
410 |     "from extended_keras import VisualisationCallback\n",
411 |     "\n",
412 |     "epochs = 30\n",
413 |     "model.fit({\"input\": X_train,},\n",
414 |     "          {\"error_loss\" : Y_train, \"complexity_loss\": np.zeros((N,1))},\n",
415 |     "          nb_epoch = epochs,\n",
416 |     "          batch_size = batch_size,\n",
417 |     "          verbose = 1., callbacks=[VisualisationCallback(model,X_test,Y_test, epochs)])\n",
418 |     "\n",
419 |     "display.clear_output()"
420 |    ]
421 |   },
422 |   {
423 |    "cell_type": "code",
424 |    "execution_count": 11,
425 |    "metadata": {
426 |     "collapsed": false,
427 |     "deletable": true,
428 |     "editable": true
429 |    },
430 |    "outputs": [
431 |     {
432 |      "data": {
433 |       "text/html": [
434 |        ""
435 |       ],
436 |       "text/plain": [
437 |        ""
438 |       ]
439 |      },
440 |      "execution_count": 11,
441 |      "metadata": {},
442 |      "output_type": "execute_result"
443 |     }
444 |    ],
445 |    "source": [
446 |     "display.Image(url='./figures/retraining.gif')"
447 |    ]
448 |   },
449 |   {
450 |    "cell_type": "markdown",
451 |    "metadata": {
452 |     "deletable": true,
453 |     "editable": true
454 |    },
455 |    "source": [
456 |     "## PART 3: Post-processing\n",
457 |     "\n",
458 |     "Now, the only thing left to do is to set each weight to the mean of the component that takes the most responsibility for it, i.e. to quantize the weights, as sketched below.\n",
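    "\n",
    "To make this assignment concrete, here is a toy sketch with made-up mixture parameters (the actual post-processing, including the merging of near-identical components, is done by `helpers.discretesize` in the next cell):\n",
    "\n",
    "```python\n",
    "# toy mixture: a zero spike and two non-zero components (illustrative values only)\n",
    "means = np.array([0.0, -0.5, 0.5])\n",
    "stds  = np.array([0.01, 0.05, 0.05])\n",
    "pis   = np.array([0.99, 0.005, 0.005])\n",
    "\n",
    "w = np.array([-0.52, -0.003, 0.002, 0.49])    # four toy weights\n",
    "# unnormalized responsibilities r_jk ~ pi_k * N(w_j | mu_k, sigma_k^2)\n",
    "resp = pis * np.exp(-0.5 * ((w[:, None] - means) / stds) ** 2) / stds\n",
    "quantized = means[np.argmax(resp, axis=1)]    # -> [-0.5, 0., 0., 0.5]\n",
    "```"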
\n" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 12, 464 | "metadata": { 465 | "collapsed": false, 466 | "deletable": true, 467 | "editable": true 468 | }, 469 | "outputs": [], 470 | "source": [ 471 | "from helpers import discretesize\n", 472 | "\n", 473 | "retrained_weights = np.copy(model.get_weights())\n", 474 | "compressed_weights = np.copy(model.get_weights())\n", 475 | "compressed_weights[:-3] = discretesize(compressed_weights, pi_zero = pi_zero)" 476 | ] 477 | }, 478 | { 479 | "cell_type": "markdown", 480 | "metadata": {}, 481 | "source": [ 482 | "Let us compare the accuracy of the reference, the retrained and the post-processed network." 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": 13, 488 | "metadata": { 489 | "collapsed": false 490 | }, 491 | "outputs": [ 492 | { 493 | "name": "stdout", 494 | "output_type": "stream", 495 | "text": [ 496 | "MODEL ACCURACY\n", 497 | "Reference Network: 0.9908\n", 498 | "Retrained Network: 0.9903\n", 499 | "Post-processed Network: 0.9902\n" 500 | ] 501 | } 502 | ], 503 | "source": [ 504 | "print(\"MODEL ACCURACY\")\n", 505 | "score = pretrained_model.evaluate({'input': X_test, },{\"error_loss\" : Y_test,}, verbose=0)[1]\n", 506 | "print(\"Reference Network: %0.4f\" %score)\n", 507 | "score = model.evaluate({'input': X_test, },{\"error_loss\" : Y_test, \"complexity_loss\": Y_test,}, verbose=0)[3]\n", 508 | "print(\"Retrained Network: %0.4f\" %score)\n", 509 | "model.set_weights(compressed_weights)\n", 510 | "score = model.evaluate({'input': X_test, },{\"error_loss\" : Y_test, \"complexity_loss\": Y_test,}, verbose=0)[3]\n", 511 | "print(\"Post-processed Network: %0.4f\" %score)" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": { 517 | "deletable": true, 518 | "editable": true 519 | }, 520 | "source": [ 521 | "Finally let us see how many weights have been pruned." 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": 14, 527 | "metadata": { 528 | "collapsed": false, 529 | "deletable": true, 530 | "editable": true 531 | }, 532 | "outputs": [ 533 | { 534 | "name": "stdout", 535 | "output_type": "stream", 536 | "text": [ 537 | "Non-zero weights: 7.43 %\n" 538 | ] 539 | } 540 | ], 541 | "source": [ 542 | "from helpers import special_flatten\n", 543 | "weights = special_flatten(compressed_weights[:-3]).flatten()\n", 544 | "print(\"Non-zero weights: %0.2f %%\" % (100.*np.count_nonzero(weights)/ weights.size) )" 545 | ] 546 | }, 547 | { 548 | "cell_type": "markdown", 549 | "metadata": { 550 | "deletable": true, 551 | "editable": true 552 | }, 553 | "source": [ 554 | "As we can see in this naive implementation we got rid of 19 out of 20 weights. Furthermore note that we quantize weights with only 16 cluster means (aka 4 bit indexes). \n", 555 | "\n", 556 | "For better results (up to 0.5%) one may anneal $\\tau$, learn the mixing proportion for the zero spike with a beta prior on it for example and ideally optimize with some hyperparamter optimization of choice such as spearmint (I also wrote some example code for deep learning and spearmint).\n", 557 | "\n", 558 | "We finish this tutorial with a series of histograms showing the results of our procedure." 
559 | ] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "execution_count": 15, 564 | "metadata": { 565 | "collapsed": false 566 | }, 567 | "outputs": [], 568 | "source": [ 569 | "from helpers import save_histogram\n", 570 | "\n", 571 | "save_histogram(pretrained_model.get_weights(),save=\"figures/reference\")\n", 572 | "save_histogram(retrained_weights,save=\"figures/retrained\")\n", 573 | "save_histogram(compressed_weights,save=\"figures/post-processed\")" 574 | ] 575 | }, 576 | { 577 | "cell_type": "markdown", 578 | "metadata": { 579 | "deletable": true, 580 | "editable": true 581 | }, 582 | "source": [ 583 | "|Weight distribution before retraining | Weight distribution after retraining| Weight distribution after post-processing \n", 584 | ":-------------------------:|:-------------------------:|:------------:|:-------------------------:\n", 585 | "histogram|![](./figures/reference.png)|| ![](./figures/post-processed.png)\n", 586 | "log-scaled histogram|![](./figures/reference_log.png)|| ![](./figures/post-processed_log.png)\n", 587 | "_______________________________\n", 588 | "### *Reference*\n", 589 | "\n", 590 | "The paper \"Soft weight-sharing for Neural Network compression\" has been accepted to ICLR 2017.\n", 591 | "\n", 592 | "\n", 593 | " @inproceedings{ullrich2017soft,\n", 594 | " title={Soft Weight-Sharing for Neural Network Compression},\n", 595 | " author={Ullrich, Karen and Meeds, Edward and Welling, Max},\n", 596 | " booktitle={ICLR 2017},\n", 597 | " year={2017}\n", 598 | " }" 599 | ] 600 | } 601 | ], 602 | "metadata": { 603 | "anaconda-cloud": {}, 604 | "celltoolbar": "Hide code", 605 | "kernelspec": { 606 | "display_name": "Python [Root]", 607 | "language": "python", 608 | "name": "Python [Root]" 609 | }, 610 | "language_info": { 611 | "codemirror_mode": { 612 | "name": "ipython", 613 | "version": 2 614 | }, 615 | "file_extension": ".py", 616 | "mimetype": "text/x-python", 617 | "name": "python", 618 | "nbconvert_exporter": "python", 619 | "pygments_lexer": "ipython2", 620 | "version": "2.7.11" 621 | } 622 | }, 623 | "nbformat": 4, 624 | "nbformat_minor": 0 625 | } 626 | --------------------------------------------------------------------------------
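As a closing illustration: the quantity that GaussianMixturePrior.compute_loss (empirical_priors.py) evaluates symbolically is the negative log-likelihood of the flattened network weights under the Gaussian mixture, computed with the log-sum-exp trick. Below is a standalone NumPy sketch of the same computation; the Keras version folds the mixing proportions and the Gaussian normalizers into the weights passed to logsumexp (extended_keras.logsumexp), which is equivalent up to numerics. All values are illustrative stand-ins, not taken from a trained model.

import numpy as np
from scipy.misc import logsumexp  # the same helper that helpers.py relies on

w = np.random.randn(1000) * 0.1            # stand-in for the flattened network weights
means = np.array([0.0, -0.2, 0.2])         # mu_0 is fixed to zero in the tutorial
precisions = np.array([1e4, 1e2, 1e2])     # lambda_k = 1 / sigma_k^2
pis = np.array([0.99, 0.005, 0.005])       # pi_0 is fixed to 0.99 in the tutorial

# log N(w_j | mu_k, 1 / lambda_k) for every weight j and component k
log_norm = 0.5 * (np.log(precisions) - np.log(2 * np.pi))
log_comp = log_norm - 0.5 * precisions * (w[:, None] - means) ** 2
# log p(w_j) = logsumexp_k [ log pi_k + log N(w_j | mu_k, 1 / lambda_k) ]
log_p = logsumexp(log_comp + np.log(pis), axis=1)
print(-log_p.sum())  # the complexity loss contributed by these weights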