├── README.md ├── classifiers.py ├── fractioncorrect.png ├── main.py ├── network.py └── tuner.py /README.md: -------------------------------------------------------------------------------- 1 | ## Elastic weight consolidation 2 | 3 | ### Introduction 4 | 5 | A TensorFlow implementation of elastic weight consolidation as presented in [Overcoming catastrophic forgetting in neural networks](http://www.pnas.org/content/114/13/3521.full). 6 | 7 | ### Usage 8 | 9 | Perform hyperparameter search over learning rates for the permuted MNIST task (fisher multiplier locked at inverse learning rate): 10 | ``` 11 | python -u main.py --hidden_layers 2 --hidden_units 800 --num_perms 5 --trials 50 --epochs 100 12 | ``` 13 | 14 | ### Results 15 | 16 |
17 | ![fraction correct](fractioncorrect.png)
18 | 
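19 | 
20 | The objective minimized on each task after the first (see `create_train_step` in `classifiers.py`) is the cross-entropy loss plus the quadratic EWC penalty. A minimal sketch, assuming `theta` holds the current weights and `theta_lagged` / `fisher_diagonal` hold the weights and Fisher diagonal captured after the previous task:
21 | 
22 | ```python
23 | # Schematic only -- mirrors the penalty built in create_train_step; the names
24 | # cross_entropy and total_loss are illustrative, the rest come from classifiers.py.
25 | penalty = tf.add_n([tf.reduce_sum(f * tf.square(w - w_old))
26 |                     for w, w_old, f in zip(theta, theta_lagged, fisher_diagonal)])
27 | total_loss = cross_entropy + (fisher_multiplier / 2.0) * penalty
28 | ```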
-------------------------------------------------------------------------------- /classifiers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | from network import Network 5 | 6 | 7 | class Classifier(Network): 8 | """Supplies fully connected prediction model with training loop which absorbs minibatches and updates weights.""" 9 | 10 | def __init__(self, checkpoint_path='logs/checkpoints/', summaries_path='logs/summaries/', *args, **kwargs): 11 | super(Classifier, self).__init__(*args, **kwargs) 12 | self.checkpoint_path = checkpoint_path 13 | self. summaries_path = summaries_path 14 | self.writer = None 15 | self.merged = None 16 | self.optimizer = None 17 | self.train_step = None 18 | self.accuracy = None 19 | self.loss = None 20 | 21 | self.create_loss_and_accuracy() 22 | 23 | def train(self, sess, model_name, model_init_name, dataset, num_updates, mini_batch_size, fisher_multiplier, 24 | learning_rate, log_frequency=None, dataset_lagged=None): # pass previous dataset as convenience 25 | print('training ' + model_name + ' with weights initialized at ' + str(model_init_name)) 26 | self.prepare_for_training(sess, model_name, model_init_name, fisher_multiplier, learning_rate) 27 | for i in range(num_updates): 28 | self.minibatch_sgd(sess, i, dataset, mini_batch_size, log_frequency) 29 | self.update_fisher_full_batch(sess, dataset) 30 | self.save_weights(i, sess, model_name) 31 | print('finished training ' + model_name) 32 | 33 | def test(self, sess, model_name, batch_xs, batch_ys): 34 | self.restore_model(sess, model_name) 35 | feed_dict = self.create_feed_dict(batch_xs, batch_ys, keep_input=1.0, keep_hidden=1.0) 36 | accuracy = sess.run(self.accuracy, feed_dict=feed_dict) 37 | return accuracy 38 | 39 | def minibatch_sgd(self, sess, i, dataset, mini_batch_size, log_frequency): 40 | batch_xs, batch_ys = dataset.next_batch(mini_batch_size) 41 | feed_dict = self.create_feed_dict(batch_xs, batch_ys) 42 | sess.run(self.train_step, feed_dict=feed_dict) 43 | if log_frequency and i % log_frequency is 0: 44 | self.evaluate(sess, i, feed_dict) 45 | 46 | def evaluate(self, sess, iteration, feed_dict): 47 | if self.apply_dropout: 48 | feed_dict.update({self.keep_prob_input: 1.0, self.keep_prob_hidden: 1.0}) 49 | summary, accuracy = sess.run([self.merged, self.accuracy], feed_dict=feed_dict) 50 | self.writer.add_summary(summary, iteration) 51 | 52 | def update_fisher_full_batch(self, sess, dataset): 53 | dataset._index_in_epoch = 0 # ensures that all training examples are included without repetitions 54 | sess.run(self.fisher_zero_op) 55 | for _ in range(0, self.ewc_batches): 56 | self.accumulate_fisher(sess, dataset) 57 | sess.run(self.fisher_full_batch_average_op) 58 | sess.run(self.update_theta_op) 59 | 60 | def accumulate_fisher(self, sess, dataset): 61 | batch_xs, batch_ys = dataset.next_batch(self.ewc_batch_size) 62 | sess.run(self.fisher_accumulate_op, feed_dict={self.x_fisher: batch_xs, self.y_fisher: batch_ys}) 63 | 64 | def prepare_for_training(self, sess, model_name, model_init_name, fisher_multiplier, learning_rate): 65 | self.writer = tf.summary.FileWriter(self.summaries_path + model_name, sess.graph) 66 | self.merged = tf.summary.merge_all() 67 | self.train_step = self.create_train_step(fisher_multiplier if model_init_name else 0.0, learning_rate) 68 | init = tf.global_variables_initializer() 69 | sess.run(init) 70 | if model_init_name: 71 | self.restore_model(sess, model_init_name) 72 | 73 | def 
create_loss_and_accuracy(self): 74 | with tf.name_scope("loss"): 75 | average_nll = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.y)) # optimized 76 | tf.summary.scalar("loss", average_nll) 77 | self.loss = average_nll 78 | with tf.name_scope('accuracy'): 79 | accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(self.scores, 1), tf.argmax(self.y, 1)), tf.float32)) 80 | tf.summary.scalar('accuracy', accuracy) 81 | self.accuracy = accuracy 82 | 83 | def create_train_step(self, fisher_multiplier, learning_rate): 84 | with tf.name_scope("optimizer"): 85 | self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 86 | penalty = tf.add_n([tf.reduce_sum(tf.square(w1-w2)*f) for w1, w2, f 87 | in zip(self.theta, self.theta_lagged, self.fisher_diagonal)]) 88 | return self.optimizer.minimize(self.loss + (fisher_multiplier / 2) * penalty, var_list=self.theta) 89 | 90 | def save_weights(self, time_step, sess, model_name): 91 | if not os.path.exists(self.checkpoint_path): 92 | os.makedirs(self.checkpoint_path) 93 | self.saver.save(sess=sess, save_path=self.checkpoint_path + model_name + '.ckpt', global_step=time_step, 94 | latest_filename=model_name) 95 | print('saving model ' + model_name + ' at time step ' + str(time_step)) 96 | 97 | def restore_model(self, sess, model_name): 98 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir=self.checkpoint_path, latest_filename=model_name) 99 | self.saver.restore(sess=sess, save_path=ckpt.model_checkpoint_path) 100 | 101 | def create_feed_dict(self, batch_xs, batch_ys, keep_hidden=0.5, keep_input=0.8): 102 | feed_dict = {self.x: batch_xs, self.y: batch_ys} 103 | if self.apply_dropout: 104 | feed_dict.update({self.keep_prob_hidden: keep_hidden, self.keep_prob_input: keep_input}) 105 | return feed_dict 106 | -------------------------------------------------------------------------------- /fractioncorrect.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stokesj/EWC/7e28b8d8cf0335962ef106d988cb45386bc008ed/fractioncorrect.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tensorflow as tf 3 | 4 | from tuner import HyperparameterTuner 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--hidden_layers', type=int, default=2, help='the number of hidden layers') 10 | parser.add_argument('--hidden_units', type=int, default=800, help='the number of units per hidden layer') 11 | parser.add_argument('--num_perms', type=int, default=5, help='the number of tasks') 12 | parser.add_argument('--trials', type=int, default=50, help='the number of hyperparameter trials per task') 13 | parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs per task') 14 | return parser.parse_args() 15 | 16 | 17 | def main(): 18 | args = parse_args() 19 | with tf.Session() as sess: 20 | tuner = HyperparameterTuner(sess=sess, hidden_layers=args.hidden_layers, hidden_units=args.hidden_units, 21 | num_perms=args.num_perms, trials=args.trials, epochs=args.epochs) 22 | tuner.search() 23 | print(tuner.best_parameters) 24 | 25 | if __name__ == "__main__": 26 | main() 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /network.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class Network(object): 5 | """Creates the computation graph for a fully connected rectifier/dropout network 6 | prediction model and Fisher diagonal.""" 7 | 8 | def __init__(self, num_features, num_class, fc_hidden_units, apply_dropout, ewc_batch_size=100, ewc_batches=550): 9 | self.num_features = num_features 10 | self.num_class = num_class 11 | self.fc_units = fc_hidden_units 12 | self.sizes = [self.num_features] + self.fc_units + [self.num_class] 13 | self.apply_dropout = apply_dropout 14 | self.ewc_batch_size = ewc_batch_size 15 | self.ewc_batches = ewc_batches 16 | 17 | self.x = None 18 | self.y = None 19 | self.x_fisher = None 20 | self.y_fisher = None 21 | self.keep_prob_input = None 22 | self.keep_prob_hidden = None 23 | 24 | self.biases = None 25 | self.weights = None 26 | self.theta = None 27 | self.biases_lagged = None 28 | self.weights_lagged = None 29 | self.theta_lagged = None 30 | 31 | self.scores = None 32 | self.fisher_diagonal = None 33 | self.fisher_minibatch = None 34 | 35 | self.fisher_accumulate_op = None 36 | self.fisher_full_batch_average_op = None 37 | self.fisher_zero_op = None 38 | self.update_theta_op = None 39 | 40 | self.create_graph() 41 | 42 | self.saver = tf.train.Saver(max_to_keep=1000, var_list=self.theta + self.theta_lagged + self.fisher_diagonal) 43 | 44 | def create_graph(self): 45 | self.create_placeholders() 46 | self.create_fc_variables() 47 | self.scores = self.fc_feedforward(self.x, self.biases, self.weights, self.apply_dropout) 48 | self.create_fisher_diagonal() 49 | 50 | def fc_feedforward(self, h, biases, weights, apply_dropout): 51 | if apply_dropout: 52 | h = tf.nn.dropout(h, self.keep_prob_input) 53 | for (w, b) in list(zip(weights, biases))[:-1]: 54 | h = self.create_fc_layer(h, w, b) 55 | if apply_dropout: 56 | h = tf.nn.dropout(h, self.keep_prob_hidden) 57 | return self.create_fc_layer(h, weights[-1], biases[-1], apply_relu=False) 58 | 59 | def create_fisher_diagonal(self): 60 | nll, biases_per_example, weights_per_example = self.unaggregated_nll() 61 | self.fisher_minibatch = self.fisher_minibatch_sum(nll, biases_per_example, weights_per_example) 62 | self.create_fisher_ops() 63 | 64 | def unaggregated_nll(self): 65 | x_examples = tf.unstack(self.x_fisher) 66 | y_examples = tf.unstack(self.y_fisher) 67 | biases_per_example = [self.clone_variable_list(self.biases) for _ in range(0, self.ewc_batch_size)] 68 | weights_per_example = [self.clone_variable_list(self.weights) for _ in range(0, self.ewc_batch_size)] 69 | nll_list = [] 70 | for (x, y, biases, weights) in zip(x_examples, y_examples, biases_per_example, weights_per_example): 71 | scores = self.fc_feedforward(tf.reshape(x, [1, self.num_features]), biases, weights, apply_dropout=False) 72 | nll = - tf.reduce_sum(y * tf.nn.log_softmax(scores)) 73 | nll_list.append(nll) 74 | nlls = tf.stack(nll_list) 75 | return tf.reduce_sum(nlls), biases_per_example, weights_per_example 76 | 77 | def fisher_minibatch_sum(self, nll_per_example, biases_per_example, weights_per_example): 78 | bias_grads_per_example = [tf.gradients(nll_per_example, biases) for biases in biases_per_example] 79 | weight_grads_per_example = [tf.gradients(nll_per_example, weights) for weights in weights_per_example] 80 | return self.sum_of_squared_gradients(bias_grads_per_example, weight_grads_per_example) 81 | 82 | def sum_of_squared_gradients(self, bias_grads_per_example, weight_grads_per_example): 83 | 
bias_grads2_sum = [] 84 | weight_grads2_sum = [] 85 | for layer in range(0, len(self.fc_units) + 1): 86 | bias_grad2_sum = tf.add_n([tf.square(example[layer]) for example in bias_grads_per_example]) 87 | weight_grad2_sum = tf.add_n([tf.square(example[layer]) for example in weight_grads_per_example]) 88 | bias_grads2_sum.append(bias_grad2_sum) 89 | weight_grads2_sum.append(weight_grad2_sum) 90 | return bias_grads2_sum + weight_grads2_sum 91 | 92 | def create_fisher_ops(self): 93 | self.fisher_diagonal = self.bias_shaped_variables(name='bias_grads2', c=0.0, trainable=False) +\ 94 | self.weight_shaped_variables(name='weight_grads2', c=0.0, trainable=False) 95 | 96 | self.fisher_accumulate_op = [tf.assign_add(f1, f2) for f1, f2 in zip(self.fisher_diagonal, self.fisher_minibatch)] 97 | scale = 1 / float(self.ewc_batches * self.ewc_batch_size) 98 | self.fisher_full_batch_average_op = [tf.assign(var, scale * var) for var in self.fisher_diagonal] 99 | self.fisher_zero_op = [tf.assign(tensor, tf.zeros_like(tensor)) for tensor in self.fisher_diagonal] 100 | 101 | @staticmethod 102 | def create_fc_layer(input, w, b, apply_relu=True): 103 | with tf.name_scope('fc_layer'): 104 | output = tf.matmul(input, w) + b 105 | if apply_relu: 106 | output = tf.nn.relu(output) 107 | return output 108 | 109 | @staticmethod 110 | def create_variable(shape, name, c=None, sigma=None, trainable=True): 111 | if sigma: 112 | initial = tf.truncated_normal(shape, stddev=sigma, name=name) 113 | else: 114 | initial = tf.constant(c if c else 0.0, shape=shape, name=name) 115 | return tf.Variable(initial, trainable=trainable) 116 | 117 | @staticmethod 118 | def clone_variable_list(variable_list): 119 | return [tf.identity(var) for var in variable_list] 120 | 121 | def bias_shaped_variables(self, name, c=None, sigma=None, trainable=True): 122 | return [self.create_variable(shape=[i], name=name + '{}'.format(layer + 1), 123 | c=c, sigma=sigma, trainable=trainable) for layer, i in enumerate(self.sizes[1:])] 124 | 125 | def weight_shaped_variables(self, name, c=None, sigma=None, trainable=True): 126 | return [self.create_variable([i, j], name=name + '{}'.format(layer + 1), 127 | c=c, sigma=sigma, trainable=trainable) 128 | for layer, (i, j) in enumerate(zip(self.sizes[:-1], self.sizes[1:]))] 129 | 130 | def create_fc_variables(self): 131 | with tf.name_scope('fc_variables'): 132 | self.biases = self.bias_shaped_variables(name='biases_fc', c=0.1, trainable=True) 133 | self.weights = self.weight_shaped_variables(name='weights_fc', sigma=0.1, trainable=True) 134 | self.theta = self.biases + self.weights 135 | with tf.name_scope('fc_variables_lagged'): 136 | self.biases_lagged = self.bias_shaped_variables(name='biases_fc_lagged', c=0.0, trainable=False) 137 | self.weights_lagged = self.weight_shaped_variables(name='weights_fc_lagged', c=0.0, trainable=False) 138 | self.theta_lagged = self.biases_lagged + self.weights_lagged 139 | self.update_theta_op = [v1.assign(v2) for v1, v2 in zip(self.theta_lagged, self.theta)] 140 | 141 | def create_placeholders(self): 142 | with tf.name_scope("prediction-inputs"): 143 | self.x = tf.placeholder(tf.float32, [None, self.num_features], name='x-input') 144 | self.y = tf.placeholder(tf.float32, [None, self.num_class], name='y-input') 145 | with tf.name_scope("dropout-probabilities"): 146 | self.keep_prob_input = tf.placeholder(tf.float32) 147 | self.keep_prob_hidden = tf.placeholder(tf.float32) 148 | with tf.name_scope("fisher-inputs"): 149 | self.x_fisher = tf.placeholder(tf.float32, 
[self.ewc_batch_size, self.num_features]) 150 | self.y_fisher = tf.placeholder(tf.float32, [self.ewc_batch_size, self.num_class]) -------------------------------------------------------------------------------- /tuner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from copy import deepcopy 4 | from classifiers import Classifier 5 | from numpy.random import RandomState 6 | from queue import PriorityQueue 7 | from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets 8 | 9 | 10 | PRNG = RandomState(12345) 11 | MINI_BATCH_SIZE = 250 12 | LOG_FREQUENCY = 1000 13 | 14 | 15 | class HyperparameterTuner(object): 16 | def __init__(self, sess, hidden_layers, hidden_units, num_perms, trials, epochs): 17 | self.hidden_layers = hidden_layers 18 | self.hidden_units = hidden_units 19 | self.num_perms = num_perms 20 | self.epochs = epochs 21 | self.task_list = self.create_permuted_mnist_task(num_perms) 22 | self.trial_learning_rates = [PRNG.uniform(1e-4, 1e-3) for _ in range(0, trials)] 23 | self.best_parameters = [] 24 | self.sess = sess 25 | self.classifier = Classifier(num_class=10, 26 | num_features=784, 27 | fc_hidden_units=[hidden_units for _ in range(hidden_layers)], 28 | apply_dropout=True) 29 | 30 | def search(self): 31 | for t in range(0, self.num_perms): 32 | queue = PriorityQueue() 33 | for learning_rate in self.trial_learning_rates: 34 | self.train_on_task(t, learning_rate, queue) 35 | self.best_parameters.append(queue.get()) 36 | self.evaluate() 37 | 38 | def evaluate(self): 39 | accuracies = [] 40 | for parameters in self.best_parameters: 41 | accuracy = self.classifier.test(sess=self.sess, 42 | model_name=parameters[1], 43 | batch_xs=self.task_list[0].test.images, 44 | batch_ys=self.task_list[0].test.labels) 45 | accuracies.append(accuracy) 46 | print(accuracies) 47 | 48 | def train_on_task(self, t, lr, queue): 49 | model_name = self.file_name(lr, t) 50 | dataset_train = self.task_list[t].train 51 | dataset_lagged = self.task_list[t - 1] if t > 0 else None 52 | model_init_name = self.best_parameters[t - 1][1] if t > 0 else None 53 | self.classifier.train(sess=self.sess, 54 | model_name=model_name, 55 | model_init_name=model_init_name, 56 | dataset=dataset_train, 57 | dataset_lagged=dataset_lagged, 58 | num_updates=(55000//MINI_BATCH_SIZE)*self.epochs, 59 | mini_batch_size=MINI_BATCH_SIZE, 60 | log_frequency=LOG_FREQUENCY, 61 | fisher_multiplier=1.0/lr, 62 | learning_rate=lr) 63 | accuracy = self.classifier.test(sess=self.sess, 64 | model_name=model_name, 65 | batch_xs=self.task_list[0].validation.images, 66 | batch_ys=self.task_list[0].validation.labels) 67 | queue.put((-accuracy, model_name)) 68 | 69 | def create_permuted_mnist_task(self, num_datasets): 70 | mnist = read_data_sets("MNIST_data/", one_hot=True) 71 | task_list = [mnist] 72 | for seed in range(1, num_datasets): 73 | task_list.append(self.permute(mnist, seed)) 74 | return task_list 75 | 76 | @staticmethod 77 | def permute(task, seed): 78 | np.random.seed(seed) 79 | perm = np.random.permutation(task.train._images.shape[1]) 80 | permuted = deepcopy(task) 81 | permuted.train._images = permuted.train._images[:, perm] 82 | permuted.test._images = permuted.test._images[:, perm] 83 | permuted.validation._images = permuted.validation._images[:, perm] 84 | return permuted 85 | 86 | def file_name(self, lr, t): 87 | return 'layers=%d,hidden=%d,lr=%.5f,multiplier=%.2f,mbsize=%d,epochs=%d,perm=%d' \ 88 | % (self.hidden_layers, self.hidden_units, lr, 1 
/ lr, MINI_BATCH_SIZE, self.epochs, t) 89 | 90 | --------------------------------------------------------------------------------
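For reference, the quantity accumulated by `fisher_accumulate_op` and normalized by `fisher_full_batch_average_op` in `network.py` is the diagonal of the empirical Fisher information: the squared per-example gradient of the negative log-likelihood at the true label, averaged over `ewc_batches * ewc_batch_size` examples. A minimal, framework-free sketch for a single softmax layer (the function and variable names below are illustrative and not part of the repo):

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def empirical_fisher_diagonal(W, X, Y):
    """Fisher diagonal for weights W (d, k), inputs X (n, d), one-hot labels Y (n, k)."""
    F = np.zeros_like(W)
    for x, y in zip(X, Y):
        p = softmax(x @ W)          # predicted class probabilities for one example
        g = np.outer(x, p - y)      # gradient of -log p(true class | x) w.r.t. W
        F += g ** 2                 # accumulate the squared per-example gradient
    return F / len(X)               # average over examples
```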