├── README.md
├── data
│   └── mnist.pkl.gz
├── datasets.py
├── main_cifar.py
├── main_permutedmnist.py
├── main_splitmnist.py
├── main_toydata.py
├── models.py
├── opt_fromp.py
├── requirements.txt
├── train_cifar.py
├── train_permutedmnist.py
├── train_splitmnist.py
├── train_toydata.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # FROMP
2 | 
3 | Contains code for the NeurIPS 2020 paper by Pan et al., "[Continual Deep Learning by Functional Regularisation of Memorable Past](https://arxiv.org/abs/2004.14070)".
4 | 
5 | ## Continual learning with functional regularisation
6 | 
7 | FROMP performs continual learning by functionally regularising on a few memorable past datapoints, in order to avoid forgetting past information.
8 | 
9 | FROMP's functional regularisation is implemented in ``opt_fromp.py``. This is a PyTorch optimiser (built upon PyTorch's Adam optimiser) with FROMP's relevant additions. The important lines are in the ``step()`` function; a minimal usage sketch is given at the end of this README, before the citation.
10 | 
11 | The provided scripts replicate FROMP's results from the paper. Running the code as in the table below yields the reported average accuracy.
12 | The files run each experiment once. Change the ``num_runs`` variable to obtain the mean and standard deviation over several runs (as reported in the paper).
13 | 
14 | | Benchmark | File | Average accuracy |
15 | |--- |--- |--- |
16 | | Split MNIST | ``main_splitmnist.py`` | 99.3% |
17 | | Permuted MNIST | ``main_permutedmnist.py`` | 94.8% |
18 | | Split CIFAR | ``main_cifar.py`` | 76.2% |
19 | | Toy dataset | ``main_toydata.py`` | (Visualisation) |
20 | 
21 | ### Further details
22 | 
23 | The code was run with ``Python 3.7`` and ``PyTorch v1.2``. For the full environment, see ``requirements.txt``.
24 | 
25 | Hyperparameters (reported in Appendix F of the paper) are set in the ``main_*.py`` files. More detailed code is in the corresponding ``train_*.py`` files.
26 | 
27 | This code was written by Siddharth Swaroop and Pingbo Pan. Please raise issues here via GitHub, or contact [Siddharth](mailto:ss2163@cam.ac.uk).
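### Minimal usage sketch

The snippet below is a rough sketch, not a runnable replica of any one script, of how the ``train_*.py`` files drive ``opt_fromp``: two closures are passed to ``step()``, one evaluating the loss on the current task's batch and one re-evaluating the network on a past task's memorable points. It assumes the repository modules are importable and, for illustration, uses the Permuted MNIST hyperparameter values from ``main_permutedmnist.py``. See ``train_permutedmnist.py`` and ``train_splitmnist.py`` for the full versions.

```python
import torch.nn as nn
from models import MLP
from opt_fromp import opt_fromp

# Two-hidden-layer MLP from models.py; Permuted MNIST settings from main_permutedmnist.py
model = MLP([784, 100, 100, 10], act='relu')
criterion = nn.CrossEntropyLoss()
optimizer = opt_fromp(model, lr=1e-3, prior_prec=1e-5, grad_clip_norm=0.01, tau=0.5)

def train_task(task_id, trainloader, memorable_points, num_epochs=10):
    # For every task after the first, cache the kernel and predictions on the
    # memorable points of all previous tasks (used by the regulariser in step()).
    if task_id > 0:
        def closure_init(t):
            optimizer.zero_grad()
            return model(memorable_points[t][0])
        optimizer.init_task(closure_init, task_id)

    model.train()
    for _ in range(num_epochs):
        for inputs, labels in trainloader:
            # Closure over the current task's batch: returns (loss, logits)
            def closure():
                optimizer.zero_grad()
                logits = model(inputs)
                return criterion(logits, labels), logits

            # Closure over task t's memorable points: returns logits only
            def closure_memorable_points(t):
                optimizer.zero_grad()
                return model(memorable_points[t][0])

            optimizer.step(closure, closure_memorable_points, task_id)
```

After training each task, the scripts select that task's memorable points (``select_memorable_points`` or ``random_memorable_points`` in ``utils.py``), append them to ``memorable_points``, and call ``update_fisher`` so the optimiser's covariance estimate is updated before the next task.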
28 | 29 | ## Citation 30 | 31 | ``` 32 | @article{pan2020continual, 33 | title = {Continual Deep Learning by Functional Regularisation of Memorable Past}, 34 | author = {Pan, Pingbo and Swaroop, Siddharth and Immer, Alexander and Eschenhagen, Runa and Turner, Richard E and Khan, Mohammad Emtiyaz}, 35 | journal = {Advances in neural information processing systems}, 36 | year = {2020} 37 | } 38 | ``` 39 | -------------------------------------------------------------------------------- /data/mnist.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/team-approx-bayes/fromp/c9606ecea72ff3309c1887ec3b0404d4d7d87906/data/mnist.pkl.gz -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.datasets.samples_generator import make_blobs 4 | from torch.utils.data import TensorDataset, Dataset, Subset 5 | import pickle 6 | import gzip 7 | from copy import deepcopy 8 | 9 | 10 | class ToydataGenerator(): 11 | def __init__(self, max_iter=5, num_samples=2000, option=0): 12 | 13 | self.offset = 5 # Offset when loading data in next_task() 14 | 15 | # Generate data 16 | if option == 0: 17 | # Standard settings 18 | centers = [[0, 0.2], [0.6, 0.9], [1.3, 0.4], [1.6, -0.1], [2.0, 0.3], 19 | [0.45, 0], [0.7, 0.45], [1., 0.1], [1.7, -0.4], [2.3, 0.1]] 20 | std = [[0.08, 0.22], [0.24, 0.08], [0.04, 0.2], [0.16, 0.05], [0.05, 0.16], 21 | [0.08, 0.16], [0.16, 0.08], [0.06, 0.16], [0.24, 0.05], [0.05, 0.22]] 22 | 23 | elif option == 1: 24 | # Six tasks 25 | centers = [[0, 0.2], [0.6, 0.9], [1.3, 0.4], [1.6, -0.1], [2.0, 0.3], [1.65, 0.1], 26 | [0.45, 0], [0.7, 0.45], [1., 0.1], [1.7, -0.4], [2.3, 0.1], [0.7, 0.25]] 27 | std = [[0.08, 0.22], [0.24, 0.08], [0.04, 0.2], [0.16, 0.05], [0.05, 0.16], [0.14, 0.14], 28 | [0.08, 0.16], [0.16, 0.08], [0.06, 0.16], [0.24, 0.05], [0.05, 0.22], [0.14, 0.14]] 29 | 30 | elif option == 2: 31 | # All std devs increased 32 | centers = [[0, 0.2], [0.6, 0.9], [1.3, 0.4], [1.6, -0.1], [2.0, 0.3], 33 | [0.45, 0], [0.7, 0.45], [1., 0.1], [1.7, -0.4], [2.3, 0.1]] 34 | std = [[0.12, 0.22], [0.24, 0.12], [0.07, 0.2], [0.16, 0.08], [0.08, 0.16], 35 | [0.12, 0.16], [0.16, 0.12], [0.08, 0.16], [0.24, 0.08], [0.08, 0.22]] 36 | 37 | elif option == 3: 38 | # Tougher to separate 39 | centers = [[0, 0.2], [0.6, 0.65], [1.3, 0.4], [1.6, -0.22], [2.0, 0.3], 40 | [0.45, 0], [0.7, 0.55], [1., 0.1], [1.7, -0.3], [2.3, 0.1]] 41 | std = [[0.08, 0.22], [0.24, 0.08], [0.04, 0.2], [0.16, 0.05], [0.05, 0.16], 42 | [0.08, 0.16], [0.16, 0.08], [0.06, 0.16], [0.24, 0.05], [0.05, 0.22]] 43 | 44 | elif option == 4: 45 | # Two tasks, of same two gaussians 46 | centers = [[0, 0.2], [0, 0.2], 47 | [0.45, 0], [0.45, 0]] 48 | std = [[0.08, 0.22], [0.08, 0.22], 49 | [0.08, 0.16], [0.08, 0.16]] 50 | 51 | else: 52 | # If new / unknown option 53 | centers = [[0, 0.2], [0.6, 0.9], [1.3, 0.4], [1.6, -0.1], [2.0, 0.3], 54 | [0.45, 0], [0.7, 0.45], [1., 0.1], [1.7, -0.4], [2.3, 0.1]] 55 | std = [[0.08, 0.22], [0.24, 0.08], [0.04, 0.2], [0.16, 0.05], [0.05, 0.16], 56 | [0.08, 0.16], [0.16, 0.08], [0.06, 0.16], [0.24, 0.05], [0.05, 0.22]] 57 | 58 | if option != 1 and max_iter > 5: 59 | raise Exception("Current toydatagenerator only supports up to 5 tasks.") 60 | 61 | self.X, self.y = make_blobs(num_samples*2*max_iter, centers=centers, cluster_std=std) 62 | self.X = 
self.X.astype('float32') 63 | h = 0.01 64 | self.x_min, self.x_max = self.X[:, 0].min() - 0.2, self.X[:, 0].max() + 0.2 65 | self.y_min, self.y_max = self.X[:, 1].min() - 0.2, self.X[:, 1].max() + 0.2 66 | self.data_min = np.array([self.x_min, self.y_min], dtype='float32') 67 | self.data_max = np.array([self.x_max, self.y_max], dtype='float32') 68 | self.data_min = np.expand_dims(self.data_min, axis=0) 69 | self.data_max = np.expand_dims(self.data_max, axis=0) 70 | xx, yy = np.meshgrid(np.arange(self.x_min, self.x_max, h), 71 | np.arange(self.y_min, self.y_max, h)) 72 | xx = xx.astype('float32') 73 | yy = yy.astype('float32') 74 | self.test_shape = xx.shape 75 | X_test = np.c_[xx.ravel(), yy.ravel()] 76 | self.X_test = torch.from_numpy(X_test) 77 | self.y_test = torch.zeros((len(self.X_test)), dtype=self.X_test.dtype) 78 | self.max_iter = max_iter 79 | self.num_samples = num_samples # number of samples per task 80 | 81 | if option == 1: 82 | self.offset = 6 83 | elif option == 4: 84 | self.offset = 2 85 | 86 | self.cur_iter = 0 87 | 88 | def next_task(self): 89 | if self.cur_iter >= self.max_iter: 90 | raise Exception("Number of tasks exceeded!") 91 | else: 92 | x_train_0 = self.X[self.y == self.cur_iter] 93 | x_train_1 = self.X[self.y == self.cur_iter + self.offset] 94 | y_train_0 = np.zeros_like(self.y[self.y == self.cur_iter]) 95 | y_train_1 = np.ones_like(self.y[self.y == self.cur_iter + self.offset]) 96 | x_train = np.concatenate([x_train_0, x_train_1], axis=0) 97 | y_train = np.concatenate([y_train_0, y_train_1], axis=0) 98 | y_train = y_train.astype('int64') 99 | self.cur_iter += 1 100 | x_train = torch.from_numpy(x_train) 101 | y_train = torch.from_numpy(y_train) 102 | return TensorDataset(x_train, y_train), TensorDataset(self.X_test, self.y_test) 103 | 104 | def full_data(self): 105 | x_train_list = [] 106 | y_train_list = [] 107 | for i in range(self.max_iter): 108 | x_train_list.append(self.X[self.y == i]) 109 | x_train_list.append(self.X[self.y == i+self.offset]) 110 | y_train_list.append(np.zeros_like(self.y[self.y == i])) 111 | y_train_list.append(np.ones_like(self.y[self.y == i+self.offset])) 112 | x_train = np.concatenate(x_train_list, axis=0) 113 | y_train = np.concatenate(y_train_list, axis=0) 114 | y_train = y_train.astype('int64') 115 | x_train = torch.from_numpy(x_train) 116 | y_train = torch.from_numpy(y_train) 117 | return TensorDataset(x_train, y_train), TensorDataset(self.X_test, self.y_test) 118 | 119 | def reset(self): 120 | self.cur_iter = 0 121 | 122 | 123 | class PermutedMnistGenerator(): 124 | def __init__(self, max_iter=10, random_seed=0): 125 | # Open data file 126 | f = gzip.open('data/mnist.pkl.gz', 'rb') 127 | train_set, valid_set, test_set = pickle.load(f, encoding='latin1') 128 | f.close() 129 | 130 | # Define train and test data 131 | self.X_train = np.vstack((train_set[0], valid_set[0])) 132 | self.Y_train = np.hstack((train_set[1], valid_set[1])) 133 | self.X_test = test_set[0] 134 | self.Y_test = test_set[1] 135 | self.random_seed = random_seed 136 | self.max_iter = max_iter 137 | self.cur_iter = 0 138 | 139 | self.out_dim = 10 # Total number of unique classes 140 | self.class_list = range(10) # List of unique classes being considered, in the order they appear 141 | 142 | # self.classes is the classes (with correct indices for training/testing) of interest at each task_id 143 | self.classes = [] 144 | for iter in range(self.max_iter): 145 | self.classes.append(range(0,10)) 146 | 147 | self.sets = self.classes 148 | 149 | def get_dims(self): 150 | # 
Get data input and output dimensions 151 | return self.X_train.shape[1], self.out_dim 152 | 153 | def next_task(self): 154 | if self.cur_iter >= self.max_iter: 155 | raise Exception('Number of tasks exceeded!') 156 | else: 157 | np.random.seed(self.cur_iter+self.random_seed) 158 | perm_inds = np.arange(self.X_train.shape[1]) 159 | 160 | # First task is (unpermuted) MNIST, subsequent tasks are random permutations of pixels 161 | if self.cur_iter > 0: 162 | np.random.shuffle(perm_inds) 163 | 164 | # Retrieve train data 165 | next_x_train = deepcopy(self.X_train) 166 | next_x_train = next_x_train[:,perm_inds] 167 | 168 | # Initialise next_y_train to zeros, then change relevant entries to ones, and then stack 169 | next_y_train = deepcopy(self.Y_train) 170 | 171 | # Retrieve test data 172 | next_x_test = deepcopy(self.X_test) 173 | next_x_test = next_x_test[:,perm_inds] 174 | 175 | next_y_test = deepcopy(self.Y_test) 176 | 177 | self.cur_iter += 1 178 | 179 | next_x_train = torch.from_numpy(next_x_train) 180 | next_y_train = torch.from_numpy(next_y_train) 181 | next_x_test = torch.from_numpy(next_x_test) 182 | next_y_test = torch.from_numpy(next_y_test) 183 | return TensorDataset(next_x_train, next_y_train), TensorDataset(next_x_test, next_y_test) 184 | 185 | def reset(self): 186 | self.cur_iter = 0 187 | 188 | 189 | class SplitMnistGenerator(): 190 | def __init__(self): 191 | # Open data file 192 | f = gzip.open('data/mnist.pkl.gz', 'rb') 193 | train_set, valid_set, test_set = pickle.load(f, encoding='latin1') 194 | f.close() 195 | 196 | # Define train and test data 197 | self.X_train = np.vstack((train_set[0], valid_set[0])) 198 | self.X_test = test_set[0] 199 | self.train_label = np.hstack((train_set[1], valid_set[1])) 200 | self.test_label = test_set[1] 201 | 202 | # split MNIST 203 | task1 = [0, 1] 204 | task2 = [2, 3] 205 | task3 = [4, 5] 206 | task4 = [6, 7] 207 | task5 = [8, 9] 208 | self.sets = [task1, task2, task3, task4, task5] 209 | 210 | self.max_iter = len(self.sets) 211 | 212 | self.out_dim = 0 # Total number of unique classes 213 | self.class_list = [] # List of unique classes being considered, in the order they appear 214 | for task_id in range(self.max_iter): 215 | for class_index in range(len(self.sets[task_id])): 216 | if self.sets[task_id][class_index] not in self.class_list: 217 | # Convert from MNIST digit numbers to class index number by using self.class_list.index(), 218 | # which is done in self.classes 219 | self.class_list.append(self.sets[task_id][class_index]) 220 | self.out_dim = self.out_dim + 1 221 | 222 | # self.classes is the classes (with correct indices for training/testing) of interest at each task_id 223 | self.classes = [] 224 | for task_id in range(self.max_iter): 225 | class_idx = [] 226 | for i in range(len(self.sets[task_id])): 227 | class_idx.append(self.class_list.index(self.sets[task_id][i])) 228 | self.classes.append(class_idx) 229 | 230 | self.cur_iter = 0 231 | 232 | def get_dims(self): 233 | # Get data input and output dimensions 234 | return self.X_train.shape[1], self.out_dim 235 | 236 | def next_task(self): 237 | if self.cur_iter >= self.max_iter: 238 | raise Exception('Number of tasks exceeded!') 239 | else: 240 | next_x_train = [] 241 | next_y_train = [] 242 | next_x_test = [] 243 | next_y_test = [] 244 | 245 | # Loop over all classes in current iteration 246 | for class_index in range(np.size(self.sets[self.cur_iter])): 247 | 248 | # Find the correct set of training inputs 249 | train_id = np.where(self.train_label == 
self.sets[self.cur_iter][class_index])[0] 250 | # Stack the training inputs 251 | if class_index == 0: 252 | next_x_train = self.X_train[train_id] 253 | else: 254 | next_x_train = np.vstack((next_x_train, self.X_train[train_id])) 255 | 256 | # Initialise next_y_train to zeros, then change relevant entries to ones, and then stack 257 | next_y_train_interm = np.zeros((len(train_id)), dtype='int64') 258 | if class_index == 0: 259 | next_y_train = next_y_train_interm 260 | else: 261 | next_y_train_interm += 1 262 | next_y_train = np.concatenate((next_y_train, next_y_train_interm), axis=0) 263 | 264 | # Repeat above process for test inputs 265 | test_id = np.where(self.test_label == self.sets[self.cur_iter][class_index])[0] 266 | if class_index == 0: 267 | next_x_test = self.X_test[test_id] 268 | else: 269 | next_x_test = np.vstack((next_x_test, self.X_test[test_id])) 270 | 271 | next_y_test_interm = np.zeros((len(test_id)), dtype='int64') 272 | if class_index == 0: 273 | next_y_test = next_y_test_interm 274 | else: 275 | next_y_test_interm += 1 276 | next_y_test = np.concatenate((next_y_test, next_y_test_interm), axis=0) 277 | 278 | self.cur_iter += 1 279 | 280 | next_x_train = torch.from_numpy(next_x_train) 281 | next_y_train = torch.from_numpy(next_y_train) 282 | next_x_test = torch.from_numpy(next_x_test) 283 | next_y_test = torch.from_numpy(next_y_test) 284 | return TensorDataset(next_x_train, next_y_train), TensorDataset(next_x_test, next_y_test), self.sets[self.cur_iter-1] 285 | 286 | def reset(self): 287 | self.cur_iter = 0 288 | 289 | 290 | class SplitCIFAR100: 291 | 292 | def __init__(self, train_dataset, val_dataset): 293 | self.train_dataset = train_dataset 294 | self.val_dataset = val_dataset 295 | 296 | self.nr_classes = 100 297 | self.nr_classes_per_task = 10 298 | self.max_iter = self.nr_classes / self.nr_classes_per_task 299 | self.cur_iter = 0 300 | self.class_sets = [ 301 | list(range(10, 20)), 302 | list(range(20, 30)), 303 | list(range(30, 40)), 304 | list(range(40, 50)), 305 | list(range(50, 60)), 306 | list(range(60, 70)), 307 | list(range(70, 80)), 308 | list(range(80, 90)), 309 | list(range(90, 100)), 310 | list(range(100, 110)) 311 | ] 312 | 313 | def get_dims(self): 314 | # Get data input and output dimensions 315 | return len(self.train_dataset) / self.nr_classes_per_task, self.nr_classes_per_task 316 | 317 | def next_task(self): 318 | if self.cur_iter >= self.max_iter: 319 | raise Exception('Number of tasks exceeded!') 320 | else: 321 | train_dataset = SplitDataSet(self.train_dataset, self.cur_iter, self.nr_classes, 322 | self.nr_classes_per_task) 323 | val_dataset = SplitDataSet(self.val_dataset, self.cur_iter, self.nr_classes, 324 | self.nr_classes_per_task) 325 | 326 | self.cur_iter += 1 327 | 328 | return train_dataset, val_dataset, self.class_sets[self.cur_iter-1] 329 | 330 | 331 | class SplitDataSet(Dataset): 332 | 333 | def __init__(self, dataset, cur_iter, nr_classes, nr_classes_per_task): 334 | self.dataset = dataset 335 | self.cur_iter = cur_iter 336 | self.classes = [i for i in range(nr_classes)] 337 | 338 | targets = self.dataset.targets 339 | task_idx = torch.nonzero(torch.from_numpy( 340 | np.isin(targets, self.classes[nr_classes_per_task * self.cur_iter: 341 | nr_classes_per_task * self.cur_iter 342 | + nr_classes_per_task]))) 343 | 344 | self.subset = Subset(self.dataset, task_idx) 345 | 346 | def __getitem__(self, index): 347 | img, target = self.subset[index] 348 | target = target - 10 * self.cur_iter 349 | 350 | return img, target 351 | 352 | def 
__len__(self): 353 | return len(self.subset) 354 | -------------------------------------------------------------------------------- /main_cifar.py: -------------------------------------------------------------------------------- 1 | from train_cifar import train_cifar 2 | import argparse 3 | import torch 4 | import numpy as np 5 | import copy 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--num_tasks', type=int, default=6, help='number of tasks for continual learning') 10 | parser.add_argument('--batch_size', type=int, default=256, help='number of data points in a batch') 11 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 12 | parser.add_argument('--num_epochs', type=int, default=80, help='number of training epochs') 13 | parser.add_argument('--num_points', type=int, default=200, help='number of inducing points for each task') 14 | parser.add_argument('--seed', type=int, default=42, help='random seed') 15 | parser.add_argument('--num_runs', type=int, default=1, help='how many random seed runs to average over') 16 | parser.add_argument('--select_method', type=str, default='lambda_descend', 17 | help='method to select memorable points, can be: {random, lambda_descend, lambda_ascend}') 18 | parser.add_argument('--tau', type=float, default=10, 19 | help='hyperparameter tau (scaled by a factor N), should be scaled with num_points') 20 | 21 | args = parser.parse_args() 22 | 23 | def main(args): 24 | 25 | use_cuda = True if torch.cuda.is_available() else False 26 | 27 | acc = train_cifar(num_tasks=args.num_tasks, batch_size=args.batch_size, lr=args.lr, num_epochs=args.num_epochs, 28 | num_points=args.num_points, use_cuda=use_cuda, select_method=args.select_method, tau=args.tau) 29 | 30 | return acc 31 | 32 | 33 | if __name__ == '__main__': 34 | 35 | acc_list = [] 36 | args_list = [] 37 | 38 | for i in range(args.num_runs): 39 | # Set random seed 40 | np.random.seed(args.seed+i) 41 | torch.manual_seed(args.seed+i) 42 | print('\nSplit CIFAR, seed', args.seed+i) 43 | 44 | # Run FROMP 45 | acc = main(args) 46 | acc_list.append(acc) 47 | args_list.append(copy.copy(args)) 48 | 49 | # Save results 50 | save_results = False 51 | if save_results: 52 | save_path = 'results/' 53 | torch.save({ 54 | 'args_list': args_list, 55 | 'accs_list': acc_list, 56 | }, save_path + 'cifar_seed_%d.tar' % (args.seed)) 57 | 58 | # Print average final accuracy and standard deviation 59 | print('Mean accuracy', np.mean([np.mean(x[-1]) for x in acc_list])) 60 | print('Mean std', np.std([np.mean(x[-1]) for x in acc_list])) 61 | -------------------------------------------------------------------------------- /main_permutedmnist.py: -------------------------------------------------------------------------------- 1 | from train_permutedmnist import train_permutedmnist 2 | import argparse 3 | import torch 4 | import numpy as np 5 | import copy 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--num_tasks', type=int, default=10, help='number of tasks for continual learning') 10 | parser.add_argument('--batch_size', type=int, default=128, help='number of data points in a batch') 11 | parser.add_argument('--hidden_size', type=int, default=100, help='network hidden layer size (2 hidden layers)') 12 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 13 | parser.add_argument('--num_epochs', type=int, default=10, help='number of training epochs') 14 | parser.add_argument('--seed', type=int, default=43, help='random seed') 15 | 
parser.add_argument('--num_runs', type=int, default=1, help='how many random seed runs to average over') 16 | parser.add_argument('--num_points', type=int, default=200, help='number of memorable points for each task') 17 | parser.add_argument('--select_method', type=str, default='lambda_descend', 18 | help='method to select memorable points, can be: {random, lambda_descend, lambda_ascend}') 19 | parser.add_argument('--tau', type=float, default=0.5, 20 | help='hyperparameter tau (scaled by a factor N), should be scaled with num_points') 21 | 22 | args = parser.parse_args() 23 | 24 | def main(args): 25 | 26 | use_cuda = True if torch.cuda.is_available() else False 27 | 28 | acc = train_permutedmnist(num_tasks=args.num_tasks, batch_size=args.batch_size, hidden_size=args.hidden_size, 29 | lr=args.lr, num_epochs=args.num_epochs, num_points=args.num_points, 30 | use_cuda=use_cuda, select_method=args.select_method, tau=args.tau) 31 | 32 | return acc 33 | 34 | 35 | if __name__ == '__main__': 36 | 37 | acc_list = [] 38 | args_list = [] 39 | 40 | for i in range(args.num_runs): 41 | # Set random seed 42 | np.random.seed(args.seed+i) 43 | torch.manual_seed(args.seed+i) 44 | print('\nPermuted MNIST, seed', args.seed+i) 45 | 46 | # Run FROMP 47 | acc = main(args) 48 | acc_list.append(acc) 49 | args_list.append(copy.copy(args)) 50 | 51 | # Save results 52 | save_results = False 53 | if save_results: 54 | save_path = 'results/' 55 | torch.save({ 56 | 'args_list': args_list, 57 | 'accs_list': acc_list, 58 | }, save_path + 'permuted_seed_%d.tar' % (args.seed)) 59 | 60 | # Print average final accuracy and standard deviation 61 | print('Mean accuracy', np.mean([np.mean(x[-1]) for x in acc_list])) 62 | print('Mean std', np.std([np.mean(x[-1]) for x in acc_list])) 63 | -------------------------------------------------------------------------------- /main_splitmnist.py: -------------------------------------------------------------------------------- 1 | from train_splitmnist import train_splitmnist 2 | import argparse 3 | import torch 4 | import numpy as np 5 | import copy 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--num_tasks', type=int, default=5, help='number of tasks for continual learning') 10 | parser.add_argument('--batch_size', type=int, default=128, help='number of data points in a batch') 11 | parser.add_argument('--hidden_size', type=int, default=256, help='network hidden layer size (2 hidden layers)') 12 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') 13 | parser.add_argument('--num_epochs', type=int, default=15, help='number of training epochs') 14 | parser.add_argument('--seed', type=int, default=42, help='random seed') 15 | parser.add_argument('--num_runs', type=int, default=1, help='how many random seed runs to average over') 16 | parser.add_argument('--num_points', type=int, default=40, help='number of memorable points for each task') 17 | parser.add_argument('--select_method', type=str, default='lambda_descend', 18 | help='method to select memorable points, can be: {random, lambda_descend, lambda_ascend}') 19 | parser.add_argument('--tau', type=float, default=10, 20 | help='hyperparameter tau (scaled by a factor N), should be scaled with num_points') 21 | args = parser.parse_args() 22 | 23 | 24 | def main(args): 25 | 26 | use_cuda = True if torch.cuda.is_available() else False 27 | 28 | acc = train_splitmnist(num_tasks=args.num_tasks, batch_size=args.batch_size, hidden_size=args.hidden_size, 29 | lr=args.lr, num_epochs=args.num_epochs, 
num_points=args.num_points, 30 | select_method=args.select_method, use_cuda=use_cuda, tau=args.tau) 31 | return acc 32 | 33 | 34 | if __name__ == '__main__': 35 | 36 | acc_list = [] 37 | args_list = [] 38 | 39 | for i in range(args.num_runs): 40 | # Set random seed 41 | np.random.seed(args.seed+i) 42 | torch.manual_seed(args.seed+i) 43 | print('\nSplit MNIST, seed', args.seed+i) 44 | 45 | # Run FROMP 46 | acc = main(args) 47 | acc_list.append(acc) 48 | args_list.append(copy.copy(args)) 49 | 50 | # Save results 51 | save_results = False 52 | if save_results: 53 | save_path = 'results/' 54 | torch.save({ 55 | 'args_list': args_list, 56 | 'accs_list': acc_list, 57 | }, save_path + 'splitmnist_seed_%d.tar' % (args.seed)) 58 | 59 | # Print average final accuracy and standard deviation 60 | print('Mean accuracy', np.mean([np.mean(x[-1]) for x in acc_list])) 61 | print('Mean std', np.std([np.mean(x[-1]) for x in acc_list])) 62 | -------------------------------------------------------------------------------- /main_toydata.py: -------------------------------------------------------------------------------- 1 | from train_toydata import train_model 2 | import argparse 3 | import torch 4 | import numpy as np 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--num_tasks', type=int, default=5, help='number of tasks for continual learning') 8 | parser.add_argument('--batch_size', type=int, default=20, help='number of data points in a batch') 9 | parser.add_argument('--hidden_size', type=int, default=20, help='network hidden layer size') 10 | parser.add_argument('--lr', type=float, default=0.01, help='learning rate') 11 | parser.add_argument('--num_epochs', type=int, default=50, help='number of training epochs') 12 | parser.add_argument('--num_points', type=int, default=20, help='number of inducing points for each task') 13 | parser.add_argument('--seed', type=int, default=123, help='random seed') 14 | parser.add_argument('--select_method', type=str, default='lambda_descend', 15 | help='method to select memorable points, can be: {random, lambda_descend, lambda_ascend}') 16 | parser.add_argument('--tau', type=float, default=1, 17 | help='hyperparameter tau (scaled by a factor N), should be scaled with num_points') 18 | 19 | args = parser.parse_args() 20 | 21 | def main(args): 22 | 23 | use_cuda = False 24 | 25 | train_model(args=args, use_cuda=use_cuda) 26 | 27 | 28 | if __name__ == '__main__': 29 | 30 | np.random.seed(args.seed) 31 | torch.manual_seed(args.seed) 32 | print('FROMP on toy data, seed %d' % (args.seed)) 33 | 34 | main(args) 35 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Fully connected network, size: input_size, hidden_size,... 
, output_size 7 | class MLP(nn.Module): 8 | def __init__(self, size, act='sigmoid'): 9 | super(type(self), self).__init__() 10 | self.num_layers = len(size) - 1 11 | lower_modules = [] 12 | for i in range(self.num_layers - 1): 13 | lower_modules.append(nn.Linear(size[i], size[i+1])) 14 | if act == 'relu': 15 | lower_modules.append(nn.ReLU()) 16 | elif act == 'sigmoid': 17 | lower_modules.append(nn.Sigmoid()) 18 | else: 19 | raise ValueError("%s activation layer hasn't been implemented in this code" %act) 20 | self.layer_1 = nn.Sequential(*lower_modules) 21 | self.layer_2 = nn.Linear(size[-2], size[-1]) 22 | 23 | 24 | def forward(self, x): 25 | o = self.layer_1(x) 26 | o = self.layer_2(o) 27 | return o 28 | 29 | 30 | class SplitMLP(nn.Module): 31 | def __init__(self, size, act='sigmoid'): 32 | super(type(self), self).__init__() 33 | self.num_layers = len(size) - 1 34 | lower_modules = [] 35 | for i in range(self.num_layers - 1): 36 | lower_modules.append(nn.Linear(size[i], size[i+1])) 37 | if act == 'relu': 38 | lower_modules.append(nn.ReLU()) 39 | elif act == 'sigmoid': 40 | lower_modules.append(nn.Sigmoid()) 41 | else: 42 | raise ValueError("%s activation layer hasn't been implemented in this code" %act) 43 | self.layer_1 = nn.Sequential(*lower_modules) 44 | self.layer_2 = nn.Linear(size[-2], size[-1]) 45 | 46 | 47 | def forward(self, x, label_set): 48 | o = self.layer_1(x) 49 | o = self.layer_2(o) 50 | o = o[:, label_set] 51 | return o 52 | 53 | 54 | class CifarNet(nn.Module): 55 | def __init__(self, in_channels, out_channels): 56 | super(type(self), self).__init__() 57 | self.conv_block = nn.Sequential( 58 | nn.Conv2d(in_channels, 32, 3, padding=1), 59 | nn.ReLU(), 60 | nn.Conv2d(32, 32, 3), 61 | nn.ReLU(), 62 | nn.MaxPool2d(2), 63 | nn.Dropout(p=0.25), 64 | nn.Conv2d(32, 64, 3, padding=1), 65 | nn.ReLU(), 66 | nn.Conv2d(64, 64, 3), 67 | nn.ReLU(), 68 | nn.MaxPool2d(2), 69 | nn.Dropout(p=0.25) 70 | ) 71 | self.linear_block = nn.Sequential( 72 | nn.Linear(64*6*6, 512), 73 | nn.ReLU(), 74 | nn.Dropout(p=0.5) 75 | ) 76 | self.out_block = nn.Linear(512, out_channels) 77 | 78 | 79 | def weight_init(self): 80 | nn.init.constant_(self.out_block.weight, 0) 81 | nn.init.constant_(self.out_block.bias, 0) 82 | 83 | 84 | def forward(self, x, label_set): 85 | o = self.conv_block(x) 86 | o = torch.flatten(o, 1) 87 | o = self.linear_block(o) 88 | o = self.out_block(o) 89 | o = o[:, label_set] 90 | return o 91 | -------------------------------------------------------------------------------- /opt_fromp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | from torch.nn.utils import parameters_to_vector, vector_to_parameters 5 | import torch.nn as nn 6 | import torch.nn.functional as f 7 | from utils import logistic_hessian, full_softmax_hessian 8 | 9 | 10 | def update_input(self, input, output): 11 | self.input = input[0].data 12 | self.output = output 13 | 14 | 15 | def _check_param_device(param, old_param_device): 16 | if old_param_device is None: 17 | old_param_device = param.get_device() if param.is_cuda else -1 18 | else: 19 | warn = False 20 | if param.is_cuda: # check if in same gpu 21 | warn = (param.get_device() != old_param_device) 22 | else: # check if in cpu 23 | warn = (old_param_device != -1) 24 | if warn: 25 | raise typeerror('found two parameters on different devices, ' 26 | 'this is currently not supported.') 27 | return old_param_device 28 | 29 | 30 | def 
parameters_to_matrix(parameters): 31 | param_device = None 32 | mat = [] 33 | for param in parameters: 34 | param_device = _check_param_device(param, param_device) 35 | m = param.shape[0] 36 | mat.append(param.view(m, -1)) 37 | return torch.cat(mat, dim=-1) 38 | 39 | 40 | def parameters_grads_to_vector(parameters): 41 | param_device = None 42 | vec = [] 43 | for param in parameters: 44 | param_device = _check_param_device(param, param_device) 45 | if param.grad is None: 46 | raise valueerror('gradient not available') 47 | vec.append(param.grad.data.view(-1)) 48 | return torch.cat(vec, dim=-1) 49 | 50 | 51 | # Optimizer that is torch.optim.adam with extra regularisation terms for FROMP (Pan et al., 2020) 52 | # grad_clip_norm: What value to clip the norm of the gradient to during training 53 | class opt_fromp(Optimizer): 54 | def __init__(self, model, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, prior_prec=1e-3, grad_clip_norm=1, tau=1, 55 | amsgrad=False): 56 | if not 0.0 <= lr: 57 | raise valueerror("invalid learning rate: {}".format(lr)) 58 | if not 0.0 <= eps: 59 | raise valueerror("invalid epsilon value: {}".format(eps)) 60 | if not 0.0 <= betas[0] < 1.0: 61 | raise valueerror("invalid beta parameter at index 0: {}".format(betas[0])) 62 | if not 0.0 <= betas[1] < 1.0: 63 | raise valueerror("invalid beta parameter at index 1: {}".format(betas[1])) 64 | if not 0.0 <= prior_prec: 65 | raise valueerror("invalid prior precision: {}".format(prior_prec)) 66 | if grad_clip_norm is not None and not 0.0 <= grad_clip_norm: 67 | raise valueerror("invalid gradient clip norm: {}".format(grad_clip_norm)) 68 | if not 0.0 <= tau: 69 | raise valueerror("invalid tau: {}".format(tau)) 70 | defaults = dict(lr=lr, betas=betas, eps=eps, prior_prec=prior_prec, grad_clip_norm=grad_clip_norm, 71 | tau=tau, amsgrad=amsgrad) 72 | super(opt_fromp, self).__init__(model.parameters(), defaults) 73 | 74 | self.model = model 75 | self.train_modules = [] 76 | self.set_train_modules(model) 77 | for module in self.train_modules: 78 | module.register_forward_hook(update_input) 79 | 80 | parameters = self.param_groups[0]['params'] 81 | 82 | p = parameters_to_vector(parameters) 83 | self.state['mu'] = p.clone().detach() 84 | self.state['mu_previous'] = p.clone().detach() 85 | self.state['fisher'] = torch.zeros_like(self.state['mu']) 86 | self.state['step'] = 0 87 | self.state['exp_avg'] = torch.zeros_like(self.state['mu']) 88 | self.state['exp_avg_sq'] = torch.zeros_like(self.state['mu']) 89 | if amsgrad: 90 | self.state['max_exp_avg_sq'] = torch.zeros_like(self.state['mu']) 91 | 92 | 93 | # Load zeros into model 94 | def load_zeros(self): 95 | zeros = torch.zeros_like(self.state['mu']) 96 | vector_to_parameters(zeros, self.param_groups[0]['params']) 97 | 98 | 99 | def update_mu(self): 100 | parameters = self.param_groups[0]['params'] 101 | p = parameters_to_vector(parameters) 102 | self.state['mu'] = p.clone().detach() 103 | 104 | # Calculate values (memorable_logits, hkh_l) for regularisation term (all but the first task) 105 | def init_task(self, closure, task_id, eps=1e-5): 106 | self.state['exp_avg'] = torch.zeros_like(self.state['mu']) 107 | self.state['exp_avg_sq'] = torch.zeros_like(self.state['mu']) 108 | self.state['step'] = 0 109 | self.state['kernel_inv'] = [] 110 | self.state['memorable_logits'] = [] 111 | fisher = self.state['fisher'] 112 | prior_prec = self.param_groups[0]['prior_prec'] 113 | mu = self.state['mu'] 114 | self.state['mu_previous'] = mu.clone().detach() 115 | parameters = self.param_groups[0]['params'] 
116 | vector_to_parameters(mu, parameters) 117 | covariance = 1. / (fisher + prior_prec) 118 | 119 | # Calculate kernel = J \Sigma J^T for all memory points, and store via cholesky decomposition 120 | self.model.eval() 121 | for i in range(task_id): 122 | preds = closure(i) 123 | num_fun = preds.shape[-1] 124 | if num_fun == 1: 125 | preds = torch.sigmoid(preds) 126 | else: 127 | preds = torch.softmax(preds, dim=-1) 128 | self.state['memorable_logits'].append(preds.detach()) 129 | lc = [] 130 | for module in self.train_modules: 131 | lc.append(module.output) 132 | kernel_inv = [] 133 | for fi in range(num_fun): 134 | loss = preds[:, fi].sum() 135 | retain_graph = True if fi < num_fun - 1 else None 136 | grad = self.cac_grad(loss, lc, retain_graph=retain_graph) 137 | kernel = torch.einsum('ij,j,pj->ip', grad, covariance, grad) + \ 138 | torch.eye(grad.shape[0], dtype=grad.dtype, device=grad.device)*eps 139 | kernel_inv.append(torch.cholesky_inverse(torch.cholesky(kernel))) 140 | 141 | self.state['kernel_inv'].append(kernel_inv) 142 | 143 | 144 | # For calculating Jacobians in PyTorch 145 | def set_train_modules(self, module): 146 | if len(list(module.children())) == 0: 147 | if len(list(module.parameters())) != 0: 148 | self.train_modules.append(module) 149 | else: 150 | for child in list(module.children()): 151 | self.set_train_modules(child) 152 | 153 | 154 | # Update the hyperparameter tau 155 | def update_tau(self, tau): 156 | self.defaults['tau'] = tau 157 | 158 | 159 | # Calculate the gradient (part of calculating Jacobian) of the parameters lc wrt loss 160 | def cac_grad(self, loss, lc, retain_graph=None): 161 | linear_grad = torch.autograd.grad(loss, lc, retain_graph=retain_graph) 162 | grad = [] 163 | for i, module in enumerate(self.train_modules): 164 | g = linear_grad[i] 165 | a = module.input.clone().detach() 166 | m = a.shape[0] 167 | 168 | if isinstance(module, nn.Linear): 169 | grad.append(torch.einsum('ij,ik->ijk', g, a)) 170 | if module.bias is not None: 171 | grad.append(g) 172 | 173 | if isinstance(module, nn.Conv2d): 174 | a = f.unfold(a, kernel_size=module.kernel_size, dilation=module.dilation, padding=module.padding, 175 | stride=module.stride) 176 | _, k, hw = a.shape 177 | _, c, _, _ = g.shape 178 | g = g.view(m, c, -1) 179 | grad.append(torch.einsum('ijl,ikl->ijk', g, a)) 180 | if module.bias is not None: 181 | a = torch.ones((m, 1, hw), device=a.device) 182 | grad.append(torch.einsum('ijl,ikl->ijk', g, a)) 183 | 184 | if isinstance(module, nn.BatchNorm1d): 185 | grad.append(torch.mul(g, a)) 186 | if module.bias is not None: 187 | grad.append(g) 188 | 189 | if isinstance(module, nn.BatchNorm2d): 190 | grad.append(torch.einsum('ijkl->ij', torch.mul(g, a))) 191 | if module.bias is not None: 192 | grad.append(torch.einsum('ijkl->ij', g)) 193 | 194 | grad_m = parameters_to_matrix(grad) 195 | return grad_m.detach() 196 | 197 | 198 | # Calculate the Jacobian matrix 199 | def cac_jacobian(self, output, lc): 200 | if output.dim() > 2: 201 | raise valueerror('the dimension of output must be smaller than 3.') 202 | elif output.dim() == 2: 203 | num_fun = output.shape[1] 204 | grad = [] 205 | for i in range(num_fun): 206 | retain_graph = None if i == num_fun - 1 else True 207 | loss = output[:, i].sum() 208 | g = self.cac_grad(loss, lc, retain_graph=retain_graph) 209 | grad.append(g) 210 | result = torch.zeros((grad[0].shape[0], grad[0].shape[1], num_fun), 211 | dtype=grad[0].dtype, device=grad[0].device) 212 | for i in range(num_fun): 213 | result[:, :, i] = grad[i] 214 | 
return result 215 | 216 | 217 | # After training on a new task, update the fisher matrix estimate 218 | def update_fisher(self, closure): 219 | fisher = self.state['fisher'] 220 | preds = closure() 221 | lc = [] 222 | for module in self.train_modules: 223 | lc.append(module.output) 224 | jac = self.cac_jacobian(preds, lc) 225 | if preds.shape[-1] == 1: 226 | hes = logistic_hessian(preds).detach() 227 | hes = hes[:, :, None] 228 | else: 229 | hes = full_softmax_hessian(preds).detach() 230 | jhj = torch.einsum('ijd,idp,ijp->j', jac, hes, jac) 231 | fisher.add_(jhj) 232 | 233 | 234 | # Iteration step for this optimiser 235 | # Added extra regularisation terms to torch.optim.adam 236 | def step(self, closure_data, closure_memorable_points, task_id): 237 | defaults = self.defaults 238 | lr = self.param_groups[0]['lr'] 239 | beta1, beta2 = self.param_groups[0]['betas'] 240 | amsgrad = self.param_groups[0]['amsgrad'] 241 | parameters = self.param_groups[0]['params'] 242 | mu = self.state['mu'] 243 | 244 | #vector_to_parameters(mu, parameters) 245 | self.model.train() 246 | 247 | # Normal loss term over current task's data 248 | vector_to_parameters(mu, parameters) 249 | loss_cur, preds_cur = closure_data() 250 | loss_cur.backward(retain_graph=True) 251 | grad = parameters_grads_to_vector(parameters).detach() 252 | grad.mul_(1/defaults['tau']) 253 | 254 | 255 | # The loss term corresponding to memorable points 256 | if task_id > 0: 257 | self.model.eval() 258 | kernel_inv = self.state['kernel_inv'] 259 | memorable_logits = self.state['memorable_logits'] 260 | grad_t_sum = torch.zeros_like(grad) 261 | for t in range(task_id): 262 | preds_t = closure_memorable_points(t) 263 | num_fun = preds_t.shape[-1] 264 | if num_fun == 1: 265 | preds_t = torch.sigmoid(preds_t) 266 | else: 267 | preds_t = torch.softmax(preds_t, dim=-1) 268 | lc = [] 269 | for module in self.train_modules: 270 | lc.append(module.output) 271 | for fi in range(num_fun): 272 | # \Lambda * Jacobian 273 | loss_jac_t = preds_t[:, fi].sum() 274 | retain_graph = True if fi < num_fun - 1 else None 275 | jac_t = self.cac_grad(loss_jac_t, lc, retain_graph=retain_graph) 276 | 277 | # m_t - m_{t-1} 278 | logits_t = preds_t[:, fi].detach() 279 | delta_logits = logits_t - memorable_logits[t][:,fi] 280 | 281 | # K_{t-1}^{-1} 282 | kernel_inv_t = kernel_inv[t][fi] 283 | 284 | # Uncomment the following line for L2 variants of algorithms 285 | # kernel_inv_t = torch.eye(kernel_inv_t.shape[0], device=kernel_inv_t.device) 286 | 287 | # Calculate K_{t-1}^{-1} (m_t - m_{t-1}) 288 | kinvf_t = torch.squeeze(torch.matmul(kernel_inv_t, delta_logits[:,None]), dim=-1) 289 | 290 | grad_t = torch.einsum('ij,i->j', jac_t, kinvf_t) 291 | grad_t_sum.add_(grad_t) 292 | 293 | grad.add_(grad_t_sum) 294 | 295 | # Do gradient norm clipping 296 | clip_norm = self.defaults['grad_clip_norm'] 297 | if clip_norm is not None: 298 | grad_norm = torch.norm(grad) 299 | grad_norm = 1.0 if grad_norm < clip_norm else grad_norm/clip_norm 300 | grad.div_(grad_norm) 301 | 302 | # Adam update equations 303 | exp_avg, exp_avg_sq = self.state['exp_avg'], self.state['exp_avg_sq'] 304 | if amsgrad: 305 | max_exp_avg_sq = self.state['max_exp_avg_sq'] 306 | self.state['step'] += 1 307 | exp_avg.mul_(beta1).add_(1-beta1, grad) 308 | exp_avg_sq.mul_(beta2).addcmul_(1-beta2, grad, grad) 309 | if amsgrad: 310 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 311 | denom = max_exp_avg_sq.sqrt().add_(self.param_groups[0]['eps']) 312 | else: 313 | denom = 
exp_avg_sq.sqrt().add_(self.param_groups[0]['eps']) 314 | 315 | bias_correction1 = 1 - beta1 ** self.state['step'] 316 | bias_correction2 = 1 - beta2 ** self.state['step'] 317 | step_size = lr * math.sqrt(bias_correction2) / bias_correction1 318 | mu.addcdiv_(-step_size, exp_avg, denom) 319 | vector_to_parameters(mu, parameters) 320 | 321 | return loss_cur, preds_cur 322 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | brewer2mpl==1.4.1 2 | certifi==2020.6.20 3 | chardet==3.0.4 4 | click==7.1.2 5 | configparser==5.0.0 6 | cycler==0.10.0 7 | docker-pycreds==0.4.0 8 | gitdb==4.0.5 9 | GitPython==3.1.9 10 | idna==2.10 11 | joblib==0.14.0 12 | kiwisolver==1.1.0 13 | matplotlib==3.1.1 14 | numpy==1.17.2 15 | pandas==1.1.1 16 | pathtools==0.1.2 17 | Pillow==6.2.0 18 | promise==2.3 19 | protobuf==3.13.0 20 | psutil==5.7.2 21 | pyparsing==2.4.2 22 | python-dateutil==2.8.0 23 | pytz==2020.1 24 | PyYAML==5.3.1 25 | requests==2.24.0 26 | scikit-learn==0.21.3 27 | scipy==1.3.1 28 | seaborn==0.10.1 29 | sentry-sdk==0.18.0 30 | shortuuid==1.0.1 31 | six==1.15.0 32 | smmap==3.0.4 33 | subprocess32==3.5.4 34 | torch==1.2.0+cu92 35 | torchvision==0.4.0+cu92 36 | urllib3==1.25.10 37 | watchdog==0.10.3 38 | -------------------------------------------------------------------------------- /train_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models import CifarNet 4 | from datasets import SplitCIFAR100 5 | from torchvision import datasets 6 | from torch.utils.data.dataloader import DataLoader 7 | from utils import select_memorable_points, update_fisher, random_memorable_points 8 | from opt_fromp import opt_fromp 9 | import logging 10 | from torchvision import transforms 11 | 12 | 13 | def train(model, dataloaders, memorable_points, criterion, optimizer, label_sets, task_id=0, 14 | num_epochs=100, use_cuda=False): 15 | trainloader, testloader = dataloaders 16 | label_set_cur = label_sets[task_id] 17 | 18 | # Train 19 | model.train() 20 | for epoch in range(num_epochs): 21 | for inputs, labels in trainloader: 22 | if use_cuda: 23 | inputs, labels = inputs.cuda(), labels.cuda() 24 | 25 | if isinstance(optimizer, opt_fromp): 26 | # Closure on current task's data 27 | def closure(): 28 | optimizer.zero_grad() 29 | logits = model.forward(inputs, label_set_cur) 30 | loss = criterion(logits, labels) 31 | return loss, logits 32 | 33 | # Closure on memorable past data 34 | def closure_memorable_points(tid): 35 | memorable_data_t = memorable_points[tid][0] 36 | label_set_t = label_sets[tid] 37 | if use_cuda: 38 | memorable_data_t = memorable_data_t.cuda() 39 | optimizer.zero_grad() 40 | logits = model.forward(memorable_data_t, label_set_t) 41 | return logits 42 | 43 | # Optimiser step 44 | loss, logits = optimizer.step(closure, closure_memorable_points, task_id) 45 | 46 | # Test 47 | model.eval() 48 | print('Begin testing...') 49 | test_accuracy = [] 50 | for tid, testdata in enumerate(testloader): 51 | total = 0 52 | correct = 0 53 | for inputs, labels in testdata: 54 | if use_cuda: 55 | inputs, labels = inputs.cuda(), labels.cuda() 56 | label_set_t = label_sets[tid] 57 | logits = model.forward(inputs, label_set_t) 58 | predict_label = torch.argmax(logits, dim=-1) 59 | total += inputs.shape[0] 60 | if use_cuda: 61 | correct += torch.sum(predict_label == labels).cpu().item() 62 | else: 63 | 
correct += torch.sum(predict_label == labels).item() 64 | test_accuracy.append(correct / total) 65 | 66 | return test_accuracy 67 | 68 | 69 | def train_cifar(num_tasks, batch_size, lr, num_epochs, num_points, select_method='lambda_descend', 70 | use_cuda=True, tau=10): 71 | 72 | # Log console output to 'clog.txt' 73 | logging.basicConfig(filename='clog.txt') 74 | logger = logging.getLogger() 75 | logger.setLevel(logging.INFO) 76 | 77 | # Data generator 78 | num_classes_per_task = 10 79 | out_dim = num_tasks*num_classes_per_task 80 | data_transforms = transforms.ToTensor() 81 | cifar10_train = datasets.CIFAR10('datasets/', 82 | train=True, transform=data_transforms, download=True) 83 | cifar10_test = datasets.CIFAR10('datasets/', 84 | train=False, transform=data_transforms, download=True) 85 | cifar100_train = datasets.CIFAR100('datasets/', 86 | train=True, transform=data_transforms, download=True) 87 | cifar100_test = datasets.CIFAR100('datasets/', 88 | train=False, transform=data_transforms, download=True) 89 | data_gen = SplitCIFAR100(cifar100_train, cifar100_test) 90 | 91 | # Model 92 | model = CifarNet(3, out_dim) 93 | 94 | criterion = nn.CrossEntropyLoss() 95 | if use_cuda: 96 | criterion.cuda() 97 | model.cuda() 98 | 99 | # Optimiser 100 | opt = opt_fromp(model, lr=lr, prior_prec=1e-4, grad_clip_norm=100, tau=tau) 101 | 102 | # Train on tasks 103 | memorable_points = [] 104 | testloaders = [] 105 | label_sets = [] # To record the labels for each task. 106 | acc_list = [] 107 | for tid in range(num_tasks): 108 | 109 | # If not first task, need to calculate and store regularisation-term-related quantities 110 | if tid > 0: 111 | def closure(task_id): 112 | memorable_points_t = memorable_points[task_id][0] 113 | label_set_t = label_sets[task_id] 114 | if use_cuda: 115 | memorable_points_t = memorable_points_t.cuda() 116 | opt.zero_grad() 117 | logits = model.forward(memorable_points_t, label_set_t) 118 | return logits 119 | opt.init_task(closure, tid, eps=1e-6) 120 | 121 | # Data generator for this task 122 | if tid == 0: 123 | itrain, itest = cifar10_train, cifar10_test 124 | ilabel_set = list(range(10)) 125 | else: 126 | itrain, itest, ilabel_set = data_gen.next_task() 127 | label_sets.append(ilabel_set) 128 | itrainloader = DataLoader(dataset=itrain, batch_size=batch_size, shuffle=True, num_workers=3) 129 | itestloader = DataLoader(dataset=itest, batch_size=batch_size, shuffle=False, num_workers=3) 130 | memorableloader = DataLoader(dataset=itrain, batch_size=6, shuffle=False, num_workers=3) 131 | testloaders.append(itestloader) 132 | iloaders = [itrainloader, testloaders] 133 | 134 | # Train and test 135 | acc = train(model, iloaders, memorable_points, criterion, opt, label_sets, task_id=tid, 136 | num_epochs=num_epochs, use_cuda=use_cuda) 137 | 138 | # Select memorable past datapoints 139 | if select_method == 'random': 140 | i_memorable_points = random_memorable_points(itrain, num_points=num_points, num_classes=num_classes_per_task) 141 | elif select_method == 'lambda_descend': 142 | i_memorable_points = select_memorable_points(memorableloader, model, num_points=num_points, num_classes=num_classes_per_task, 143 | use_cuda=use_cuda, label_set=ilabel_set, descending=True) 144 | elif select_method == 'lambda_ascend': 145 | i_memorable_points = select_memorable_points(memorableloader, model, num_points=num_points, num_classes=num_classes_per_task, 146 | use_cuda=use_cuda, label_set=ilabel_set, descending=False) 147 | else: 148 | raise Exception('Invalid memorable points selection 
method.') 149 | 150 | memorable_points.append(i_memorable_points) 151 | 152 | # Update covariance (\Sigma) 153 | update_fisher(memorableloader, model, opt, use_cuda=use_cuda, label_set=ilabel_set) 154 | 155 | print(acc) 156 | print('Mean accuracy after task %d: %f'%(tid+1, sum(acc)/len(acc))) 157 | logger.info('After learn task: %d'%(tid+1)) 158 | logger.info(acc) 159 | logger.info('Mean accuracy is: %f'%(sum(acc)/len(acc))) 160 | acc_list.append(acc) 161 | 162 | return acc_list 163 | -------------------------------------------------------------------------------- /train_permutedmnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models import MLP 4 | from datasets import PermutedMnistGenerator 5 | from torch.utils.data.dataloader import DataLoader 6 | from utils import select_memorable_points, update_fisher, random_memorable_points 7 | from opt_fromp import opt_fromp 8 | import logging 9 | 10 | 11 | def train(model, dataloaders, memorable_points, criterion, optimizer, task_id=0, num_epochs=10, use_cuda=True): 12 | trainloader, testloader = dataloaders 13 | 14 | # Train 15 | model.train() 16 | for epoch in range(num_epochs): 17 | for inputs, labels in trainloader: 18 | if use_cuda: 19 | inputs, labels = inputs.cuda(), labels.cuda() 20 | 21 | if isinstance(optimizer, opt_fromp): 22 | # Closure on current task's data 23 | def closure(): 24 | optimizer.zero_grad() 25 | logits = model.forward(inputs) 26 | loss = criterion(logits, labels) 27 | return loss, logits 28 | 29 | # Closure on memorable past data 30 | def closure_memorable_points(task_id): 31 | memorable_points_t = memorable_points[task_id][0] 32 | if use_cuda: 33 | memorable_points_t = memorable_points_t.cuda() 34 | optimizer.zero_grad() 35 | logits = model.forward(memorable_points_t) 36 | return logits 37 | 38 | # Optimiser step 39 | loss, logits = optimizer.step(closure, closure_memorable_points, task_id) 40 | 41 | # Test 42 | model.eval() 43 | print('Begin testing...') 44 | test_accuracy = [] 45 | for testdata in testloader: 46 | total = 0 47 | correct = 0 48 | for inputs, labels in testdata: 49 | if use_cuda: 50 | inputs, labels = inputs.cuda(), labels.cuda() 51 | logits = model.forward(inputs) 52 | predict_label = torch.argmax(logits, dim=-1) 53 | total += inputs.shape[0] 54 | if use_cuda: 55 | correct += torch.sum(predict_label == labels).cpu().item() 56 | else: 57 | correct += torch.sum(predict_label == labels).item() 58 | test_accuracy.append(correct / total) 59 | 60 | return test_accuracy 61 | 62 | 63 | def train_permutedmnist(num_tasks, batch_size, hidden_size, lr, num_epochs, num_points, 64 | select_method='lambda_descend', use_cuda=True, tau=0.5): 65 | 66 | # Log console output to 'pmlog.txt' 67 | logging.basicConfig(filename='pmlog.txt') 68 | logger = logging.getLogger() 69 | logger.setLevel(logging.INFO) 70 | 71 | # Data generator 72 | datagen = PermutedMnistGenerator(max_iter=num_tasks) 73 | 74 | # Model 75 | num_classes = 10 76 | layer_size = [784, hidden_size, hidden_size, num_classes] 77 | model = MLP(layer_size, act='relu') 78 | 79 | criterion = nn.CrossEntropyLoss() 80 | if use_cuda: 81 | criterion.cuda() 82 | model.cuda() 83 | 84 | # Optimiser 85 | opt = opt_fromp(model, lr=lr, prior_prec=1e-5, grad_clip_norm=0.01, tau=tau) 86 | 87 | # Train on tasks 88 | memorable_points = [] 89 | testloaders = [] 90 | acc_list = [] 91 | for tid in range(num_tasks): 92 | 93 | # If not first task, need to calculate and store 
regularisation-term-related quantities 94 | if tid > 0: 95 | def closure(task_id): 96 | memorable_points_t = memorable_points[task_id][0] 97 | if use_cuda: 98 | memorable_points_t = memorable_points_t.cuda() 99 | opt.zero_grad() 100 | logits = model.forward(memorable_points_t) 101 | return logits 102 | opt.init_task(closure, tid, eps=1e-5) 103 | 104 | # Data generator for this task 105 | itrain, itest = datagen.next_task() 106 | itrainloader = DataLoader(dataset=itrain, batch_size=batch_size, shuffle=True, num_workers=3) 107 | itestloader = DataLoader(dataset=itest, batch_size=batch_size, shuffle=False, num_workers=3) 108 | memorableloader = DataLoader(dataset=itrain, batch_size=batch_size, shuffle=False, num_workers=3) 109 | testloaders.append(itestloader) 110 | iloaders = [itrainloader, testloaders] 111 | 112 | # Train and test 113 | acc = train(model, iloaders, memorable_points, criterion, opt, task_id=tid, num_epochs=num_epochs, 114 | use_cuda=use_cuda) 115 | 116 | # Select memorable past datapoints 117 | if select_method == 'random': 118 | i_memorable_points = random_memorable_points(itrain, num_points=num_points, num_classes=num_classes) 119 | elif select_method == 'lambda_descend': 120 | i_memorable_points = select_memorable_points(memorableloader, model, num_points=num_points, num_classes=num_classes, 121 | use_cuda=use_cuda, descending=True) 122 | elif select_method == 'lambda_ascend': 123 | i_memorable_points = select_memorable_points(memorableloader, model, num_points=num_points, num_classes=num_classes, 124 | use_cuda=use_cuda, descending=False) 125 | else: 126 | raise Exception('Invalid memorable points selection method.') 127 | 128 | memorable_points.append(i_memorable_points) 129 | 130 | # Update covariance (\Sigma) 131 | update_fisher(memorableloader, model, opt, use_cuda=use_cuda) 132 | 133 | print(acc) 134 | print('Mean accuracy after task %d: %f'%(tid+1, sum(acc)/len(acc))) 135 | logger.info('After learn task: %d'%(tid+1)) 136 | logger.info(acc) 137 | acc_list.append(acc) 138 | 139 | return acc_list 140 | -------------------------------------------------------------------------------- /train_splitmnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models import SplitMLP 4 | from datasets import SplitMnistGenerator 5 | from torch.utils.data.dataloader import DataLoader 6 | from utils import select_memorable_points, update_fisher, random_memorable_points 7 | from opt_fromp import opt_fromp 8 | import logging 9 | 10 | 11 | def train(model, dataloaders, memorable_points, criterion, optimizer, label_sets, task_id=0, 12 | num_epochs=15, use_cuda=True): 13 | trainloader, testloader = dataloaders 14 | 15 | # Train 16 | model.train() 17 | label_set_cur = label_sets[task_id] 18 | for epoch in range(num_epochs): 19 | for inputs, labels in trainloader: 20 | if use_cuda: 21 | inputs, labels = inputs.cuda(), labels.cuda() 22 | 23 | if isinstance(optimizer, opt_fromp): 24 | # Closure on current task's data 25 | def closure(): 26 | optimizer.zero_grad() 27 | logits = model.forward(inputs, label_set_cur) 28 | loss = criterion(logits, labels) 29 | return loss, logits 30 | 31 | # Closure on memorable past data 32 | def closure_memorable_points(task_id): 33 | memorable_points_t = memorable_points[task_id][0] 34 | label_set_t = label_sets[task_id] 35 | if use_cuda: 36 | memorable_points_t = memorable_points_t.cuda() 37 | optimizer.zero_grad() 38 | logits = model.forward(memorable_points_t, label_set_t) 
39 | return logits 40 | 41 | # Optimiser step 42 | loss, logits = optimizer.step(closure, closure_memorable_points, task_id) 43 | 44 | # Test 45 | model.eval() 46 | print('Begin testing...') 47 | test_accuracy = [] 48 | for tid, testdata in enumerate(testloader): 49 | total = 0 50 | correct = 0 51 | for inputs, labels in testdata: 52 | if use_cuda: 53 | inputs, labels = inputs.cuda(), labels.cuda() 54 | label_set_t = label_sets[tid] 55 | logits = model.forward(inputs, label_set_t) 56 | predict_label = torch.argmax(logits, dim=-1) 57 | total += inputs.shape[0] 58 | if use_cuda: 59 | correct += torch.sum(predict_label == labels).cpu().item() 60 | else: 61 | correct += torch.sum(predict_label == labels).item() 62 | test_accuracy.append(correct / total) 63 | 64 | return test_accuracy 65 | 66 | 67 | def train_splitmnist(num_tasks, batch_size, hidden_size, lr, num_epochs, num_points, 68 | select_method='lambda_descend', use_cuda=True, tau=1): 69 | 70 | # Log console output to 'smlog.txt' 71 | logging.basicConfig(filename='smlog.txt') 72 | logger = logging.getLogger() 73 | logger.setLevel(logging.INFO) 74 | 75 | # Data generator 76 | datagen = SplitMnistGenerator() 77 | 78 | # Model 79 | num_classes_per_task = 2 80 | layer_size = [784, hidden_size, hidden_size, 10] 81 | model = SplitMLP(layer_size, act='relu') 82 | 83 | criterion = nn.CrossEntropyLoss() 84 | if use_cuda: 85 | criterion.cuda() 86 | model.cuda() 87 | 88 | # Optimiser 89 | opt = opt_fromp(model, lr=lr, prior_prec=1e-3, grad_clip_norm=0.1, tau=tau) 90 | 91 | # Train on tasks 92 | memorable_points = [] 93 | testloaders = [] 94 | label_sets = [] # To record the labels for each task 95 | acc_list = [] 96 | for tid in range(num_tasks): 97 | 98 | # If not first task, need to calculate and store regularisation-term-related quantities 99 | if tid > 0: 100 | def closure(task_id): 101 | memorable_points_t = memorable_points[task_id][0] 102 | label_set_t = label_sets[task_id] 103 | if use_cuda: 104 | memorable_points_t = memorable_points_t.cuda() 105 | opt.zero_grad() 106 | logits = model.forward(memorable_points_t, label_set_t) 107 | return logits 108 | opt.init_task(closure, tid, eps=1e-6) 109 | 110 | # Data generator for this task 111 | itrain, itest, ilabel_set = datagen.next_task() 112 | label_sets.append(ilabel_set) 113 | itrainloader = DataLoader(dataset=itrain, batch_size=batch_size, shuffle=True, num_workers=3) 114 | itestloader = DataLoader(dataset=itest, batch_size=batch_size, shuffle=False, num_workers=3) 115 | memorableloader = DataLoader(dataset=itrain, batch_size=batch_size, shuffle=False, num_workers=3) 116 | testloaders.append(itestloader) 117 | iloaders = [itrainloader, testloaders] 118 | 119 | # Train and test 120 | acc = train(model, iloaders, memorable_points, criterion, opt, label_sets, task_id=tid, 121 | num_epochs=num_epochs, use_cuda=use_cuda) 122 | 123 | # Select memorable past datapoints 124 | if select_method == 'random': 125 | i_memorable_points = random_memorable_points(itrain, num_points=num_points, num_classes=num_classes_per_task) 126 | elif select_method == 'lambda_descend': 127 | i_memorable_points = select_memorable_points(memorableloader, model, num_points=num_points, num_classes=num_classes_per_task, 128 | use_cuda=use_cuda, label_set=ilabel_set, descending=True) 129 | elif select_method == 'lambda_ascend': 130 | i_memorable_points = select_memorable_points(memorableloader, model, num_points=num_points, num_classes=num_classes_per_task, 131 | use_cuda=use_cuda, label_set=ilabel_set, descending=False) 132 | 
else: 133 | raise Exception('Invalid memorable points selection method.') 134 | 135 | memorable_points.append(i_memorable_points) 136 | 137 | # Update covariance (\Sigma) 138 | update_fisher(memorableloader, model, opt, use_cuda=use_cuda, label_set=ilabel_set) 139 | 140 | print(acc) 141 | print('Mean accuracy after task %d: %f'%(tid, sum(acc)/len(acc))) 142 | logger.info('After learn task: %d'%(tid+1)) 143 | logger.info(acc) 144 | acc_list.append(acc) 145 | 146 | return acc_list 147 | -------------------------------------------------------------------------------- /train_toydata.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models import MLP 4 | from opt_fromp import opt_fromp 5 | from datasets import ToydataGenerator 6 | from torch.utils.data.dataloader import DataLoader 7 | from utils import select_memorable_points, update_fisher, random_memorable_points 8 | import numpy as np 9 | import time 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def train(model, dataloaders, memorable_points, criterion, optimizer, task_id=0, num_epochs=25, use_cuda=False): 14 | 15 | trainloader, testloader = dataloaders 16 | 17 | model.train() 18 | for epoch in range(num_epochs): 19 | running_train_loss = 0 20 | count = 0 21 | for inputs, labels in trainloader: 22 | if use_cuda: 23 | inputs, labels = inputs.cuda(), labels.cuda() 24 | 25 | # Continual learning optimiser 26 | if isinstance(optimizer, opt_fromp): 27 | def closure(): 28 | optimizer.zero_grad() 29 | logits = model.forward(inputs) 30 | loss = criterion(torch.squeeze(logits, dim=-1), labels) 31 | return loss, logits 32 | def closure_memorable_points(task_id): 33 | memorable_points_t = memorable_points[task_id] 34 | if use_cuda: 35 | memorable_points_t = memorable_points_t.cuda() 36 | optimizer.zero_grad() 37 | logits = model.forward(memorable_points_t) 38 | return logits 39 | loss, logits = optimizer.step(closure, closure_memorable_points, task_id) 40 | 41 | if use_cuda: 42 | loss_val = loss.detach().cpu().item() 43 | else: 44 | loss_val = loss.detach().item() 45 | running_train_loss += loss_val 46 | count += 1 47 | if epoch == 0 or epoch == num_epochs-1: 48 | print('Epoch[%d]: Train loss: %f' %(epoch, running_train_loss/count)) 49 | 50 | # Run on test data (a 2D grid of points for plotting) 51 | full_outputs = [] 52 | model.eval() 53 | print('Begin test.') 54 | for inputs, _ in testloader: 55 | if use_cuda: 56 | inputs = inputs.cuda() 57 | outputs = model(inputs) 58 | full_outputs.append(outputs) 59 | full_outputs = torch.cat(full_outputs, dim=0) 60 | full_outputs = torch.sigmoid(full_outputs) 61 | 62 | return full_outputs 63 | 64 | 65 | def train_model(args, use_cuda=False): 66 | start_time = time.time() 67 | 68 | # Read values from args 69 | num_tasks = args.num_tasks 70 | batch_size = args.batch_size 71 | hidden_size = args.hidden_size 72 | lr = args.lr 73 | num_epochs = args.num_epochs 74 | num_points = args.num_points 75 | coreset_select_method = args.select_method 76 | 77 | # Some parameters 78 | dataset_generation_test = False 79 | dataset_num_samples = 2000 80 | 81 | # Colours for plotting 82 | color = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9'] 83 | 84 | # Load / Generate toy data 85 | datagen = ToydataGenerator(max_iter=num_tasks, num_samples=dataset_num_samples) 86 | 87 | plt.figure() 88 | datagen.reset() 89 | total_loaders = [] 90 | criterion_cl = nn.CrossEntropyLoss() 91 | 92 | # Create model 93 | layer_size = [2, hidden_size, 
hidden_size, 2] 94 | model = MLP(layer_size, act='sigmoid') 95 | if use_cuda: 96 | model = model.cuda() 97 | 98 | # Optimiser 99 | opt = opt_fromp(model, lr=lr, prior_prec=1e-4, grad_clip_norm=None, tau=args.tau) 100 | 101 | memorable_points = None 102 | inducing_targets = None 103 | 104 | for tid in range(num_tasks): 105 | # If not first task, need to calculate and store regularisation-term-related quantities 106 | if tid > 0: 107 | def closure(task_id): 108 | memorable_points_t = memorable_points[task_id] 109 | if use_cuda: 110 | memorable_points_t = memorable_points_t.cuda() 111 | opt.zero_grad() 112 | logits = model.forward(memorable_points_t) 113 | return logits 114 | opt.init_task(closure, tid, eps=1e-3) 115 | 116 | # Data generator for this task 117 | itrain, itest = datagen.next_task() 118 | itrainloader = DataLoader(dataset=itrain, batch_size=batch_size, shuffle=True, num_workers=8) 119 | itestloader = DataLoader(dataset=itest, batch_size=batch_size, shuffle=False, num_workers=8) 120 | inducingloader = DataLoader(dataset=itrain, batch_size=batch_size, shuffle=False, num_workers=8) 121 | iloaders = [itrainloader, itestloader] 122 | 123 | if tid == 0: 124 | total_loaders = [itrainloader] 125 | else: 126 | total_loaders.append(itrainloader) 127 | 128 | # Train and test 129 | cl_outputs = train(model, iloaders, memorable_points, criterion_cl, opt, task_id=tid, num_epochs=num_epochs, 130 | use_cuda=use_cuda) 131 | 132 | # Select memorable past datapoints 133 | if coreset_select_method == 'random': 134 | i_memorable_points, i_inducing_targets = random_memorable_points( 135 | itrain, num_points=num_points, num_classes=2) 136 | else: 137 | i_memorable_points, i_inducing_targets = select_memorable_points( 138 | inducingloader, model, num_points=num_points, use_cuda=use_cuda) 139 | 140 | # Add memory points to set 141 | if tid > 0: 142 | memorable_points.append(i_memorable_points) 143 | inducing_targets.append(i_inducing_targets) 144 | else: 145 | memorable_points = [i_memorable_points] 146 | inducing_targets = [i_inducing_targets] 147 | 148 | # Update covariance (\Sigma) 149 | update_fisher(inducingloader, model, opt, use_cuda=use_cuda) 150 | 151 | # Plot visualisation (2D figure) 152 | cl_outputs, _ = torch.max(cl_outputs, dim=-1) 153 | cl_show = 2*cl_outputs - 1 154 | 155 | cl_show = cl_show.detach() 156 | if use_cuda: 157 | cl_show = cl_show.cpu() 158 | cl_show = cl_show.numpy() 159 | cl_show = cl_show.reshape(datagen.test_shape) 160 | 161 | plt.figure() 162 | axs = plt.subplot(111) 163 | axs.title.set_text('FROMP') 164 | if not dataset_generation_test: 165 | plt.imshow(cl_show, cmap='gray', 166 | extent=(datagen.x_min, datagen.x_max, datagen.y_min, datagen.y_max), origin='lower') 167 | for l in range(tid+1): 168 | idx = np.where(datagen.y == l) 169 | plt.scatter(datagen.X[idx][:,0], datagen.X[idx][:,1], c=color[l], s=0.03) 170 | idx = np.where(datagen.y == l+datagen.offset) 171 | plt.scatter(datagen.X[idx][:,0], datagen.X[idx][:,1], c=color[l+datagen.offset], s=0.03) 172 | if not dataset_generation_test: 173 | plt.scatter(memorable_points[l][:,0], memorable_points[l][:, 1], c='m', s=0.4, marker='x') 174 | 175 | plt.show() 176 | 177 | # Calculate and print train accuracy and negative log likelihood 178 | with torch.no_grad(): 179 | if not dataset_generation_test: 180 | model.eval() 181 | N = len(itrain) 182 | 183 | metric_task_id = 0 184 | nll_loss_avg = 0 185 | accuracy_avg = 0 186 | for metric_loader in total_loaders: 187 | nll_loss = 0 188 | correct = 0 189 | for inputs, labels in 
metric_loader: 190 | if use_cuda: 191 | inputs, labels = inputs.cuda(), labels.cuda() 192 | 193 | logits = model.forward(inputs) 194 | 195 | nll_loss += nn.functional.cross_entropy(torch.squeeze(logits, dim=-1), labels) * float(inputs.shape[0]) 196 | 197 | # Calculate predicted classes 198 | pred = logits.data.max(1, keepdim=True)[1] 199 | 200 | # Count number of correctly predicted datapoints 201 | correct += pred.eq(labels.data.view_as(pred)).sum() 202 | 203 | nll_loss /= N 204 | accuracy = float(correct) / float(N) * 100. 205 | 206 | print('Task {}, Train accuracy: {:.2f}%, Train Loglik: {:.4f}'.format( 207 | metric_task_id, accuracy, nll_loss)) 208 | 209 | metric_task_id += 1 210 | nll_loss_avg += nll_loss 211 | accuracy_avg += accuracy 212 | 213 | print('Avg train accuracy: {:.2f}%, Avg train Loglik: {:.4f}'.format( 214 | accuracy_avg/metric_task_id, nll_loss_avg/metric_task_id)) 215 | 216 | print('Time taken: ', time.time()-start_time) 217 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # We only calculate the diagonal elements of the hessian 6 | def logistic_hessian(f): 7 | f = f[:, :] 8 | pi = torch.sigmoid(f) 9 | return pi*(1-pi) 10 | 11 | 12 | def softmax_hessian(f): 13 | s = F.softmax(f, dim=-1) 14 | return s - s*s 15 | 16 | 17 | # Calculate the full softmax hessian 18 | def full_softmax_hessian(f): 19 | s = F.softmax(f, dim=-1) 20 | e = torch.eye(s.shape[-1], dtype=s.dtype, device=s.device) 21 | return s[:, :, None]*e[None, :, :] - s[:, :, None]*s[:, None, :] 22 | 23 | 24 | # Select memorable points ordered by their lambda values (descending=True picks most important points) 25 | def select_memorable_points(dataloader, model, num_points=10, num_classes=2, 26 | use_cuda=False, label_set=None, descending=True): 27 | memorable_points = {} 28 | scores = {} 29 | num_points_per_class = int(num_points/num_classes) 30 | for i, dt in enumerate(dataloader): 31 | data, target = dt 32 | if use_cuda: 33 | data_in = data.cuda() 34 | else: 35 | data_in = data 36 | if label_set == None: 37 | f = model.forward(data_in) 38 | else: 39 | f = model.forward(data_in, label_set) 40 | if f.shape[-1] > 1: 41 | lamb = softmax_hessian(f) 42 | if use_cuda: 43 | lamb = lamb.cpu() 44 | lamb = torch.sum(lamb, dim=-1) 45 | lamb = lamb.detach() 46 | else: 47 | lamb = logistic_hessian(f) 48 | if use_cuda: 49 | lamb = lamb.cpu() 50 | lamb = torch.squeeze(lamb, dim=-1) 51 | lamb = lamb.detach() 52 | for cid in range(num_classes): 53 | p_c = data[target == cid] 54 | if len(p_c) > 0: 55 | s_c = lamb[target == cid] 56 | if len(s_c) > 0: 57 | if cid not in memorable_points: 58 | memorable_points[cid] = p_c 59 | scores[cid] = s_c 60 | else: 61 | memorable_points[cid] = torch.cat([memorable_points[cid], p_c], dim=0) 62 | scores[cid] = torch.cat([scores[cid], s_c], dim=0) 63 | if len(memorable_points[cid]) > num_points_per_class: 64 | _, indices = scores[cid].sort(descending=descending) 65 | memorable_points[cid] = memorable_points[cid][indices[:num_points_per_class]] 66 | scores[cid] = scores[cid][indices[:num_points_per_class]] 67 | r_points = [] 68 | r_labels = [] 69 | for cid in range(num_classes): 70 | r_points.append(memorable_points[cid]) 71 | r_labels.append(torch.ones(memorable_points[cid].shape[0], dtype=torch.long, 72 | device=memorable_points[cid].device)*cid) 73 | return [torch.cat(r_points, dim=0), torch.cat(r_labels, 
dim=0)] 74 | 75 | 76 | # Randomly select some points as memory 77 | def random_memorable_points(dataset, num_points, num_classes): 78 | memorable_points = {} 79 | num_points_per_class = int(num_points/num_classes) 80 | exact_num_points = num_points_per_class*num_classes 81 | idx_list = torch.randperm(len(dataset)) 82 | select_points_num = 0 83 | for idx in range(len(idx_list)): 84 | data, label = dataset[idx_list[idx]] 85 | cid = label.item() if isinstance(label, torch.Tensor) else label 86 | if cid in memorable_points: 87 | if len(memorable_points[cid]) < num_points_per_class: 88 | memorable_points[cid].append(data) 89 | select_points_num += 1 90 | else: 91 | memorable_points[cid] = [data] 92 | select_points_num += 1 93 | if select_points_num >= exact_num_points: 94 | break 95 | r_points = [] 96 | r_labels = [] 97 | for cid in range(num_classes): 98 | r_points.append(torch.stack(memorable_points[cid], dim=0)) 99 | r_labels.append(torch.ones(len(memorable_points[cid]), dtype=torch.long, 100 | device=r_points[cid].device)*cid) 101 | return [torch.cat(r_points, dim=0), torch.cat(r_labels, dim=0)] 102 | 103 | 104 | # Update the fisher matrix after training on a task 105 | def update_fisher(dataloader, model, opt, label_set=None, use_cuda=False): 106 | model.eval() 107 | for data, label in dataloader: 108 | if use_cuda: 109 | data = data.cuda() 110 | def closure(): 111 | opt.zero_grad() 112 | if label_set == None: 113 | logits = model.forward(data) 114 | else: 115 | logits = model.forward(data, label_set) 116 | return logits 117 | opt.update_fisher(closure) 118 | 119 | 120 | def save(opt, memorable_points, path): 121 | torch.save({ 122 | 'mu': opt.state['mu'], 123 | 'fisher': opt.state['fisher'], 124 | 'memorable_points': memorable_points 125 | }, path) 126 | 127 | 128 | def load(opt, path): 129 | checkpoint = torch.load(path) 130 | opt.state['mu'] = checkpoint['mu'] 131 | opt.state['fisher'] = checkpoint['fisher'] 132 | return checkpoint['memorable_points'] 133 | 134 | 135 | def softmax_predictive_accuracy(logits_list, y, ret_loss = False): 136 | probs_list = [F.log_softmax(logits, dim=1) for logits in logits_list] 137 | probs_tensor = torch.stack(probs_list, dim = 2) 138 | probs = torch.mean(probs_tensor, dim=2) 139 | if ret_loss: 140 | loss = F.nll_loss(probs, y, reduction='sum').item() 141 | _, pred_class = torch.max(probs, 1) 142 | correct = pred_class.eq(y.view_as(pred_class)).sum().item() 143 | if ret_loss: 144 | return correct, loss 145 | return correct 146 | --------------------------------------------------------------------------------
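A minimal standalone sketch (not part of the repo) of what the Hessian utilities in ``utils.py`` compute: ``softmax_hessian`` returns the per-class diagonal of ``full_softmax_hessian``, and ``select_memorable_points`` ranks datapoints by the sum of these diagonal values (its ``lamb`` score), keeping the highest-scoring points under the ``lambda_descend`` option. The logits tensor ``f`` below is made up purely for illustration.

```
import torch
import torch.nn.functional as F

# Made-up logits for 4 datapoints and 3 classes, purely for illustration.
torch.manual_seed(0)
f = torch.randn(4, 3)

s = F.softmax(f, dim=-1)

# Diagonal-only Hessian, as in softmax_hessian(f): s * (1 - s) per class.
diag_hess = s - s * s

# Full Hessian, as in full_softmax_hessian(f): diag(s) - s s^T per datapoint.
e = torch.eye(s.shape[-1], dtype=s.dtype)
full_hess = s[:, :, None] * e[None, :, :] - s[:, :, None] * s[:, None, :]

# The [N, C] diagonal of the [N, C, C] full Hessian matches the diagonal-only version.
assert torch.allclose(torch.diagonal(full_hess, dim1=-2, dim2=-1), diag_hess)

# select_memorable_points() scores each datapoint by summing this diagonal over
# classes; 'lambda_descend' keeps the highest-scoring (most uncertain) points.
lamb = diag_hess.sum(dim=-1)
top_idx = lamb.argsort(descending=True)[:2]
print(lamb, top_idx)
```

This mirrors the ``lambda_descend`` selection used by ``train_splitmnist.py`` and ``train_toydata.py`` above, just without a trained model or dataloader.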