├── README.md
├── classifiers.py
├── fractioncorrect.png
├── main.py
├── network.py
└── tuner.py
/README.md:
--------------------------------------------------------------------------------
1 | ## Elastic weight consolidation
2 |
3 | ### Introduction
4 |
5 | A TensorFlow implementation of elastic weight consolidation (EWC), as presented in [Overcoming catastrophic forgetting in neural networks](http://www.pnas.org/content/114/13/3521.full).
6 |
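In brief, and in the notation of the paper: after task A has been learned with resulting weights θ*_A, training on the next task B minimizes the task-B loss plus a quadratic penalty that anchors every parameter to its previous value, weighted by the diagonal of the Fisher information matrix estimated on task A:

```latex
\mathcal{L}(\theta) \;=\; \mathcal{L}_B(\theta) \;+\; \sum_i \frac{\lambda}{2}\, F_i \left(\theta_i - \theta^{*}_{A,i}\right)^2
```

Here λ corresponds to the `fisher_multiplier` argument, and the penalty term is assembled in `create_train_step` in `classifiers.py`.
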
7 | ### Usage
8 |
9 | Run a hyperparameter search over learning rates for the permuted MNIST task (the Fisher multiplier is locked to the inverse of the learning rate):
10 | ```
11 | python -u main.py --hidden_layers 2 --hidden_units 800 --num_perms 5 --trials 50 --epochs 100
12 | ```
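
For reference, the remaining hyperparameters are derived inside `tuner.py` rather than sampled, so each trial only draws a learning rate; a sketch of the relevant assignments:

```
fisher_multiplier = 1.0 / learning_rate             # locked at the inverse learning rate
num_updates = (55000 // MINI_BATCH_SIZE) * epochs   # 55000 training images, MINI_BATCH_SIZE = 250
```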
13 |
14 | ### Results
15 |
16 | ![fraction correct](fractioncorrect.png)
17 |
18 |
--------------------------------------------------------------------------------
/classifiers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 |
4 | from network import Network
5 |
6 |
7 | class Classifier(Network):
8 | """Supplies fully connected prediction model with training loop which absorbs minibatches and updates weights."""
9 |
10 | def __init__(self, checkpoint_path='logs/checkpoints/', summaries_path='logs/summaries/', *args, **kwargs):
11 | super(Classifier, self).__init__(*args, **kwargs)
12 | self.checkpoint_path = checkpoint_path
13 | self.summaries_path = summaries_path
14 | self.writer = None
15 | self.merged = None
16 | self.optimizer = None
17 | self.train_step = None
18 | self.accuracy = None
19 | self.loss = None
20 |
21 | self.create_loss_and_accuracy()
22 |
23 | def train(self, sess, model_name, model_init_name, dataset, num_updates, mini_batch_size, fisher_multiplier,
24 | learning_rate, log_frequency=None, dataset_lagged=None): # dataset_lagged (the previous task's data) is accepted for convenience but unused
25 | print('training ' + model_name + ' with weights initialized at ' + str(model_init_name))
26 | self.prepare_for_training(sess, model_name, model_init_name, fisher_multiplier, learning_rate)
27 | for i in range(num_updates):
28 | self.minibatch_sgd(sess, i, dataset, mini_batch_size, log_frequency)
29 | self.update_fisher_full_batch(sess, dataset)
30 | self.save_weights(i, sess, model_name)
31 | print('finished training ' + model_name)
32 |
33 | def test(self, sess, model_name, batch_xs, batch_ys):
34 | self.restore_model(sess, model_name)
35 | feed_dict = self.create_feed_dict(batch_xs, batch_ys, keep_input=1.0, keep_hidden=1.0)
36 | accuracy = sess.run(self.accuracy, feed_dict=feed_dict)
37 | return accuracy
38 |
39 | def minibatch_sgd(self, sess, i, dataset, mini_batch_size, log_frequency):
40 | batch_xs, batch_ys = dataset.next_batch(mini_batch_size)
41 | feed_dict = self.create_feed_dict(batch_xs, batch_ys)
42 | sess.run(self.train_step, feed_dict=feed_dict)
43 | if log_frequency and i % log_frequency == 0:
44 | self.evaluate(sess, i, feed_dict)
45 |
46 | def evaluate(self, sess, iteration, feed_dict):
47 | if self.apply_dropout:
48 | feed_dict.update({self.keep_prob_input: 1.0, self.keep_prob_hidden: 1.0})
49 | summary, accuracy = sess.run([self.merged, self.accuracy], feed_dict=feed_dict)
50 | self.writer.add_summary(summary, iteration)
51 |
52 | def update_fisher_full_batch(self, sess, dataset):
53 | dataset._index_in_epoch = 0 # ensures that all training examples are included without repetitions
54 | sess.run(self.fisher_zero_op)
55 | for _ in range(0, self.ewc_batches):
56 | self.accumulate_fisher(sess, dataset)
57 | sess.run(self.fisher_full_batch_average_op)
58 | sess.run(self.update_theta_op)
59 |
60 | def accumulate_fisher(self, sess, dataset):
61 | batch_xs, batch_ys = dataset.next_batch(self.ewc_batch_size)
62 | sess.run(self.fisher_accumulate_op, feed_dict={self.x_fisher: batch_xs, self.y_fisher: batch_ys})
63 |
64 | def prepare_for_training(self, sess, model_name, model_init_name, fisher_multiplier, learning_rate):
65 | self.writer = tf.summary.FileWriter(self.summaries_path + model_name, sess.graph)
66 | self.merged = tf.summary.merge_all()
67 | self.train_step = self.create_train_step(fisher_multiplier if model_init_name else 0.0, learning_rate)
68 | init = tf.global_variables_initializer()
69 | sess.run(init)
70 | if model_init_name:
71 | self.restore_model(sess, model_init_name)
72 |
73 | def create_loss_and_accuracy(self):
74 | with tf.name_scope("loss"):
75 | average_nll = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.y)) # fused op: numerically stable softmax + cross-entropy
76 | tf.summary.scalar("loss", average_nll)
77 | self.loss = average_nll
78 | with tf.name_scope('accuracy'):
79 | accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(self.scores, 1), tf.argmax(self.y, 1)), tf.float32))
80 | tf.summary.scalar('accuracy', accuracy)
81 | self.accuracy = accuracy
82 |
83 | def create_train_step(self, fisher_multiplier, learning_rate):
84 | with tf.name_scope("optimizer"):
85 | self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
86 | penalty = tf.add_n([tf.reduce_sum(tf.square(w1-w2)*f) for w1, w2, f
87 | in zip(self.theta, self.theta_lagged, self.fisher_diagonal)])
88 | return self.optimizer.minimize(self.loss + (fisher_multiplier / 2) * penalty, var_list=self.theta)
89 |
90 | def save_weights(self, time_step, sess, model_name):
91 | if not os.path.exists(self.checkpoint_path):
92 | os.makedirs(self.checkpoint_path)
93 | self.saver.save(sess=sess, save_path=self.checkpoint_path + model_name + '.ckpt', global_step=time_step,
94 | latest_filename=model_name)
95 | print('saving model ' + model_name + ' at time step ' + str(time_step))
96 |
97 | def restore_model(self, sess, model_name):
98 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir=self.checkpoint_path, latest_filename=model_name)
99 | self.saver.restore(sess=sess, save_path=ckpt.model_checkpoint_path)
100 |
101 | def create_feed_dict(self, batch_xs, batch_ys, keep_hidden=0.5, keep_input=0.8):
102 | feed_dict = {self.x: batch_xs, self.y: batch_ys}
103 | if self.apply_dropout:
104 | feed_dict.update({self.keep_prob_hidden: keep_hidden, self.keep_prob_input: keep_input})
105 | return feed_dict
106 |
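A minimal usage sketch of the `Classifier` API above, mirroring how `tuner.py` drives it. It assumes TensorFlow 1.x; the single unpermuted MNIST task and the one-epoch schedule are illustrative stand-ins, not the settings used for the results:

```
import tensorflow as tf
from classifiers import Classifier
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

mnist = read_data_sets("MNIST_data/", one_hot=True)  # stand-in task; tuner.py builds permuted copies

with tf.Session() as sess:
    classifier = Classifier(num_features=784, num_class=10,
                            fc_hidden_units=[800, 800], apply_dropout=True)
    # First task: model_init_name=None, so create_train_step uses multiplier 0.0 (no EWC penalty);
    # the Fisher diagonal and the theta snapshot are still computed at the end of training.
    classifier.train(sess, model_name='task0', model_init_name=None,
                     dataset=mnist.train, num_updates=55000 // 250, mini_batch_size=250,
                     fisher_multiplier=1.0 / 1e-3, learning_rate=1e-3, log_frequency=100)
    # A later task would pass model_init_name='task0', activating the penalty that anchors
    # theta to the task-0 weights.
    print(classifier.test(sess, 'task0', mnist.test.images, mnist.test.labels))
```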
--------------------------------------------------------------------------------
/fractioncorrect.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stokesj/EWC/7e28b8d8cf0335962ef106d988cb45386bc008ed/fractioncorrect.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import tensorflow as tf
3 |
4 | from tuner import HyperparameterTuner
5 |
6 |
7 | def parse_args():
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--hidden_layers', type=int, default=2, help='the number of hidden layers')
10 | parser.add_argument('--hidden_units', type=int, default=800, help='the number of units per hidden layer')
11 | parser.add_argument('--num_perms', type=int, default=5, help='the number of tasks')
12 | parser.add_argument('--trials', type=int, default=50, help='the number of hyperparameter trials per task')
13 | parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs per task')
14 | return parser.parse_args()
15 |
16 |
17 | def main():
18 | args = parse_args()
19 | with tf.Session() as sess:
20 | tuner = HyperparameterTuner(sess=sess, hidden_layers=args.hidden_layers, hidden_units=args.hidden_units,
21 | num_perms=args.num_perms, trials=args.trials, epochs=args.epochs)
22 | tuner.search()
23 | print(tuner.best_parameters)
24 |
25 | if __name__ == "__main__":
26 | main()
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class Network(object):
5 | """Creates the computation graph for a fully connected rectifier/dropout network
6 | prediction model and Fisher diagonal."""
7 |
8 | def __init__(self, num_features, num_class, fc_hidden_units, apply_dropout, ewc_batch_size=100, ewc_batches=550):
9 | self.num_features = num_features
10 | self.num_class = num_class
11 | self.fc_units = fc_hidden_units
12 | self.sizes = [self.num_features] + self.fc_units + [self.num_class]
13 | self.apply_dropout = apply_dropout
14 | self.ewc_batch_size = ewc_batch_size
15 | self.ewc_batches = ewc_batches
16 |
17 | self.x = None
18 | self.y = None
19 | self.x_fisher = None
20 | self.y_fisher = None
21 | self.keep_prob_input = None
22 | self.keep_prob_hidden = None
23 |
24 | self.biases = None
25 | self.weights = None
26 | self.theta = None
27 | self.biases_lagged = None
28 | self.weights_lagged = None
29 | self.theta_lagged = None
30 |
31 | self.scores = None
32 | self.fisher_diagonal = None
33 | self.fisher_minibatch = None
34 |
35 | self.fisher_accumulate_op = None
36 | self.fisher_full_batch_average_op = None
37 | self.fisher_zero_op = None
38 | self.update_theta_op = None
39 |
40 | self.create_graph()
41 |
42 | self.saver = tf.train.Saver(max_to_keep=1000, var_list=self.theta + self.theta_lagged + self.fisher_diagonal)
43 |
44 | def create_graph(self):
45 | self.create_placeholders()
46 | self.create_fc_variables()
47 | self.scores = self.fc_feedforward(self.x, self.biases, self.weights, self.apply_dropout)
48 | self.create_fisher_diagonal()
49 |
50 | def fc_feedforward(self, h, biases, weights, apply_dropout):
51 | if apply_dropout:
52 | h = tf.nn.dropout(h, self.keep_prob_input)
53 | for (w, b) in list(zip(weights, biases))[:-1]:
54 | h = self.create_fc_layer(h, w, b)
55 | if apply_dropout:
56 | h = tf.nn.dropout(h, self.keep_prob_hidden)
57 | return self.create_fc_layer(h, weights[-1], biases[-1], apply_relu=False)
58 |
59 | def create_fisher_diagonal(self):
60 | nll, biases_per_example, weights_per_example = self.unaggregated_nll()
61 | self.fisher_minibatch = self.fisher_minibatch_sum(nll, biases_per_example, weights_per_example)
62 | self.create_fisher_ops()
63 |
64 | def unaggregated_nll(self):
65 | x_examples = tf.unstack(self.x_fisher)
66 | y_examples = tf.unstack(self.y_fisher)
67 | biases_per_example = [self.clone_variable_list(self.biases) for _ in range(0, self.ewc_batch_size)]
68 | weights_per_example = [self.clone_variable_list(self.weights) for _ in range(0, self.ewc_batch_size)]
69 | nll_list = []
70 | for (x, y, biases, weights) in zip(x_examples, y_examples, biases_per_example, weights_per_example):
71 | scores = self.fc_feedforward(tf.reshape(x, [1, self.num_features]), biases, weights, apply_dropout=False)
72 | nll = - tf.reduce_sum(y * tf.nn.log_softmax(scores))
73 | nll_list.append(nll)
74 | nlls = tf.stack(nll_list)
75 | return tf.reduce_sum(nlls), biases_per_example, weights_per_example
76 |
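# Note: unaggregated_nll gives every example its own tf.identity clone of the weights, so the
# tf.gradients calls below produce one gradient per example; squaring and summing those
# per-example gradients (sum_of_squared_gradients) is what gets accumulated into the Fisher diagonal.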
77 | def fisher_minibatch_sum(self, nll_per_example, biases_per_example, weights_per_example):
78 | bias_grads_per_example = [tf.gradients(nll_per_example, biases) for biases in biases_per_example]
79 | weight_grads_per_example = [tf.gradients(nll_per_example, weights) for weights in weights_per_example]
80 | return self.sum_of_squared_gradients(bias_grads_per_example, weight_grads_per_example)
81 |
82 | def sum_of_squared_gradients(self, bias_grads_per_example, weight_grads_per_example):
83 | bias_grads2_sum = []
84 | weight_grads2_sum = []
85 | for layer in range(0, len(self.fc_units) + 1):
86 | bias_grad2_sum = tf.add_n([tf.square(example[layer]) for example in bias_grads_per_example])
87 | weight_grad2_sum = tf.add_n([tf.square(example[layer]) for example in weight_grads_per_example])
88 | bias_grads2_sum.append(bias_grad2_sum)
89 | weight_grads2_sum.append(weight_grad2_sum)
90 | return bias_grads2_sum + weight_grads2_sum
91 |
92 | def create_fisher_ops(self):
93 | self.fisher_diagonal = self.bias_shaped_variables(name='bias_grads2', c=0.0, trainable=False) +\
94 | self.weight_shaped_variables(name='weight_grads2', c=0.0, trainable=False)
95 |
96 | self.fisher_accumulate_op = [tf.assign_add(f1, f2) for f1, f2 in zip(self.fisher_diagonal, self.fisher_minibatch)]
97 | scale = 1 / float(self.ewc_batches * self.ewc_batch_size)
98 | self.fisher_full_batch_average_op = [tf.assign(var, scale * var) for var in self.fisher_diagonal]
99 | self.fisher_zero_op = [tf.assign(tensor, tf.zeros_like(tensor)) for tensor in self.fisher_diagonal]
100 |
101 | @staticmethod
102 | def create_fc_layer(input, w, b, apply_relu=True):
103 | with tf.name_scope('fc_layer'):
104 | output = tf.matmul(input, w) + b
105 | if apply_relu:
106 | output = tf.nn.relu(output)
107 | return output
108 |
109 | @staticmethod
110 | def create_variable(shape, name, c=None, sigma=None, trainable=True):
111 | if sigma:
112 | initial = tf.truncated_normal(shape, stddev=sigma, name=name)
113 | else:
114 | initial = tf.constant(c if c else 0.0, shape=shape, name=name)
115 | return tf.Variable(initial, trainable=trainable)
116 |
117 | @staticmethod
118 | def clone_variable_list(variable_list):
119 | return [tf.identity(var) for var in variable_list]
120 |
121 | def bias_shaped_variables(self, name, c=None, sigma=None, trainable=True):
122 | return [self.create_variable(shape=[i], name=name + '{}'.format(layer + 1),
123 | c=c, sigma=sigma, trainable=trainable) for layer, i in enumerate(self.sizes[1:])]
124 |
125 | def weight_shaped_variables(self, name, c=None, sigma=None, trainable=True):
126 | return [self.create_variable([i, j], name=name + '{}'.format(layer + 1),
127 | c=c, sigma=sigma, trainable=trainable)
128 | for layer, (i, j) in enumerate(zip(self.sizes[:-1], self.sizes[1:]))]
129 |
130 | def create_fc_variables(self):
131 | with tf.name_scope('fc_variables'):
132 | self.biases = self.bias_shaped_variables(name='biases_fc', c=0.1, trainable=True)
133 | self.weights = self.weight_shaped_variables(name='weights_fc', sigma=0.1, trainable=True)
134 | self.theta = self.biases + self.weights
135 | with tf.name_scope('fc_variables_lagged'):
136 | self.biases_lagged = self.bias_shaped_variables(name='biases_fc_lagged', c=0.0, trainable=False)
137 | self.weights_lagged = self.weight_shaped_variables(name='weights_fc_lagged', c=0.0, trainable=False)
138 | self.theta_lagged = self.biases_lagged + self.weights_lagged
139 | self.update_theta_op = [v1.assign(v2) for v1, v2 in zip(self.theta_lagged, self.theta)]
140 |
141 | def create_placeholders(self):
142 | with tf.name_scope("prediction-inputs"):
143 | self.x = tf.placeholder(tf.float32, [None, self.num_features], name='x-input')
144 | self.y = tf.placeholder(tf.float32, [None, self.num_class], name='y-input')
145 | with tf.name_scope("dropout-probabilities"):
146 | self.keep_prob_input = tf.placeholder(tf.float32)
147 | self.keep_prob_hidden = tf.placeholder(tf.float32)
148 | with tf.name_scope("fisher-inputs"):
149 | self.x_fisher = tf.placeholder(tf.float32, [self.ewc_batch_size, self.num_features])
150 | self.y_fisher = tf.placeholder(tf.float32, [self.ewc_batch_size, self.num_class])
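
For reference, the `fisher_*` ops above compute an empirical estimate of the Fisher diagonal: squared per-example gradients of the negative log-likelihood, evaluated at the dataset labels, are accumulated over `ewc_batches` mini-batches of `ewc_batch_size` examples and then averaged. In the notation of the README:

```latex
F_i \;\approx\; \frac{1}{N} \sum_{n=1}^{N}
    \left( \frac{\partial}{\partial \theta_i} \bigl[-\log p(y_n \mid x_n, \theta)\bigr] \right)^{2},
\qquad N = \text{ewc\_batches} \times \text{ewc\_batch\_size}.
```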
--------------------------------------------------------------------------------
/tuner.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from copy import deepcopy
4 | from classifiers import Classifier
5 | from numpy.random import RandomState
6 | from queue import PriorityQueue
7 | from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
8 |
9 |
10 | PRNG = RandomState(12345)
11 | MINI_BATCH_SIZE = 250
12 | LOG_FREQUENCY = 1000
13 |
14 |
15 | class HyperparameterTuner(object):
16 | def __init__(self, sess, hidden_layers, hidden_units, num_perms, trials, epochs):
17 | self.hidden_layers = hidden_layers
18 | self.hidden_units = hidden_units
19 | self.num_perms = num_perms
20 | self.epochs = epochs
21 | self.task_list = self.create_permuted_mnist_task(num_perms)
22 | self.trial_learning_rates = [PRNG.uniform(1e-4, 1e-3) for _ in range(0, trials)]
23 | self.best_parameters = []
24 | self.sess = sess
25 | self.classifier = Classifier(num_class=10,
26 | num_features=784,
27 | fc_hidden_units=[hidden_units for _ in range(hidden_layers)],
28 | apply_dropout=True)
29 |
30 | def search(self):
31 | for t in range(0, self.num_perms):
32 | queue = PriorityQueue()
33 | for learning_rate in self.trial_learning_rates:
34 | self.train_on_task(t, learning_rate, queue)
35 | self.best_parameters.append(queue.get())
36 | self.evaluate()
37 |
38 | def evaluate(self):
39 | accuracies = []
40 | for parameters in self.best_parameters:
41 | accuracy = self.classifier.test(sess=self.sess,
42 | model_name=parameters[1],
43 | batch_xs=self.task_list[0].test.images,
44 | batch_ys=self.task_list[0].test.labels)
45 | accuracies.append(accuracy)
46 | print(accuracies)
47 |
48 | def train_on_task(self, t, lr, queue):
49 | model_name = self.file_name(lr, t)
50 | dataset_train = self.task_list[t].train
51 | dataset_lagged = self.task_list[t - 1] if t > 0 else None
52 | model_init_name = self.best_parameters[t - 1][1] if t > 0 else None
53 | self.classifier.train(sess=self.sess,
54 | model_name=model_name,
55 | model_init_name=model_init_name,
56 | dataset=dataset_train,
57 | dataset_lagged=dataset_lagged,
58 | num_updates=(55000//MINI_BATCH_SIZE)*self.epochs,
59 | mini_batch_size=MINI_BATCH_SIZE,
60 | log_frequency=LOG_FREQUENCY,
61 | fisher_multiplier=1.0/lr,
62 | learning_rate=lr)
63 | accuracy = self.classifier.test(sess=self.sess,
64 | model_name=model_name,
65 | batch_xs=self.task_list[0].validation.images,
66 | batch_ys=self.task_list[0].validation.labels)
67 | queue.put((-accuracy, model_name))
68 |
69 | def create_permuted_mnist_task(self, num_datasets):
70 | mnist = read_data_sets("MNIST_data/", one_hot=True)
71 | task_list = [mnist]
72 | for seed in range(1, num_datasets):
73 | task_list.append(self.permute(mnist, seed))
74 | return task_list
75 |
76 | @staticmethod
77 | def permute(task, seed):
78 | np.random.seed(seed)
79 | perm = np.random.permutation(task.train._images.shape[1])
80 | permuted = deepcopy(task)
81 | permuted.train._images = permuted.train._images[:, perm]
82 | permuted.test._images = permuted.test._images[:, perm]
83 | permuted.validation._images = permuted.validation._images[:, perm]
84 | return permuted
85 |
86 | def file_name(self, lr, t):
87 | return 'layers=%d,hidden=%d,lr=%.5f,multiplier=%.2f,mbsize=%d,epochs=%d,perm=%d' \
88 | % (self.hidden_layers, self.hidden_units, lr, 1 / lr, MINI_BATCH_SIZE, self.epochs, t)
89 |
90 |
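Since `queue.put((-accuracy, model_name))` feeds a `PriorityQueue`, `queue.get()` returns the trial with the highest validation accuracy, so `best_parameters[t]` is a `(-accuracy, model_name)` pair for task `t`. A small sketch of reading the result back after `search()` (here `tuner` is the `HyperparameterTuner` instance created in `main.py`):

```
tuner.search()
for t, (neg_acc, model_name) in enumerate(tuner.best_parameters):
    print('task %d: best validation accuracy %.4f with %s' % (t, -neg_acc, model_name))
```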
--------------------------------------------------------------------------------