├── .gitignore
├── datasets
│   └── var_u.mat
├── requirements.txt
├── LICENSE
├── demo.py
├── README.md
├── kde.py
├── simplebinmi.py
├── MNIST_SaveActivations.ipynb
├── utils.py
├── IBnet_SaveActivations.ipynb
├── loggingreporter.py
├── IBnet_ComputeMI.ipynb
└── MNIST_ComputeMI.ipynb
/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *# 3 | rawdata/ 4 | .ipynb_checkpoints/ 5 | -------------------------------------------------------------------------------- /datasets/var_u.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artemyk/ibsgd/HEAD/datasets/var_u.mat -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | matplotlib 4 | keras 5 | pathlib2 6 | tensorflow 7 | seaborn 8 | six -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Andrew Michael Saxe, Yamini Bansal, Joel Dapello, Madhu Advani, 4 | Artemy Kolchinsky, Brendan Daniel Tracey, David Daniel Cox 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # Simple example of how to estimate MI between X and Y, where Y = f(X) + Noise(0, noise_variance) 2 | from __future__ import print_function 3 | import kde 4 | import keras.backend as K 5 | import numpy as np 6 | 7 | Y_samples = K.placeholder(ndim=2) 8 | 9 | noise_variance = 0.05 10 | entropy_func_upper = K.function([Y_samples,], [kde.entropy_estimator_kl(Y_samples, noise_variance),]) 11 | entropy_func_lower = K.function([Y_samples,], [kde.entropy_estimator_bd(Y_samples, noise_variance),]) 12 | 13 | data = np.random.random( size = (1000, 20) ) # N x dims 14 | H_Y_given_X = kde.kde_condentropy(data, noise_variance) 15 | H_Y_upper = entropy_func_upper([data,])[0] 16 | H_Y_lower = entropy_func_lower([data,])[0] 17 | 18 | print("Upper bound: %0.3f nats" % (H_Y_upper - H_Y_given_X)) 19 | print("Lower bound: %0.3f nats" % (H_Y_lower - H_Y_given_X)) 20 | 21 | # Alternative calculation, direct from distance matrices 22 | dims, N = kde.get_shape(K.variable(data)) 23 | dists = kde.Kget_dists(K.variable(data)) 24 | dists2 = dists / (2*noise_variance) 25 | mi2 = K.eval(-K.mean(K.logsumexp(-dists2, axis=1) - K.log(N))) 26 | print("Upper bound2: %0.3f nats" % mi2) 27 | 28 | 29 | dims, N = kde.get_shape(K.variable(data)) 30 | dists = kde.Kget_dists(K.variable(data)) 31 | dists2 = dists / (2*4*noise_variance) # the Bhattacharyya lower bound uses Gaussians with variance 4*noise_variance 32 | mi2 = K.eval(-K.mean(K.logsumexp(-dists2, axis=1) - K.log(N)) ) 33 | print("Lower bound2: %0.3f nats" % mi2) 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Code for On the Information Bottleneck Theory of Deep Learning 2 | 3 | Recently updated for Python 3, `tensorflow` 2.1.0, `keras` 2.3.1. 4 | 5 | Other requirements: `six`, `pathlib2`, `seaborn`. 6 | 7 | * `MNIST_SaveActivations.ipynb` is a jupyter notebook that trains on MNIST and saves (in a data directory) activations when run on test set inputs (as well as weight norms, &c.) for each epoch. 8 | 9 | * `MNIST_ComputeMI.ipynb` is a jupyter notebook that loads the data files created by `MNIST_SaveActivations.ipynb`, computes MI values, and produces the infoplane and SNR plots. 10 | 11 | * `IBnet_SaveActivations.ipynb` is a jupyter notebook that recreates the network and data from https://github.com/ravidziv/IDNNs and saves activations, weight norms, &c. for each epoch for a single trial. 12 | 13 | * `IBnet_ComputeMI.ipynb` is a jupyter notebook that loads the data files created by `IBnet_SaveActivations.ipynb`, computes MI values, and produces the infoplane and SNR plots. Note: for the full results in the paper, MI and SNR values were averaged over 50 distinct trials; MI and SNR for individual runs can vary substantially. 14 | 15 | * `demo.py` is a simple script showing how to compute MI between X and Y, where Y = f(X) + Noise. 16 | 17 | Andrew Michael Saxe, Yamini Bansal, Joel Dapello, Madhu Advani, Artemy Kolchinsky, Brendan Daniel Tracey, David Daniel Cox, On the Information Bottleneck Theory of Deep Learning, *ICLR 2018*. 
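
## Checking the KDE bounds

The entropy bounds in `kde.py` are the pairwise-distance estimators of Kolchinsky and Tracey (references are in the comments of `kde.py`). As a sanity check, the upper bound can be reproduced outside of Keras. The following is a minimal numpy sketch, not part of the repository; the `0.05` noise variance and the random 1000x20 data mirror the illustrative values in `demo.py`:

```python
import numpy as np
from scipy.special import logsumexp

def entropy_upper_np(x, var):
    # KL-based upper bound on the entropy of a mixture of Gaussians centered
    # on the rows of x, each with covariance var * I
    # (a pure-numpy mirror of kde.entropy_estimator_kl).
    N, dims = x.shape
    x2 = np.sum(x ** 2, axis=1)
    dists = x2[:, None] + x2[None, :] - 2 * x @ x.T   # pairwise squared distances
    lprobs = logsumexp(-dists / (2 * var), axis=1) - np.log(N) \
             - (dims / 2.0) * np.log(2 * np.pi * var)
    return dims / 2.0 - np.mean(lprobs)

data = np.random.random((1000, 20))                            # N x dims
h_upper = entropy_upper_np(data, 0.05)                         # upper bound on H(Y), in nats
h_cond = (data.shape[1] / 2.0) * (np.log(2 * np.pi * 0.05) + 1)  # H(Y|X): entropy of the Gaussian noise
print("Upper bound: %0.3f nats" % (h_upper - h_cond))
```

Run on the same `data` array, this should match the "Upper bound" value printed by `demo.py` up to floating-point error.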
18 | -------------------------------------------------------------------------------- /kde.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import keras.backend as K 3 | 4 | import numpy as np 5 | 6 | def Kget_dists(X): 7 | """Keras code to compute the pairwise squared-distance matrix for a set of 8 | vectors specified by the matrix X. 9 | """ 10 | x2 = K.expand_dims(K.sum(K.square(X), axis=1), 1) 11 | dists = x2 + K.transpose(x2) - 2*K.dot(X, K.transpose(X)) 12 | return dists 13 | 14 | def get_shape(x): 15 | dims = K.cast( K.shape(x)[1], K.floatx() ) 16 | N = K.cast( K.shape(x)[0], K.floatx() ) 17 | return dims, N 18 | 19 | def entropy_estimator_kl(x, var): 20 | # KL-based upper bound on entropy of mixture of Gaussians with covariance matrix var * I 21 | # see Kolchinsky and Tracey, Estimating Mixture Entropy with Pairwise Distances, Entropy, 2017. Section 4. 22 | # and Kolchinsky and Tracey, Nonlinear Information Bottleneck, 2017. Eq. 10 23 | dims, N = get_shape(x) 24 | dists = Kget_dists(x) 25 | dists2 = dists / (2*var) 26 | normconst = (dims/2.0)*K.log(2*np.pi*var) 27 | lprobs = K.logsumexp(-dists2, axis=1) - K.log(N) - normconst 28 | h = -K.mean(lprobs) 29 | return dims/2 + h 30 | 31 | def entropy_estimator_bd(x, var): 32 | # Bhattacharyya-based lower bound on entropy of mixture of Gaussians with covariance matrix var * I 33 | # see Kolchinsky and Tracey, Estimating Mixture Entropy with Pairwise Distances, Entropy, 2017. Section 4. 34 | dims, N = get_shape(x) 35 | val = entropy_estimator_kl(x,4*var) 36 | return val + np.log(0.25)*dims/2 37 | 38 | def kde_condentropy(output, var): 39 | # Return the entropy of a multivariate Gaussian with covariance var * I, in nats 40 | dims = output.shape[1] 41 | return (dims/2.0)*(np.log(2*np.pi*var) + 1) 42 | 43 | -------------------------------------------------------------------------------- /simplebinmi.py: -------------------------------------------------------------------------------- 1 | # Simplified MI computation code from https://github.com/ravidziv/IDNNs 2 | import numpy as np 3 | 4 | def get_unique_probs(x): 5 | uniqueids = np.ascontiguousarray(x).view(np.dtype((np.void, x.dtype.itemsize * x.shape[1]))) 6 | _, unique_inverse, unique_counts = np.unique(uniqueids, return_index=False, return_inverse=True, return_counts=True) 7 | return np.asarray(unique_counts / float(sum(unique_counts))), unique_inverse 8 | 9 | def bin_calc_information(inputdata, layerdata, num_of_bins): 10 | p_xs, unique_inverse_x = get_unique_probs(inputdata) 11 | 12 | bins = np.linspace(-1, 1, num_of_bins, dtype='float32') 13 | digitized = bins[np.digitize(np.squeeze(layerdata.reshape(1, -1)), bins) - 1].reshape(len(layerdata), -1) 14 | p_ts, _ = get_unique_probs( digitized ) 15 | 16 | H_LAYER = -np.sum(p_ts * np.log(p_ts)) 17 | H_LAYER_GIVEN_INPUT = 0. 
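    # H(T|X): for each unique input pattern x, take the entropy of the binned
    # layer activity over the samples sharing that x, weighted by p(x)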
18 | for xval in np.arange(len(p_xs)): 19 | p_t_given_x, _ = get_unique_probs(digitized[unique_inverse_x == xval, :]) 20 | H_LAYER_GIVEN_INPUT += - p_xs[xval] * np.sum(p_t_given_x * np.log(p_t_given_x)) 21 | return H_LAYER - H_LAYER_GIVEN_INPUT 22 | 23 | def bin_calc_information2(labelixs, layerdata, binsize): 24 | # This is even further simplified, where we use np.floor instead of digitize 25 | def get_h(d): 26 | digitized = np.floor(d / binsize).astype('int') 27 | p_ts, _ = get_unique_probs( digitized ) 28 | return -np.sum(p_ts * np.log(p_ts)) 29 | 30 | H_LAYER = get_h(layerdata) 31 | H_LAYER_GIVEN_OUTPUT = 0 32 | for label, ixs in labelixs.items(): 33 | H_LAYER_GIVEN_OUTPUT += ixs.mean() * get_h(layerdata[ixs,:]) 34 | return H_LAYER, H_LAYER - H_LAYER_GIVEN_OUTPUT 35 | -------------------------------------------------------------------------------- /MNIST_SaveActivations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import keras\n", 10 | "import keras.backend as K\n", 11 | "import numpy as np\n", 12 | "\n", 13 | "import utils\n", 14 | "import loggingreporter \n", 15 | "\n", 16 | "cfg = {}\n", 17 | "cfg['SGD_BATCHSIZE'] = 128\n", 18 | "cfg['SGD_LEARNINGRATE'] = 0.001\n", 19 | "cfg['NUM_EPOCHS'] = 10000\n", 20 | "\n", 21 | "#cfg['ACTIVATION'] = 'relu'\n", 22 | "cfg['ACTIVATION'] = 'tanh'\n", 23 | "# How many hidden neurons to put into each of the layers\n", 24 | "cfg['LAYER_DIMS'] = [1024, 20, 20, 20]\n", 25 | "#cfg['LAYER_DIMS'] = [32, 28, 24, 20, 16, 12, 8, 8]\n", 26 | "#cfg['LAYER_DIMS'] = [128, 64, 32, 16, 16] # 0.967 w. 128\n", 27 | "#cfg['LAYER_DIMS'] = [20, 20, 20, 20, 20, 20] # 0.967 w. 128\n", 28 | "ARCH_NAME = '-'.join(map(str,cfg['LAYER_DIMS']))\n", 29 | "trn, tst = utils.get_mnist()\n", 30 | "\n", 31 | "# Where to save activation and weights data\n", 32 | "cfg['SAVE_DIR'] = 'rawdata/' + cfg['ACTIVATION'] + '_' + ARCH_NAME " 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "input_layer = keras.layers.Input((trn.X.shape[1],))\n", 42 | "clayer = input_layer\n", 43 | "for n in cfg['LAYER_DIMS']:\n", 44 | " clayer = keras.layers.Dense(n, activation=cfg['ACTIVATION'])(clayer)\n", 45 | "output_layer = keras.layers.Dense(trn.nb_classes, activation='softmax')(clayer)\n", 46 | "\n", 47 | "model = keras.models.Model(inputs=input_layer, outputs=output_layer)\n", 48 | "optimizer = keras.optimizers.SGD(lr=cfg['SGD_LEARNINGRATE'])\n", 49 | "\n", 50 | "model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "scrolled": false 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "def do_report(epoch):\n", 62 | " # Only log activity for some epochs. 
Mainly this is to make things run faster.\n", 63 | "    if epoch < 20:       # Log every one of the first 20 epochs\n", 64 | "        return True\n", 65 | "    elif epoch < 100:    # Then for every 5th epoch\n", 66 | "        return (epoch % 5 == 0)\n", 67 | "    elif epoch < 200:    # Then every 10th\n", 68 | "        return (epoch % 10 == 0)\n", 69 | "    else:                # Then every 100th\n", 70 | "        return (epoch % 100 == 0)\n", 71 | "    \n", 72 | "reporter = loggingreporter.LoggingReporter(cfg=cfg, \n", 73 | "                                          trn=trn, \n", 74 | "                                          tst=tst, \n", 75 | "                                          do_save_func=do_report)\n", 76 | "r = model.fit(x=trn.X, y=trn.Y, \n", 77 | "              verbose    = 2, \n", 78 | "              batch_size = cfg['SGD_BATCHSIZE'],\n", 79 | "              epochs     = cfg['NUM_EPOCHS'],\n", 80 | "              # validation_data=(tst.X, tst.Y),\n", 81 | "              callbacks  = [reporter,])\n" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [] 90 | } 91 | ], 92 | "metadata": { 93 | "kernelspec": { 94 | "display_name": "Python 3", 95 | "language": "python", 96 | "name": "python3" 97 | }, 98 | "language_info": { 99 | "codemirror_mode": { 100 | "name": "ipython", 101 | "version": 3 102 | }, 103 | "file_extension": ".py", 104 | "mimetype": "text/x-python", 105 | "name": "python", 106 | "nbconvert_exporter": "python", 107 | "pygments_lexer": "ipython3", 108 | "version": "3.7.6" 109 | } 110 | }, 111 | "nbformat": 4, 112 | "nbformat_minor": 1 113 | } 114 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import keras.backend as K 3 | import numpy as np 4 | import scipy.io as sio 5 | from pathlib2 import Path 6 | from collections import namedtuple 7 | 8 | def get_mnist(): 9 | # Returns two namedtuples, with MNIST training and testing data 10 | #   trn.X is training data 11 | #   trn.y is the training class, with numbers from 0 to 9 12 | #   trn.Y is the training class, but coded as a 10-dim vector with one entry set to 1 13 | # similarly for tst 14 | nb_classes = 10 15 | (X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data() 16 | X_train = np.reshape(X_train, [X_train.shape[0], -1]).astype('float32') / 255. 17 | X_test  = np.reshape(X_test , [X_test.shape[0] , -1]).astype('float32') / 255. 
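    # Pixels are scaled to [0, 1] above; uncommenting the two lines below
    # rescales them to [-1, 1] instead (sometimes preferable with tanh units)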
18 | #X_train = X_train * 2.0 - 1.0 19 | #X_test  = X_test  * 2.0 - 1.0 20 | 21 | Y_train = keras.utils.np_utils.to_categorical(y_train, nb_classes).astype('float32') 22 | Y_test  = keras.utils.np_utils.to_categorical(y_test, nb_classes).astype('float32') 23 | 24 | Dataset = namedtuple('Dataset',['X','Y','y','nb_classes']) 25 | trn = Dataset(X_train, Y_train, y_train, nb_classes) 26 | tst = Dataset(X_test , Y_test, y_test, nb_classes) 27 | 28 | del X_train, X_test, Y_train, Y_test, y_train, y_test 29 | 30 | return trn, tst 31 | 32 | def get_IB_data(ID): 33 | # Returns two namedtuples, with IB training and testing data 34 | #   trn.X is training data 35 | #   trn.y is the training class, with values 0 or 1 36 | #   trn.Y is the training class, but coded as a 2-dim vector with one entry set to 1 37 | # similarly for tst 38 | nb_classes = 2 39 | data_file = Path('datasets/IB_data_'+str(ID)+'.npz') 40 | if data_file.is_file(): 41 | data = np.load('datasets/IB_data_'+str(ID)+'.npz') 42 | else: 43 | create_IB_data(ID) 44 | data = np.load('datasets/IB_data_'+str(ID)+'.npz') 45 | 46 | (X_train, y_train), (X_test, y_test) = (data['X_train'], data['y_train']), (data['X_test'], data['y_test']) 47 | 48 | Y_train = keras.utils.np_utils.to_categorical(y_train, nb_classes).astype('float32') 49 | Y_test  = keras.utils.np_utils.to_categorical(y_test, nb_classes).astype('float32') 50 | 51 | Dataset = namedtuple('Dataset',['X','Y','y','nb_classes']) 52 | trn = Dataset(X_train, Y_train, y_train, nb_classes) 53 | tst = Dataset(X_test , Y_test, y_test, nb_classes) 54 | del X_train, X_test, Y_train, Y_test, y_train, y_test 55 | return trn, tst 56 | 57 | def create_IB_data(idx): 58 | data_sets_org = load_data() 59 | data_sets = data_shuffle(data_sets_org, 80, shuffle_data=True) 60 | X_train, y_train, X_test, y_test = data_sets.train.data, data_sets.train.labels[:,0], data_sets.test.data, data_sets.test.labels[:,0] 61 | np.savez_compressed('datasets/IB_data_'+str(idx), X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test) 62 | 63 | def construct_full_dataset(trn, tst): 64 | Dataset = namedtuple('Dataset',['X','Y','y','nb_classes']) 65 | X = np.concatenate((trn.X,tst.X)) 66 | y = np.concatenate((trn.y,tst.y)) 67 | Y = np.concatenate((trn.Y,tst.Y)) 68 | return Dataset(X, Y, y, trn.nb_classes) 69 | 70 | def load_data(): 71 | """Load the IB dataset from datasets/var_u.mat 72 | (the data used by https://github.com/ravidziv/IDNNs); 73 | return an object with data and labels""" 74 | print ('Loading Data...') 75 | 76 | 77 | d = sio.loadmat('datasets/var_u.mat') 78 | F = d['F'] 79 | y = d['y'] 80 | C = type('type_C', (object,), {}) 81 | data_sets = C() 82 | data_sets.data = F 83 | data_sets.labels = np.squeeze(np.concatenate((y[None, :], 1 - y[None, :]), axis=0).T) 84 | return data_sets 85 | 86 | def shuffle_in_unison_inplace(a, b): 87 | """Shuffle two arrays with the same random permutation""" 88 | assert len(a) == len(b) 89 | p = np.random.permutation(len(a)) 90 | return a[p], b[p] 91 | 92 | def data_shuffle(data_sets_org, percent_of_train, min_test_data=80, shuffle_data=False): 93 | """Divide the data into train and test sets and optionally shuffle it""" 94 | perc = lambda i, t: np.rint((i * t) / 100).astype(np.int32) 95 | C = type('type_C', (object,), {}) 96 | data_sets = C() 97 | stop_train_index = perc(percent_of_train, data_sets_org.data.shape[0]) 98 | start_test_index = stop_train_index 99 | 100 | if percent_of_train > min_test_data: 101 | start_test_index = perc(min_test_data, data_sets_org.data.shape[0]) 102 | data_sets.train = C() 103 | data_sets.test = C()
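    # Shuffle data and labels together (with a shared permutation) before
    # slicing out the train and test splits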
104 | if shuffle_data: 105 | shuffled_data, shuffled_labels = shuffle_in_unison_inplace(data_sets_org.data, data_sets_org.labels) 106 | else: 107 | shuffled_data, shuffled_labels = data_sets_org.data, data_sets_org.labels 108 | data_sets.train.data = shuffled_data[:stop_train_index, :] 109 | data_sets.train.labels = shuffled_labels[:stop_train_index, :] 110 | data_sets.test.data = shuffled_data[start_test_index:, :] 111 | data_sets.test.labels = shuffled_labels[start_test_index:, :] 112 | return data_sets 113 | -------------------------------------------------------------------------------- /IBnet_SaveActivations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import keras\n", 10 | "import keras.backend as K\n", 11 | "import tensorflow as tf\n", 12 | "import numpy as np\n", 13 | "\n", 14 | "import utils\n", 15 | "import loggingreporter \n", 16 | "\n", 17 | "cfg = {}\n", 18 | "cfg['SGD_BATCHSIZE'] = 256\n", 19 | "cfg['SGD_LEARNINGRATE'] = 0.0004\n", 20 | "cfg['NUM_EPOCHS'] = 10000\n", 21 | "cfg['FULL_MI'] = True\n", 22 | "\n", 23 | "cfg['ACTIVATION'] = 'tanh'\n", 24 | "# cfg['ACTIVATION'] = 'relu'\n", 25 | "# cfg['ACTIVATION'] = 'softsign'\n", 26 | "# cfg['ACTIVATION'] = 'softplus'\n", 27 | "\n", 28 | "# How many hidden neurons to put into each of the layers\n", 29 | "cfg['LAYER_DIMS'] = [10,7,5,4,3] # original IB network\n", 30 | "ARCH_NAME = '-'.join(map(str,cfg['LAYER_DIMS']))\n", 31 | "\n", 32 | "trn, tst = utils.get_IB_data('2017_12_21_16_51_3_275766')\n", 33 | "\n", 34 | "# Where to save activation and weights data\n", 35 | "cfg['SAVE_DIR'] = 'rawdata/' + cfg['ACTIVATION'] + '_' + ARCH_NAME " 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "input_layer = keras.layers.Input(shape=(trn.X.shape[1],))\n", 45 | "clayer = input_layer\n", 46 | "for n in cfg['LAYER_DIMS']:\n", 47 | " clayer = keras.layers.Dense(n, \n", 48 | " activation=cfg['ACTIVATION'],\n", 49 | " kernel_initializer=keras.initializers.TruncatedNormal(mean=0.0, stddev=1/np.sqrt(float(n)), seed=None),\n", 50 | " bias_initializer='zeros'\n", 51 | " )(clayer)\n", 52 | "output_layer = keras.layers.Dense(trn.nb_classes, activation='softmax')(clayer)\n", 53 | "\n", 54 | "model = keras.models.Model(inputs=input_layer, outputs=output_layer)\n", 55 | "optimizer = keras.optimizers.Adam(learning_rate=cfg['SGD_LEARNINGRATE'])\n", 56 | "\n", 57 | "model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])\n" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": { 64 | "scrolled": false 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "def do_report(epoch):\n", 69 | " # Only log activity for some epochs. 
Mainly this is to make things run faster.\n", 70 | "    if epoch < 20:       # Log every one of the first 20 epochs\n", 71 | "        return True\n", 72 | "    elif epoch < 100:    # Then for every 5th epoch\n", 73 | "        return (epoch % 5 == 0)\n", 74 | "    elif epoch < 2000:   # Then every 20th\n", 75 | "        return (epoch % 20 == 0)\n", 76 | "    else:                # Then every 100th\n", 77 | "        return (epoch % 100 == 0)\n", 78 | "    \n", 79 | "reporter = loggingreporter.LoggingReporter(cfg=cfg, \n", 80 | "                                          trn=trn, \n", 81 | "                                          tst=tst, \n", 82 | "                                          do_save_func=do_report)\n", 83 | "r = model.fit(x=trn.X, y=trn.Y, \n", 84 | "              verbose    = 2, \n", 85 | "              batch_size = cfg['SGD_BATCHSIZE'],\n", 86 | "              epochs     = cfg['NUM_EPOCHS'],\n", 87 | "              # validation_data=(tst.X, tst.Y),\n", 88 | "              callbacks  = [reporter,])" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [] 97 | } 98 | ], 99 | "metadata": { 100 | "kernelspec": { 101 | "display_name": "Python 3", 102 | "language": "python", 103 | "name": "python3" 104 | }, 105 | "language_info": { 106 | "codemirror_mode": { 107 | "name": "ipython", 108 | "version": 3 109 | }, 110 | "file_extension": ".py", 111 | "mimetype": "text/x-python", 112 | "name": "python", 113 | "nbconvert_exporter": "python", 114 | "pygments_lexer": "ipython3", 115 | "version": "3.7.6" 116 | }, 117 | "varInspector": { 118 | "cols": { 119 | "lenName": 16, 120 | "lenType": 16, 121 | "lenVar": 40 122 | }, 123 | "kernels_config": { 124 | "python": { 125 | "delete_cmd_postfix": "", 126 | "delete_cmd_prefix": "del ", 127 | "library": "var_list.py", 128 | "varRefreshCmd": "print(var_dic_list())" 129 | }, 130 | "r": { 131 | "delete_cmd_postfix": ") ", 132 | "delete_cmd_prefix": "rm(", 133 | "library": "var_list.r", 134 | "varRefreshCmd": "cat(var_dic_list()) " 135 | } 136 | }, 137 | "types_to_exclude": [ 138 | "module", 139 | "function", 140 | "builtin_function_or_method", 141 | "instance", 142 | "_Feature" 143 | ], 144 | "window_display": false 145 | } 146 | }, 147 | "nbformat": 4, 148 | "nbformat_minor": 1 149 | } 150 | -------------------------------------------------------------------------------- /loggingreporter.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import keras 3 | import keras.backend as K 4 | import numpy as np 5 | 6 | from six.moves import cPickle 7 | import os 8 | 9 | import utils 10 | 11 | class LoggingReporter(keras.callbacks.Callback): 12 | def __init__(self, cfg, trn, tst, do_save_func=None, *kargs, **kwargs): 13 | super(LoggingReporter, self).__init__(*kargs, **kwargs) 14 | self.cfg = cfg # Configuration options dictionary 15 | self.trn = trn # Train data 16 | self.tst = tst # Test data 17 | 18 | if 'FULL_MI' not in cfg: 19 | self.cfg['FULL_MI'] = False # Whether to compute MI on train and test data, or just test 20 | 21 | if self.cfg['FULL_MI']: 22 | self.full = utils.construct_full_dataset(trn,tst) 23 | 24 | # do_save_func(epoch) should return True if we should save on that epoch 25 | self.do_save_func = do_save_func 26 | 27 | def on_train_begin(self, logs={}): 28 | if not os.path.exists(self.cfg['SAVE_DIR']): 29 | print("Making directory", self.cfg['SAVE_DIR']) 30 | os.makedirs(self.cfg['SAVE_DIR']) 31 | 32 | # Indexes of the layers which we keep track of. 
Basically, this will be any layer 33 | # which has a 'kernel' attribute, i.e. the Dense and Dense-like layers 34 | self.layerixs = [] 35 | 36 | # Functions that return the activity of each layer 37 | self.layerfuncs = [] 38 | 39 | # Functions that return the weights of each layer 40 | self.layerweights = [] 41 | for lndx, l in enumerate(self.model.layers): 42 | if hasattr(l, 'kernel'): 43 | self.layerixs.append(lndx) 44 | self.layerfuncs.append(K.function(self.model.inputs, [l.output,])) 45 | self.layerweights.append(l.kernel) 46 | 47 | inputs = [self.model._feed_inputs, 48 | self.model._feed_targets, 49 | self.model._feed_sample_weights, 50 | K.learning_phase()] 51 | 52 | # Get gradients of all the relevant layers at once 53 | grads = self.model.optimizer.get_gradients(self.model.total_loss, self.layerweights) 54 | self.get_gradients = K.function(inputs=inputs, outputs=grads) 55 | 56 | # Get cross-entropy loss 57 | self.get_loss = K.function(inputs=inputs, outputs=[self.model.total_loss,]) 58 | 59 | def on_epoch_begin(self, epoch, logs={}): 60 | if self.do_save_func is not None and not self.do_save_func(epoch): 61 | # Don't log this epoch 62 | self._log_gradients = False 63 | else: 64 | # We will log this epoch. For each batch in this epoch, we will save the gradients (in on_batch_begin) 65 | # We will then compute the means and variances of these gradients 66 | 67 | self._log_gradients = True 68 | self._batch_weightnorm = [] 69 | 70 | self._batch_gradients = [ [] for _ in self.model.layers[1:] ] 71 | 72 | # Indexes of all the training data samples. These are shuffled and read in chunks of SGD_BATCHSIZE 73 | ixs = list(range(len(self.trn.X))) 74 | np.random.shuffle(ixs) 75 | self._batch_todo_ixs = ixs 76 | 77 | def on_batch_begin(self, batch, logs={}): 78 | if not self._log_gradients: 79 | # We are not keeping track of batch gradients, so do nothing 80 | return 81 | 82 | # Sample a batch 83 | batchsize = self.cfg['SGD_BATCHSIZE'] 84 | cur_ixs = self._batch_todo_ixs[:batchsize] 85 | # Advance the indexing, so next on_batch_begin samples a different batch 86 | self._batch_todo_ixs = self._batch_todo_ixs[batchsize:] 87 | 88 | # Get gradients for this batch 89 | 90 | x, y, weights = self.model._standardize_user_data(self.trn.X[cur_ixs,:], self.trn.Y[cur_ixs,:]) 91 | inputs = [x, y, weights, 1] # 1 indicates training phase 92 | 93 | for lndx, g in enumerate(self.get_gradients(inputs)): 94 | # g is the gradient for the weights of the lndx'th tracked layer 95 | oneDgrad = np.reshape(g, [-1, 1]) # Flatten to one dimensional vector 96 | self._batch_gradients[lndx].append(oneDgrad) 97 | 98 | 99 | def on_epoch_end(self, epoch, logs={}): 100 | if self.do_save_func is not None and not self.do_save_func(epoch): 101 | # Don't log this epoch 102 | return 103 | 104 | # Get overall performance 105 | loss = {} 106 | for cdata, cdataname, istrain in ((self.trn,'trn',1), (self.tst, 'tst',0)): 107 | x, y, weights = self.model._standardize_user_data(cdata.X, cdata.Y) 108 | loss[cdataname] = self.get_loss([x, y, weights, istrain])[0].flat[0] 109 | 110 | data = { 111 | 'weights_norm' : [], # L2 norm of weights 112 | 'gradmean' : [], # Norm of the mean gradient across batches 113 | 'gradstd' : [], # Norm of the std of gradients across batches 114 | 'activity_tst' : [] # Activity in each layer for test set 115 | } 116 | 117 | for lndx, layerix in enumerate(self.layerixs): 118 | clayer = self.model.layers[layerix] 119 | 120 | data['weights_norm'].append( np.linalg.norm(K.get_value(clayer.kernel)) ) 121 | 122 | stackedgrads = np.stack(self._batch_gradients[lndx], axis=1) 123 | 
data['gradmean' ].append( np.linalg.norm(stackedgrads.mean(axis=1)) ) 124 | data['gradstd'  ].append( np.linalg.norm(stackedgrads.std(axis=1)) ) 125 | 126 | if self.cfg['FULL_MI']: 127 | data['activity_tst'].append(self.layerfuncs[lndx]([self.full.X,])[0]) 128 | else: 129 | data['activity_tst'].append(self.layerfuncs[lndx]([self.tst.X,])[0]) 130 | 131 | fname = self.cfg['SAVE_DIR'] + "/epoch%08d"% epoch 132 | print("Saving", fname) 133 | with open(fname, 'wb') as f: 134 | cPickle.dump({'ACTIVATION':self.cfg['ACTIVATION'], 'epoch':epoch, 'data':data, 'loss':loss}, f, cPickle.HIGHEST_PROTOCOL) 135 | 136 | -------------------------------------------------------------------------------- /IBnet_ComputeMI.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from __future__ import print_function\n", 10 | "%load_ext autoreload\n", 11 | "%autoreload 2\n", 12 | "import os\n", 13 | "if not os.path.exists('plots/'):\n", 14 | "    os.mkdir('plots')\n", 15 | "\n", 16 | "from six.moves import cPickle\n", 17 | "from collections import defaultdict, OrderedDict\n", 18 | "\n", 19 | "import numpy as np\n", 20 | "import keras.backend as K\n", 21 | "\n", 22 | "import kde\n", 23 | "import simplebinmi\n", 24 | "\n", 25 | "%matplotlib inline\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "import matplotlib.gridspec as gridspec\n", 28 | "import seaborn as sns\n", 29 | "sns.set_style('darkgrid')\n", 30 | "\n", 31 | "import utils\n", 32 | "\n", 33 | "# load data network was trained on\n", 34 | "trn, tst = utils.get_IB_data('2017_12_21_16_51_3_275766')\n", 35 | "\n", 36 | "# Calculate MI for train and test data; IBnet_SaveActivations must have been run with cfg['FULL_MI'] = True\n", 37 | "FULL_MI = True\n", 38 | "\n", 39 | "# Which measure to plot\n", 40 | "infoplane_measure = 'upper'\n", 41 | "# infoplane_measure = 'bin'\n", 42 | "\n", 43 | "DO_SAVE        = True # Whether to save plots or just show them\n", 44 | "DO_LOWER       = (infoplane_measure == 'lower')   # Whether to compute lower bounds also\n", 45 | "DO_BINNED      = (infoplane_measure == 'bin')     # Whether to compute MI estimates based on binning\n", 46 | "\n", 47 | "MAX_EPOCHS = 10000      # Max number of epochs for which to compute the mutual information measures\n", 48 | "NUM_LABELS = 2\n", 49 | "# MAX_EPOCHS = 1000\n", 50 | "COLORBAR_MAX_EPOCHS = 10000\n", 51 | "\n", 52 | "# Directories from which to load saved layer activity\n", 53 | "# ARCH = '1024-20-20-20'\n", 54 | "ARCH = '10-7-5-4-3'\n", 55 | "#ARCH = '20-20-20-20-20-20'\n", 56 | "#ARCH = '32-28-24-20-16-12'\n", 57 | "#ARCH = '32-28-24-20-16-12-8-8'\n", 58 | "DIR_TEMPLATE = '%%s_%s'%ARCH\n", 59 | "\n", 60 | "# Functions to return upper and lower bounds on entropy of layer activity\n", 61 | "noise_variance = 1e-3                    # Added Gaussian noise variance\n", 62 | "binsize = 0.07                           # size of bins for binning method\n", 63 | "Klayer_activity = K.placeholder(ndim=2)  # Keras placeholder \n", 64 | "entropy_func_upper = K.function([Klayer_activity,], [kde.entropy_estimator_kl(Klayer_activity, noise_variance),])\n", 65 | "entropy_func_lower = K.function([Klayer_activity,], [kde.entropy_estimator_bd(Klayer_activity, noise_variance),])\n", 66 | "\n", 67 | "# nats to bits conversion factor\n", 68 | "nats2bits = 1.0/np.log(2) \n", 69 | "\n", 70 | "# Save indexes of test data for each of the output classes\n", 71 | "saved_labelixs = {}\n", 72 | "\n", 73 | "y = tst.y\n", 74 | "Y = tst.Y\n", 75
| "if FULL_MI:\n", 76 | " full = utils.construct_full_dataset(trn,tst)\n", 77 | " y = full.y\n", 78 | " Y = full.Y\n", 79 | "\n", 80 | "for i in range(NUM_LABELS):\n", 81 | " saved_labelixs[i] = y == i\n", 82 | "\n", 83 | "labelprobs = np.mean(Y, axis=0)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "PLOT_LAYERS = None # Which layers to plot. If None, all saved layers are plotted \n", 93 | "\n", 94 | "# Data structure used to store results\n", 95 | "measures = OrderedDict()\n", 96 | "measures['tanh'] = {}\n", 97 | "measures['relu'] = {}\n", 98 | "# measures['softsign'] = {}\n", 99 | "# measures['softplus'] = {}" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "Compute MI measures\n", 107 | "-----" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": { 114 | "scrolled": false 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "for activation in measures.keys():\n", 119 | " cur_dir = 'rawdata/' + DIR_TEMPLATE % activation\n", 120 | " if not os.path.exists(cur_dir):\n", 121 | " print(\"Directory %s not found\" % cur_dir)\n", 122 | " continue\n", 123 | " \n", 124 | " # Load files saved during each epoch, and compute MI measures of the activity in that epoch\n", 125 | " print('*** Doing %s ***' % cur_dir)\n", 126 | " for epochfile in sorted(os.listdir(cur_dir)):\n", 127 | " if not epochfile.startswith('epoch'):\n", 128 | " continue\n", 129 | " \n", 130 | " fname = cur_dir + \"/\" + epochfile\n", 131 | " with open(fname, 'rb') as f:\n", 132 | " d = cPickle.load(f)\n", 133 | "\n", 134 | " epoch = d['epoch']\n", 135 | " if epoch in measures[activation]: # Skip this epoch if its already been processed\n", 136 | " continue # this is a trick to allow us to rerun this cell multiple times)\n", 137 | " \n", 138 | " if epoch > MAX_EPOCHS:\n", 139 | " continue\n", 140 | "\n", 141 | " print(\"Doing\", fname)\n", 142 | " \n", 143 | " num_layers = len(d['data']['activity_tst'])\n", 144 | "\n", 145 | " if PLOT_LAYERS is None:\n", 146 | " PLOT_LAYERS = []\n", 147 | " for lndx in range(num_layers):\n", 148 | " #if d['data']['activity_tst'][lndx].shape[1] < 200 and lndx != num_layers - 1:\n", 149 | " PLOT_LAYERS.append(lndx)\n", 150 | " \n", 151 | " cepochdata = defaultdict(list)\n", 152 | " for lndx in range(num_layers):\n", 153 | " activity = d['data']['activity_tst'][lndx]\n", 154 | "\n", 155 | " # Compute marginal entropies\n", 156 | " h_upper = entropy_func_upper([activity,])[0]\n", 157 | " if DO_LOWER:\n", 158 | " h_lower = entropy_func_lower([activity,])[0]\n", 159 | " \n", 160 | " # Layer activity given input. 
This is simply the entropy of the Gaussian noise\n", 161 | " hM_given_X = kde.kde_condentropy(activity, noise_variance)\n", 162 | "\n", 163 | " # Compute conditional entropies of layer activity given output\n", 164 | " hM_given_Y_upper=0.\n", 165 | " for i in range(NUM_LABELS):\n", 166 | " hcond_upper = entropy_func_upper([activity[saved_labelixs[i],:],])[0]\n", 167 | " hM_given_Y_upper += labelprobs[i] * hcond_upper\n", 168 | " \n", 169 | " if DO_LOWER:\n", 170 | " hM_given_Y_lower=0.\n", 171 | " for i in range(NUM_LABELS):\n", 172 | " hcond_lower = entropy_func_lower([activity[saved_labelixs[i],:],])[0]\n", 173 | " hM_given_Y_lower += labelprobs[i] * hcond_lower\n", 174 | " \n", 175 | " cepochdata['MI_XM_upper'].append( nats2bits * (h_upper - hM_given_X) )\n", 176 | " cepochdata['MI_YM_upper'].append( nats2bits * (h_upper - hM_given_Y_upper) )\n", 177 | " cepochdata['H_M_upper' ].append( nats2bits * h_upper )\n", 178 | "\n", 179 | " pstr = 'upper: MI(X;M)=%0.3f, MI(Y;M)=%0.3f' % (cepochdata['MI_XM_upper'][-1], cepochdata['MI_YM_upper'][-1])\n", 180 | " if DO_LOWER: # Compute lower bounds\n", 181 | " cepochdata['MI_XM_lower'].append( nats2bits * (h_lower - hM_given_X) )\n", 182 | " cepochdata['MI_YM_lower'].append( nats2bits * (h_lower - hM_given_Y_lower) )\n", 183 | " cepochdata['H_M_lower' ].append( nats2bits * h_lower )\n", 184 | " pstr += ' | lower: MI(X;M)=%0.3f, MI(Y;M)=%0.3f' % (cepochdata['MI_XM_lower'][-1], cepochdata['MI_YM_lower'][-1])\n", 185 | "\n", 186 | " if DO_BINNED: # Compute binned estimates\n", 187 | " binxm, binym = simplebinmi.bin_calc_information2(saved_labelixs, activity, binsize)\n", 188 | " cepochdata['MI_XM_bin'].append( nats2bits * binxm )\n", 189 | " cepochdata['MI_YM_bin'].append( nats2bits * binym )\n", 190 | " pstr += ' | bin: MI(X;M)=%0.3f, MI(Y;M)=%0.3f' % (cepochdata['MI_XM_bin'][-1], cepochdata['MI_YM_bin'][-1])\n", 191 | " \n", 192 | " print('- Layer %d %s' % (lndx, pstr) )\n", 193 | "\n", 194 | " measures[activation][epoch] = cepochdata" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "Plot Infoplane Visualization\n", 202 | "----" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": { 209 | "scrolled": false 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "max_epoch = max( (max(vals.keys()) if len(vals) else 0) for vals in measures.values())\n", 214 | "sm = plt.cm.ScalarMappable(cmap='gnuplot', norm=plt.Normalize(vmin=0, vmax=COLORBAR_MAX_EPOCHS))\n", 215 | "sm._A = []\n", 216 | "\n", 217 | "fig=plt.figure(figsize=(10,5))\n", 218 | "for actndx, (activation, vals) in enumerate(measures.items()):\n", 219 | " epochs = sorted(vals.keys())\n", 220 | " if not len(epochs):\n", 221 | " continue\n", 222 | " plt.subplot(1,2,actndx+1) \n", 223 | " for epoch in epochs:\n", 224 | " c = sm.to_rgba(epoch)\n", 225 | " xmvals = np.array(vals[epoch]['MI_XM_'+infoplane_measure])[PLOT_LAYERS]\n", 226 | " ymvals = np.array(vals[epoch]['MI_YM_'+infoplane_measure])[PLOT_LAYERS]\n", 227 | "\n", 228 | " plt.plot(xmvals, ymvals, c=c, alpha=0.1, zorder=1)\n", 229 | " plt.scatter(xmvals, ymvals, s=20, facecolors=[c for _ in PLOT_LAYERS], edgecolor='none', zorder=2)\n", 230 | "\n", 231 | " plt.ylim([0, 1])\n", 232 | " plt.xlim([0, 12])\n", 233 | "# plt.ylim([0, 3.5])\n", 234 | "# plt.xlim([0, 14])\n", 235 | " plt.xlabel('I(X;M)')\n", 236 | " plt.ylabel('I(Y;M)')\n", 237 | " plt.title(activation)\n", 238 | " \n", 239 | "cbaxes = fig.add_axes([1.0, 0.125, 0.03, 0.8]) \n", 240 
| "plt.colorbar(sm, label='Epoch', cax=cbaxes)\n", 241 | "plt.tight_layout()\n", 242 | "\n", 243 | "if DO_SAVE:\n", 244 | " plt.savefig('plots/' + DIR_TEMPLATE % ('infoplane_'+ARCH),bbox_inches='tight')" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": {}, 250 | "source": [ 251 | "Plot SNR curves\n", 252 | "----" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "plt.figure(figsize=(12,5))\n", 262 | "\n", 263 | "gs = gridspec.GridSpec(len(measures), len(PLOT_LAYERS))\n", 264 | "for activation in measures.keys():\n", 265 | " cur_dir = 'rawdata/' + DIR_TEMPLATE % activation\n", 266 | " if not os.path.exists(cur_dir):\n", 267 | " continue\n", 268 | " \n", 269 | " epochs = []\n", 270 | " means = []\n", 271 | " stds = []\n", 272 | " wnorms = []\n", 273 | " for epochfile in sorted(os.listdir(cur_dir)):\n", 274 | " if not epochfile.startswith('epoch'):\n", 275 | " continue\n", 276 | " \n", 277 | " with open(cur_dir + \"/\"+epochfile, 'rb') as f:\n", 278 | " d = cPickle.load(f)\n", 279 | " \n", 280 | " epoch = d['epoch']\n", 281 | " epochs.append(epoch)\n", 282 | " wnorms.append(d['data']['weights_norm'])\n", 283 | " means.append(d['data']['gradmean'])\n", 284 | " stds.append(d['data']['gradstd'])\n", 285 | "\n", 286 | " wnorms, means, stds = map(np.array, [wnorms, means, stds])\n", 287 | " for lndx,layerid in enumerate(PLOT_LAYERS):\n", 288 | " plt.subplot(gs[actndx, lndx])\n", 289 | " plt.plot(epochs, means[:,layerid], 'b', label=\"Mean\")\n", 290 | " plt.plot(epochs, stds[:,layerid], 'orange', label=\"Std\")\n", 291 | " plt.plot(epochs, means[:,layerid]/stds[:,layerid], 'red', label=\"SNR\")\n", 292 | " plt.plot(epochs, wnorms[:,layerid], 'g', label=\"||W||\")\n", 293 | "\n", 294 | " plt.title('Layer %d'%layerid)\n", 295 | " plt.xlabel('Epoch')\n", 296 | " plt.gca().set_xscale(\"log\", nonposx='clip')\n", 297 | " plt.gca().set_yscale(\"log\", nonposy='clip')\n", 298 | " \n", 299 | "\n", 300 | "plt.legend(loc='lower left', bbox_to_anchor=(1.1, 0.2))\n", 301 | "plt.tight_layout()\n", 302 | "\n", 303 | "if DO_SAVE:\n", 304 | " plt.savefig('plots/' + DIR_TEMPLATE % ('snr_'+ARCH), bbox_inches='tight')\n" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": null, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [] 313 | } 314 | ], 315 | "metadata": { 316 | "kernelspec": { 317 | "display_name": "Python 3", 318 | "language": "python", 319 | "name": "python3" 320 | }, 321 | "language_info": { 322 | "codemirror_mode": { 323 | "name": "ipython", 324 | "version": 3 325 | }, 326 | "file_extension": ".py", 327 | "mimetype": "text/x-python", 328 | "name": "python", 329 | "nbconvert_exporter": "python", 330 | "pygments_lexer": "ipython3", 331 | "version": "3.7.6" 332 | }, 333 | "varInspector": { 334 | "cols": { 335 | "lenName": 16, 336 | "lenType": 16, 337 | "lenVar": 40 338 | }, 339 | "kernels_config": { 340 | "python": { 341 | "delete_cmd_postfix": "", 342 | "delete_cmd_prefix": "del ", 343 | "library": "var_list.py", 344 | "varRefreshCmd": "print(var_dic_list())" 345 | }, 346 | "r": { 347 | "delete_cmd_postfix": ") ", 348 | "delete_cmd_prefix": "rm(", 349 | "library": "var_list.r", 350 | "varRefreshCmd": "cat(var_dic_list()) " 351 | } 352 | }, 353 | "types_to_exclude": [ 354 | "module", 355 | "function", 356 | "builtin_function_or_method", 357 | "instance", 358 | "_Feature" 359 | ], 360 | "window_display": false 361 | } 362 | }, 363 | "nbformat": 4, 
364 | "nbformat_minor": 1 365 | } 366 | -------------------------------------------------------------------------------- /MNIST_ComputeMI.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from __future__ import print_function\n", 10 | "%load_ext autoreload\n", 11 | "%autoreload 2\n", 12 | "import os\n", 13 | "if not os.path.exists('plots/'):\n", 14 | " os.mkdir('plots')\n", 15 | "\n", 16 | "from six.moves import cPickle\n", 17 | "from collections import defaultdict, OrderedDict\n", 18 | "\n", 19 | "import numpy as np\n", 20 | "import keras.backend as K\n", 21 | "\n", 22 | "import kde\n", 23 | "import simplebinmi\n", 24 | "\n", 25 | "import utils\n", 26 | "trn, tst = utils.get_mnist()\n", 27 | "\n", 28 | "# Which measure to plot\n", 29 | "infoplane_measure = 'upper'\n", 30 | "#infoplane_measure = 'bin'\n", 31 | "\n", 32 | "DO_SAVE = False # Whether to save plots or just show them\n", 33 | "DO_LOWER = True # (infoplane_measure == 'lower') # Whether to compute lower bounds also\n", 34 | "DO_BINNED = True #(infoplane_measure == 'bin') # Whether to compute MI estimates based on binning\n", 35 | "\n", 36 | "MAX_EPOCHS = 10000 # Max number of epoch for which to compute mutual information measure\n", 37 | "# MAX_EPOCHS = 1000\n", 38 | "COLORBAR_MAX_EPOCHS = 10000\n", 39 | "\n", 40 | "# Directories from which to load saved layer activity\n", 41 | "ARCH = '1024-20-20-20'\n", 42 | "#ARCH = '20-20-20-20-20-20'\n", 43 | "#ARCH = '32-28-24-20-16-12'\n", 44 | "#ARCH = '32-28-24-20-16-12-8-8'\n", 45 | "DIR_TEMPLATE = '%%s_%s'%ARCH\n", 46 | "\n", 47 | "# Functions to return upper and lower bounds on entropy of layer activity\n", 48 | "noise_variance = 1e-1 # Added Gaussian noise variance\n", 49 | "Klayer_activity = K.placeholder(ndim=2) # Keras placeholder \n", 50 | "entropy_func_upper = K.function([Klayer_activity,], [kde.entropy_estimator_kl(Klayer_activity, noise_variance),])\n", 51 | "entropy_func_lower = K.function([Klayer_activity,], [kde.entropy_estimator_bd(Klayer_activity, noise_variance),])\n", 52 | "\n", 53 | "\n", 54 | "# nats to bits conversion factor\n", 55 | "nats2bits = 1.0/np.log(2) \n", 56 | "\n", 57 | "\n", 58 | "# Save indexes of tests data for each of the output classes\n", 59 | "saved_labelixs = {}\n", 60 | "for i in range(10):\n", 61 | " saved_labelixs[i] = tst.y == i\n", 62 | "\n", 63 | "labelprobs = np.mean(tst.Y, axis=0)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "PLOT_LAYERS = None # Which layers to plot. 
If None, all saved layers are plotted \n", 73 | "\n", 74 | "# Data structure used to store results\n", 75 | "measures = OrderedDict()\n", 76 | "measures['relu'] = {}\n", 77 | "measures['tanh'] = {}" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "Compute MI measures\n", 85 | "-----" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "scrolled": false 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "for activation in measures.keys():\n", 97 | "    cur_dir = 'rawdata/' + DIR_TEMPLATE % activation\n", 98 | "    if not os.path.exists(cur_dir):\n", 99 | "        print(\"Directory %s not found\" % cur_dir)\n", 100 | "        continue\n", 101 | "        \n", 102 | "    # Load files saved during each epoch, and compute MI measures of the activity in that epoch\n", 103 | "    print('*** Doing %s ***' % cur_dir)\n", 104 | "    for epochfile in sorted(os.listdir(cur_dir)):\n", 105 | "        if not epochfile.startswith('epoch'):\n", 106 | "            continue\n", 107 | "        \n", 108 | "        fname = cur_dir + \"/\" + epochfile\n", 109 | "        with open(fname, 'rb') as f:\n", 110 | "            d = cPickle.load(f)\n", 111 | "\n", 112 | "        epoch = d['epoch']\n", 113 | "        if epoch in measures[activation]: # Skip this epoch if it's already been processed\n", 114 | "            continue                      # this is a trick to allow us to rerun this cell multiple times\n", 115 | "            \n", 116 | "        if epoch > MAX_EPOCHS:\n", 117 | "            continue\n", 118 | "\n", 119 | "        print(\"Doing\", fname)\n", 120 | "        \n", 121 | "        num_layers = len(d['data']['activity_tst'])\n", 122 | "\n", 123 | "        if PLOT_LAYERS is None:\n", 124 | "            PLOT_LAYERS = []\n", 125 | "            for lndx in range(num_layers):\n", 126 | "                #if d['data']['activity_tst'][lndx].shape[1] < 200 and lndx != num_layers - 1:\n", 127 | "                PLOT_LAYERS.append(lndx)\n", 128 | "                \n", 129 | "        cepochdata = defaultdict(list)\n", 130 | "        for lndx in range(num_layers):\n", 131 | "            activity = d['data']['activity_tst'][lndx]\n", 132 | "\n", 133 | "            # Compute marginal entropies\n", 134 | "            h_upper = entropy_func_upper([activity,])[0]\n", 135 | "            if DO_LOWER:\n", 136 | "                h_lower = entropy_func_lower([activity,])[0]\n", 137 | "                \n", 138 | "            # Layer activity given input. This is simply the entropy of the Gaussian noise\n", 139 | "            hM_given_X = kde.kde_condentropy(activity, noise_variance)\n", 140 | "\n", 141 | "            # Compute conditional entropies of layer activity given output\n", 142 | "            hM_given_Y_upper=0.\n", 143 | "            for i in range(10):\n", 144 | "                hcond_upper = entropy_func_upper([activity[saved_labelixs[i],:],])[0]\n", 145 | "                hM_given_Y_upper += labelprobs[i] * hcond_upper\n", 146 | "                \n", 147 | "            if DO_LOWER:\n", 148 | "                hM_given_Y_lower=0.\n", 149 | "                for i in range(10):\n", 150 | "                    hcond_lower = entropy_func_lower([activity[saved_labelixs[i],:],])[0]\n", 151 | "                    hM_given_Y_lower += labelprobs[i] * hcond_lower\n", 152 | "                \n", 153 | "            \n", 154 | "            # # It's also possible to treat the last layer probabilistically. Here is the \n", 155 | "            # # code to do so. 
Should only be applied when lndx == num_layers - 1\n", 156 | "\n", 157 | "            # ps = activity.mean(axis=0)\n", 158 | "            # h_lower = h_upper = sum([-p*np.log(p) for p in ps if p != 0])\n", 159 | "\n", 160 | "            # x = -activity * np.log(activity)\n", 161 | "            # x[activity == 0] = 0.\n", 162 | "            # hM_given_X = np.mean(x.sum(axis=1))\n", 163 | "\n", 164 | "            # hM_given_Y=0.\n", 165 | "            # for i in range(10):\n", 166 | "            #     ixs = tst.y == i\n", 167 | "            #     ps = activity[ixs,:].mean(axis=0)\n", 168 | "            #     hcond = sum([-p*np.log(p) for p in ps if p != 0])\n", 169 | "            #     prob = np.mean(ixs)\n", 170 | "            #     hM_given_Y += prob * hcond\n", 171 | "            # hM_given_Y_lower = hM_given_Y_upper = hM_given_Y\n", 172 | "            # del hM_given_Y\n", 173 | "            \n", 174 | "            cepochdata['MI_XM_upper'].append( nats2bits * (h_upper - hM_given_X) )\n", 175 | "            cepochdata['MI_YM_upper'].append( nats2bits * (h_upper - hM_given_Y_upper) )\n", 176 | "            cepochdata['H_M_upper'  ].append( nats2bits * h_upper )\n", 177 | "\n", 178 | "            pstr = 'upper: MI(X;M)=%0.3f, MI(Y;M)=%0.3f' % (cepochdata['MI_XM_upper'][-1], cepochdata['MI_YM_upper'][-1])\n", 179 | "            if DO_LOWER:  # Compute lower bounds\n", 180 | "                cepochdata['MI_XM_lower'].append( nats2bits * (h_lower - hM_given_X) )\n", 181 | "                cepochdata['MI_YM_lower'].append( nats2bits * (h_lower - hM_given_Y_lower) )\n", 182 | "                cepochdata['H_M_lower'  ].append( nats2bits * h_lower )\n", 183 | "                pstr += ' | lower: MI(X;M)=%0.3f, MI(Y;M)=%0.3f' % (cepochdata['MI_XM_lower'][-1], cepochdata['MI_YM_lower'][-1])\n", 184 | "\n", 185 | "            if DO_BINNED: # Compute binned estimates\n", 186 | "                binxm, binym = simplebinmi.bin_calc_information2(saved_labelixs, activity, 0.5)\n", 187 | "                cepochdata['MI_XM_bin'].append( nats2bits * binxm )\n", 188 | "                cepochdata['MI_YM_bin'].append( nats2bits * binym )\n", 189 | "                pstr += ' | bin: MI(X;M)=%0.3f, MI(Y;M)=%0.3f' % (cepochdata['MI_XM_bin'][-1], cepochdata['MI_YM_bin'][-1])\n", 190 | "                        \n", 191 | "            print('- Layer %d %s' % (lndx, pstr) )\n", 192 | "\n", 193 | "        measures[activation][epoch] = cepochdata\n", 194 | "        " 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "Plot overall summaries\n", 202 | "----\n", 203 | "\n", 204 | "This is more for diagnostic purposes, not for the article\n" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": { 211 | "scrolled": false 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "%matplotlib inline\n", 216 | "import matplotlib.pyplot as plt\n", 217 | "import matplotlib.gridspec as gridspec\n", 218 | "import seaborn as sns\n", 219 | "sns.set_style('darkgrid')\n", 220 | "\n", 221 | "\n", 222 | "#PLOT_LAYERS = [0,1,2,3,4] # [1,2,3]\n", 223 | "#PLOT_LAYERS = [0,1,2,3]\n", 224 | "#PLOT_LAYERS = [0,1,2,3]\n", 225 | "plt.figure(figsize=(8,8))\n", 226 | "gs = gridspec.GridSpec(4,2)\n", 227 | "for actndx, (activation, vals) in enumerate(measures.items()):\n", 228 | "    epochs = sorted(vals.keys())\n", 229 | "    if not len(epochs):\n", 230 | "        continue\n", 231 | "    \n", 232 | "    plt.subplot(gs[0,actndx])\n", 233 | "    for lndx, layerid in enumerate(PLOT_LAYERS):\n", 234 | "        xmvalsU = np.array([vals[epoch]['H_M_upper'][layerid] for epoch in epochs])\n", 235 | "        if DO_LOWER:\n", 236 | "            xmvalsL = np.array([vals[epoch]['H_M_lower'][layerid] for epoch in epochs])\n", 237 | "        plt.plot(epochs, xmvalsU, label='Layer %d'%layerid)\n", 238 | "        #plt.errorbar(epochs, (xmvalsL + xmvalsU)/2,xmvalsU - xmvalsL, label='Layer %d'%layerid)\n", 239 | "    plt.xscale('log')\n", 240 |
" plt.yscale('log')\n", 241 | " plt.title(activation)\n", 242 | " plt.ylabel('H(M)')\n", 243 | " \n", 244 | " plt.subplot(gs[1,actndx])\n", 245 | " for lndx, layerid in enumerate(PLOT_LAYERS):\n", 246 | " #for epoch in epochs:\n", 247 | " # print('her',epoch, measures[activation][epoch]['MI_XM_upper'])\n", 248 | " xmvalsU = np.array([vals[epoch]['MI_XM_upper'][layerid] for epoch in epochs])\n", 249 | " if DO_LOWER:\n", 250 | " xmvalsL = np.array([vals[epoch]['MI_XM_lower'][layerid] for epoch in epochs])\n", 251 | " plt.plot(epochs, xmvalsU, label='Layer %d'%layerid)\n", 252 | " #plt.errorbar(epochs, (xmvalsL + xmvalsU)/2,xmvalsU - xmvalsL, label='Layer %d'%layerid)\n", 253 | " plt.xscale('log')\n", 254 | " plt.ylabel('I(X;M)')\n", 255 | "\n", 256 | "\n", 257 | " plt.subplot(gs[2,actndx])\n", 258 | " for lndx, layerid in enumerate(PLOT_LAYERS):\n", 259 | " ymvalsU = np.array([vals[epoch]['MI_YM_upper'][layerid] for epoch in epochs])\n", 260 | " if DO_LOWER:\n", 261 | " ymvalsL = np.array([vals[epoch]['MI_YM_lower'][layerid] for epoch in epochs])\n", 262 | " plt.plot(epochs, ymvalsU, label='Layer %d'%layerid)\n", 263 | " plt.xscale('log')\n", 264 | " plt.ylabel('MI(Y;M)')\n", 265 | "\n", 266 | " if DO_BINNED:\n", 267 | " plt.subplot(gs[3,actndx])\n", 268 | " for lndx, layerid in enumerate(PLOT_LAYERS):\n", 269 | " hbinnedvals = np.array([vals[epoch]['MI_XM_bin'][layerid] for epoch in epochs])\n", 270 | " plt.semilogx(epochs, hbinnedvals, label='Layer %d'%layerid)\n", 271 | " plt.xlabel('Epoch')\n", 272 | " plt.ylabel(\"I(X;M)bin\")\n", 273 | " \n", 274 | " if actndx == 0:\n", 275 | " plt.legend(loc='lower right')\n", 276 | " \n", 277 | "plt.tight_layout()" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "Plot Infoplane Visualization\n", 285 | "----" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": { 292 | "scrolled": false 293 | }, 294 | "outputs": [], 295 | "source": [ 296 | "max_epoch = max( (max(vals.keys()) if len(vals) else 0) for vals in measures.values())\n", 297 | "sm = plt.cm.ScalarMappable(cmap='gnuplot', norm=plt.Normalize(vmin=0, vmax=COLORBAR_MAX_EPOCHS))\n", 298 | "sm._A = []\n", 299 | "\n", 300 | "fig=plt.figure(figsize=(10,5))\n", 301 | "for actndx, (activation, vals) in enumerate(measures.items()):\n", 302 | " epochs = sorted(vals.keys())\n", 303 | " if not len(epochs):\n", 304 | " continue\n", 305 | " plt.subplot(1,2,actndx+1) \n", 306 | " for epoch in epochs:\n", 307 | " c = sm.to_rgba(epoch)\n", 308 | " xmvals = np.array(vals[epoch]['MI_XM_'+infoplane_measure])[PLOT_LAYERS]\n", 309 | " ymvals = np.array(vals[epoch]['MI_YM_'+infoplane_measure])[PLOT_LAYERS]\n", 310 | "\n", 311 | " plt.plot(xmvals, ymvals, c=c, alpha=0.1, zorder=1)\n", 312 | " plt.scatter(xmvals, ymvals, s=20, facecolors=[c for _ in PLOT_LAYERS], edgecolor='none', zorder=2)\n", 313 | " \n", 314 | " plt.ylim([0, 3.5])\n", 315 | " plt.xlim([0, 14])\n", 316 | " plt.xlabel('I(X;M)')\n", 317 | " plt.ylabel('I(Y;M)')\n", 318 | " plt.title(activation)\n", 319 | " \n", 320 | "cbaxes = fig.add_axes([1.0, 0.125, 0.03, 0.8]) \n", 321 | "plt.colorbar(sm, label='Epoch', cax=cbaxes)\n", 322 | "plt.tight_layout()\n", 323 | "\n", 324 | "if DO_SAVE:\n", 325 | " plt.savefig('plots/' + DIR_TEMPLATE % ('infoplane_'+ARCH),bbox_inches='tight')\n", 326 | " \n", 327 | " \n", 328 | " " 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [ 335 | "Plot SNR curves\n", 336 | "----" 337 | ] 
338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": null, 342 | "metadata": {}, 343 | "outputs": [], 344 | "source": [ 345 | "plt.figure(figsize=(12,5))\n", 346 | "\n", 347 | "gs = gridspec.GridSpec(len(measures), len(PLOT_LAYERS))\n", 348 | "saved_data = {}\n", 349 | "for actndx, activation in enumerate(measures.keys()):\n", 350 | "    cur_dir = 'rawdata/' + DIR_TEMPLATE % activation\n", 351 | "    if not os.path.exists(cur_dir):\n", 352 | "        continue\n", 353 | "        \n", 354 | "    epochs = []\n", 355 | "    means = []\n", 356 | "    stds = []\n", 357 | "    wnorms = []\n", 358 | "    trnloss = []\n", 359 | "    tstloss = []\n", 360 | "    for epochfile in sorted(os.listdir(cur_dir)):\n", 361 | "        if not epochfile.startswith('epoch'):\n", 362 | "            continue\n", 363 | "            \n", 364 | "        with open(cur_dir + \"/\"+epochfile, 'rb') as f:\n", 365 | "            try:\n", 366 | "                d = cPickle.load(f)\n", 367 | "            except Exception:\n", 368 | "                print('Error loading ', epochfile)\n", 369 | "                continue\n", 370 | "            \n", 371 | "        epoch = d['epoch']\n", 372 | "        epochs.append(epoch)\n", 373 | "        wnorms.append(d['data']['weights_norm'])\n", 374 | "        means.append(d['data']['gradmean'])\n", 375 | "        stds.append(d['data']['gradstd'])\n", 376 | "        trnloss.append(d['loss']['trn'])\n", 377 | "        tstloss.append(d['loss']['tst'])\n", 378 | "\n", 379 | "    wnorms, means, stds, trnloss, tstloss = map(np.array, [wnorms, means, stds, trnloss, tstloss])\n", 380 | "    saved_data[activation] = {'epochs':epochs, 'wnorms':wnorms, 'means': means, 'stds': stds, 'trnloss': trnloss, 'tstloss':tstloss}\n", 381 | "    \n", 382 | "    \n", 383 | "    for lndx,layerid in enumerate(PLOT_LAYERS):\n", 384 | "        plt.subplot(gs[actndx, lndx])\n", 385 | "        plt.plot(epochs, means[:,layerid], 'b', label=\"Mean\")\n", 386 | "        plt.plot(epochs, stds[:,layerid], 'orange', label=\"Std\")\n", 387 | "        plt.plot(epochs, means[:,layerid]/stds[:,layerid], 'red', label=\"SNR\")\n", 388 | "        plt.plot(epochs, wnorms[:,layerid], 'g', label=\"||W||\")\n", 389 | "\n", 390 | "        plt.title('%s - Layer %d'%(activation, layerid))\n", 391 | "        plt.xlabel('Epoch')\n", 392 | "        plt.gca().set_xscale(\"log\", nonposx='clip')\n", 393 | "        plt.gca().set_yscale(\"log\", nonposy='clip')\n", 394 | "    \n", 395 | "\n", 396 | "plt.legend(loc='lower left', bbox_to_anchor=(1.1, 0.2))\n", 397 | "plt.tight_layout()\n", 398 | "\n", 399 | "if DO_SAVE:\n", 400 | "    plt.savefig('plots/' + DIR_TEMPLATE % ('snr_'+ARCH), bbox_inches='tight')\n" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": null, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "GRID_PLOT_LAYERS = [0,1,2,3] # [1,2,3]\n", 410 | "sns.set_style('whitegrid')\n", 411 | "max_epoch = max( (max(vals.keys()) if len(vals) else 0) for vals in measures.values())\n", 412 | "H_X = np.log2(10000) # entropy of X, uniform over the 10000 MNIST test images, in bits\n", 413 | "for actndx, (activation, vals) in enumerate(measures.items()):\n", 414 | "    fig = plt.figure(figsize=(12,11))\n", 415 | "    gs = gridspec.GridSpec(4, len(GRID_PLOT_LAYERS))\n", 416 | "    \n", 417 | "    cur_epochs = np.array(sorted(vals.keys()))\n", 418 | "    if not len(cur_epochs):\n", 419 | "        continue\n", 420 | "    \n", 421 | "    plt.subplot(gs[0,0])\n", 422 | "    plt.title('Loss')\n", 423 | "    plt.xlabel('Epoch')\n", 424 | "    plt.plot(cur_epochs,saved_data[activation]['trnloss'][cur_epochs]/np.log(2), label='Train')\n", 425 | "    plt.plot(cur_epochs,saved_data[activation]['tstloss'][cur_epochs]/np.log(2), label='Test')\n", 426 | "    plt.ylabel('Cross entropy loss')\n", 427 | "    plt.gca().set_xscale(\"log\", nonposx='clip')\n", 428 | "    \n", 429 | "    plt.legend(loc='upper
right', frameon=True)\n", 430 | "    \n", 431 | "    vals_binned = np.array([vals[epoch]['MI_XM_bin'] for epoch in cur_epochs])\n", 432 | "    vals_lower = np.array([vals[epoch]['MI_XM_lower'] for epoch in cur_epochs])\n", 433 | "    vals_upper = np.array([vals[epoch]['MI_XM_upper'] for epoch in cur_epochs])\n", 434 | "    for layerndx, layerid in enumerate(GRID_PLOT_LAYERS):\n", 435 | "        plt.subplot(gs[1,layerndx])\n", 436 | "        plt.plot(cur_epochs, cur_epochs*0 + H_X, 'k:', label=r'$H(X)$')\n", 437 | "        plt.fill_between(cur_epochs, vals_lower[:,layerid], vals_upper[:,layerid])\n", 438 | "        plt.gca().set_xscale(\"log\", nonposx='clip')\n", 439 | "        plt.ylim([0, 1.1*H_X])\n", 440 | "        plt.title('Layer %d Mutual Info (KDE)'%(layerid+1))\n", 441 | "        plt.ylabel(r'$I(X;T)$')\n", 442 | "        plt.xlabel('Epoch')\n", 443 | "        if layerndx == len(GRID_PLOT_LAYERS)-1:\n", 444 | "            plt.legend(loc='lower right', frameon=True)\n", 445 | "        \n", 446 | "        plt.subplot(gs[2,layerndx])\n", 447 | "        plt.plot(cur_epochs, cur_epochs*0 + H_X, 'k:', label=r'$H(X)$')\n", 448 | "        plt.plot(cur_epochs, vals_binned[:,layerid])\n", 449 | "        plt.gca().set_xscale(\"log\", nonposx='clip')\n", 450 | "        plt.ylim([0, 1.1*H_X])\n", 451 | "        plt.ylabel(r'$I(X;T)$')\n", 452 | "        plt.title('Layer %d Mutual Info (binned)'%(layerid+1))\n", 453 | "        plt.xlabel('Epoch')\n", 454 | "        if layerndx == len(GRID_PLOT_LAYERS)-1:\n", 455 | "            plt.legend(loc='lower right', frameon=True)\n", 456 | "        \n", 457 | "        plt.subplot(gs[3,layerndx])\n", 458 | "        plt.title('Layer %d SNR'%(layerid+1))\n", 459 | "        plt.plot(cur_epochs, saved_data[activation]['means' ][cur_epochs,layerid], 'b', label=\"Mean\")\n", 460 | "        plt.plot(cur_epochs, saved_data[activation]['stds'  ][cur_epochs,layerid], 'orange', label=\"Std\")\n", 461 | "        plt.plot(cur_epochs, saved_data[activation]['means' ][cur_epochs,layerid]/saved_data[activation]['stds'][cur_epochs,layerid], 'red', label=\"SNR\")\n", 462 | "        plt.plot(cur_epochs, saved_data[activation]['wnorms'][cur_epochs,layerid], 'g', label=\"||W||\")\n", 463 | "\n", 464 | "        plt.xlabel('Epoch')\n", 465 | "        plt.gca().set_xscale(\"log\", nonposx='clip')\n", 466 | "        plt.gca().set_yscale(\"log\", nonposy='clip')\n", 467 | "    \n", 468 | "    plt.tight_layout()\n", 469 | "    plt.legend(loc='lower left', frameon=True)\n", 470 | "    \n", 471 | "    if DO_SAVE:\n", 472 | "        plt.savefig('plots/' + DIR_TEMPLATE % ('gridplot_'+activation) + '.pdf', bbox_inches='tight')\n" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": null, 478 | "metadata": {}, 479 | "outputs": [], 480 | "source": [] 481 | } 482 | ], 483 | "metadata": { 484 | "kernelspec": { 485 | "display_name": "Python 3", 486 | "language": "python", 487 | "name": "python3" 488 | }, 489 | "language_info": { 490 | "codemirror_mode": { 491 | "name": "ipython", 492 | "version": 3 493 | }, 494 | "file_extension": ".py", 495 | "mimetype": "text/x-python", 496 | "name": "python", 497 | "nbconvert_exporter": "python", 498 | "pygments_lexer": "ipython3", 499 | "version": "3.7.6" 500 | } 501 | }, 502 | "nbformat": 4, 503 | "nbformat_minor": 1 504 | } 505 | --------------------------------------------------------------------------------