├── .gitignore ├── ACL-resnet ├── README.md ├── data │ └── split_miniimagenet.py └── src │ ├── acl.py │ ├── configs │ ├── config_cifar100.yml │ └── config_miniimagenet.yml │ ├── dataloaders │ ├── cifar100.py │ ├── datasets_utils.py │ ├── miniimagenet.py │ └── utils.py │ ├── main.py │ ├── networks │ ├── alexnet_acl.py │ ├── discriminator.py │ ├── mlp_acl.py │ └── resnet_acl.py │ └── utils.py ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── data ├── notMNIST.zip └── split_miniimagenet.py ├── requirements.txt └── src ├── acl.py ├── configs ├── config_cifar100.yml ├── config_miniimagenet.yml ├── config_mnist5.yml ├── config_multidatasets.yml └── config_pmnist.yml ├── dataloaders ├── cifar100.py ├── datasets_utils.py ├── miniimagenet.py ├── mnist5.py ├── mulitidatasets.py ├── pmnist.py └── utils.py ├── main.py ├── networks ├── alexnet_acl.py ├── discriminator.py └── mlp_acl.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/Adversarial-Continual-Learning/a99dfadcc59a12d903af6e5366a025ca44b3af07/.gitignore -------------------------------------------------------------------------------- /ACL-resnet/README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Continual Learning 2 | 3 | 4 | This is the official PyTorch implementation of [Adversarial Continual Learning](https://arxiv.org/abs/2003.09553), published at ECCV 2020. 5 | 6 | 7 | ## Notice: 8 | For the experiments shown in the main paper, please refer to the [main directory](https://github.com/facebookresearch/Adversarial-Continual-Learning). This directory provides ACL with a ResNet18 backbone architecture. Follow the instructions [here](https://github.com/facebookresearch/Adversarial-Continual-Learning) to install the requirements and clone the repo, then use the following commands to run ACL with ResNet18 on CIFAR100 and miniImageNet from the current directory. If you have a different dataset, you can create a Dataset class and corresponding dataloaders following how this is done for the given datasets in the `dataloaders` directory. 9 | 10 | 11 | Split CIFAR100 (20 Tasks): 12 | 13 | `python main.py --config ./configs/config_cifar100.yml` 14 | 15 | Split MiniImageNet (20 Tasks): 16 | 17 | `python main.py --config ./configs/config_miniimagenet.yml` 18 | 19 | 20 | ## Authors: 21 | [Sayna Ebrahimi](https://people.eecs.berkeley.edu/~sayna/) (UC Berkeley, FAIR), [Franziska Meier](https://am.is.tuebingen.mpg.de/person/fmeier) (FAIR), [Roberto Calandra](https://www.robertocalandra.com/about/) (FAIR), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/) (UC Berkeley), [Marcus Rohrbach](http://rohrbach.vision/) (FAIR) 22 | 23 | ### Citation 24 | If using this code, parts of it, or developments from it, please cite our paper: 25 | ``` 26 | @article{ebrahimi2020adversarial, 27 | title={Adversarial Continual Learning}, 28 | author={Ebrahimi, Sayna and Meier, Franziska and Calandra, Roberto and Darrell, Trevor and Rohrbach, Marcus}, 29 | journal={arXiv preprint arXiv:2003.09553}, 30 | year={2020} 31 | } 32 | ``` 33 | 34 | #### Datasets 35 | 36 | *miniImageNet* data should be [downloaded](https://github.com/yaoyao-liu/mini-imagenet-tools#about-mini-ImageNet) and pickled as a dictionary (`data.pkl`) with `images` and `labels` keys, then placed in a sub-folder of `args.data_dir` named `miniimagenet`.
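For reference, a minimal sketch of what the dataloaders in `src/dataloaders/miniimagenet.py` expect from `data.pkl` (shapes are illustrative):

```python
import pickle

# data.pkl is a dict with 'images' and 'labels' covering 100 classes,
# 600 images per class (see data/split_miniimagenet.py).
with open('data/miniimagenet/data.pkl', 'rb') as f:
    data_dict = pickle.load(f)

images = data_dict['images']   # sequence of RGB image arrays
labels = data_dict['labels']   # integer class ids in [0, 99]
assert len(images) == len(labels)
```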
The script used to split `data.pkl` into training and test sets is included in the data directory (`data/`). 37 | 38 | The *notMNIST* dataset is included here in `./data/notMNIST` as it was used in our experiments. 39 | 40 | Other datasets will be automatically downloaded and extracted to `./data` if they do not exist. 41 | 42 | ## Questions/Bugs 43 | * For questions/bugs, contact the author Sayna Ebrahimi via email at sayna@berkeley.edu 44 | 45 | 46 | 47 | ## License 48 | This source code is released under The MIT License found in the LICENSE file in the root directory of this source tree. 49 | 50 | 51 | ## Acknowledgements 52 | Our code structure is inspired by [HAT](https://github.com/joansj/hat). 53 | -------------------------------------------------------------------------------- /ACL-resnet/data/split_miniimagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import pickle 7 | import numpy as np 8 | import os 9 | 10 | np.random.seed(1234) 11 | 12 | # we want 500 for training, 100 for test for each class 13 | n = 500 14 | 15 | def get_total(data): 16 | data_x, data_y = [], [] 17 | for k, v in data.items(): 18 | for i in range(len(v)): 19 | data_x.append(v[i]) 20 | data_y.append(k) 21 | d = {} 22 | d['images'] = data_x 23 | d['labels'] = data_y 24 | return d 25 | 26 | 27 | # loading the pickled data 28 | with open(os.path.join('../data/miniimagenet/data.pkl'), 'rb') as f: 29 | data_dict = pickle.load(f) 30 | data = data_dict['images'] 31 | labels = data_dict['labels'] 32 | 33 | # split data into classes, 600 images per class 34 | class_dict = {} 35 | for i in range(len(set(labels))): 36 | class_dict[i] = [] 37 | 38 | for i in range(len(data)): 39 | class_dict[labels[i]].append(data[i]) 40 | 41 | # Split data for each class to 500 and 100 42 | x_train, x_test = {}, {} 43 | for i in range(len(set(labels))): 44 | np.random.shuffle(class_dict[i]) 45 | x_test[i] = class_dict[i][n:] 46 | x_train[i] = class_dict[i][:n] 47 | 48 | # mix the data 49 | d_train = get_total(x_train) 50 | d_test = get_total(x_test) 51 | 52 | with open(os.path.join('../data/miniimagenet/train.pkl'), 'wb') as f: 53 | pickle.dump(d_train, f) 54 | with open(os.path.join('../data/miniimagenet/test.pkl'), 'wb') as f: 55 | pickle.dump(d_test, f) -------------------------------------------------------------------------------- /ACL-resnet/src/configs/config_cifar100.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree.
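# Key hyperparameters (consumed in src/acl.py): e_lr/e_wd configure the
# optimizer of the task network (shared encoder + private modules + heads),
# d_lr/d_wd the discriminator's; s_step and d_step set how many update steps
# each gets per batch. adv and orth weight the adversarial and diff
# (orthogonality) losses, and lam scales the gradient reversal used in
# networks/discriminator.py.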
5 | 6 | num_runs: 5 7 | experiment: "cifar100" 8 | data_dir: "../data" 9 | checkpoint: "./checkpoints/" 10 | output: "" 11 | tsne: "no" 12 | diff: "yes" 13 | lam: 1 #diff loss lambda 14 | adv: 0.05 #adversarial loss regularizer 15 | orth: 0.1 #diff loss regularizer 16 | 17 | ntasks: 20 18 | use_memory: "no" 19 | samples: 0 20 | 21 | e_lr: 0.005 22 | e_wd: 0.01 23 | s_step: 5 24 | 25 | d_lr: 0.001 26 | d_wd: 0.01 27 | d_step: 1 28 | 29 | lr_factor: 3 30 | lr_min: 1.0e-06 31 | lr_patience: 5 32 | mom: 0.9 33 | 34 | arch: 'resnet' 35 | nlayers: 2 36 | units: 175 37 | head_units: 32 38 | latent_dim: 128 39 | 40 | batch_size: 64 41 | nepochs: 200 42 | pc_valid: 0.15 43 | 44 | 45 | workers: 4 46 | device: "cuda" 47 | -------------------------------------------------------------------------------- /ACL-resnet/src/configs/config_miniimagenet.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | num_runs: 5 7 | experiment: "miniimagenet" 8 | data_dir: "../data" 9 | checkpoint: "./checkpoints/" 10 | output: "" 11 | tsne: "no" 12 | diff: "yes" 13 | lam: 1 #diff loss lambda 14 | adv: 0.005 #adversarial loss regularizer 15 | orth: 0.1 #diff loss regularizer 16 | 17 | ntasks: 20 18 | use_memory: "no" 19 | samples: 0 20 | 21 | e_lr: 0.003 22 | e_wd: 0.01 23 | s_step: 5 24 | 25 | d_lr: 0.001 26 | d_wd: 0.01 27 | d_step: 1 28 | 29 | lr_factor: 3 30 | lr_min: 1.0e-06 31 | lr_patience: 5 32 | mom: 0.9 33 | 34 | arch: 'resnet' 35 | nlayers: 2 36 | units: 175 37 | head_units: 128 38 | latent_dim: 256 39 | 40 | batch_size: 64 41 | nepochs: 200 42 | pc_valid: 0.02 43 | 44 | 45 | workers: 4 46 | device: "cuda:4" 47 | -------------------------------------------------------------------------------- /ACL-resnet/src/dataloaders/cifar100.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
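# Split-CIFAR100 dataloaders: iCIFAR100 restricts torchvision's CIFAR100 to the
# classes of a single task (optionally mixed with replay memory from earlier
# tasks), while DatasetGen shuffles the 100 classes into args.ntasks tasks and
# builds the per-task train/valid/test loaders.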
5 | 6 | from __future__ import print_function 7 | from PIL import Image 8 | import os 9 | import os.path 10 | import sys 11 | 12 | 13 | if sys.version_info[0] == 2: 14 | import cPickle as pickle 15 | else: 16 | import pickle 17 | 18 | import torch.utils.data as data 19 | import numpy as np 20 | 21 | import torch 22 | from torchvision import datasets, transforms 23 | 24 | from utils import * 25 | 26 | 27 | class iCIFAR10(datasets.CIFAR10): 28 | 29 | def __init__(self, root, classes, memory_classes, memory, task_num, train, transform=None, target_transform=None, download=True): 30 | 31 | super(iCIFAR10, self).__init__(root, transform=transform, 32 | target_transform=target_transform, download=True) 33 | self.train = train # training set or test set 34 | if not isinstance(classes, list): 35 | classes = [classes] 36 | 37 | self.class_mapping = {c: i for i, c in enumerate(classes)} 38 | self.class_indices = {} 39 | 40 | for cls in classes: 41 | self.class_indices[self.class_mapping[cls]] = [] 42 | 43 | if self.train: 44 | train_data = [] 45 | train_labels = [] 46 | train_tt = [] # task module labels 47 | train_td = [] # discriminator labels 48 | 49 | for i in range(len(self.data)): 50 | if self.targets[i] in classes: 51 | train_data.append(self.data[i]) 52 | train_labels.append(self.class_mapping[self.targets[i]]) 53 | train_tt.append(task_num) 54 | train_td.append(task_num+1) 55 | self.class_indices[self.class_mapping[self.targets[i]]].append(i) 56 | 57 | if memory_classes: 58 | for task_id in range(task_num): 59 | for i in range(len(memory[task_id]['x'])): 60 | if memory[task_id]['y'][i] in range(len(memory_classes[task_id])): 61 | train_data.append(memory[task_id]['x'][i]) 62 | train_labels.append(memory[task_id]['y'][i]) 63 | train_tt.append(memory[task_id]['tt'][i]) 64 | train_td.append(memory[task_id]['td'][i]) 65 | 66 | self.train_data = np.array(train_data) 67 | self.train_labels = train_labels 68 | self.train_tt = train_tt 69 | self.train_td = train_td 70 | 71 | 72 | if not self.train: 73 | f = self.test_list[0][0] 74 | file = os.path.join(self.root, self.base_folder, f) 75 | fo = open(file, 'rb') 76 | if sys.version_info[0] == 2: 77 | entry = pickle.load(fo) 78 | else: 79 | entry = pickle.load(fo, encoding='latin1') 80 | self.test_data = entry['data'] 81 | if 'labels' in entry: 82 | self.test_labels = entry['labels'] 83 | else: 84 | 85 | self.test_labels = entry['fine_labels'] 86 | fo.close() 87 | self.test_data = self.test_data.reshape((10000, 3, 32, 32)) 88 | self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC 89 | 90 | test_data = [] 91 | test_labels = [] 92 | test_tt = [] # task module labels 93 | test_td = [] # discriminator labels 94 | for i in range(len(self.test_data)): 95 | if self.test_labels[i] in classes: 96 | test_data.append(self.test_data[i]) 97 | test_labels.append(self.class_mapping[self.test_labels[i]]) 98 | test_tt.append(task_num) 99 | test_td.append(task_num + 1) 100 | self.class_indices[self.class_mapping[self.test_labels[i]]].append(i) 101 | 102 | self.test_data = np.array(test_data) 103 | self.test_labels = test_labels 104 | self.test_tt = test_tt 105 | self.test_td = test_td 106 | 107 | 108 | def __getitem__(self, index): 109 | if self.train: 110 | img, target, tt, td = self.train_data[index], self.train_labels[index], self.train_tt[index], self.train_td[index] 111 | else: 112 | img, target, tt, td = self.test_data[index], self.test_labels[index], self.test_tt[index], self.test_td[index] 113 | 114 | # doing this so that it is consistent
with all other datasets 115 | # to return a PIL Image 116 | try: 117 | img = Image.fromarray(img) 118 | except: 119 | pass 120 | 121 | try: 122 | if self.transform is not None: 123 | img = self.transform(img) 124 | except: 125 | pass 126 | try: 127 | if self.target_transform is not None: 128 | target = self.target_transform(target) 129 | except: 130 | pass 131 | 132 | return img, target, tt, td 133 | 134 | 135 | 136 | 137 | def __len__(self): 138 | if self.train: 139 | return len(self.train_data) 140 | else: 141 | return len(self.test_data) 142 | 143 | 144 | 145 | class iCIFAR100(iCIFAR10): 146 | """`CIFAR100 `_ Dataset. 147 | This is a subclass of the `CIFAR10` Dataset. 148 | """ 149 | base_folder = 'cifar-100-python' 150 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 151 | filename = "cifar-100-python.tar.gz" 152 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 153 | train_list = [ 154 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 155 | ] 156 | 157 | test_list = [ 158 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 159 | ] 160 | meta = { 161 | 'filename': 'meta', 162 | 'key': 'fine_label_names', 163 | 'md5': '7973b15100ade9c7d40fb424638fde48', 164 | } 165 | 166 | 167 | 168 | class DatasetGen(object): 169 | """docstring for DatasetGen""" 170 | 171 | def __init__(self, args): 172 | super(DatasetGen, self).__init__() 173 | 174 | self.seed = args.seed 175 | self.batch_size=args.batch_size 176 | self.pc_valid=args.pc_valid 177 | self.root = args.data_dir 178 | self.latent_dim = args.latent_dim 179 | 180 | self.num_tasks = args.ntasks 181 | self.num_classes = 100 182 | 183 | self.num_samples = args.samples 184 | 185 | 186 | self.inputsize = [3,32,32] 187 | mean=[x/255 for x in [125.3,123.0,113.9]] 188 | std=[x/255 for x in [63.0,62.1,66.7]] 189 | 190 | self.transformation = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) 191 | 192 | self.taskcla = [[t, int(self.num_classes/self.num_tasks)] for t in range(self.num_tasks)] 193 | 194 | self.indices = {} 195 | self.dataloaders = {} 196 | self.idx={} 197 | 198 | self.num_workers = args.workers 199 | self.pin_memory = True 200 | 201 | np.random.seed(self.seed) 202 | task_ids = np.split(np.random.permutation(self.num_classes),self.num_tasks) 203 | self.task_ids = [list(arr) for arr in task_ids] 204 | 205 | 206 | self.train_set = {} 207 | self.test_set = {} 208 | self.train_split = {} 209 | 210 | self.task_memory = {} 211 | for i in range(self.num_tasks): 212 | self.task_memory[i] = {} 213 | self.task_memory[i]['x'] = [] 214 | self.task_memory[i]['y'] = [] 215 | self.task_memory[i]['tt'] = [] 216 | self.task_memory[i]['td'] = [] 217 | 218 | self.use_memory = args.use_memory 219 | 220 | def get(self, task_id): 221 | 222 | self.dataloaders[task_id] = {} 223 | sys.stdout.flush() 224 | 225 | 226 | if task_id == 0: 227 | memory_classes = None 228 | memory=None 229 | else: 230 | memory_classes = self.task_ids 231 | memory = self.task_memory 232 | 233 | self.train_set[task_id] = iCIFAR100(root=self.root, classes=self.task_ids[task_id], memory_classes=memory_classes, 234 | memory=memory, task_num=task_id, train=True, download=True, transform=self.transformation) 235 | self.test_set[task_id] = iCIFAR100(root=self.root, classes=self.task_ids[task_id], memory_classes=None, 236 | memory=None, task_num=task_id, train=False, 237 | download=True, transform=self.transformation) 238 | 239 | 240 | 241 | 242 | 243 | split = int(np.floor(self.pc_valid * len(self.train_set[task_id]))) 244 | train_split, valid_split = 
torch.utils.data.random_split(self.train_set[task_id], [len(self.train_set[task_id]) - split, split]) 245 | 246 | self.train_split[task_id] = train_split 247 | train_loader = torch.utils.data.DataLoader(train_split, batch_size=self.batch_size, num_workers=self.num_workers, 248 | pin_memory=self.pin_memory,shuffle=True) 249 | valid_loader = torch.utils.data.DataLoader(valid_split, batch_size=int(self.batch_size * self.pc_valid), 250 | num_workers=self.num_workers, pin_memory=self.pin_memory,shuffle=True) 251 | test_loader = torch.utils.data.DataLoader(self.test_set[task_id], batch_size=self.batch_size, num_workers=self.num_workers, 252 | pin_memory=self.pin_memory,shuffle=True) 253 | 254 | 255 | self.dataloaders[task_id]['train'] = train_loader 256 | self.dataloaders[task_id]['valid'] = valid_loader 257 | self.dataloaders[task_id]['test'] = test_loader 258 | self.dataloaders[task_id]['name'] = 'CIFAR100-{}-{}'.format(task_id,self.task_ids[task_id]) 259 | 260 | print ("Training set size: {} images of {}x{}".format(len(train_loader.dataset),self.inputsize[1],self.inputsize[1])) 261 | print ("Validation set size: {} images of {}x{}".format(len(valid_loader.dataset),self.inputsize[1],self.inputsize[1])) 262 | print ("Train+Val set size: {} images of {}x{}".format(len(valid_loader.dataset)+len(train_loader.dataset),self.inputsize[1],self.inputsize[1])) 263 | print ("Test set size: {} images of {}x{}".format(len(test_loader.dataset),self.inputsize[1],self.inputsize[1])) 264 | 265 | if self.use_memory == 'yes' and self.num_samples > 0 : 266 | self.update_memory(task_id) 267 | 268 | return self.dataloaders 269 | 270 | 271 | 272 | def update_memory(self, task_id): 273 | 274 | num_samples_per_class = self.num_samples // len(self.task_ids[task_id]) 275 | mem_class_mapping = {i: i for i, c in enumerate(self.task_ids[task_id])} 276 | 277 | 278 | # Looping over each class in the current task 279 | for i in range(len(self.task_ids[task_id])): 280 | # Getting all samples for this class 281 | data_loader = torch.utils.data.DataLoader(self.train_split[task_id], batch_size=1, 282 | num_workers=self.num_workers, 283 | pin_memory=self.pin_memory) 284 | # Randomly choosing num_samples_per_class for this class 285 | randind = torch.randperm(len(data_loader.dataset))[:num_samples_per_class] 286 | 287 | # Adding the selected samples to memory 288 | for ind in randind: 289 | self.task_memory[task_id]['x'].append(data_loader.dataset[ind][0]) 290 | self.task_memory[task_id]['y'].append(mem_class_mapping[i]) 291 | self.task_memory[task_id]['tt'].append(data_loader.dataset[ind][2]) 292 | self.task_memory[task_id]['td'].append(data_loader.dataset[ind][3]) 293 | 294 | print ('Memory updated by adding {} images'.format(len(self.task_memory[task_id]['x']))) 295 | 296 | 297 | -------------------------------------------------------------------------------- /ACL-resnet/src/dataloaders/miniimagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
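# Split-miniImageNet dataloaders: MiniImageNet reads the train/test pickles
# produced by data/split_miniimagenet.py, and iMiniImageNet restricts them to a
# single task's classes (optionally mixed with replay memory from earlier tasks).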
5 | 6 | from __future__ import print_function 7 | from PIL import Image 8 | import os 9 | import os.path 10 | import sys 11 | 12 | if sys.version_info[0] == 2: 13 | import cPickle as pickle 14 | else: 15 | import pickle 16 | 17 | import torch.utils.data as data 18 | import numpy as np 19 | 20 | import torch 21 | from torchvision import transforms 22 | 23 | from utils import * 24 | 25 | 26 | 27 | class MiniImageNet(torch.utils.data.Dataset): 28 | 29 | def __init__(self, root, train): 30 | super(MiniImageNet, self).__init__() 31 | if train: 32 | self.name='train' 33 | else: 34 | self.name='test' 35 | root = os.path.join(root, 'miniimagenet') 36 | with open(os.path.join(root,'{}.pkl'.format(self.name)), 'rb') as f: 37 | data_dict = pickle.load(f) 38 | 39 | self.data = data_dict['images'] 40 | self.labels = data_dict['labels'] 41 | 42 | def __len__(self): 43 | return len(self.data) 44 | 45 | def __getitem__(self, i): 46 | img, label = self.data[i], self.labels[i] 47 | return img, label 48 | 49 | 50 | class iMiniImageNet(MiniImageNet): 51 | 52 | def __init__(self, root, classes, memory_classes, memory, task_num, train, transform=None): 53 | super(iMiniImageNet, self).__init__(root=root, train=train) 54 | 55 | self.transform = transform 56 | if not isinstance(classes, list): 57 | classes = [classes] 58 | 59 | self.class_mapping = {c: i for i, c in enumerate(classes)} 60 | self.class_indices = {} 61 | 62 | for cls in classes: 63 | self.class_indices[self.class_mapping[cls]] = [] 64 | 65 | data = [] 66 | labels = [] 67 | tt = [] # task module labels 68 | td = [] # discriminator labels 69 | 70 | for i in range(len(self.data)): 71 | if self.labels[i] in classes: 72 | data.append(self.data[i]) 73 | labels.append(self.class_mapping[self.labels[i]]) 74 | tt.append(task_num) 75 | td.append(task_num+1) 76 | self.class_indices[self.class_mapping[self.labels[i]]].append(i) 77 | 78 | if memory_classes: 79 | for task_id in range(task_num): 80 | for i in range(len(memory[task_id]['x'])): 81 | if memory[task_id]['y'][i] in range(len(memory_classes[task_id])): 82 | data.append(memory[task_id]['x'][i]) 83 | labels.append(memory[task_id]['y'][i]) 84 | tt.append(memory[task_id]['tt'][i]) 85 | td.append(memory[task_id]['td'][i]) 86 | 87 | self.data = np.array(data) 88 | self.labels = labels 89 | self.tt = tt 90 | self.td = td 91 | 92 | 93 | 94 | def __getitem__(self, index): 95 | 96 | img, target, tt, td = self.data[index], self.labels[index], self.tt[index], self.td[index] 97 | 98 | # doing this so that it is consistent with all other datasets 99 | # to return a PIL Image 100 | if not torch.is_tensor(img): 101 | img = Image.fromarray(img) 102 | img = self.transform(img) 103 | return img, target, tt, td 104 | 105 | 106 | 107 | 108 | def __len__(self): 109 | return len(self.data) 110 | 111 | 112 | 113 | 114 | class DatasetGen(object): 115 | """docstring for DatasetGen""" 116 | 117 | def __init__(self, args): 118 | super(DatasetGen, self).__init__() 119 | 120 | self.seed = args.seed 121 | self.batch_size=args.batch_size 122 | self.pc_valid=args.pc_valid 123 | self.root = args.data_dir 124 | self.latent_dim = args.latent_dim 125 | self.use_memory = args.use_memory 126 | 127 | self.num_tasks = args.ntasks 128 | self.num_classes = 100 129 | 130 | self.num_samples = args.samples 131 | 132 | self.inputsize = [3,84,84] 133 | mean = [0.485, 0.456, 0.406] 134 | std = [0.229, 0.224, 0.225] 135 | 136 | self.transformation = transforms.Compose([ 137 | transforms.Resize((84,84)), 138 | transforms.ToTensor(), 139 |
transforms.Normalize(mean=mean, std=std)]) 140 | 141 | self.taskcla = [[t, int(self.num_classes/self.num_tasks)] for t in range(self.num_tasks)] 142 | 143 | self.indices = {} 144 | self.dataloaders = {} 145 | self.idx={} 146 | 147 | self.num_workers = args.workers 148 | self.pin_memory = True 149 | 150 | np.random.seed(self.seed) 151 | task_ids = np.split(np.random.permutation(self.num_classes),self.num_tasks) 152 | self.task_ids = [list(arr) for arr in task_ids] 153 | 154 | self.train_set = {} 155 | self.train_split = {} 156 | self.test_set = {} 157 | 158 | 159 | self.task_memory = {} 160 | for i in range(self.num_tasks): 161 | self.task_memory[i] = {} 162 | self.task_memory[i]['x'] = [] 163 | self.task_memory[i]['y'] = [] 164 | self.task_memory[i]['tt'] = [] 165 | self.task_memory[i]['td'] = [] 166 | 167 | 168 | 169 | def get(self, task_id): 170 | 171 | self.dataloaders[task_id] = {} 172 | sys.stdout.flush() 173 | 174 | if task_id == 0: 175 | memory_classes = None 176 | memory=None 177 | else: 178 | memory_classes = self.task_ids 179 | memory = self.task_memory 180 | 181 | 182 | self.train_set[task_id] = iMiniImageNet(root=self.root, classes=self.task_ids[task_id], 183 | memory_classes=memory_classes, memory=memory, 184 | task_num=task_id, train=True, transform=self.transformation) 185 | 186 | self.test_set[task_id] = iMiniImageNet(root=self.root, classes=self.task_ids[task_id], memory_classes=None, 187 | memory=None, task_num=task_id, train=False, transform=self.transformation) 188 | 189 | 190 | split = int(np.floor(self.pc_valid * len(self.train_set[task_id]))) 191 | train_split, valid_split = torch.utils.data.random_split(self.train_set[task_id], [len(self.train_set[task_id]) - split, split]) 192 | self.train_split[task_id] = train_split 193 | 194 | train_loader = torch.utils.data.DataLoader(train_split, batch_size=self.batch_size, num_workers=self.num_workers, 195 | pin_memory=self.pin_memory,shuffle=True) 196 | valid_loader = torch.utils.data.DataLoader(valid_split, batch_size=int(self.batch_size * self.pc_valid), 197 | num_workers=self.num_workers, pin_memory=self.pin_memory,shuffle=True) 198 | test_loader = torch.utils.data.DataLoader(self.test_set[task_id], batch_size=self.batch_size, num_workers=self.num_workers, 199 | pin_memory=self.pin_memory, shuffle=True) 200 | 201 | 202 | self.dataloaders[task_id]['train'] = train_loader 203 | self.dataloaders[task_id]['valid'] = valid_loader 204 | self.dataloaders[task_id]['test'] = test_loader 205 | self.dataloaders[task_id]['name'] = 'iMiniImageNet-{}-{}'.format(task_id,self.task_ids[task_id]) 206 | self.dataloaders[task_id]['tsne'] = torch.utils.data.DataLoader(self.test_set[task_id], 207 | batch_size=len(test_loader.dataset), 208 | num_workers=self.num_workers, 209 | pin_memory=self.pin_memory, shuffle=True) 210 | 211 | print ("Task ID: ", task_id) 212 | print ("Training set size: {} images of {}x{}".format(len(train_loader.dataset),self.inputsize[1],self.inputsize[1])) 213 | print ("Validation set size: {} images of {}x{}".format(len(valid_loader.dataset),self.inputsize[1],self.inputsize[1])) 214 | print ("Train+Val set size: {} images of {}x{}".format(len(valid_loader.dataset)+len(train_loader.dataset),self.inputsize[1],self.inputsize[1])) 215 | print ("Test set size: {} images of {}x{}".format(len(test_loader.dataset),self.inputsize[1],self.inputsize[1])) 216 | 217 | if self.use_memory == 'yes' and self.num_samples > 0 : 218 | self.update_memory(task_id) 219 | 220 | 221 | return self.dataloaders 222 | 223 | 224 | 225 | def 
update_memory(self, task_id): 226 | num_samples_per_class = self.num_samples // len(self.task_ids[task_id]) 227 | mem_class_mapping = {i: i for i, c in enumerate(self.task_ids[task_id])} 228 | 229 | for i in range(len(self.task_ids[task_id])): 230 | data_loader = torch.utils.data.DataLoader(self.train_split[task_id], batch_size=1, 231 | num_workers=self.num_workers, 232 | pin_memory=self.pin_memory) 233 | 234 | randind = torch.randperm(len(data_loader.dataset))[:num_samples_per_class] # randomly sample some data 235 | 236 | 237 | for ind in randind: 238 | self.task_memory[task_id]['x'].append(data_loader.dataset[ind][0]) 239 | self.task_memory[task_id]['y'].append(mem_class_mapping[i]) 240 | self.task_memory[task_id]['tt'].append(data_loader.dataset[ind][2]) 241 | self.task_memory[task_id]['td'].append(data_loader.dataset[ind][3]) 242 | 243 | print ('Memory updated by adding {} images'.format(len(self.task_memory[task_id]['x']))) -------------------------------------------------------------------------------- /ACL-resnet/src/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | # https://github.com/pytorch/vision/blob/8635be94d1216f10fb8302da89233bd86445e449/torchvision/datasets/utils.py 8 | 9 | import os 10 | import os.path 11 | import hashlib 12 | import gzip 13 | import errno 14 | import tarfile 15 | import zipfile 16 | import numpy as np 17 | import torch 18 | import codecs 19 | 20 | from torch.utils.model_zoo import tqdm 21 | 22 | 23 | def gen_bar_updater(): 24 | pbar = tqdm(total=None) 25 | 26 | def bar_update(count, block_size, total_size): 27 | if pbar.total is None and total_size: 28 | pbar.total = total_size 29 | progress_bytes = count * block_size 30 | pbar.update(progress_bytes - pbar.n) 31 | 32 | return bar_update 33 | 34 | 35 | def calculate_md5(fpath, chunk_size=1024 * 1024): 36 | md5 = hashlib.md5() 37 | with open(fpath, 'rb') as f: 38 | for chunk in iter(lambda: f.read(chunk_size), b''): 39 | md5.update(chunk) 40 | return md5.hexdigest() 41 | 42 | 43 | def check_md5(fpath, md5, **kwargs): 44 | return md5 == calculate_md5(fpath, **kwargs) 45 | 46 | 47 | def check_integrity(fpath, md5=None): 48 | if not os.path.isfile(fpath): 49 | return False 50 | if md5 is None: 51 | return True 52 | return check_md5(fpath, md5) 53 | 54 | 55 | def makedir_exist_ok(dirpath): 56 | """ 57 | Python2 support for os.makedirs(.., exist_ok=True) 58 | """ 59 | try: 60 | os.makedirs(dirpath) 61 | except OSError as e: 62 | if e.errno == errno.EEXIST: 63 | pass 64 | else: 65 | raise 66 | 67 | 68 | def download_url(url, root, filename=None, md5=None): 69 | """Download a file from a url and place it in root. 70 | 71 | Args: 72 | url (str): URL to download file from 73 | root (str): Directory to place downloaded file in 74 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 75 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 76 | """ 77 | from six.moves import urllib 78 | 79 | root = os.path.expanduser(root) 80 | if not filename: 81 | filename = os.path.basename(url) 82 | fpath = os.path.join(root, filename) 83 | 84 | makedir_exist_ok(root) 85 | 86 | # downloads file 87 | if check_integrity(fpath, md5): 88 | print('Using downloaded and verified file: ' + fpath) 89 | else: 90 | try: 91 | print('Downloading ' + url + ' to ' + fpath) 92 | urllib.request.urlretrieve( 93 | url, fpath, 94 | reporthook=gen_bar_updater() 95 | ) 96 | except (urllib.error.URLError, IOError) as e: 97 | if url[:5] == 'https': 98 | url = url.replace('https:', 'http:') 99 | print('Failed download. Trying https -> http instead.' 100 | ' Downloading ' + url + ' to ' + fpath) 101 | urllib.request.urlretrieve( 102 | url, fpath, 103 | reporthook=gen_bar_updater() 104 | ) 105 | else: 106 | raise e 107 | 108 | 109 | def list_dir(root, prefix=False): 110 | """List all directories at a given root 111 | 112 | Args: 113 | root (str): Path to directory whose folders need to be listed 114 | prefix (bool, optional): If true, prepends the path to each result, otherwise 115 | only returns the name of the directories found 116 | """ 117 | root = os.path.expanduser(root) 118 | directories = list( 119 | filter( 120 | lambda p: os.path.isdir(os.path.join(root, p)), 121 | os.listdir(root) 122 | ) 123 | ) 124 | 125 | if prefix is True: 126 | directories = [os.path.join(root, d) for d in directories] 127 | 128 | return directories 129 | 130 | 131 | def list_files(root, suffix, prefix=False): 132 | """List all files ending with a suffix at a given root 133 | 134 | Args: 135 | root (str): Path to directory whose folders need to be listed 136 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 137 | It uses the Python "str.endswith" method and is passed directly 138 | prefix (bool, optional): If true, prepends the path to each result, otherwise 139 | only returns the name of the files found 140 | """ 141 | root = os.path.expanduser(root) 142 | files = list( 143 | filter( 144 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 145 | os.listdir(root) 146 | ) 147 | ) 148 | 149 | if prefix is True: 150 | files = [os.path.join(root, d) for d in files] 151 | 152 | return files 153 | 154 | 155 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 156 | """Download a Google Drive file from and place it in root. 157 | 158 | Args: 159 | file_id (str): id of file to be downloaded 160 | root (str): Directory to place downloaded file in 161 | filename (str, optional): Name to save the file under. If None, use the id of the file. 162 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 163 | """ 164 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 165 | import requests 166 | url = "https://docs.google.com/uc?export=download" 167 | 168 | root = os.path.expanduser(root) 169 | if not filename: 170 | filename = file_id 171 | fpath = os.path.join(root, filename) 172 | 173 | makedir_exist_ok(root) 174 | 175 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 176 | print('Using downloaded and verified file: ' + fpath) 177 | else: 178 | session = requests.Session() 179 | 180 | response = session.get(url, params={'id': file_id}, stream=True) 181 | token = _get_confirm_token(response) 182 | 183 | if token: 184 | params = {'id': file_id, 'confirm': token} 185 | response = session.get(url, params=params, stream=True) 186 | 187 | _save_response_content(response, fpath) 188 | 189 | 190 | def _get_confirm_token(response): 191 | for key, value in response.cookies.items(): 192 | if key.startswith('download_warning'): 193 | return value 194 | 195 | return None 196 | 197 | 198 | def _save_response_content(response, destination, chunk_size=32768): 199 | with open(destination, "wb") as f: 200 | pbar = tqdm(total=None) 201 | progress = 0 202 | for chunk in response.iter_content(chunk_size): 203 | if chunk: # filter out keep-alive new chunks 204 | f.write(chunk) 205 | progress += len(chunk) 206 | pbar.update(progress - pbar.n) 207 | pbar.close() 208 | 209 | 210 | def _is_tar(filename): 211 | return filename.endswith(".tar") 212 | 213 | 214 | def _is_targz(filename): 215 | return filename.endswith(".tar.gz") 216 | 217 | 218 | def _is_gzip(filename): 219 | return filename.endswith(".gz") and not filename.endswith(".tar.gz") 220 | 221 | 222 | def _is_zip(filename): 223 | return filename.endswith(".zip") 224 | 225 | 226 | def extract_archive(from_path, to_path=None, remove_finished=False): 227 | if to_path is None: 228 | to_path = os.path.dirname(from_path) 229 | 230 | if _is_tar(from_path): 231 | with tarfile.open(from_path, 'r') as tar: 232 | tar.extractall(path=to_path) 233 | elif _is_targz(from_path): 234 | with tarfile.open(from_path, 'r:gz') as tar: 235 | tar.extractall(path=to_path) 236 | elif _is_gzip(from_path): 237 | to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) 238 | with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: 239 | out_f.write(zip_f.read()) 240 | elif _is_zip(from_path): 241 | with zipfile.ZipFile(from_path, 'r') as z: 242 | z.extractall(to_path) 243 | else: 244 | raise ValueError("Extraction of {} not supported".format(from_path)) 245 | 246 | if remove_finished: 247 | os.remove(from_path) 248 | 249 | 250 | def download_and_extract_archive(url, download_root, extract_root=None, filename=None, 251 | md5=None, remove_finished=False): 252 | download_root = os.path.expanduser(download_root) 253 | if extract_root is None: 254 | extract_root = download_root 255 | if not filename: 256 | filename = os.path.basename(url) 257 | 258 | download_url(url, download_root, filename, md5) 259 | 260 | archive = os.path.join(download_root, filename) 261 | print("Extracting {} to {}".format(archive, extract_root)) 262 | extract_archive(archive, extract_root, remove_finished) 263 | 264 | 265 | def iterable_to_str(iterable): 266 | return "'" + "', '".join([str(item) for item in iterable]) + "'" 267 | 268 | 269 | def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None): 270 | if not isinstance(value, torch._six.string_classes): 271 | if arg is 
None: 272 | msg = "Expected type str, but got type {type}." 273 | else: 274 | msg = "Expected type str for argument {arg}, but got type {type}." 275 | msg = msg.format(type=type(value), arg=arg) 276 | raise ValueError(msg) 277 | 278 | if valid_values is None: 279 | return value 280 | 281 | if value not in valid_values: 282 | if custom_msg is not None: 283 | msg = custom_msg 284 | else: 285 | msg = ("Unknown value '{value}' for argument {arg}. " 286 | "Valid values are {{{valid_values}}}.") 287 | msg = msg.format(value=value, arg=arg, 288 | valid_values=iterable_to_str(valid_values)) 289 | raise ValueError(msg) 290 | 291 | return value 292 | 293 | 294 | def get_int(b): 295 | return int(codecs.encode(b, 'hex'), 16) 296 | 297 | 298 | def open_maybe_compressed_file(path): 299 | """Return a file object that possibly decompresses 'path' on the fly. 300 | Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'. 301 | """ 302 | if not isinstance(path, torch._six.string_classes): 303 | return path 304 | if path.endswith('.gz'): 305 | import gzip 306 | return gzip.open(path, 'rb') 307 | if path.endswith('.xz'): 308 | import lzma 309 | return lzma.open(path, 'rb') 310 | return open(path, 'rb') 311 | 312 | 313 | def read_sn3_pascalvincent_tensor(path, strict=True): 314 | """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). 315 | Argument may be a filename, compressed filename, or file object. 316 | """ 317 | # typemap 318 | if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'): 319 | read_sn3_pascalvincent_tensor.typemap = { 320 | 8: (torch.uint8, np.uint8, np.uint8), 321 | 9: (torch.int8, np.int8, np.int8), 322 | 11: (torch.int16, np.dtype('>i2'), 'i2'), 323 | 12: (torch.int32, np.dtype('>i4'), 'i4'), 324 | 13: (torch.float32, np.dtype('>f4'), 'f4'), 325 | 14: (torch.float64, np.dtype('>f8'), 'f8')} 326 | # read 327 | with open_maybe_compressed_file(path) as f: 328 | data = f.read() 329 | # parse 330 | magic = get_int(data[0:4]) 331 | nd = magic % 256 332 | ty = magic // 256 333 | assert nd >= 1 and nd <= 3 334 | assert ty >= 8 and ty <= 14 335 | m = read_sn3_pascalvincent_tensor.typemap[ty] 336 | s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)] 337 | parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) 338 | assert parsed.shape[0] == np.prod(s) or not strict 339 | return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) 340 | 341 | 342 | def read_label_file(path): 343 | with open(path, 'rb') as f: 344 | x = read_sn3_pascalvincent_tensor(f, strict=False) 345 | assert(x.dtype == torch.uint8) 346 | assert(x.ndimension() == 1) 347 | return x.long() 348 | 349 | 350 | def read_image_file(path): 351 | with open(path, 'rb') as f: 352 | x = read_sn3_pascalvincent_tensor(f, strict=False) 353 | assert(x.dtype == torch.uint8) 354 | assert(x.ndimension() == 3) 355 | return x -------------------------------------------------------------------------------- /ACL-resnet/src/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
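# Entry point: loads a YAML config with OmegaConf, selects the dataloader and
# network for args.experiment/args.arch, and trains the tasks sequentially with
# acl.ACL. Example (from the README): python main.py --config ./configs/config_cifar100.yml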
5 | 6 | import os,argparse,time 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | from copy import deepcopy 10 | import torch 11 | 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim 14 | import torch.utils.data 15 | import torch.utils.data.distributed 16 | import utils 17 | 18 | 19 | tstart=time.time() 20 | 21 | 22 | # Arguments 23 | parser = argparse.ArgumentParser(description='Adversarial Continual Learning...') 24 | # Load the config file 25 | parser.add_argument('--config', type=str, default='./configs/config_mnist5.yml') 26 | flags = parser.parse_args() 27 | args = OmegaConf.load(flags.config) 28 | 29 | print() 30 | 31 | 32 | ######################################################################################################################## 33 | 34 | # Args -- Experiment 35 | if args.experiment=='pmnist': 36 | from dataloaders import pmnist as datagenerator 37 | elif args.experiment=='mnist5': 38 | from dataloaders import mnist5 as datagenerator 39 | elif args.experiment=='cifar100': 40 | from dataloaders import cifar100 as datagenerator 41 | elif args.experiment=='miniimagenet': 42 | from dataloaders import miniimagenet as datagenerator 43 | elif args.experiment=='multidatasets': 44 | from dataloaders import mulitidatasets as datagenerator 45 | else: 46 | raise NotImplementedError 47 | 48 | from acl import ACL as approach 49 | 50 | # Args -- Network 51 | 52 | 53 | if args.experiment == 'mnist5' or args.experiment == 'pmnist': 54 | from networks import mlp_acl as network 55 | elif args.experiment == 'cifar100' or args.experiment == 'miniimagenet' or args.experiment == 'multidatasets': 56 | if args.arch == 'alexnet': 57 | from networks import alexnet_acl as network 58 | elif args.arch == 'resnet': 59 | from networks import resnet_acl as network 60 | else: 61 | raise NotImplementedError 62 | else: 63 | raise NotImplementedError 64 | 65 | ######################################################################################################################## 66 | 67 | def run(args, run_id): 68 | 69 | np.random.seed(args.seed) 70 | torch.manual_seed(args.seed) 71 | if torch.cuda.is_available(): 72 | torch.cuda.manual_seed(args.seed) 73 | 74 | # Faster run but not deterministic: 75 | # torch.backends.cudnn.benchmark = True 76 | 77 | # To get deterministic results that match with paper at cost of lower speed: 78 | torch.backends.cudnn.deterministic = True 79 | torch.backends.cudnn.benchmark = False 80 | 81 | # Data loader 82 | print('Instantiate data generators and model...') 83 | dataloader = datagenerator.DatasetGen(args) 84 | args.taskcla, args.inputsize = dataloader.taskcla, dataloader.inputsize 85 | if args.experiment == 'multidatasets': args.lrs = dataloader.lrs 86 | 87 | # Model 88 | net = network.Net(args) 89 | net = net.to(args.device) 90 | 91 | net.print_model_size() 92 | # print (net) 93 | 94 | # Approach 95 | appr=approach(net,args,network=network) 96 | 97 | # Loop tasks 98 | acc=np.zeros((len(args.taskcla),len(args.taskcla)),dtype=np.float32) 99 | lss=np.zeros((len(args.taskcla),len(args.taskcla)),dtype=np.float32) 100 | 101 | for t,ncla in args.taskcla: 102 | 103 | print('*'*250) 104 | dataset = dataloader.get(t) 105 | print(' '*105, 'Dataset {:2d} ({:s})'.format(t+1,dataset[t]['name'])) 106 | print('*'*250) 107 | 108 | # Train 109 | appr.train(t,dataset[t]) 110 | print('-'*250) 111 | print() 112 | 113 | for u in range(t+1): 114 | # Load previous model and replace the shared module with the current one 115 | test_model = appr.load_model(u) 116 | 
test_res = appr.test(dataset[u]['test'], u, model=test_model) 117 | 118 | print('>>> Test on task {:2d} - {:15s}: loss={:.3f}, acc={:5.1f}% <<<'.format(u, dataset[u]['name'], 119 | test_res['loss_t'], 120 | test_res['acc_t'])) 121 | acc[t, u] = test_res['acc_t'] 122 | lss[t, u] = test_res['loss_t'] 123 | 124 | 125 | # Save 126 | print() 127 | print('Saved accuracies at '+os.path.join(args.checkpoint,args.output)) 128 | np.savetxt(os.path.join(args.checkpoint,args.output),acc,'%.6f') 129 | 130 | # Extract embeddings to plot in tensorboard for miniimagenet 131 | if args.tsne == 'yes' and args.experiment == 'miniimagenet': 132 | appr.get_tsne_embeddings_first_ten_tasks(dataset, model=appr.load_model(t)) 133 | appr.get_tsne_embeddings_last_three_tasks(dataset, model=appr.load_model(t)) 134 | 135 | avg_acc, gem_bwt = utils.print_log_acc_bwt(args.taskcla, acc, lss, output_path=args.checkpoint, run_id=run_id) 136 | 137 | return avg_acc, gem_bwt 138 | 139 | 140 | 141 | ####################################################################################################################### 142 | 143 | 144 | def main(args): 145 | utils.make_directories(args) 146 | utils.some_sanity_checks(args) 147 | utils.save_code(args) 148 | 149 | print('=' * 100) 150 | print('Arguments =') 151 | for arg in vars(args): 152 | print('\t' + arg + ':', getattr(args, arg)) 153 | print('=' * 100) 154 | 155 | 156 | accuracies, forgetting = [], [] 157 | for n in range(args.num_runs): 158 | args.seed = n 159 | args.output = '{}_{}_tasks_seed_{}.txt'.format(args.experiment, args.ntasks, args.seed) 160 | print ("args.output: ", args.output) 161 | 162 | print (" >>>> Run #", n) 163 | acc, bwt = run(args, n) 164 | accuracies.append(acc) 165 | forgetting.append(bwt) 166 | 167 | 168 | print('*' * 100) 169 | print ("Average over {} runs: ".format(args.num_runs)) 170 | print ('AVG ACC: {:5.4f}% \pm {:5.4f}'.format(np.array(accuracies).mean(), np.array(accuracies).std())) 171 | print ('AVG BWT: {:5.2f}% \pm {:5.4f}'.format(np.array(forgetting).mean(), np.array(forgetting).std())) 172 | 173 | 174 | print ("All Done! 
") 175 | print('[Elapsed time = {:.1f} min]'.format((time.time()-tstart)/(60))) 176 | utils.print_time() 177 | 178 | 179 | def test_trained_model(args, final_model_id): 180 | args.seed = 0 181 | 182 | print('Instantiate data generators and model...') 183 | dataloader = datagenerator.DatasetGen(args) 184 | args.taskcla, args.inputsize = dataloader.taskcla, dataloader.inputsize 185 | if args.experiment == 'multidatasets': args.lrs = dataloader.lrs 186 | 187 | def get_model(final_model_id, test_data_id): 188 | # Load the test model 189 | test_net = network.Net(args) 190 | checkpoint_test = torch.load(os.path.join(args.checkpoint, 'model_{}.pth.tar'.format(test_data_id))) 191 | test_net.load_state_dict(checkpoint_test['model_state_dict']) 192 | 193 | # Load your final trained model 194 | net = network.Net(args) 195 | checkpoint = torch.load(os.path.join(args.checkpoint, 'model_{}.pth.tar'.format(final_model_id))) 196 | net.load_state_dict(checkpoint['model_state_dict']) 197 | 198 | # # Change the shared module with the final model's shared module 199 | final_shared = deepcopy(net.shared.state_dict()) 200 | test_net.shared.load_state_dict(final_shared) 201 | test_net = test_net.to(args.device) 202 | 203 | return test_net 204 | 205 | for t,ncla in args.taskcla: 206 | print('*'*250) 207 | dataset = dataloader.get(t) 208 | print(' '*105, 'Dataset {:2d} ({:s})'.format(t+1,dataset[t]['name'])) 209 | print('*'*250) 210 | 211 | # Model 212 | test_model = get_model(final_model_id, test_data_id=t) 213 | 214 | # Approach 215 | appr = approach(test_model, args, network=network) 216 | 217 | # Test 218 | test_res = appr.inference(dataset[t]['test'], t, model=test_model) 219 | 220 | print('>>> Test on task {:2d} - {:15s}: loss={:.3f}, acc={:5.4f}% <<<'.format(t, dataset[t]['name'], 221 | test_res['loss_t'], 222 | test_res['acc_t'])) 223 | 224 | 225 | 226 | 227 | ####################################################################################################################### 228 | 229 | if __name__ == '__main__': 230 | main(args) 231 | # test_trained_model(args, final_model_id=4) -------------------------------------------------------------------------------- /ACL-resnet/src/networks/alexnet_acl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import utils 8 | 9 | class Shared(torch.nn.Module): 10 | 11 | def __init__(self,args): 12 | super(Shared, self).__init__() 13 | 14 | self.ncha,size,_=args.inputsize 15 | self.taskcla=args.taskcla 16 | self.latent_dim = args.latent_dim 17 | 18 | if args.experiment == 'cifar100': 19 | hiddens = [64, 128, 256, 1024, 1024, 512] 20 | 21 | elif args.experiment == 'miniimagenet': 22 | hiddens = [64, 128, 256, 512, 512, 512] 23 | 24 | # ---------------------------------- 25 | elif args.experiment == 'multidatasets': 26 | hiddens = [64, 128, 256, 1024, 1024, 512] 27 | 28 | else: 29 | raise NotImplementedError 30 | 31 | self.conv1=torch.nn.Conv2d(self.ncha,hiddens[0],kernel_size=size//8) 32 | s=utils.compute_conv_output_size(size,size//8) 33 | s=s//2 34 | self.conv2=torch.nn.Conv2d(hiddens[0],hiddens[1],kernel_size=size//10) 35 | s=utils.compute_conv_output_size(s,size//10) 36 | s=s//2 37 | self.conv3=torch.nn.Conv2d(hiddens[1],hiddens[2],kernel_size=2) 38 | s=utils.compute_conv_output_size(s,2) 39 | s=s//2 40 | self.maxpool=torch.nn.MaxPool2d(2) 41 | self.relu=torch.nn.ReLU() 42 | 43 | self.drop1=torch.nn.Dropout(0.2) 44 | self.drop2=torch.nn.Dropout(0.5) 45 | self.fc1=torch.nn.Linear(hiddens[2]*s*s,hiddens[3]) 46 | self.fc2=torch.nn.Linear(hiddens[3],hiddens[4]) 47 | self.fc3=torch.nn.Linear(hiddens[4],hiddens[5]) 48 | self.fc4=torch.nn.Linear(hiddens[5], self.latent_dim) 49 | 50 | 51 | def forward(self, x_s): 52 | x_s = x_s.view_as(x_s) 53 | h = self.maxpool(self.drop1(self.relu(self.conv1(x_s)))) 54 | h = self.maxpool(self.drop1(self.relu(self.conv2(h)))) 55 | h = self.maxpool(self.drop2(self.relu(self.conv3(h)))) 56 | h = h.view(x_s.size(0), -1) 57 | h = self.drop2(self.relu(self.fc1(h))) 58 | h = self.drop2(self.relu(self.fc2(h))) 59 | h = self.drop2(self.relu(self.fc3(h))) 60 | h = self.drop2(self.relu(self.fc4(h))) 61 | return h 62 | 63 | 64 | 65 | class Private(torch.nn.Module): 66 | def __init__(self, args): 67 | super(Private, self).__init__() 68 | 69 | self.ncha,self.size,_=args.inputsize 70 | self.taskcla=args.taskcla 71 | self.latent_dim = args.latent_dim 72 | self.num_tasks = args.ntasks 73 | self.device = args.device 74 | 75 | if args.experiment == 'cifar100': 76 | hiddens=[32,32] 77 | flatten=1152 78 | 79 | elif args.experiment == 'miniimagenet': 80 | # hiddens=[8,8] 81 | # flatten=1800 82 | hiddens=[16,16] 83 | flatten=3600 84 | 85 | 86 | elif args.experiment == 'multidatasets': 87 | hiddens=[32,32] 88 | flatten=1152 89 | 90 | 91 | else: 92 | raise NotImplementedError 93 | 94 | self.task_out = torch.nn.Sequential() 95 | self.task_out.add_module('conv1', torch.nn.Conv2d(self.ncha, hiddens[0], kernel_size=self.size // 8)) 96 | self.task_out.add_module('relu1', torch.nn.ReLU(inplace=True)) 97 | self.task_out.add_module('drop1', torch.nn.Dropout(0.2)) 98 | self.task_out.add_module('maxpool1', torch.nn.MaxPool2d(2)) 99 | self.task_out.add_module('conv2', torch.nn.Conv2d(hiddens[0], hiddens[1], kernel_size=self.size // 10)) 100 | self.task_out.add_module('relu2', torch.nn.ReLU(inplace=True)) 101 | self.task_out.add_module('dropout2', torch.nn.Dropout(0.5)) 102 | self.task_out.add_module('maxpool2', torch.nn.MaxPool2d(2)) 103 | 104 | self.linear = torch.nn.Sequential() 105 | self.linear.add_module('linear1', torch.nn.Linear(flatten, self.latent_dim)) 106 | self.linear.add_module('relu3', torch.nn.ReLU(inplace=True)) 107 | 108 | def forward(self, x): 109 | x = x.view_as(x) 110 | out = self.task_out(x) 111 | out = out.view(out.size(0), -1) 112 | out = 
self.linear(out) 113 | return out 114 | 115 | # def forward(self, x, task_id): 116 | # x = x.view_as(x) 117 | # out = self.task_out[2*task_id].forward(x) 118 | # out = out.view(out.size(0),-1) 119 | # out = self.task_out[2*task_id+1].forward(out) 120 | # return out 121 | 122 | 123 | 124 | class Net(torch.nn.Module): 125 | 126 | def __init__(self, args): 127 | super(Net, self).__init__() 128 | self.ncha,size,_=args.inputsize 129 | self.taskcla=args.taskcla 130 | self.latent_dim = args.latent_dim 131 | self.ntasks = args.ntasks 132 | self.samples = args.samples 133 | self.image_size = self.ncha*size*size 134 | self.args=args 135 | 136 | self.hidden1 = args.head_units 137 | self.hidden2 = args.head_units//2 138 | 139 | self.shared = Shared(args) 140 | self.private = Private(args) 141 | 142 | self.head = torch.nn.Sequential( 143 | torch.nn.Linear(2*self.latent_dim, self.hidden1), 144 | torch.nn.ReLU(inplace=True), 145 | torch.nn.Dropout(), 146 | torch.nn.Linear(self.hidden1, self.hidden2), 147 | torch.nn.ReLU(inplace=True), 148 | torch.nn.Linear(self.hidden2, self.taskcla[0][1]) 149 | ) 150 | 151 | 152 | def forward(self, x_s, x_p, tt=None): 153 | 154 | x_s = x_s.view_as(x_s) 155 | x_p = x_p.view_as(x_p) 156 | 157 | # x_s = self.shared(x_s) 158 | # x_p = self.private(x_p) 159 | # 160 | # x = torch.cat([x_p, x_s], dim=1) 161 | 162 | # if self.args.experiment == 'multidatasets': 163 | # # if no memory is used this is faster: 164 | # y=[] 165 | # for i,_ in self.taskcla: 166 | # y.append(self.head[i](x)) 167 | # return y[task_id] 168 | # else: 169 | # return torch.stack([self.head[tt[i]].forward(x[i]) for i in range(x.size(0))]) 170 | 171 | # if torch.is_tensor(tt): 172 | # return torch.stack([self.head[tt[i]].forward(x[i]) for i in range(x.size(0))]) 173 | # else: 174 | # return self.head(x) 175 | output = {} 176 | output['shared'] = self.shared(x_s) 177 | output['private'] = self.private(x_p) 178 | concat_features = torch.cat([output['private'], output['shared']], dim=1) 179 | if torch.is_tensor(tt): 180 | output['out'] = torch.stack([self.head[tt[i]].forward(concat_features[i]) for i in range( 181 | concat_features.size(0))]) 182 | else: 183 | output['out'] = self.head(concat_features) 184 | return output 185 | 186 | 187 | # def get_encoded_ftrs(self, x_s, x_p, task_id=None): 188 | # return self.shared(x_s), self.private(x_p) 189 | 190 | def print_model_size(self): 191 | 192 | count_P = sum(p.numel() for p in self.private.parameters() if p.requires_grad) 193 | count_S = sum(p.numel() for p in self.shared.parameters() if p.requires_grad) 194 | count_H = sum(p.numel() for p in self.head.parameters() if p.requires_grad) 195 | 196 | print("Size of the network for one task including (S+P+p)") 197 | print('Num parameters in S = %s ' % (self.pretty_print(count_S))) 198 | print('Num parameters in P = %s ' % (self.pretty_print(count_P))) 199 | print('Num parameters in p = %s ' % (self.pretty_print(count_H))) 200 | print('Num parameters in P+p = %s ' % self.pretty_print(count_P + count_H)) 201 | print('--------------------------> Architecture size in total for all tasks: %s parameters (%sB)' % ( 202 | self.pretty_print(count_S + self.ntasks*count_P + self.ntasks*count_H), 203 | self.pretty_print(4 * (count_S + self.ntasks*count_P + self.ntasks*count_H)))) 204 | 205 | classes_per_task = self.taskcla[0][1] 206 | print("--------------------------> Memory size: %s samples per task (%sB)" % (self.samples*classes_per_task, 207 | self.pretty_print( 208 | self.ntasks * 4 * self.samples * classes_per_task* 
self.image_size))) 209 | print("------------------------------------------------------------------------------") 210 | print(" TOTAL: %sB" % self.pretty_print( 211 | 4 * (count_S + self.ntasks *count_P + self.ntasks *count_H) + self.ntasks * 4 * self.samples * classes_per_task * self.image_size)) 212 | 213 | def pretty_print(self, num): 214 | magnitude = 0 215 | while abs(num) >= 1000: 216 | magnitude += 1 217 | num /= 1000.0 218 | return '%.1f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude]) 219 | 220 | 221 | -------------------------------------------------------------------------------- /ACL-resnet/src/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import utils 8 | 9 | class Discriminator(torch.nn.Module): 10 | def __init__(self,args,task_id): 11 | super(Discriminator, self).__init__() 12 | 13 | self.num_tasks=args.ntasks 14 | self.units=args.units 15 | self.latent_dim=args.latent_dim 16 | 17 | 18 | if args.diff == 'yes': 19 | self.dis = torch.nn.Sequential( 20 | GradientReversal(args.lam), 21 | torch.nn.Linear(self.latent_dim, args.units), 22 | torch.nn.LeakyReLU(), 23 | torch.nn.Linear(args.units, args.units), 24 | torch.nn.Linear(args.units, task_id + 2) 25 | ) 26 | else: 27 | self.dis = torch.nn.Sequential( 28 | torch.nn.Linear(self.latent_dim, args.units), 29 | torch.nn.LeakyReLU(), 30 | torch.nn.Linear(args.units, args.units), 31 | torch.nn.Linear(args.units, task_id + 2) 32 | ) 33 | 34 | 35 | def forward(self, z): 36 | return self.dis(z) 37 | 38 | def pretty_print(self, num): 39 | magnitude=0 40 | while abs(num) >= 1000: 41 | magnitude+=1 42 | num/=1000.0 43 | return '%.1f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude]) 44 | 45 | 46 | def get_size(self): 47 | count=sum(p.numel() for p in self.dis.parameters() if p.requires_grad) 48 | print('Num parameters in D = %s ' % (self.pretty_print(count))) 49 | 50 | 51 | class GradientReversalFunction(torch.autograd.Function): 52 | """ 53 | From: 54 | https://github.com/jvanvugt/pytorch-domain-adaptation/blob/cb65581f20b71ff9883dd2435b2275a1fd4b90df/utils.py#L26 55 | 56 | Gradient Reversal Layer from: 57 | Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015) 58 | Forward pass is the identity function. In the backward pass, 59 | the upstream gradients are multiplied by -lambda (i.e. gradient is reversed) 60 | """ 61 | 62 | @staticmethod 63 | def forward(ctx, x, lambda_): 64 | ctx.lambda_ = lambda_ 65 | return x.clone() 66 | 67 | @staticmethod 68 | def backward(ctx, grads): 69 | lambda_ = ctx.lambda_ 70 | lambda_ = grads.new_tensor(lambda_) 71 | dx = -lambda_ * grads 72 | return dx, None 73 | 74 | 75 | class GradientReversal(torch.nn.Module): 76 | def __init__(self, lambda_): 77 | super(GradientReversal, self).__init__() 78 | self.lambda_ = lambda_ 79 | 80 | def forward(self, x): 81 | return GradientReversalFunction.apply(x, self.lambda_) -------------------------------------------------------------------------------- /ACL-resnet/src/networks/mlp_acl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | 9 | class Private(torch.nn.Module): 10 | def __init__(self, args): 11 | super(Private, self).__init__() 12 | 13 | self.ncha,self.size,_=args.inputsize 14 | self.taskcla=args.taskcla 15 | self.latent_dim = args.latent_dim 16 | self.num_tasks = args.ntasks 17 | self.nhid = args.units 18 | self.device = args.device 19 | 20 | self.task_out = torch.nn.ModuleList() 21 | for _ in range(self.num_tasks): 22 | linear = torch.nn.Sequential() 23 | linear.add_module('linear', torch.nn.Linear(self.ncha*self.size*self.size, self.latent_dim)) 24 | linear.add_module('relu', torch.nn.ReLU(inplace=True)) 25 | self.task_out.append(linear) 26 | 27 | def forward(self, x_p, task_id): 28 | x_p = x_p.view(x_p.size(0), -1) 29 | return self.task_out[task_id].forward(x_p) 30 | 31 | 32 | 33 | class Shared(torch.nn.Module): 34 | 35 | def __init__(self,args): 36 | super(Shared, self).__init__() 37 | 38 | ncha,self.size,_=args.inputsize 39 | self.taskcla=args.taskcla 40 | self.latent_dim = args.latent_dim 41 | self.nhid = args.units 42 | self.nlayers = args.nlayers 43 | 44 | self.relu=torch.nn.ReLU() 45 | self.drop=torch.nn.Dropout(0.2) 46 | self.fc1=torch.nn.Linear(ncha*self.size*self.size, self.nhid) 47 | 48 | if self.nlayers == 3: 49 | self.fc2 = torch.nn.Linear(self.nhid, self.nhid) 50 | self.fc3=torch.nn.Linear(self.nhid,self.latent_dim) 51 | else: 52 | self.fc2 = torch.nn.Linear(self.nhid,self.latent_dim) 53 | 54 | def forward(self, x_s): 55 | 56 | h = x_s.view(x_s.size(0), -1) 57 | h = self.drop(self.relu(self.fc1(h))) 58 | h = self.drop(self.relu(self.fc2(h))) 59 | if self.nlayers == 3: 60 | h = self.drop(self.relu(self.fc3(h))) 61 | 62 | return h 63 | 64 | 65 | class Net(torch.nn.Module): 66 | 67 | def __init__(self, args): 68 | super(Net, self).__init__() 69 | ncha,size,_=args.inputsize 70 | self.taskcla=args.taskcla 71 | self.latent_dim = args.latent_dim 72 | self.num_tasks = args.ntasks 73 | self.device = args.device 74 | 75 | if args.experiment == 'mnist5': 76 | self.hidden1 = 28 77 | self.hidden2 = 14 78 | elif args.experiment == 'pmnist': 79 | self.hidden1 = 28 80 | self.hidden2 = 28 81 | 82 | self.samples = args.samples 83 | 84 | self.shared = Shared(args) 85 | self.private = Private(args) 86 | 87 | self.head = torch.nn.ModuleList() 88 | for i in range(self.num_tasks): 89 | self.head.append( 90 | torch.nn.Sequential( 91 | torch.nn.Linear(2 * self.latent_dim, self.hidden1), 92 | torch.nn.ReLU(inplace=True), 93 | torch.nn.Dropout(), 94 | torch.nn.Linear(self.hidden1, self.hidden2), 95 | torch.nn.ReLU(inplace=True), 96 | torch.nn.Linear(self.hidden2, self.taskcla[i][1]) 97 | )) 98 | 99 | def forward(self, x_s, x_p, tt, task_id): 100 | 101 | h_s = x_s.view(x_s.size(0), -1) 102 | h_p = x_p.view(x_p.size(0), -1) 103 | 104 | x_s = self.shared(h_s) 105 | x_p = self.private(h_p, task_id) 106 | 107 | x = torch.cat([x_p, x_s], dim=1) 108 | 109 | return torch.stack([self.head[tt[i]].forward(x[i]) for i in range(x.size(0))]) 110 | 111 | 112 | def get_encoded_ftrs(self, x_s, x_p, task_id): 113 | return self.shared(x_s), self.private(x_p, task_id) 114 | 115 | 116 | def print_model_size(self): 117 | count_P = sum(p.numel() for p in self.private.parameters() if p.requires_grad) 118 | count_S = sum(p.numel() for p in self.shared.parameters() if p.requires_grad) 119 | count_H = sum(p.numel() for p in self.head.parameters() if
p.requires_grad) 120 | 121 | print('Num parameters in S = %s ' % (self.pretty_print(count_S))) 122 | print('Num parameters in P = %s, per task = %s ' % (self.pretty_print(count_P),self.pretty_print(count_P/self.num_tasks))) 123 | print('Num parameters in p = %s, per task = %s ' % (self.pretty_print(count_H),self.pretty_print(count_H/self.num_tasks))) 124 | print('Num parameters in P+p = %s ' % self.pretty_print(count_P+count_H)) 125 | print('--------------------------> Total architecture size: %s parameters (%sB)' % (self.pretty_print(count_S + count_P + count_H), 126 | self.pretty_print(4*(count_S + count_P + count_H)))) 127 | 128 | def pretty_print(self, num): 129 | magnitude=0 130 | while abs(num) >= 1000: 131 | magnitude+=1 132 | num/=1000.0 133 | return '%.2f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude]) 134 | -------------------------------------------------------------------------------- /ACL-resnet/src/networks/resnet_acl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | class Shared(torch.nn.Module): 12 | 13 | def __init__(self,args): 14 | super(Shared, self).__init__() 15 | 16 | 17 | self.taskcla=args.taskcla 18 | self.latent_dim = args.latent_dim 19 | ncha,size,_ = args.inputsize 20 | 21 | self.pretrained = False 22 | 23 | if args.experiment == 'cifar100': 24 | hiddens = [64, 128, 256] 25 | 26 | elif args.experiment == 'miniimagenet': 27 | hiddens = [1024, 512, 256] 28 | 29 | else: 30 | raise NotImplementedError 31 | 32 | # Small resnet 33 | resnet = resnet18_small(self.latent_dim, shared=True) 34 | self.features = torch.nn.Sequential(*list(resnet.children())[:-2]) 35 | 36 | if args.experiment == 'miniimagenet': 37 | # num_ftrs = 4608 38 | num_ftrs = 2304 # without average pool (-2) 39 | 40 | elif args.experiment == 'cifar100': 41 | # num_ftrs = 25088 # without average pool 42 | num_ftrs = 256 43 | else: 44 | raise NotImplementedError 45 | 46 | self.relu=torch.nn.ReLU() 47 | self.drop1=torch.nn.Dropout(0.2) 48 | self.drop2=torch.nn.Dropout(0.5) 49 | self.fc1=torch.nn.Linear(num_ftrs,hiddens[0]) 50 | self.fc2=torch.nn.Linear(hiddens[0],hiddens[1]) 51 | self.fc3=torch.nn.Linear(hiddens[1],hiddens[1]) 52 | self.fc4=torch.nn.Linear(hiddens[1], self.latent_dim) 53 | 54 | def forward(self, x): 55 | x = x.view_as(x) 56 | x = self.features(x) 57 | x = torch.flatten(x, 1) 58 | x = self.drop2(self.relu(self.fc1(x))) 59 | x = self.drop2(self.relu(self.fc2(x))) 60 | x = self.drop2(self.relu(self.fc3(x))) 61 | x = self.drop2(self.relu(self.fc4(x))) 62 | return x 63 | 64 | 65 | 66 | 67 | 68 | class Net(torch.nn.Module): 69 | 70 | def __init__(self, args): 71 | super(Net, self).__init__() 72 | ncha,size,_=args.inputsize 73 | self.image_size = ncha * size * size 74 | 75 | self.taskcla = args.taskcla 76 | self.latent_dim = args.latent_dim 77 | self.ntasks = args.ntasks 78 | self.samples = args.samples 79 | self.image_size = ncha * size * size 80 | self.use_memory = args.use_memory 81 | 82 | self.hidden1 = args.head_units 83 | self.hidden2 = args.head_units 84 | 85 | self.shared = Shared(args) 86 | self.private = resnet18_small(self.latent_dim, shared=False) 87 | 88 | self.head = torch.nn.Sequential( 89 | torch.nn.Linear(2*self.latent_dim, self.hidden1), 90 | 
torch.nn.ReLU(inplace=True), 91 | torch.nn.Dropout(), 92 | torch.nn.Linear(self.hidden1, self.hidden2), 93 | torch.nn.ReLU(inplace=True), 94 | torch.nn.Linear(self.hidden2, self.taskcla[0][1]) 95 | ) 96 | 97 | 98 | def forward(self, x_s, x_p, tt=None): 99 | 100 | x_s = x_s.view_as(x_s) 101 | x_p = x_p.view_as(x_p) 102 | 103 | # x_s = self.shared(x_s) 104 | # x_p = self.private(x_p) 105 | 106 | # x = torch.cat([x_p, x_s], dim=1) 107 | 108 | # if self.args.experiment == 'multidatasets': 109 | # # if no memory is used this is faster: 110 | # y=[] 111 | # for i,_ in self.taskcla: 112 | # y.append(self.head[i](x)) 113 | # return y[task_id] 114 | # else: 115 | # return torch.stack([self.head[tt[i]].forward(x[i]) for i in range(x.size(0))]) 116 | 117 | 118 | output = {} 119 | output['shared'] = self.shared(x_s) 120 | output['private'] = self.private(x_p) 121 | concat_features = torch.cat([output['private'], output['shared']], dim=1) 122 | if torch.is_tensor(tt): 123 | output['out'] = torch.stack([self.head[tt[i]].forward(concat_features[i]) for i in range( 124 | concat_features.size(0))]) 125 | else: 126 | output['out'] = self.head(concat_features) 127 | return output 128 | 129 | # output['shared'] = self.shared(x_s) 130 | # output['private'] = self.private(x_p) 131 | # 132 | # concat_features = torch.cat([output['private'], output['shared']], dim=1) 133 | # 134 | # if torch.is_tensor(tt): 135 | # 136 | # output['out'] = torch.stack([self.head[tt[i]].forward(concat_features[i]) for i in range(concat_features.size(0))]) 137 | # else: 138 | # if self.use_memory == 'no': 139 | # output['out'] = self.head.forward(concat_features) 140 | # 141 | # elif self.use_memory == 'yes': 142 | # y = [] 143 | # for i, _ in self.taskcla: 144 | # y.append(self.head[i](concat_features)) 145 | # output['out'] = y[task_id] 146 | # 147 | # return output 148 | 149 | 150 | # def get_encoded_ftrs(self, x_s, x_p, task_id=None): 151 | # return self.shared(x_s), self.private(x_p) 152 | 153 | 154 | def print_model_size(self): 155 | 156 | count_P = sum(p.numel() for p in self.private.parameters() if p.requires_grad) 157 | count_S = sum(p.numel() for p in self.shared.parameters() if p.requires_grad) 158 | count_H = sum(p.numel() for p in self.head.parameters() if p.requires_grad) 159 | 160 | print("Size of the network for one task including (S+P+p)") 161 | print('Num parameters in S = %s ' % (self.pretty_print(count_S))) 162 | print('Num parameters in P = %s ' % (self.pretty_print(count_P))) 163 | print('Num parameters in p = %s ' % (self.pretty_print(count_H))) 164 | print('Num parameters in P+p = %s ' % self.pretty_print(count_P + count_H)) 165 | print('--------------------------> Architecture size in total for all tasks: %s parameters (%sB)' % ( 166 | self.pretty_print(count_S + self.ntasks*count_P + self.ntasks*count_H), 167 | self.pretty_print(4 * (count_S + self.ntasks*count_P + self.ntasks*count_H)))) 168 | 169 | classes_per_task = self.taskcla[0][1] 170 | print("--------------------------> Memory size: %s samples per task (%sB)" % (self.samples*classes_per_task, 171 | self.pretty_print( 172 | self.ntasks * 4 * self.samples * classes_per_task* self.image_size))) 173 | print("------------------------------------------------------------------------------") 174 | print(" TOTAL: %sB" % self.pretty_print( 175 | 4 * (count_S + self.ntasks *count_P + self.ntasks *count_H) + self.ntasks * 4 * self.samples * classes_per_task * self.image_size)) 176 | 177 | def pretty_print(self, num): 178 | magnitude = 0 179 | while abs(num) >= 
1000: 180 | magnitude += 1 181 | num /= 1000.0 182 | return '%.1f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude]) 183 | 184 | 185 | 186 | 187 | class _CustomDataParallel(torch.nn.DataParallel): 188 | def __init__(self, model): 189 | super(_CustomDataParallel, self).__init__(model) 190 | 191 | def __getattr__(self, name): 192 | try: 193 | return super(_CustomDataParallel, self).__getattr__(name) 194 | except AttributeError: 195 | return getattr(self.module, name) 196 | 197 | 198 | 199 | 200 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 201 | """3x3 convolution with padding""" 202 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 203 | padding=dilation, groups=groups, bias=False, dilation=dilation) 204 | 205 | 206 | def conv1x1(in_planes, out_planes, stride=1): 207 | """1x1 convolution""" 208 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 209 | 210 | 211 | class BasicBlock(nn.Module): 212 | expansion = 1 213 | 214 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 215 | base_width=64, dilation=1, norm_layer=None): 216 | super(BasicBlock, self).__init__() 217 | if norm_layer is None: 218 | norm_layer = nn.BatchNorm2d 219 | if groups != 1 or base_width != 64: 220 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 221 | if dilation > 1: 222 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 223 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 224 | self.conv1 = conv3x3(inplanes, planes, stride) 225 | self.bn1 = norm_layer(planes) 226 | self.relu = nn.ReLU(inplace=True) 227 | self.conv2 = conv3x3(planes, planes) 228 | self.bn2 = norm_layer(planes) 229 | self.downsample = downsample 230 | self.stride = stride 231 | 232 | def forward(self, x): 233 | identity = x 234 | 235 | out = self.conv1(x) 236 | out = self.bn1(out) 237 | out = self.relu(out) 238 | 239 | out = self.conv2(out) 240 | out = self.bn2(out) 241 | 242 | if self.downsample is not None: 243 | identity = self.downsample(x) 244 | 245 | out += identity 246 | out = self.relu(out) 247 | 248 | return out 249 | 250 | 251 | class Bottleneck(nn.Module): 252 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 253 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 254 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 255 | # This variant is also known as ResNet V1.5 and improves accuracy according to 256 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
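# Note: resnet18_small() below is built from BasicBlock only, so this Bottleneck
# variant is unused by the small backbones here and is kept for completeness.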
257 | 258 | expansion = 4 259 | 260 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 261 | base_width=64, dilation=1, norm_layer=None): 262 | super(Bottleneck, self).__init__() 263 | if norm_layer is None: 264 | norm_layer = nn.BatchNorm2d 265 | width = int(planes * (base_width / 64.)) * groups 266 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 267 | self.conv1 = conv1x1(inplanes, width) 268 | self.bn1 = norm_layer(width) 269 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 270 | self.bn2 = norm_layer(width) 271 | self.conv3 = conv1x1(width, planes * self.expansion) 272 | self.bn3 = norm_layer(planes * self.expansion) 273 | self.relu = nn.ReLU(inplace=True) 274 | self.downsample = downsample 275 | self.stride = stride 276 | 277 | def forward(self, x): 278 | identity = x 279 | 280 | out = self.conv1(x) 281 | out = self.bn1(out) 282 | out = self.relu(out) 283 | 284 | out = self.conv2(out) 285 | out = self.bn2(out) 286 | out = self.relu(out) 287 | 288 | out = self.conv3(out) 289 | out = self.bn3(out) 290 | 291 | if self.downsample is not None: 292 | identity = self.downsample(x) 293 | 294 | out += identity 295 | out = self.relu(out) 296 | 297 | return out 298 | 299 | 300 | 301 | 302 | class ResNet(nn.Module): 303 | 304 | def __init__(self, shared, block, layers, num_classes, zero_init_residual=False, 305 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 306 | norm_layer=None): 307 | super(ResNet, self).__init__() 308 | if norm_layer is None: 309 | norm_layer = nn.BatchNorm2d 310 | self._norm_layer = norm_layer 311 | 312 | self.inplanes = 64 313 | self.dilation = 1 314 | 315 | # small resnet 316 | if shared: 317 | hiddens = [32, 64, 128, 256] 318 | else: 319 | hiddens = [16, 32, 32, 64] 320 | 321 | # original resnet 322 | # hiddens = [64, 128, 256, 512] 323 | 324 | if replace_stride_with_dilation is None: 325 | # each element in the tuple indicates if we should replace 326 | # the 2x2 stride with a dilated convolution instead 327 | replace_stride_with_dilation = [False, False, False] 328 | if len(replace_stride_with_dilation) != 3: 329 | raise ValueError("replace_stride_with_dilation should be None " 330 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 331 | self.groups = groups 332 | self.base_width = width_per_group 333 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 334 | bias=False) 335 | self.bn1 = norm_layer(self.inplanes) 336 | self.relu = nn.ReLU(inplace=True) 337 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 338 | self.layer1 = self._make_layer(block, hiddens[0], layers[0]) 339 | self.layer2 = self._make_layer(block, hiddens[1], layers[1], stride=2, 340 | dilate=replace_stride_with_dilation[0]) 341 | self.layer3 = self._make_layer(block, hiddens[2], layers[2], stride=2, 342 | dilate=replace_stride_with_dilation[1]) 343 | self.layer4 = self._make_layer(block, hiddens[3], layers[3], stride=2, 344 | dilate=replace_stride_with_dilation[2]) 345 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 346 | self.fc = nn.Linear(hiddens[3] * block.expansion, num_classes) 347 | 348 | for m in self.modules(): 349 | if isinstance(m, nn.Conv2d): 350 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 351 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 352 | nn.init.constant_(m.weight, 1) 353 | nn.init.constant_(m.bias, 0) 354 | 355 | # Zero-initialize the last BN in each residual branch, 356 | # so that the 
residual branch starts with zeros, and each residual block behaves like an identity. 357 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 358 | if zero_init_residual: 359 | for m in self.modules(): 360 | if isinstance(m, Bottleneck): 361 | nn.init.constant_(m.bn3.weight, 0) 362 | elif isinstance(m, BasicBlock): 363 | nn.init.constant_(m.bn2.weight, 0) 364 | 365 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 366 | norm_layer = self._norm_layer 367 | downsample = None 368 | previous_dilation = self.dilation 369 | if dilate: 370 | self.dilation *= stride 371 | stride = 1 372 | if stride != 1 or self.inplanes != planes * block.expansion: 373 | downsample = nn.Sequential( 374 | conv1x1(self.inplanes, planes * block.expansion, stride), 375 | norm_layer(planes * block.expansion), 376 | ) 377 | 378 | layers = [] 379 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 380 | self.base_width, previous_dilation, norm_layer)) 381 | self.inplanes = planes * block.expansion 382 | for _ in range(1, blocks): 383 | layers.append(block(self.inplanes, planes, groups=self.groups, 384 | base_width=self.base_width, dilation=self.dilation, 385 | norm_layer=norm_layer)) 386 | 387 | return nn.Sequential(*layers) 388 | 389 | def _forward_impl(self, x): 390 | x = self.conv1(x) 391 | x = self.bn1(x) 392 | x = self.relu(x) 393 | x = self.maxpool(x) 394 | 395 | x = self.layer1(x) 396 | x = self.layer2(x) 397 | x = self.layer3(x) 398 | x = self.layer4(x) 399 | 400 | x = self.avgpool(x) 401 | x = torch.flatten(x, 1) 402 | x = self.fc(x) 403 | x = self.relu(x) 404 | return x 405 | 406 | def forward(self, x): 407 | return self._forward_impl(x) 408 | 409 | 410 | 411 | 412 | def resnet18_small(latent_dim, shared): 413 | # ResNet-18 model from 414 | # "Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385> 415 | return ResNet(shared, BasicBlock, [2, 2, 2, 2], num_classes=latent_dim) 416 | 417 | -------------------------------------------------------------------------------- /ACL-resnet/src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree.
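#
# Utility helpers shared by the training scripts: parameter-count formatting,
# per-epoch train/validation reporting, conv output-size arithmetic, and
# accuracy/BWT logging. Backward transfer (BWT) follows GEM
# (https://arxiv.org/abs/1706.08840):
#
#   BWT = (1 / (T - 1)) * sum_{i=1}^{T-1} (R[T, i] - R[i, i])
#
# where R[i, j] is the test accuracy on task j after training on task i
# (1-indexed); this is exactly what gem_bwt computes below from the accuracy
# matrix acc.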
5 | 6 | import os 7 | import numpy as np 8 | from copy import deepcopy 9 | import pickle 10 | import time 11 | import uuid 12 | from subprocess import call 13 | ######################################################################################################################## 14 | 15 | def human_format(num): 16 | magnitude=0 17 | while abs(num)>=1000: 18 | magnitude+=1 19 | num/=1000.0 20 | return '%.1f%s'%(num,['','K','M','G','T','P'][magnitude]) 21 | 22 | 23 | def report_tr(res, e, sbatch, clock0, clock1): 24 | # Training performance 25 | print( 26 | '| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train losses={:.3f} | T: loss={:.3f}, acc={:5.2f}% | D: loss={:.3f}, acc={:5.1f}%, ' 27 | 'Diff loss:{:.3f} |'.format( 28 | e + 1, 29 | 1000 * sbatch * (clock1 - clock0) / res['size'], 30 | 1000 * sbatch * (time.time() - clock1) / res['size'], res['loss_tot'], 31 | res['loss_t'], res['acc_t'], res['loss_a'], res['acc_d'], res['loss_d']), end='') 32 | 33 | def report_val(res): 34 | # Validation performance 35 | print(' Valid losses={:.3f} | T: loss={:.6f}, acc={:5.2f}%, | D: loss={:.3f}, acc={:5.2f}%, Diff loss={:.3f} |'.format( 36 | res['loss_tot'], res['loss_t'], res['acc_t'], res['loss_a'], res['acc_d'], res['loss_d']), end='') 37 | 38 | 39 | ######################################################################################################################## 40 | 41 | def get_model(model): 42 | return deepcopy(model.state_dict()) 43 | 44 | ######################################################################################################################## 45 | 46 | def compute_conv_output_size(Lin,kernel_size,stride=1,padding=0,dilation=1): 47 | return int(np.floor((Lin+2*padding-dilation*(kernel_size-1)-1)/float(stride)+1)) 48 | 49 | ######################################################################################################################## 50 | 51 | def save_print_log(taskcla, acc, lss, output_path): 52 | 53 | print('*'*100) 54 | print('Accuracies =') 55 | for i in range(acc.shape[0]): 56 | print('\t',end=',') 57 | for j in range(acc.shape[1]): 58 | print('{:5.4f}% '.format(acc[i,j]),end=',') 59 | print() 60 | print ('ACC: {:5.4f}%'.format((np.mean(acc[acc.shape[0]-1,:])))) 61 | print() 62 | 63 | print ('BWD Transfer = ') 64 | 65 | print () 66 | print ("Diagonal R_ii") 67 | for i in range(acc.shape[0]): 68 | print('\t',end='') 69 | print('{:5.2f}% '.format(np.diag(acc)[i]), end=',') 70 | 71 | 72 | print() 73 | print ("Last row") 74 | for i in range(acc.shape[0]): 75 | print('\t', end=',') 76 | print('{:5.2f}% '.format(acc[-1][i]), end=',') 77 | 78 | print() 79 | # BWT calculated based on GEM paper (https://arxiv.org/abs/1706.08840) 80 | gem_bwt = sum(acc[-1]-np.diag(acc))/ (len(acc[-1])-1) 81 | # BWT calculated based on our UCB paper (https://openreview.net/pdf?id=HklUCCVKDB) 82 | ucb_bwt = (acc[-1] - np.diag(acc)).mean() 83 | print ('BWT: {:5.2f}%'.format(gem_bwt)) 84 | # print ('BWT (UCB paper): {:5.2f}%'.format(ucb_bwt)) 85 | 86 | print('*'*100) 87 | print('Done!') 88 | 89 | 90 | logs = {} 91 | # save results 92 | logs['name'] = output_path 93 | logs['taskcla'] = taskcla 94 | logs['acc'] = acc 95 | logs['loss'] = lss 96 | logs['gem_bwt'] = gem_bwt 97 | logs['ucb_bwt'] = ucb_bwt 98 | logs['rii'] = np.diag(acc) 99 | logs['rij'] = acc[-1] 100 | 101 | # pickle 102 | with open(os.path.join(output_path, 'logs.p'), 'wb') as output: 103 | pickle.dump(logs, output) 104 | 105 | print ("Log file saved in ", os.path.join(output_path, 'logs.p')) 106 | 107 | 108 | def 
print_log_acc_bwt(taskcla, acc, lss, output_path, run_id): 109 | 110 | print('*'*100) 111 | print('Accuracies =') 112 | for i in range(acc.shape[0]): 113 | print('\t',end=',') 114 | for j in range(acc.shape[1]): 115 | print('{:5.4f}% '.format(acc[i,j]),end=',') 116 | print() 117 | 118 | avg_acc = np.mean(acc[acc.shape[0]-1,:]) 119 | print ('ACC: {:5.4f}%'.format(avg_acc)) 120 | print() 121 | print() 122 | # BWT calculated based on GEM paper (https://arxiv.org/abs/1706.08840) 123 | gem_bwt = sum(acc[-1]-np.diag(acc))/ (len(acc[-1])-1) 124 | # BWT calculated based on UCB paper (https://arxiv.org/abs/1906.02425) 125 | ucb_bwt = (acc[-1] - np.diag(acc)).mean() 126 | print ('BWT: {:5.2f}%'.format(gem_bwt)) 127 | # print ('BWT (UCB paper): {:5.2f}%'.format(ucb_bwt)) 128 | 129 | print('*'*100) 130 | print('Done!') 131 | 132 | 133 | logs = {} 134 | # save results 135 | logs['name'] = output_path 136 | logs['taskcla'] = taskcla 137 | logs['acc'] = acc 138 | logs['loss'] = lss 139 | logs['gem_bwt'] = gem_bwt 140 | logs['ucb_bwt'] = ucb_bwt 141 | logs['rii'] = np.diag(acc) 142 | logs['rij'] = acc[-1] 143 | 144 | # pickle 145 | path = os.path.join(output_path, 'logs_run_id_{}.p'.format(run_id)) 146 | with open(path, 'wb') as output: 147 | pickle.dump(logs, output) 148 | 149 | print ("Log file saved in ", path) 150 | return avg_acc, gem_bwt 151 | 152 | 153 | def print_running_acc_bwt(acc, task_num): 154 | print() 155 | acc = acc[:task_num+1,:task_num+1] 156 | avg_acc = np.mean(acc[acc.shape[0] - 1, :]) 157 | gem_bwt = sum(acc[-1] - np.diag(acc)) / (len(acc[-1]) - 1) 158 | print('ACC: {:5.4f}% || BWT: {:5.2f}% '.format(avg_acc, gem_bwt)) 159 | print() 160 | 161 | 162 | def make_directories(args): 163 | uid = uuid.uuid4().hex 164 | if args.checkpoint is None: 165 | os.makedirs('checkpoints', exist_ok=True) 166 | args.checkpoint = os.path.join('./checkpoints/',uid) 167 | os.mkdir(args.checkpoint) 168 | else: 169 | if not os.path.exists(args.checkpoint): 170 | os.mkdir(args.checkpoint) 171 | args.checkpoint = os.path.join(args.checkpoint, uid) 172 | os.mkdir(args.checkpoint) 173 | 174 | 175 | 176 | 177 | def some_sanity_checks(args): 178 | # Make sure the chosen experiment matches the number of tasks used in the paper: 179 | datasets_tasks = {} 180 | datasets_tasks['mnist5']=[5] 181 | datasets_tasks['pmnist']=[10,20,30,40] 182 | datasets_tasks['cifar100']=[20] 183 | datasets_tasks['miniimagenet']=[20] 184 | datasets_tasks['multidatasets']=[5] 185 | 186 | 187 | if args.ntasks not in datasets_tasks[args.experiment]: 188 | raise Exception("Chosen number of tasks ({}) does not match the {} experiment".format(args.ntasks,args.experiment)) 189 | 190 | # Make sure the memory flags are consistent: 191 | if args.use_memory == 'yes' and not args.samples > 0: 192 | raise Exception("Flags required to use memory: --use_memory yes --samples n where n>0") 193 | 194 | if args.use_memory == 'no' and args.samples > 0: 195 | raise Exception("Flags required when not using memory: --use_memory no --samples 0") 196 | 197 | 198 | 199 | def save_code(args): 200 | cwd = os.getcwd() 201 | des = os.path.join(args.checkpoint, 'code') + '/' 202 | if not os.path.exists(des): 203 | os.mkdir(des) 204 | 205 | def get_folder(folder): 206 | return os.path.join(cwd,folder) 207 | 208 | folders = [get_folder(item) for item in ['dataloaders', 'networks', 'configs', 'main.py', 'acl.py', 'utils.py']] 209 | 210 | for folder in folders: 211 | call('cp -rf {} {}'.format(folder, des),shell=True) 212 | 213 | 214 | def print_time(): 215 | from datetime
import datetime 216 | 217 | # datetime object containing current date and time 218 | now = datetime.now() 219 | 220 | # dd/mm/YY H:M:S 221 | dt_string = now.strftime("%d/%m/%Y %H:%M:%S") 222 | print("Job finished at =", dt_string) 223 | 224 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to acl 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## Coding Style 30 | * 2 spaces for indentation rather than tabs 31 | * 80 character line length 32 | * ... 33 | 34 | ## License 35 | By contributing to acl, you agree that your contributions will be licensed 36 | under the LICENSE file in the root directory of this source tree. 37 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2020-present, Facebook, Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Continual Learning 2 | 3 | 4 | This is the official PyTorch implementation of the [Adversarial Continual Learning](https://arxiv.org/abs/2003.09553) paper, published at *ECCV 2020*. 5 | 6 | 7 | ## Abstract 8 | 9 | Continual learning aims to learn new tasks without forgetting previously learned ones. We hypothesize that representations learned to solve each task in a sequence have shared structure while containing some task-specific properties. We show that shared features are significantly less prone to forgetting and propose a novel hybrid continual learning framework that learns a disjoint representation for task-invariant and task-specific features required to solve a sequence of tasks. Our model combines architecture growth to prevent forgetting of task-specific skills and an experience replay approach to preserve shared skills. We demonstrate our hybrid approach is effective in avoiding forgetting and show it is superior to both architecture-based and memory-based approaches on class-incremental learning of a single dataset as well as a sequence of multiple datasets in image classification. 10 | 11 | ## Authors: 12 | [Sayna Ebrahimi](https://people.eecs.berkeley.edu/~sayna/) (UC Berkeley, FAIR), [Franziska Meier](https://am.is.tuebingen.mpg.de/person/fmeier) (FAIR), [Roberto Calandra](https://www.robertocalandra.com/about/) (FAIR), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/) (UC Berkeley), [Marcus Rohrbach](http://rohrbach.vision/) (FAIR) 13 | 14 | ### Citation 15 | If using this code, parts of it, or developments from it, please cite our paper: 16 | ``` 17 | @article{ebrahimi2020adversarial, 18 | title={Adversarial Continual Learning}, 19 | author={Ebrahimi, Sayna and Meier, Franziska and Calandra, Roberto and Darrell, Trevor and Rohrbach, Marcus}, 20 | journal={arXiv preprint arXiv:2003.09553}, 21 | year={2020} 22 | } 23 | ``` 24 | 25 | ### Prerequisites: 26 | - Linux-64 27 | - Python 3.6 28 | - PyTorch 1.3.1 29 | - CPU or NVIDIA GPU + CUDA10 CuDNN7.5 30 | 31 | 32 | ### Installation 33 | - Create a conda environment and install required packages: 34 | ```bash 35 | conda create -n <env-name> python=3.6 36 | conda activate <env-name> 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | - Clone this repo: 41 | ```bash 42 | mkdir ACL 43 | cd ACL 44 | git clone git@github.com:facebookresearch/Adversarial-Continual-Learning.git 45 | ``` 46 | 47 | - The following structure is expected in the main directory: 48 | 49 | ``` 50 | ./src : main directory where all scripts are placed 51 | ./data : data directory 52 | ./src/checkpoints : results are saved in here 53 | ``` 54 | 55 | ##### For each dataset, run the following commands from the `src` directory.
The config file for each experiment contains the hyperparameters we used in the paper: 56 | 57 | Split MNIST (5 Tasks): 58 | 59 | `python main.py --config ./configs/config_mnist5.yml` 60 | 61 | 62 | Permuted MNIST (10 Tasks): 63 | 64 | `python main.py --config ./configs/config_pmnist.yml` 65 | 66 | 67 | Split CIFAR100 (20 Tasks): 68 | 69 | `python main.py --config ./configs/config_cifar100.yml` 70 | 71 | Split MiniImageNet (20 Tasks): 72 | 73 | `python main.py --config ./configs/config_miniimagenet.yml` 74 | 75 | Sequence of 5 Tasks (CIFAR10, MNIST, notMNIST, Fashion MNIST, SVHN): 76 | 77 | `python main.py --config ./configs/config_multidatasets.yml` 78 | 79 | ###### ACL with ResNet18 backbone 80 | See [here](https://github.com/facebookresearch/Adversarial-Continual-Learning/tree/master/ACL-resnet). 81 | 82 | #### Datasets 83 | 84 | *miniImageNet* data should be [downloaded](https://github.com/yaoyao-liu/mini-imagenet-tools#about-mini-ImageNet) and pickled as a dictionary (`data.pkl`) with `images` and `labels` keys and placed in a sub-folder of `args.data_dir` named `miniimagenet`. The script used to split `data.pkl` into training and test sets is included in the data directory (`data/`). 85 | 86 | The *notMNIST* dataset is included here in `./data/notMNIST` as it was used in our experiments. 87 | 88 | Other datasets will be automatically downloaded and extracted to `./data` if they do not exist. 89 | 90 | ## Questions/Bugs 91 | * For questions or bug reports, contact the author Sayna Ebrahimi via email: sayna@berkeley.edu 92 | 93 | 94 | 95 | ## License 96 | This source code is released under The MIT License found in the LICENSE file in the root directory of this source tree. 97 | 98 | 99 | ## Acknowledgements 100 | Our code structure is inspired by [HAT](https://github.com/joansj/hat). 101 | -------------------------------------------------------------------------------- /data/notMNIST.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/Adversarial-Continual-Learning/a99dfadcc59a12d903af6e5366a025ca44b3af07/data/notMNIST.zip -------------------------------------------------------------------------------- /data/split_miniimagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree.
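#
# Splits the pickled miniImageNet dump (data.pkl, 600 images per class) into
# train.pkl (500 images per class) and test.pkl (100 images per class), keeping
# the {'images': [...], 'labels': [...]} dictionary layout expected by the
# miniimagenet dataloader. Run it from inside the data/ directory so that the
# relative '../data/miniimagenet' paths below resolve correctly.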
5 | 6 | import pickle 7 | import numpy as np 8 | import os 9 | 10 | np.random.seed(1234) 11 | 12 | # we want 500 for training, 100 for test for wach class 13 | n = 500 14 | 15 | def get_total(data): 16 | data_x, data_y = [], [] 17 | for k, v in data.items(): 18 | for i in range(len(v)): 19 | data_x.append(v[i]) 20 | data_y.append(k) 21 | d = {} 22 | d['images'] = data_x 23 | d['labels'] = data_y 24 | return d 25 | 26 | 27 | # loading the pickled data 28 | with open(os.path.join('../data/miniimagenet/data.pkl'), 'rb') as f: 29 | data_dict = pickle.load(f) 30 | data = data_dict['images'] 31 | labels = data_dict['labels'] 32 | 33 | # split data into classes, 600 images per class 34 | class_dict = {} 35 | for i in range(len(set(labels))): 36 | class_dict[i] = [] 37 | 38 | for i in range(len(data)): 39 | class_dict[labels[i]].append(data[i]) 40 | 41 | # Split data for each class to 500 and 100 42 | x_train, x_test = {}, {} 43 | for i in range(len(set(labels))): 44 | np.random.shuffle(class_dict[i]) 45 | x_test[i] = class_dict[i][n:] 46 | x_train[i] = class_dict[i][:n] 47 | 48 | # mix the data 49 | d_train = get_total(x_train) 50 | d_test = get_total(x_test) 51 | 52 | with open(os.path.join('../data/miniimagenet/train.pkl'), 'wb') as f: 53 | pickle.dump(d_train, f) 54 | with open(os.path.join('../data/miniimagenet/test.pkl'), 'wb') as f: 55 | pickle.dump(d_test, f) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.1 2 | astor==0.8.0 3 | attrs==19.3.0 4 | backcall==0.1.0 5 | bleach==3.1.0 6 | blessings==1.7 7 | certifi==2019.11.28 8 | cffi==1.11.5 9 | chardet==3.0.4 10 | Click==7.0 11 | configparser==4.0.2 12 | cycler==0.10.0 13 | decorator==4.3.0 14 | defusedxml==0.6.0 15 | docker-pycreds==0.4.0 16 | entrypoints==0.3 17 | gast==0.3.2 18 | gitdb==4.0.2 19 | GitPython==3.1.0 20 | google-pasta==0.1.8 21 | gpustat==0.6.0 22 | GPUtil==1.4.0 23 | gql==0.2.0 24 | graphql-core==1.1 25 | grpcio==1.25.0 26 | h5py==2.10.0 27 | idna==2.9 28 | imageio==2.6.1 29 | importlib-metadata==1.2.0 30 | ipykernel==5.1.3 31 | ipython==7.2.0 32 | ipython-genutils==0.2.0 33 | jedi==0.13.2 34 | Jinja2==2.10.3 35 | joblib==0.14.0 36 | json5==0.8.5 37 | jsonschema==3.2.0 38 | jupyter-client==5.3.4 39 | jupyter-core==4.6.1 40 | jupyterlab==1.2.3 41 | jupyterlab-server==1.0.6 42 | Keras-Applications==1.0.8 43 | Keras-Preprocessing==1.1.0 44 | kiwisolver==1.1.0 45 | Markdown==2.6.11 46 | MarkupSafe==1.1.1 47 | matplotlib==3.1.2 48 | missingno==0.4.2 49 | mistune==0.8.4 50 | more-itertools==8.0.2 51 | nbconvert==5.6.1 52 | nbformat==4.4.0 53 | networkx==2.4 54 | notebook==6.0.2 55 | numpy==1.17.5 56 | nvidia-ml-py3==7.352.0 57 | olefile==0.46 58 | omegaconf==1.4.0 59 | pandas==0.25.3 60 | pandocfilters==1.4.2 61 | parso==0.3.1 62 | pathtools==0.1.2 63 | pexpect==4.6.0 64 | pickleshare==0.7.5 65 | Pillow==5.3.0 66 | prometheus-client==0.7.1 67 | promise==2.3 68 | prompt-toolkit==2.0.7 69 | protobuf==3.7.1 70 | psutil==5.6.7 71 | ptyprocess==0.6.0 72 | pycparser==2.19 73 | Pygments==2.3.1 74 | pyparsing==2.4.5 75 | pyrsistent==0.15.6 76 | python-dateutil==2.8.1 77 | pytorch-revgrad==0.1.1 78 | pytz==2019.3 79 | PyWavelets==1.1.1 80 | PyYAML==5.2 81 | pyzmq==18.1.1 82 | requests==2.23.0 83 | scikit-image==0.16.2 84 | scikit-learn==0.22 85 | scipy==1.1.0 86 | seaborn==0.9.0 87 | Send2Trash==1.5.0 88 | sentry-sdk==0.14.2 89 | shap==0.32.1 90 | shortuuid==1.0.1 91 | six==1.12.0 92 | 
smmap==3.0.1 93 | subprocess32==3.5.4 94 | tensorboard==1.14.0 95 | tensorboardX==1.6 96 | tensorflow-estimator==1.14.0 97 | tensorflow-gpu==1.14.0 98 | termcolor==1.1.0 99 | terminado==0.8.3 100 | testpath==0.4.4 101 | torch==1.3.1 102 | torchvision==0.4.2 103 | tornado==6.0.3 104 | tqdm==4.40.1 105 | traitlets==4.3.2 106 | urllib3==1.25.8 107 | wandb==0.8.29 108 | watchdog==0.10.2 109 | wcwidth==0.1.7 110 | webencodings==0.5.1 111 | Werkzeug==0.15.2 112 | wrapt==1.11.2 113 | zipp==0.6.0 114 | -------------------------------------------------------------------------------- /src/acl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import sys, time, os 8 | import numpy as np 9 | import torch 10 | import copy 11 | import utils 12 | 13 | from copy import deepcopy 14 | from tqdm import tqdm 15 | 16 | sys.path.append('../') 17 | 18 | from networks.discriminator import Discriminator 19 | 20 | class ACL(object): 21 | 22 | def __init__(self, model, args, network): 23 | self.args=args 24 | self.nepochs=args.nepochs 25 | self.sbatch=args.batch_size 26 | 27 | # optimizer & adaptive lr 28 | self.e_lr=args.e_lr 29 | self.d_lr=args.d_lr 30 | 31 | if not args.experiment == 'multidatasets': 32 | self.e_lr=[args.e_lr] * args.ntasks 33 | self.d_lr=[args.d_lr] * args.ntasks 34 | else: 35 | self.e_lr = [self.args.lrs[i][1] for i in range(len(args.lrs))] 36 | self.d_lr = [self.args.lrs[i][1]/10. for i in range(len(args.lrs))] 37 | print ("d_lrs : ", self.d_lr) 38 | 39 | self.lr_min=args.lr_min 40 | self.lr_factor=args.lr_factor 41 | self.lr_patience=args.lr_patience 42 | 43 | self.samples=args.samples 44 | 45 | self.device=args.device 46 | self.checkpoint=args.checkpoint 47 | 48 | self.adv_loss_reg=args.adv 49 | self.diff_loss_reg=args.orth 50 | self.s_steps=args.s_step 51 | self.d_steps=args.d_step 52 | 53 | self.diff=args.diff 54 | 55 | self.network=network 56 | self.inputsize=args.inputsize 57 | self.taskcla=args.taskcla 58 | self.num_tasks=args.ntasks 59 | 60 | # Initialize generator and discriminator 61 | self.model=model 62 | self.discriminator=self.get_discriminator(0) 63 | self.discriminator.get_size() 64 | 65 | self.latent_dim=args.latent_dim 66 | 67 | self.task_loss=torch.nn.CrossEntropyLoss().to(self.device) 68 | self.adversarial_loss_d=torch.nn.CrossEntropyLoss().to(self.device) 69 | self.adversarial_loss_s=torch.nn.CrossEntropyLoss().to(self.device) 70 | self.diff_loss=DiffLoss().to(self.device) 71 | 72 | self.optimizer_S=self.get_S_optimizer(0) 73 | self.optimizer_D=self.get_D_optimizer(0) 74 | 75 | self.task_encoded={} 76 | 77 | self.mu=0.0 78 | self.sigma=1.0 79 | 80 | print() 81 | 82 | def get_discriminator(self, task_id): 83 | discriminator=Discriminator(self.args, task_id).to(self.args.device) 84 | return discriminator 85 | 86 | def get_S_optimizer(self, task_id, e_lr=None): 87 | if e_lr is None: e_lr=self.e_lr[task_id] 88 | optimizer_S=torch.optim.SGD(self.model.parameters(), momentum=self.args.mom, 89 | weight_decay=self.args.e_wd, lr=e_lr) 90 | return optimizer_S 91 | 92 | def get_D_optimizer(self, task_id, d_lr=None): 93 | if d_lr is None: d_lr=self.d_lr[task_id] 94 | optimizer_D=torch.optim.SGD(self.discriminator.parameters(), weight_decay=self.args.d_wd, lr=d_lr) 95 | return optimizer_D 96 | 97 | def train(self, task_id, dataset): 98 
| self.discriminator=self.get_discriminator(task_id) 99 | 100 | best_loss=np.inf 101 | best_model=utils.get_model(self.model) 102 | 103 | 104 | best_loss_d=np.inf 105 | best_model_d=utils.get_model(self.discriminator) 106 | 107 | dis_lr_update=True 108 | d_lr=self.d_lr[task_id] 109 | patience_d=self.lr_patience 110 | self.optimizer_D=self.get_D_optimizer(task_id, d_lr) 111 | 112 | e_lr=self.e_lr[task_id] 113 | patience=self.lr_patience 114 | self.optimizer_S=self.get_S_optimizer(task_id, e_lr) 115 | 116 | 117 | for e in range(self.nepochs): 118 | 119 | # Train 120 | clock0=time.time() 121 | self.train_epoch(dataset['train'], task_id) 122 | clock1=time.time() 123 | 124 | train_res=self.eval_(dataset['train'], task_id) 125 | 126 | utils.report_tr(train_res, e, self.sbatch, clock0, clock1) 127 | 128 | # lowering the learning rate in the beginning if it predicts random chance for the first 5 epochs 129 | if (self.args.experiment == 'cifar100' or self.args.experiment == 'miniimagenet') and e == 4: 130 | random_chance=20. 131 | threshold=random_chance + 2 132 | 133 | if train_res['acc_t'] < threshold: 134 | # Restore best validation model 135 | d_lr=self.d_lr[task_id] / 10. 136 | self.optimizer_D=self.get_D_optimizer(task_id, d_lr) 137 | print("Performance on task {} is {} so Dis's lr is decreased to {}".format(task_id, train_res[ 138 | 'acc_t'], d_lr), end=" ") 139 | 140 | e_lr=self.e_lr[task_id] / 10. 141 | self.optimizer_S=self.get_S_optimizer(task_id, e_lr) 142 | 143 | self.discriminator=self.get_discriminator(task_id) 144 | 145 | if task_id > 0: 146 | self.model=self.load_checkpoint(task_id - 1) 147 | else: 148 | self.model=self.network.Net(self.args).to(self.args.device) 149 | 150 | 151 | # Valid 152 | valid_res=self.eval_(dataset['valid'], task_id) 153 | utils.report_val(valid_res) 154 | 155 | 156 | # Adapt lr for S and D 157 | if valid_res['loss_tot'] < best_loss: 158 | best_loss=valid_res['loss_tot'] 159 | best_model=utils.get_model(self.model) 160 | patience=self.lr_patience 161 | print(' *', end='') 162 | else: 163 | patience-=1 164 | if patience <= 0: 165 | e_lr/=self.lr_factor 166 | print(' lr={:.1e}'.format(e_lr), end='') 167 | if e_lr < self.lr_min: 168 | print() 169 | break 170 | patience=self.lr_patience 171 | self.optimizer_S=self.get_S_optimizer(task_id, e_lr) 172 | 173 | if train_res['loss_a'] < best_loss_d: 174 | best_loss_d=train_res['loss_a'] 175 | best_model_d=utils.get_model(self.discriminator) 176 | patience_d=self.lr_patience 177 | else: 178 | patience_d-=1 179 | if patience_d <= 0 and dis_lr_update: 180 | d_lr/=self.lr_factor 181 | print(' Dis lr={:.1e}'.format(d_lr)) 182 | if d_lr < self.lr_min: 183 | dis_lr_update=False 184 | print("Dis lr reached minimum value") 185 | print() 186 | patience_d=self.lr_patience 187 | self.optimizer_D=self.get_D_optimizer(task_id, d_lr) 188 | print() 189 | 190 | # Restore best validation model (early-stopping) 191 | self.model.load_state_dict(copy.deepcopy(best_model)) 192 | self.discriminator.load_state_dict(copy.deepcopy(best_model_d)) 193 | 194 | self.save_all_models(task_id) 195 | 196 | 197 | def train_epoch(self, train_loader, task_id): 198 | 199 | self.model.train() 200 | self.discriminator.train() 201 | 202 | for data, target, tt, td in train_loader: 203 | 204 | x=data.to(device=self.device) 205 | y=target.to(device=self.device, dtype=torch.long) 206 | tt=tt.to(device=self.device) 207 | 208 | # Detaching samples in the batch which do not belong to the current task before feeding them to P 209 | t_current=task_id * 
torch.ones_like(tt) 210 | body_mask=torch.eq(t_current, tt).cpu().numpy() 211 | # x_task_module=data.to(device=self.device) 212 | x_task_module=data.clone() 213 | for index in range(x.size(0)): 214 | if body_mask[index] == 0: 215 | x_task_module[index]=x_task_module[index].detach() 216 | x_task_module=x_task_module.to(device=self.device) 217 | 218 | # Discriminator's real and fake task labels 219 | t_real_D=td.to(self.device) 220 | t_fake_D=torch.zeros_like(t_real_D).to(self.device) 221 | 222 | # ================================================================== # 223 | # Train Shared Module # 224 | # ================================================================== # 225 | # training S for s_steps 226 | for s_step in range(self.s_steps): 227 | self.optimizer_S.zero_grad() 228 | self.model.zero_grad() 229 | 230 | output=self.model(x, x_task_module, tt, task_id) 231 | task_loss=self.task_loss(output, y) 232 | 233 | shared_encoded, task_encoded=self.model.get_encoded_ftrs(x, x_task_module, task_id) 234 | dis_out_gen_training=self.discriminator.forward(shared_encoded, t_real_D, task_id) 235 | adv_loss=self.adversarial_loss_s(dis_out_gen_training, t_real_D) 236 | 237 | if self.diff == 'yes': 238 | diff_loss=self.diff_loss(shared_encoded, task_encoded) 239 | else: 240 | diff_loss=torch.tensor(0).to(device=self.device, dtype=torch.float32) 241 | self.diff_loss_reg=0 242 | 243 | total_loss=task_loss + self.adv_loss_reg * adv_loss + self.diff_loss_reg * diff_loss 244 | total_loss.backward(retain_graph=True) 245 | 246 | self.optimizer_S.step() 247 | 248 | # ================================================================== # 249 | # Train Discriminator # 250 | # ================================================================== # 251 | # training discriminator for d_steps 252 | for d_step in range(self.d_steps): 253 | self.optimizer_D.zero_grad() 254 | self.discriminator.zero_grad() 255 | 256 | # training discriminator on real data 257 | output=self.model(x, x_task_module, tt, task_id) 258 | shared_encoded, task_out=self.model.get_encoded_ftrs(x, x_task_module, task_id) 259 | dis_real_out=self.discriminator.forward(shared_encoded.detach(), t_real_D, task_id) 260 | dis_real_loss=self.adversarial_loss_d(dis_real_out, t_real_D) 261 | if self.args.experiment == 'miniimagenet': 262 | dis_real_loss*=self.adv_loss_reg 263 | dis_real_loss.backward(retain_graph=True) 264 | 265 | # training discriminator on fake data 266 | z_fake=torch.as_tensor(np.random.normal(self.mu, self.sigma, (x.size(0), self.latent_dim)),dtype=torch.float32, device=self.device) 267 | dis_fake_out=self.discriminator.forward(z_fake, t_real_D, task_id) 268 | dis_fake_loss=self.adversarial_loss_d(dis_fake_out, t_fake_D) 269 | if self.args.experiment == 'miniimagenet': 270 | dis_fake_loss*=self.adv_loss_reg 271 | dis_fake_loss.backward(retain_graph=True) 272 | 273 | self.optimizer_D.step() 274 | 275 | return 276 | 277 | 278 | def eval_(self, data_loader, task_id): 279 | loss_a, loss_t, loss_d, loss_total=0, 0, 0, 0 280 | correct_d, correct_t = 0, 0 281 | num=0 282 | batch=0 283 | 284 | self.model.eval() 285 | self.discriminator.eval() 286 | 287 | res={} 288 | with torch.no_grad(): 289 | for batch, (data, target, tt, td) in enumerate(data_loader): 290 | x=data.to(device=self.device) 291 | y=target.to(device=self.device, dtype=torch.long) 292 | tt=tt.to(device=self.device) 293 | t_real_D=td.to(self.device) 294 | 295 | # Forward 296 | output=self.model(x, x, tt, task_id) 297 | shared_out, task_out=self.model.get_encoded_ftrs(x, x, task_id) 
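# At eval time the same batch x feeds both the shared and the private branches;
# the encodings are reused below for the discriminator accuracy and the diff loss.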
298 | _, pred=output.max(1) 299 | correct_t+=pred.eq(y.view_as(pred)).sum().item() 300 | 301 | # Discriminator's performance: 302 | output_d=self.discriminator.forward(shared_out, t_real_D, task_id) 303 | _, pred_d=output_d.max(1) 304 | correct_d+=pred_d.eq(t_real_D.view_as(pred_d)).sum().item() 305 | 306 | # Loss values 307 | task_loss=self.task_loss(output, y) 308 | adv_loss=self.adversarial_loss_d(output_d, t_real_D) 309 | 310 | if self.diff == 'yes': 311 | diff_loss=self.diff_loss(shared_out, task_out) 312 | else: 313 | diff_loss=torch.tensor(0).to(device=self.device, dtype=torch.float32) 314 | self.diff_loss_reg=0 315 | 316 | total_loss = task_loss + self.adv_loss_reg * adv_loss + self.diff_loss_reg * diff_loss 317 | 318 | loss_t+=task_loss 319 | loss_a+=adv_loss 320 | loss_d+=diff_loss 321 | loss_total+=total_loss 322 | 323 | num+=x.size(0) 324 | 325 | res['loss_t'], res['acc_t']=loss_t.item() / (batch + 1), 100 * correct_t / num 326 | res['loss_a'], res['acc_d']=loss_a.item() / (batch + 1), 100 * correct_d / num 327 | res['loss_d']=loss_d.item() / (batch + 1) 328 | res['loss_tot']=loss_total.item() / (batch + 1) 329 | res['size']=self.loader_size(data_loader) 330 | 331 | return res 332 | 333 | # 334 | 335 | def test(self, data_loader, task_id, model): 336 | loss_a, loss_t, loss_d, loss_total=0, 0, 0, 0 337 | correct_d, correct_t=0, 0 338 | num=0 339 | batch=0 340 | 341 | model.eval() 342 | self.discriminator.eval() 343 | 344 | res={} 345 | with torch.no_grad(): 346 | for batch, (data, target, tt, td) in enumerate(data_loader): 347 | x=data.to(device=self.device) 348 | y=target.to(device=self.device, dtype=torch.long) 349 | tt=tt.to(device=self.device) 350 | t_real_D=td.to(self.device) 351 | 352 | # Forward 353 | output=model.forward(x, x, tt, task_id) 354 | shared_out, task_out=model.get_encoded_ftrs(x, x, task_id) 355 | 356 | _, pred=output.max(1) 357 | correct_t+=pred.eq(y.view_as(pred)).sum().item() 358 | 359 | # Discriminator's performance: 360 | output_d=self.discriminator.forward(shared_out, tt, task_id) 361 | _, pred_d=output_d.max(1) 362 | correct_d+=pred_d.eq(t_real_D.view_as(pred_d)).sum().item() 363 | 364 | if self.diff == 'yes': 365 | diff_loss=self.diff_loss(shared_out, task_out) 366 | else: 367 | diff_loss=torch.tensor(0).to(device=self.device, dtype=torch.float32) 368 | self.diff_loss_reg=0 369 | 370 | # Loss values 371 | adv_loss=self.adversarial_loss_d(output_d, t_real_D) 372 | task_loss=self.task_loss(output, y) 373 | 374 | total_loss=task_loss + self.adv_loss_reg * adv_loss + self.diff_loss_reg * diff_loss 375 | 376 | loss_t+=task_loss 377 | loss_a+=adv_loss 378 | loss_d+=diff_loss 379 | loss_total+=total_loss 380 | 381 | num+=x.size(0) 382 | 383 | res['loss_t'], res['acc_t']=loss_t.item() / (batch + 1), 100 * correct_t / num 384 | res['loss_a'], res['acc_d']=loss_a.item() / (batch + 1), 100 * correct_d / num 385 | res['loss_d']=loss_d.item() / (batch + 1) 386 | res['loss_tot']=loss_total.item() / (batch + 1) 387 | res['size']=self.loader_size(data_loader) 388 | 389 | return res 390 | 391 | 392 | 393 | def save_all_models(self, task_id): 394 | print("Saving all models for task {} ...".format(task_id+1)) 395 | dis=utils.get_model(self.discriminator) 396 | torch.save({'model_state_dict': dis, 397 | }, os.path.join(self.checkpoint, 'discriminator_{}.pth.tar'.format(task_id))) 398 | 399 | model=utils.get_model(self.model) 400 | torch.save({'model_state_dict': model, 401 | }, os.path.join(self.checkpoint, 'model_{}.pth.tar'.format(task_id))) 402 | 403 | 404 | 405 | def 
load_model(self, task_id): 406 | 407 | # Load a previous model 408 | net=self.network.Net(self.args) 409 | checkpoint=torch.load(os.path.join(self.checkpoint, 'model_{}.pth.tar'.format(task_id))) 410 | net.load_state_dict(checkpoint['model_state_dict']) 411 | 412 | # # Change the previous shared module with the current one 413 | current_shared_module=deepcopy(self.model.shared.state_dict()) 414 | net.shared.load_state_dict(current_shared_module) 415 | 416 | net=net.to(self.args.device) 417 | return net 418 | 419 | 420 | def load_checkpoint(self, task_id): 421 | print("Loading checkpoint for task {} ...".format(task_id)) 422 | 423 | # Load a previous model 424 | net=self.network.Net(self.args) 425 | checkpoint=torch.load(os.path.join(self.checkpoint, 'model_{}.pth.tar'.format(task_id))) 426 | net.load_state_dict(checkpoint['model_state_dict']) 427 | net=net.to(self.args.device) 428 | return net 429 | 430 | 431 | def loader_size(self, data_loader): 432 | return data_loader.dataset.__len__() 433 | 434 | 435 | 436 | def get_tsne_embeddings_first_ten_tasks(self, dataset, model): 437 | from tensorboardX import SummaryWriter 438 | 439 | model.eval() 440 | 441 | tag_ = '_diff_{}'.format(self.args.diff) 442 | all_images, all_shared, all_private = [], [], [] 443 | 444 | # Test final model on first 10 tasks: 445 | writer = SummaryWriter() 446 | for t in range(10): 447 | for itr, (data, _, tt, td) in enumerate(dataset[t]['tsne']): 448 | x = data.to(device=self.device) 449 | tt = tt.to(device=self.device) 450 | output = model.forward(x, x, tt, t) 451 | shared_out, private_out = model.get_encoded_ftrs(x, x, t) 452 | all_shared.append(shared_out) 453 | all_private.append(private_out) 454 | all_images.append(x) 455 | 456 | print (torch.stack(all_shared).size()) 457 | 458 | tag = ['Shared10_{}_{}'.format(tag_,i) for i in range(1,11)] 459 | writer.add_embedding(mat=torch.stack(all_shared,dim=1).data, label_img=torch.stack(all_images,dim=1).data, metadata=list(range(1,11)), 460 | tag=tag)#, metadata_header=list(range(1,11))) 461 | 462 | tag = ['Private10_{}_{}'.format(tag_, i) for i in range(1, 11)] 463 | writer.add_embedding(mat=torch.stack(all_private,dim=1).data, label_img=torch.stack(all_images,dim=1).data, metadata=list(range(1,11)), 464 | tag=tag)#,metadata_header=list(range(1,11))) 465 | writer.close() 466 | 467 | 468 | def get_tsne_embeddings_last_three_tasks(self, dataset, model): 469 | from tensorboardX import SummaryWriter 470 | 471 | # Test final model on last 3 tasks: 472 | model.eval() 473 | tag = '_diff_{}'.format(self.args.diff) 474 | 475 | for t in [17,18,19]: 476 | all_images, all_labels, all_shared, all_private = [], [], [], [] 477 | writer = SummaryWriter() 478 | for itr, (data, target, tt, td) in enumerate(dataset[t]['tsne']): 479 | x = data.to(device=self.device) 480 | y = target.to(device=self.device, dtype=torch.long) 481 | tt = tt.to(device=self.device) 482 | output = model.forward(x, x, tt, t) 483 | shared_out, private_out = model.get_encoded_ftrs(x, x, t) 484 | # print (shared_out.size()) 485 | 486 | all_shared.append(shared_out) 487 | all_private.append(private_out) 488 | all_images.append(x) 489 | all_labels.append(y) 490 | 491 | writer.add_embedding(mat=torch.stack(all_shared,dim=1).data, label_img=torch.stack(all_images,dim=1).data, 492 | metadata=list(range(1,6)), tag='Shared_{}_{}'.format(t, tag)) 493 | # ,metadata_header=list(range(1,6))) 494 | writer.add_embedding(mat=torch.stack(all_private,dim=1).data, label_img=torch.stack(all_images,dim=1).data, 495 | 
metadata=list(range(1,6)), tag='Private_{}_{}'.format(t, tag)) 496 | # ,metadata_header=list(range(1,6))) 497 | 498 | writer.close() 499 | 500 | 501 | 502 | # 503 | class DiffLoss(torch.nn.Module): 504 | # From: Domain Separation Networks (https://arxiv.org/abs/1608.06019) 505 | # Konstantinos Bousmalis, George Trigeorgis, Nathan Silberman, Dilip Krishnan, Dumitru Erhan 506 | 507 | def __init__(self): 508 | super(DiffLoss, self).__init__() 509 | 510 | def forward(self, D1, D2): 511 | D1=D1.view(D1.size(0), -1) 512 | D1_norm=torch.norm(D1, p=2, dim=1, keepdim=True).detach() 513 | D1_norm=D1.div(D1_norm.expand_as(D1) + 1e-6) 514 | 515 | D2=D2.view(D2.size(0), -1) 516 | D2_norm=torch.norm(D2, p=2, dim=1, keepdim=True).detach() 517 | D2_norm=D2.div(D2_norm.expand_as(D2) + 1e-6) 518 | 519 | # return torch.mean((D1_norm.mm(D2_norm.t()).pow(2))) 520 | return torch.mean((D1_norm.mm(D2_norm.t()).pow(2))) 521 | -------------------------------------------------------------------------------- /src/configs/config_cifar100.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | num_runs: 5 7 | experiment: "cifar100" 8 | data_dir: "../data" 9 | checkpoint: "./checkpoints/" 10 | output: "" 11 | tsne: "no" 12 | diff: "yes" 13 | lam: 1 #diff loss lambda 14 | adv: 0.05 #adversarial loss regularizer 15 | orth: 0.1 #diff loss regularizer 16 | 17 | ntasks: 20 18 | use_memory: "no" 19 | samples: 0 20 | 21 | e_lr: 0.01 22 | e_wd: 0.01 23 | s_step: 5 24 | 25 | d_lr: 0.001 26 | d_wd: 0.01 27 | d_step: 1 28 | 29 | lr_factor: 3 30 | lr_min: 1.0e-06 31 | lr_patience: 5 32 | mom: 0.9 33 | 34 | nlayers: 2 35 | units: 175 36 | head_units: 32 37 | latent_dim: 128 38 | 39 | batch_size: 64 40 | nepochs: 200 41 | pc_valid: 0.15 42 | 43 | 44 | workers: 4 45 | device: "cuda:3" 46 | -------------------------------------------------------------------------------- /src/configs/config_miniimagenet.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | num_runs: 5 7 | experiment: "miniimagenet" 8 | data_dir: "../data" 9 | checkpoint: "./checkpoints/" 10 | output: "" 11 | tsne: "no" 12 | diff: "yes" 13 | lam: 1 #diff loss lambda 14 | adv: 0.005 #adversarial loss regularizer 15 | orth: 0.1 #diff loss regularizer 16 | 17 | ntasks: 20 18 | use_memory: "yes" 19 | samples: 1 20 | 21 | e_lr: 0.003 22 | e_wd: 0.01 23 | s_step: 5 24 | 25 | d_lr: 0.001 26 | d_wd: 0.01 27 | d_step: 1 28 | 29 | lr_factor: 3 30 | lr_min: 1.0e-06 31 | lr_patience: 5 32 | mom: 0.9 33 | 34 | nlayers: 2 35 | units: 175 36 | head_units: 128 37 | latent_dim: 256 38 | 39 | batch_size: 64 40 | nepochs: 200 41 | pc_valid: 0.02 42 | 43 | 44 | workers: 4 45 | device: "cuda:4" 46 | -------------------------------------------------------------------------------- /src/configs/config_mnist5.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
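#
# Field guide for the hyperparameters below (the same keys are used by every
# config in this directory): e_lr/e_wd are the SGD learning rate and weight
# decay for the shared and private modules (S), d_lr/d_wd the same for the
# task discriminator (D); s_step and d_step set how many S vs. D updates run
# per batch; adv and orth weight the adversarial and diff (orthogonality) loss
# terms; use_memory/samples control episodic replay, with samples counted per
# class per task.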
5 | 6 | num_runs: 5 7 | experiment: "mnist5" 8 | data_dir: "../data" 9 | checkpoint: "./checkpoints/" 10 | output: "" 11 | tsne: "no" 12 | diff: "no" 13 | lam: 1 #diff loss lambda 14 | adv: 0.05 #adversarial loss regularizer 15 | orth: 0.1 #diff loss regularizer 16 | 17 | ntasks: 5 18 | use_memory: "no" 19 | samples: 0 20 | 21 | e_lr: 0.001 22 | e_wd: 0.01 23 | s_step: 20 24 | 25 | d_lr: 0.0001 26 | d_wd: 0.01 27 | d_step: 1 28 | 29 | lr_factor: 3 30 | lr_min: 1.0e-06 31 | lr_patience: 5 32 | mom: 0.9 33 | 34 | nlayers: 2 35 | units: 175 36 | head_units: 28 37 | latent_dim: 64 38 | 39 | batch_size: 32 40 | nepochs: 200 41 | pc_valid: 0.15 42 | 43 | workers: 4 44 | device: "cuda:7" 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /src/configs/config_multidatasets.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | num_runs: 5 7 | experiment: "multidatasets" 8 | data_dir: "../data/" 9 | checkpoint: "./checkpoints/" 10 | output: "" 11 | tsne: "no" 12 | diff: "no" 13 | lam: 1 #diff loss lambda 14 | adv: 0.05 #adversarial loss regularizer 15 | orth: 0.075 #diff loss regularizer 16 | 17 | ntasks: 5 18 | use_memory: "no" #memory is not supported for this experiment 19 | samples: 0 20 | 21 | e_lr: 0.01 22 | e_wd: 1.0e-4 23 | s_step: 5 24 | 25 | d_lr: 0.001 26 | d_wd: 0.01 27 | d_step: 1 28 | 29 | lr_factor: 3 30 | lr_min: 1.0e-06 31 | lr_patience: 5 32 | mom: 0.9 33 | 34 | nlayers: 2 35 | units: 175 36 | head_units: 220 37 | latent_dim: 130 38 | 39 | batch_size: 64 40 | nepochs: 200 41 | pc_valid: 0.15 42 | 43 | 44 | workers: 4 45 | device: "cuda:4" 46 | -------------------------------------------------------------------------------- /src/configs/config_pmnist.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | num_runs: 5 7 | experiment: "pmnist" 8 | data_dir: "../data" 9 | checkpoint: "./checkpoints/" 10 | output: "" 11 | tsne: "no" 12 | diff: "no" 13 | lam: 1 #diff loss lambda 14 | adv: 0.05 #adversarial loss regularizer 15 | orth: 0.1 #diff loss regularizer 16 | 17 | ntasks: 40 18 | use_memory: "no" 19 | samples: 0 20 | 21 | e_lr: 0.001 22 | e_wd: 0.01 23 | s_step: 5 24 | 25 | d_lr: 0.0001 26 | d_wd: 0.01 27 | d_step: 1 28 | 29 | lr_factor: 3 30 | lr_min: 1.0e-06 31 | lr_patience: 5 32 | mom: 0.9 33 | 34 | nlayers: 2 35 | units: 175 36 | head_units: 28 37 | latent_dim: 64 38 | 39 | batch_size: 64 40 | nepochs: 200 41 | pc_valid: 0.15 42 | 43 | 44 | workers: 4 45 | device: "cuda:7" 46 | 47 | 48 | -------------------------------------------------------------------------------- /src/dataloaders/cifar100.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
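# Split-CIFAR100 loader for ACL: iCIFAR100 keeps only one task's classes and
# tags every sample with a task-module label (tt) and a discriminator label
# (td), and DatasetGen serves one train/valid/test loader triple per task.
# A minimal usage sketch, mirroring how src/main.py drives it (`args` is
# assumed to carry the fields read in DatasetGen.__init__):
#
#     dataloader = DatasetGen(args)
#     for t, ncla in dataloader.taskcla:
#         data = dataloader.get(t)           # dict of loaders, keyed by task id
#         train_loader = data[t]['train']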
5 | 
6 | from __future__ import print_function
7 | from PIL import Image
8 | import os
9 | import os.path
10 | import sys
11 | 
12 | 
13 | if sys.version_info[0] == 2:
14 |     import cPickle as pickle
15 | else:
16 |     import pickle
17 | 
18 | import torch.utils.data as data
19 | import numpy as np
20 | 
21 | import torch
22 | from torchvision import datasets, transforms
23 | 
24 | from utils import *
25 | 
26 | 
27 | class iCIFAR10(datasets.CIFAR10):
28 | 
29 |     def __init__(self, root, classes, memory_classes, memory, task_num, train, transform=None, target_transform=None, download=True):
30 | 
31 |         super(iCIFAR10, self).__init__(root, transform=transform,
32 |                                        target_transform=target_transform, download=download)
33 |         self.train = train # training set or test set
34 |         if not isinstance(classes, list):
35 |             classes = [classes]
36 | 
37 |         self.class_mapping = {c: i for i, c in enumerate(classes)}
38 |         self.class_indices = {}
39 | 
40 |         for cls in classes:
41 |             self.class_indices[self.class_mapping[cls]] = []
42 | 
43 |         if self.train:
44 |             train_data = []
45 |             train_labels = []
46 |             train_tt = [] # task module labels
47 |             train_td = [] # discriminator labels
48 | 
49 |             for i in range(len(self.data)):
50 |                 if self.targets[i] in classes:
51 |                     train_data.append(self.data[i])
52 |                     train_labels.append(self.class_mapping[self.targets[i]])
53 |                     train_tt.append(task_num)
54 |                     train_td.append(task_num+1)
55 |                     self.class_indices[self.class_mapping[self.targets[i]]].append(i)
56 | 
57 |             if memory_classes:
58 |                 for task_id in range(task_num):
59 |                     for i in range(len(memory[task_id]['x'])):
60 |                         if memory[task_id]['y'][i] in range(len(memory_classes[task_id])):
61 |                             train_data.append(memory[task_id]['x'][i])
62 |                             train_labels.append(memory[task_id]['y'][i])
63 |                             train_tt.append(memory[task_id]['tt'][i])
64 |                             train_td.append(memory[task_id]['td'][i])
65 | 
66 |             self.train_data = np.array(train_data)
67 |             self.train_labels = train_labels
68 |             self.train_tt = train_tt
69 |             self.train_td = train_td
70 | 
71 | 
72 |         if not self.train:
73 |             f = self.test_list[0][0]
74 |             file = os.path.join(self.root, self.base_folder, f)
75 |             fo = open(file, 'rb')
76 |             if sys.version_info[0] == 2:
77 |                 entry = pickle.load(fo)
78 |             else:
79 |                 entry = pickle.load(fo, encoding='latin1')
80 |             self.test_data = entry['data']
81 |             if 'labels' in entry:
82 |                 self.test_labels = entry['labels']
83 |             else:
84 | 
85 |                 self.test_labels = entry['fine_labels']
86 |             fo.close()
87 |             self.test_data = self.test_data.reshape((10000, 3, 32, 32))
88 |             self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC
89 | 
90 |             test_data = []
91 |             test_labels = []
92 |             test_tt = [] # task module labels
93 |             test_td = [] # discriminator labels
94 |             for i in range(len(self.test_data)):
95 |                 if self.test_labels[i] in classes:
96 |                     test_data.append(self.test_data[i])
97 |                     test_labels.append(self.class_mapping[self.test_labels[i]])
98 |                     test_tt.append(task_num)
99 |                     test_td.append(task_num + 1)
100 |                     self.class_indices[self.class_mapping[self.test_labels[i]]].append(i)
101 | 
102 |             self.test_data = np.array(test_data)
103 |             self.test_labels = test_labels
104 |             self.test_tt = test_tt
105 |             self.test_td = test_td
106 | 
107 | 
108 |     def __getitem__(self, index):
109 |         if self.train:
110 |             img, target, tt, td = self.train_data[index], self.train_labels[index], self.train_tt[index], self.train_td[index]
111 |         else:
112 |             img, target, tt, td = self.test_data[index], self.test_labels[index], self.test_tt[index], self.test_td[index]
113 | 
114 |         # doing this so that it is consistent
with all other datasets 115 | # to return a PIL Image 116 | try: 117 | img = Image.fromarray(img) 118 | except: 119 | pass 120 | 121 | try: 122 | if self.transform is not None: 123 | img = self.transform(img) 124 | except: 125 | pass 126 | try: 127 | if self.target_transform is not None: 128 | target = self.target_transform(target) 129 | except: 130 | pass 131 | 132 | return img, target, tt, td 133 | 134 | 135 | 136 | 137 | def __len__(self): 138 | if self.train: 139 | return len(self.train_data) 140 | else: 141 | return len(self.test_data) 142 | 143 | 144 | 145 | class iCIFAR100(iCIFAR10): 146 | """`CIFAR100 `_ Dataset. 147 | This is a subclass of the `CIFAR10` Dataset. 148 | """ 149 | base_folder = 'cifar-100-python' 150 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 151 | filename = "cifar-100-python.tar.gz" 152 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 153 | train_list = [ 154 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 155 | ] 156 | 157 | test_list = [ 158 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 159 | ] 160 | meta = { 161 | 'filename': 'meta', 162 | 'key': 'fine_label_names', 163 | 'md5': '7973b15100ade9c7d40fb424638fde48', 164 | } 165 | 166 | 167 | 168 | class DatasetGen(object): 169 | """docstring for DatasetGen""" 170 | 171 | def __init__(self, args): 172 | super(DatasetGen, self).__init__() 173 | 174 | self.seed = args.seed 175 | self.batch_size=args.batch_size 176 | self.pc_valid=args.pc_valid 177 | self.root = args.data_dir 178 | self.latent_dim = args.latent_dim 179 | 180 | self.num_tasks = args.ntasks 181 | self.num_classes = 100 182 | 183 | self.num_samples = args.samples 184 | 185 | 186 | self.inputsize = [3,32,32] 187 | mean=[x/255 for x in [125.3,123.0,113.9]] 188 | std=[x/255 for x in [63.0,62.1,66.7]] 189 | 190 | self.transformation = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) 191 | 192 | self.taskcla = [[t, int(self.num_classes/self.num_tasks)] for t in range(self.num_tasks)] 193 | 194 | self.indices = {} 195 | self.dataloaders = {} 196 | self.idx={} 197 | 198 | self.num_workers = args.workers 199 | self.pin_memory = True 200 | 201 | np.random.seed(self.seed) 202 | task_ids = np.split(np.random.permutation(self.num_classes),self.num_tasks) 203 | self.task_ids = [list(arr) for arr in task_ids] 204 | 205 | 206 | self.train_set = {} 207 | self.test_set = {} 208 | self.train_split = {} 209 | 210 | self.task_memory = {} 211 | for i in range(self.num_tasks): 212 | self.task_memory[i] = {} 213 | self.task_memory[i]['x'] = [] 214 | self.task_memory[i]['y'] = [] 215 | self.task_memory[i]['tt'] = [] 216 | self.task_memory[i]['td'] = [] 217 | 218 | self.use_memory = args.use_memory 219 | 220 | def get(self, task_id): 221 | 222 | self.dataloaders[task_id] = {} 223 | sys.stdout.flush() 224 | 225 | 226 | if task_id == 0: 227 | memory_classes = None 228 | memory=None 229 | else: 230 | memory_classes = self.task_ids 231 | memory = self.task_memory 232 | 233 | self.train_set[task_id] = iCIFAR100(root=self.root, classes=self.task_ids[task_id], memory_classes=memory_classes, 234 | memory=memory, task_num=task_id, train=True, download=True, transform=self.transformation) 235 | self.test_set[task_id] = iCIFAR100(root=self.root, classes=self.task_ids[task_id], memory_classes=None, 236 | memory=None, task_num=task_id, train=False, 237 | download=True, transform=self.transformation) 238 | 239 | 240 | 241 | 242 | 243 | split = int(np.floor(self.pc_valid * len(self.train_set[task_id]))) 244 | train_split, valid_split = 
torch.utils.data.random_split(self.train_set[task_id], [len(self.train_set[task_id]) - split, split]) 245 | 246 | self.train_split[task_id] = train_split 247 | train_loader = torch.utils.data.DataLoader(train_split, batch_size=self.batch_size, num_workers=self.num_workers, 248 | pin_memory=self.pin_memory,shuffle=True) 249 | valid_loader = torch.utils.data.DataLoader(valid_split, batch_size=int(self.batch_size * self.pc_valid), 250 | num_workers=self.num_workers, pin_memory=self.pin_memory,shuffle=True) 251 | test_loader = torch.utils.data.DataLoader(self.test_set[task_id], batch_size=self.batch_size, num_workers=self.num_workers, 252 | pin_memory=self.pin_memory,shuffle=True) 253 | 254 | 255 | self.dataloaders[task_id]['train'] = train_loader 256 | self.dataloaders[task_id]['valid'] = valid_loader 257 | self.dataloaders[task_id]['test'] = test_loader 258 | self.dataloaders[task_id]['name'] = 'CIFAR100-{}-{}'.format(task_id,self.task_ids[task_id]) 259 | 260 | print ("Training set size: {} images of {}x{}".format(len(train_loader.dataset),self.inputsize[1],self.inputsize[1])) 261 | print ("Validation set size: {} images of {}x{}".format(len(valid_loader.dataset),self.inputsize[1],self.inputsize[1])) 262 | print ("Train+Val set size: {} images of {}x{}".format(len(valid_loader.dataset)+len(train_loader.dataset),self.inputsize[1],self.inputsize[1])) 263 | print ("Test set size: {} images of {}x{}".format(len(test_loader.dataset),self.inputsize[1],self.inputsize[1])) 264 | 265 | if self.use_memory == 'yes' and self.num_samples > 0 : 266 | self.update_memory(task_id) 267 | 268 | return self.dataloaders 269 | 270 | 271 | 272 | def update_memory(self, task_id): 273 | 274 | num_samples_per_class = self.num_samples // len(self.task_ids[task_id]) 275 | mem_class_mapping = {i: i for i, c in enumerate(self.task_ids[task_id])} 276 | 277 | 278 | # Looping over each class in the current task 279 | for i in range(len(self.task_ids[task_id])): 280 | # Getting all samples for this class 281 | data_loader = torch.utils.data.DataLoader(self.train_split[task_id], batch_size=1, 282 | num_workers=self.num_workers, 283 | pin_memory=self.pin_memory) 284 | # Randomly choosing num_samples_per_class for this class 285 | randind = torch.randperm(len(data_loader.dataset))[:num_samples_per_class] 286 | 287 | # Adding the selected samples to memory 288 | for ind in randind: 289 | self.task_memory[task_id]['x'].append(data_loader.dataset[ind][0]) 290 | self.task_memory[task_id]['y'].append(mem_class_mapping[i]) 291 | self.task_memory[task_id]['tt'].append(data_loader.dataset[ind][2]) 292 | self.task_memory[task_id]['td'].append(data_loader.dataset[ind][3]) 293 | 294 | print ('Memory updated by adding {} images'.format(len(self.task_memory[task_id]['x']))) 295 | 296 | 297 | -------------------------------------------------------------------------------- /src/dataloaders/miniimagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
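# miniImageNet counterpart of cifar100.py. MiniImageNet expects pre-split
# pickles (train.pkl / test.pkl, each a dict with 'images' and 'labels') under
# <args.data_dir>/miniimagenet/; iMiniImageNet then filters one task's classes
# and mixes in replayed memory samples the same way iCIFAR100 does.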
5 | 
6 | from __future__ import print_function
7 | from PIL import Image
8 | import os
9 | import os.path
10 | import sys
11 | 
12 | if sys.version_info[0] == 2:
13 |     import cPickle as pickle
14 | else:
15 |     import pickle
16 | 
17 | import torch.utils.data as data
18 | import numpy as np
19 | 
20 | import torch
21 | from torchvision import transforms
22 | 
23 | from utils import *
24 | 
25 | 
26 | 
27 | class MiniImageNet(torch.utils.data.Dataset):
28 | 
29 |     def __init__(self, root, train):
30 |         super(MiniImageNet, self).__init__()
31 |         if train:
32 |             self.name='train'
33 |         else:
34 |             self.name='test'
35 |         root = os.path.join(root, 'miniimagenet')
36 |         with open(os.path.join(root,'{}.pkl'.format(self.name)), 'rb') as f:
37 |             data_dict = pickle.load(f)
38 | 
39 |         self.data = data_dict['images']
40 |         self.labels = data_dict['labels']
41 | 
42 |     def __len__(self):
43 |         return len(self.data)
44 | 
45 |     def __getitem__(self, i):
46 |         img, label = self.data[i], self.labels[i]
47 |         return img, label
48 | 
49 | 
50 | class iMiniImageNet(MiniImageNet):
51 | 
52 |     def __init__(self, root, classes, memory_classes, memory, task_num, train, transform=None):
53 |         super(iMiniImageNet, self).__init__(root=root, train=train)
54 | 
55 |         self.transform = transform
56 |         if not isinstance(classes, list):
57 |             classes = [classes]
58 | 
59 |         self.class_mapping = {c: i for i, c in enumerate(classes)}
60 |         self.class_indices = {}
61 | 
62 |         for cls in classes:
63 |             self.class_indices[self.class_mapping[cls]] = []
64 | 
65 |         data = []
66 |         labels = []
67 |         tt = [] # task module labels
68 |         td = [] # discriminator labels
69 | 
70 |         for i in range(len(self.data)):
71 |             if self.labels[i] in classes:
72 |                 data.append(self.data[i])
73 |                 labels.append(self.class_mapping[self.labels[i]])
74 |                 tt.append(task_num)
75 |                 td.append(task_num+1)
76 |                 self.class_indices[self.class_mapping[self.labels[i]]].append(i)
77 | 
78 |         if memory_classes:
79 |             for task_id in range(task_num):
80 |                 for i in range(len(memory[task_id]['x'])):
81 |                     if memory[task_id]['y'][i] in range(len(memory_classes[task_id])):
82 |                         data.append(memory[task_id]['x'][i])
83 |                         labels.append(memory[task_id]['y'][i])
84 |                         tt.append(memory[task_id]['tt'][i])
85 |                         td.append(memory[task_id]['td'][i])
86 | 
87 |         self.data = np.array(data)
88 |         self.labels = labels
89 |         self.tt = tt
90 |         self.td = td
91 | 
92 | 
93 | 
94 |     def __getitem__(self, index):
95 | 
96 |         img, target, tt, td = self.data[index], self.labels[index], self.tt[index], self.td[index]
97 | 
98 |         # doing this so that it is consistent with all other datasets
99 |         # to return a PIL Image
100 |         if not torch.is_tensor(img):
101 |             img = Image.fromarray(img)
102 |             img = self.transform(img)
103 |         return img, target, tt, td
104 | 
105 | 
106 | 
107 | 
108 |     def __len__(self):
109 |         return len(self.data)
110 | 
111 | 
112 | 
113 | 
114 | class DatasetGen(object):
115 |     """docstring for DatasetGen"""
116 | 
117 |     def __init__(self, args):
118 |         super(DatasetGen, self).__init__()
119 | 
120 |         self.seed = args.seed
121 |         self.batch_size=args.batch_size
122 |         self.pc_valid=args.pc_valid
123 |         self.root = args.data_dir
124 |         self.latent_dim = args.latent_dim
125 |         self.use_memory = args.use_memory
126 | 
127 |         self.num_tasks = args.ntasks
128 |         self.num_classes = 100
129 | 
130 |         self.num_samples = args.samples
131 | 
132 |         self.inputsize = [3,84,84]
133 |         mean = [0.485, 0.456, 0.406]
134 |         std = [0.229, 0.224, 0.225]
135 | 
136 |         self.transformation = transforms.Compose([
137 |             transforms.Resize((84,84)),
138 |             transforms.ToTensor(),
139 | 
transforms.Normalize(mean=mean, std=std)]) 140 | 141 | self.taskcla = [[t, int(self.num_classes/self.num_tasks)] for t in range(self.num_tasks)] 142 | 143 | self.indices = {} 144 | self.dataloaders = {} 145 | self.idx={} 146 | 147 | self.num_workers = args.workers 148 | self.pin_memory = True 149 | 150 | np.random.seed(self.seed) 151 | task_ids = np.split(np.random.permutation(self.num_classes),self.num_tasks) 152 | self.task_ids = [list(arr) for arr in task_ids] 153 | 154 | self.train_set = {} 155 | self.train_split = {} 156 | self.test_set = {} 157 | 158 | 159 | self.task_memory = {} 160 | for i in range(self.num_tasks): 161 | self.task_memory[i] = {} 162 | self.task_memory[i]['x'] = [] 163 | self.task_memory[i]['y'] = [] 164 | self.task_memory[i]['tt'] = [] 165 | self.task_memory[i]['td'] = [] 166 | 167 | 168 | 169 | def get(self, task_id): 170 | 171 | self.dataloaders[task_id] = {} 172 | sys.stdout.flush() 173 | 174 | if task_id == 0: 175 | memory_classes = None 176 | memory=None 177 | else: 178 | memory_classes = self.task_ids 179 | memory = self.task_memory 180 | 181 | 182 | self.train_set[task_id] = iMiniImageNet(root=self.root, classes=self.task_ids[task_id], 183 | memory_classes=memory_classes, memory=memory, 184 | task_num=task_id, train=True, transform=self.transformation) 185 | 186 | self.test_set[task_id] = iMiniImageNet(root=self.root, classes=self.task_ids[task_id], memory_classes=None, 187 | memory=None, task_num=task_id, train=False, transform=self.transformation) 188 | 189 | 190 | split = int(np.floor(self.pc_valid * len(self.train_set[task_id]))) 191 | train_split, valid_split = torch.utils.data.random_split(self.train_set[task_id], [len(self.train_set[task_id]) - split, split]) 192 | self.train_split[task_id] = train_split 193 | 194 | train_loader = torch.utils.data.DataLoader(train_split, batch_size=self.batch_size, num_workers=self.num_workers, 195 | pin_memory=self.pin_memory,shuffle=True) 196 | valid_loader = torch.utils.data.DataLoader(valid_split, batch_size=int(self.batch_size * self.pc_valid), 197 | num_workers=self.num_workers, pin_memory=self.pin_memory,shuffle=True) 198 | test_loader = torch.utils.data.DataLoader(self.test_set[task_id], batch_size=self.batch_size, num_workers=self.num_workers, 199 | pin_memory=self.pin_memory, shuffle=True) 200 | 201 | 202 | self.dataloaders[task_id]['train'] = train_loader 203 | self.dataloaders[task_id]['valid'] = valid_loader 204 | self.dataloaders[task_id]['test'] = test_loader 205 | self.dataloaders[task_id]['name'] = 'iMiniImageNet-{}-{}'.format(task_id,self.task_ids[task_id]) 206 | self.dataloaders[task_id]['tsne'] = torch.utils.data.DataLoader(self.test_set[task_id], 207 | batch_size=len(test_loader.dataset), 208 | num_workers=self.num_workers, 209 | pin_memory=self.pin_memory, shuffle=True) 210 | 211 | print ("Task ID: ", task_id) 212 | print ("Training set size: {} images of {}x{}".format(len(train_loader.dataset),self.inputsize[1],self.inputsize[1])) 213 | print ("Validation set size: {} images of {}x{}".format(len(valid_loader.dataset),self.inputsize[1],self.inputsize[1])) 214 | print ("Train+Val set size: {} images of {}x{}".format(len(valid_loader.dataset)+len(train_loader.dataset),self.inputsize[1],self.inputsize[1])) 215 | print ("Test set size: {} images of {}x{}".format(len(test_loader.dataset),self.inputsize[1],self.inputsize[1])) 216 | 217 | if self.use_memory == 'yes' and self.num_samples > 0 : 218 | self.update_memory(task_id) 219 | 220 | 221 | return self.dataloaders 222 | 223 | 224 | 225 | def 
update_memory(self, task_id):
226 |         num_samples_per_class = self.num_samples // len(self.task_ids[task_id])
227 |         mem_class_mapping = {i: i for i, c in enumerate(self.task_ids[task_id])}
228 | 
229 |         for i in range(len(self.task_ids[task_id])):
230 |             data_loader = torch.utils.data.DataLoader(self.train_split[task_id], batch_size=1,
231 |                                                       num_workers=self.num_workers,
232 |                                                       pin_memory=self.pin_memory)
233 | 
234 |             randind = torch.randperm(len(data_loader.dataset))[:num_samples_per_class] # randomly sample some data
235 | 
236 | 
237 |             for ind in randind:
238 |                 self.task_memory[task_id]['x'].append(data_loader.dataset[ind][0])
239 |                 self.task_memory[task_id]['y'].append(mem_class_mapping[i])
240 |                 self.task_memory[task_id]['tt'].append(data_loader.dataset[ind][2])
241 |                 self.task_memory[task_id]['td'].append(data_loader.dataset[ind][3])
242 | 
243 |         print ('Memory updated by adding {} images'.format(len(self.task_memory[task_id]['x'])))
--------------------------------------------------------------------------------
/src/dataloaders/mnist5.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | from __future__ import print_function
7 | from PIL import Image
8 | import torch
9 | import numpy as np
10 | import os.path
11 | import sys
12 | 
13 | import torch.utils.data as data
14 | from torchvision import datasets, transforms
15 | 
16 | 
17 | 
18 | class iMNIST(datasets.MNIST):
19 | 
20 |     def __init__(self, root, classes, memory_classes, memory, task_num, train, transform=None, target_transform=None, download=True):
21 | 
22 |         super(iMNIST, self).__init__(root, train, transform=transform,
23 |                                      target_transform=target_transform, download=download)
24 | 
25 |         self.train = train # training set or test set
26 |         self.root = root
27 |         self.target_transform=target_transform
28 |         self.transform=transform
29 |         if download:
30 |             self.download()
31 | 
32 |         if not self._check_exists():
33 |             raise RuntimeError('Dataset not found.'
+ ' You can use download=True to download it') 34 | 35 | if self.train: 36 | data_file = self.training_file 37 | else: 38 | data_file = self.test_file 39 | 40 | self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) 41 | self.data=np.array(self.data).astype(np.float32) 42 | self.targets=list(np.array(self.targets)) 43 | 44 | self.train = train # training set or test set 45 | if not isinstance(classes, list): 46 | classes = [classes] 47 | 48 | self.class_mapping = {c: i for i, c in enumerate(classes)} 49 | self.class_indices = {} 50 | 51 | for cls in classes: 52 | self.class_indices[self.class_mapping[cls]] = [] 53 | 54 | 55 | data = [] 56 | targets = [] 57 | tt = [] # task module labels 58 | td = [] # discriminator labels 59 | 60 | for i in range(len(self.data)): 61 | if self.targets[i] in classes: 62 | data.append(self.data[i]) 63 | targets.append(self.class_mapping[self.targets[i]]) 64 | tt.append(task_num) 65 | td.append(task_num+1) 66 | self.class_indices[self.class_mapping[self.targets[i]]].append(i) 67 | 68 | 69 | if self.train: 70 | if memory_classes: 71 | for task_id in range(task_num): 72 | for i in range(len(memory[task_id]['x'])): 73 | if memory[task_id]['y'][i] in range(len(memory_classes[task_id])): 74 | data.append(memory[task_id]['x'][i]) 75 | targets.append(memory[task_id]['y'][i]) 76 | tt.append(memory[task_id]['tt'][i]) 77 | td.append(memory[task_id]['td'][i]) 78 | 79 | 80 | self.data = data.copy() 81 | self.targets = targets.copy() 82 | self.tt = tt.copy() 83 | self.td = td.copy() 84 | 85 | 86 | 87 | def __getitem__(self, index): 88 | """ 89 | Args: 90 | index (int): Index 91 | 92 | Returns: 93 | tuple: (image, target) where target is index of the target class. 94 | """ 95 | img, target, tt, td = self.data[index], int(self.targets[index]), self.tt[index], self.td[index] 96 | 97 | # doing this so that it is consistent with all other datasets 98 | # to return a PIL Image 99 | try: 100 | img = Image.fromarray(img.numpy(), mode='L') 101 | except: 102 | pass 103 | 104 | try: 105 | if self.transform is not None: img = self.transform(img) 106 | except: 107 | pass 108 | try: 109 | if self.target_transform is not None: tt = self.target_transform(tt) 110 | if self.target_transform is not None: td = self.target_transform(td) 111 | except: 112 | pass 113 | 114 | return img, target, tt, td 115 | 116 | 117 | def __len__(self): 118 | return len(self.data) 119 | 120 | 121 | 122 | 123 | class DatasetGen(object): 124 | """docstring for DatasetGen""" 125 | 126 | def __init__(self, args): 127 | super(DatasetGen, self).__init__() 128 | 129 | self.seed = args.seed 130 | self.batch_size=args.batch_size 131 | self.pc_valid=args.pc_valid 132 | self.root = args.data_dir 133 | self.latent_dim = args.latent_dim 134 | 135 | self.num_tasks = args.ntasks 136 | self.num_classes = 10 137 | 138 | self.num_samples = args.samples 139 | 140 | self.inputsize = [1,28,28] 141 | mean = (0.1307,) 142 | std = (0.3081,) 143 | 144 | self.transformation = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) 145 | 146 | self.taskcla = [[t, int(self.num_classes/self.num_tasks)] for t in range(self.num_tasks)] 147 | 148 | self.indices = {} 149 | self.dataloaders = {} 150 | self.idx={} 151 | 152 | self.num_workers = args.workers 153 | self.pin_memory = True 154 | 155 | np.random.seed(self.seed) 156 | self.task_ids = [[0,1], [2,3], [4,5], [6,7], [8,9]] 157 | 158 | self.train_set = {} 159 | self.test_set = {} 160 | 161 | self.task_memory = {} 162 | for i in 
range(self.num_tasks): 163 | self.task_memory[i] = {} 164 | self.task_memory[i]['x'] = [] 165 | self.task_memory[i]['y'] = [] 166 | self.task_memory[i]['tt'] = [] 167 | self.task_memory[i]['td'] = [] 168 | 169 | 170 | def get(self, task_id): 171 | 172 | self.dataloaders[task_id] = {} 173 | sys.stdout.flush() 174 | 175 | if task_id == 0: 176 | memory_classes = None 177 | memory=None 178 | else: 179 | memory_classes = self.task_ids 180 | memory = self.task_memory 181 | 182 | self.train_set[task_id] = iMNIST(root=self.root, classes=self.task_ids[task_id], memory_classes=memory_classes, 183 | memory=memory, task_num=task_id, train=True, 184 | download=True, transform=self.transformation) 185 | self.test_set[task_id] = iMNIST(root=self.root, classes=self.task_ids[task_id], memory_classes=None, 186 | memory=None, task_num=task_id, train=False, 187 | download=True, transform=self.transformation) 188 | 189 | split = int(np.floor(self.pc_valid * len(self.train_set[task_id]))) 190 | train_split, valid_split = torch.utils.data.random_split(self.train_set[task_id], [len(self.train_set[task_id]) - split, split]) 191 | 192 | 193 | train_loader = torch.utils.data.DataLoader(train_split, batch_size=self.batch_size, num_workers=self.num_workers, 194 | pin_memory=self.pin_memory, drop_last=True,shuffle=True) 195 | valid_loader = torch.utils.data.DataLoader(valid_split, batch_size=int(self.batch_size * self.pc_valid),shuffle=True, 196 | num_workers=self.num_workers, pin_memory=self.pin_memory, drop_last=True) 197 | test_loader = torch.utils.data.DataLoader(self.test_set[task_id], batch_size=self.batch_size, num_workers=self.num_workers, 198 | pin_memory=self.pin_memory, drop_last=True,shuffle=True) 199 | 200 | self.dataloaders[task_id]['train'] = train_loader 201 | self.dataloaders[task_id]['valid'] = valid_loader 202 | self.dataloaders[task_id]['test'] = test_loader 203 | self.dataloaders[task_id]['name'] = '5Split-MNIST-{}-{}'.format(task_id,self.task_ids[task_id]) 204 | 205 | print ("Training set size: {} images of {}x{}".format(len(train_loader.dataset),self.inputsize[1],self.inputsize[1])) 206 | print ("Validation set size: {} images of {}x{}".format(len(valid_loader.dataset),self.inputsize[1],self.inputsize[1])) 207 | print ("Test set size: {} images of {}x{}".format(len(test_loader.dataset),self.inputsize[1],self.inputsize[1])) 208 | return self.dataloaders 209 | 210 | 211 | def update_memory(self, task_id): 212 | num_samples_per_class = self.num_samples // len(self.task_ids[task_id]) 213 | mem_class_mapping = {i: i for i, c in enumerate(self.task_ids[task_id])} 214 | 215 | # Looping over each class in the current task 216 | for i in range(len(self.task_ids[task_id])): 217 | 218 | dataset = iMNIST(root=self.root, classes=self.task_ids[task_id][i], memory_classes=None, memory=None, 219 | task_num=task_id, train=True, download=True, transform=self.transformation) 220 | 221 | data_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, 222 | num_workers=self.num_workers, 223 | pin_memory=self.pin_memory) 224 | 225 | # Randomly choosing num_samples_per_class for this class 226 | randind = torch.randperm(len(data_loader.dataset))[:num_samples_per_class] 227 | 228 | # Adding the selected samples to memory 229 | for ind in randind: 230 | self.task_memory[task_id]['x'].append(data_loader.dataset[ind][0]) 231 | self.task_memory[task_id]['y'].append(mem_class_mapping[i]) 232 | self.task_memory[task_id]['tt'].append(data_loader.dataset[ind][2]) 233 | 
self.task_memory[task_id]['td'].append(data_loader.dataset[ind][3]) -------------------------------------------------------------------------------- /src/dataloaders/mulitidatasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from __future__ import print_function 7 | import sys 8 | 9 | if sys.version_info[0] == 2: 10 | import cPickle as pickle 11 | else: 12 | import pickle 13 | 14 | import torch.utils.data as data 15 | import torch.utils.data 16 | from .datasets_utils import * 17 | from utils import * 18 | from torchvision import transforms 19 | 20 | mean_datasets = { 21 | 'CIFAR10': [x/255 for x in [125.3,123.0,113.9]], 22 | 'notMNIST': (0.4254,), 23 | 'MNIST': (0.1,) , 24 | 'SVHN':[0.4377,0.4438,0.4728] , 25 | 'FashionMNIST': (0.2190,), 26 | 27 | } 28 | std_datasets = { 29 | 'CIFAR10': [x/255 for x in [63.0,62.1,66.7]], 30 | 'notMNIST': (0.4501,), 31 | 'MNIST': (0.2752,), 32 | 'SVHN': [0.198,0.201,0.197], 33 | 'FashionMNIST': (0.3318,) 34 | } 35 | 36 | classes_datasets = { 37 | 'CIFAR10': 10, 38 | 'notMNIST': 10, 39 | 'MNIST': 10, 40 | 'SVHN': 10, 41 | 'FashionMNIST': 10, 42 | } 43 | 44 | lr_datasets = { 45 | 'CIFAR10': 0.001, 46 | 'notMNIST': 0.01, 47 | 'MNIST': 0.01, 48 | 'SVHN': 0.001, 49 | 'FashionMNIST': 0.01, 50 | 51 | } 52 | 53 | 54 | gray_datasets = { 55 | 'CIFAR10': False, 56 | 'notMNIST': True, 57 | 'MNIST': True, 58 | 'SVHN': False, 59 | 'FashionMNIST': True, 60 | } 61 | 62 | 63 | 64 | class DatasetGen(object): 65 | """docstring for DatasetGen""" 66 | 67 | def __init__(self, args): 68 | super(DatasetGen, self).__init__() 69 | 70 | self.seed = args.seed 71 | self.batch_size=args.batch_size 72 | self.pc_valid=args.pc_valid 73 | self.root = args.data_dir 74 | self.latent_dim = args.latent_dim 75 | 76 | self.num_tasks = args.ntasks 77 | self.num_samples = args.samples 78 | 79 | 80 | self.inputsize = [3,32,32] 81 | 82 | self.indices = {} 83 | self.dataloaders = {} 84 | self.idx={} 85 | 86 | self.num_workers = args.workers 87 | self.pin_memory = True 88 | 89 | np.random.seed(self.seed) 90 | self.datasets_idx = list(np.random.permutation(self.num_tasks)) 91 | print('Task order =', [list(classes_datasets.keys())[item] for item in self.datasets_idx]) 92 | self.datasets_names = [list(classes_datasets.keys())[item] for item in self.datasets_idx] 93 | 94 | 95 | self.taskcla = [] 96 | self.lrs = [] 97 | 98 | for i in range(self.num_tasks): 99 | t = self.datasets_idx[i] 100 | self.taskcla.append([i, list(classes_datasets.values())[t]]) 101 | self.lrs.append([i, list(lr_datasets.values())[t]]) 102 | print('Learning Rates =', self.lrs) 103 | print('taskcla =', self.taskcla) 104 | 105 | 106 | self.train_set = {} 107 | self.train_split = {} 108 | self.test_set = {} 109 | 110 | self.args=args 111 | 112 | 113 | self.dataloaders, self.memory_set = {}, {} 114 | self.memoryloaders = {} 115 | 116 | self.dataloaders, self.memory_set, self.indices = {}, {}, {} 117 | self.memoryloaders = {} 118 | self.saliency_loaders, self.saliency_set = {}, {} 119 | 120 | for i in range(self.num_tasks): 121 | self.dataloaders[i] = {} 122 | self.memory_set[i] = {} 123 | self.memoryloaders[i] = {} 124 | self.indices[i] = {} 125 | # self.saliency_set = {} 126 | self.saliency_loaders[i] = {} 127 | 128 | self.download = True 129 | 130 | self.train_set = {} 131 | 
self.test_set = {} 132 | self.train_split = {} 133 | 134 | self.task_memory = {} 135 | for i in range(self.num_tasks): 136 | self.task_memory[i] = {} 137 | self.task_memory[i]['x'] = [] 138 | self.task_memory[i]['y'] = [] 139 | self.task_memory[i]['tt'] = [] 140 | self.task_memory[i]['td'] = [] 141 | 142 | self.use_memory = args.use_memory 143 | 144 | 145 | def get_dataset(self, dataset_idx, task_num, num_samples_per_class=False, normalize=True): 146 | dataset_name = list(mean_datasets.keys())[dataset_idx] 147 | nspc = num_samples_per_class 148 | if normalize: 149 | transformation = transforms.Compose([transforms.ToTensor(), 150 | transforms.Normalize(mean_datasets[dataset_name],std_datasets[dataset_name])]) 151 | mnist_transformation = transforms.Compose([ 152 | transforms.Pad(padding=2, fill=0), 153 | transforms.ToTensor(), 154 | transforms.Normalize(mean_datasets[dataset_name], std_datasets[dataset_name])]) 155 | else: 156 | transformation = transforms.Compose([transforms.ToTensor()]) 157 | mnist_transformation = transforms.Compose([ 158 | transforms.Pad(padding=2, fill=0), 159 | transforms.ToTensor(), 160 | ]) 161 | 162 | # target_transormation = transforms.Compose([transforms.ToTensor()]) 163 | target_transormation = None 164 | 165 | if dataset_idx == 0: 166 | trainset = CIFAR10_(root=self.root, task_num=task_num, num_samples_per_class=nspc, train=True, download=self.download, target_transform = target_transormation, transform=transformation) 167 | testset = CIFAR10_(root=self.root, task_num=task_num, num_samples_per_class=nspc, train=False, download=self.download, target_transform = target_transormation, transform=transformation) 168 | 169 | if dataset_idx == 1: 170 | trainset = notMNIST_(root=self.root, task_num=task_num, num_samples_per_class=nspc, train=True, download=self.download, target_transform = target_transormation, transform=mnist_transformation) 171 | testset = notMNIST_(root=self.root, task_num=task_num, num_samples_per_class=nspc, train=False, download=self.download, target_transform = target_transormation, transform=mnist_transformation) 172 | 173 | if dataset_idx == 2: 174 | trainset = MNIST_RGB(root=self.root, train=True, num_samples_per_class=nspc, task_num=task_num, download=self.download, target_transform = target_transormation, transform=mnist_transformation) 175 | testset = MNIST_RGB(root=self.root, train=False, num_samples_per_class=nspc, task_num=task_num, download=self.download, target_transform = target_transormation, transform=mnist_transformation) 176 | 177 | if dataset_idx == 3: 178 | trainset = SVHN_(root=self.root, train=True, num_samples_per_class=nspc, task_num=task_num, download=self.download, target_transform = target_transormation, transform=transformation) 179 | testset = SVHN_(root=self.root, train=False, num_samples_per_class=nspc, task_num=task_num, download=self.download, target_transform = target_transormation, transform=transformation) 180 | 181 | if dataset_idx == 4: 182 | trainset = FashionMNIST_(root=self.root, num_samples_per_class=nspc, task_num=task_num, train=True, download=self.download, target_transform = target_transormation, transform=mnist_transformation) 183 | testset = FashionMNIST_(root=self.root, num_samples_per_class=nspc, task_num=task_num, train=False, download=self.download, target_transform = target_transormation, transform=mnist_transformation) 184 | 185 | return trainset, testset 186 | 187 | 188 | def get(self, task_id): 189 | 190 | self.dataloaders[task_id] = {} 191 | sys.stdout.flush() 192 | 193 | current_dataset_idx 
= self.datasets_idx[task_id]
194 |         dataset_name = list(mean_datasets.keys())[current_dataset_idx]
195 |         self.train_set[task_id], self.test_set[task_id] = self.get_dataset(current_dataset_idx,task_id)
196 | 
197 |         self.num_classes = classes_datasets[dataset_name]
198 | 
199 |         split = int(np.floor(self.pc_valid * len(self.train_set[task_id])))
200 |         train_split, valid_split = torch.utils.data.random_split(self.train_set[task_id], [len(self.train_set[task_id]) - split, split])
201 | 
202 |         self.train_split[task_id] = train_split
203 |         train_loader = torch.utils.data.DataLoader(train_split, batch_size=self.batch_size, num_workers=self.num_workers,
204 |                                                    pin_memory=self.pin_memory,shuffle=True)
205 |         valid_loader = torch.utils.data.DataLoader(valid_split, batch_size=int(self.batch_size * self.pc_valid),
206 |                                                    num_workers=self.num_workers, pin_memory=self.pin_memory,shuffle=True)
207 |         test_loader = torch.utils.data.DataLoader(self.test_set[task_id], batch_size=self.batch_size, num_workers=self.num_workers,
208 |                                                   pin_memory=self.pin_memory,shuffle=True)
209 | 
210 | 
211 |         self.dataloaders[task_id]['train'] = train_loader
212 |         self.dataloaders[task_id]['valid'] = valid_loader
213 |         self.dataloaders[task_id]['test'] = test_loader
214 |         self.dataloaders[task_id]['name'] = '{} - {} classes - {} images'.format(dataset_name,
215 |                                                                                  classes_datasets[dataset_name],
216 |                                                                                  len(self.train_set[task_id]))
217 |         self.dataloaders[task_id]['classes'] = self.num_classes
218 | 
219 | 
220 |         print ("Training set size: {} images of {}x{}".format(len(train_loader.dataset),self.inputsize[1],self.inputsize[1]))
221 |         print ("Validation set size: {} images of {}x{}".format(len(valid_loader.dataset),self.inputsize[1],self.inputsize[1]))
222 |         print ("Train+Val set size: {} images of {}x{}".format(len(valid_loader.dataset)+len(train_loader.dataset),self.inputsize[1],self.inputsize[1]))
223 |         print ("Test set size: {} images of {}x{}".format(len(test_loader.dataset),self.inputsize[1],self.inputsize[1]))
224 | 
225 |         if self.use_memory == 'yes' and self.num_samples > 0 :
226 |             self.update_memory(task_id)
227 | 
228 |         return self.dataloaders
229 | 
230 | 
231 | 
232 |     def update_memory(self, task_id):
233 |         # NOTE: effectively unused; config_multidatasets.yml sets use_memory to "no" (memory is not supported for this experiment), and self.task_ids is never defined in this class, so calling this method would raise an AttributeError.
234 |         num_samples_per_class = self.num_samples // len(self.task_ids[task_id])
235 |         mem_class_mapping = {i: i for i, c in enumerate(self.task_ids[task_id])}
236 | 
237 |         # Looping over each class in the current task
238 |         for i in range(len(self.task_ids[task_id])):
239 |             # Getting all samples for this class
240 |             data_loader = torch.utils.data.DataLoader(self.train_split[task_id], batch_size=1,
241 |                                                       num_workers=self.num_workers,
242 |                                                       pin_memory=self.pin_memory)
243 |             # Randomly choosing num_samples_per_class for this class
244 |             randind = torch.randperm(len(data_loader.dataset))[:num_samples_per_class]
245 | 
246 |             # Adding the selected samples to memory
247 |             for ind in randind:
248 |                 self.task_memory[task_id]['x'].append(data_loader.dataset[ind][0])
249 |                 self.task_memory[task_id]['y'].append(mem_class_mapping[i])
250 |                 self.task_memory[task_id]['tt'].append(data_loader.dataset[ind][2])
251 |                 self.task_memory[task_id]['td'].append(data_loader.dataset[ind][3])
252 | 
253 |         print('Memory updated by adding {} images'.format(len(self.task_memory[task_id]['x'])))
254 | 
255 | 
256 | 
257 |     def report_size(self,dataset_name,task_id):
258 |         print("Dataset {} size: {} ".format(dataset_name, len(self.train_set[task_id])))
259 | 
260 | 
--------------------------------------------------------------------------------
/src/dataloaders/pmnist.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import sys, os 8 | import numpy as np 9 | from PIL import Image 10 | import torch.utils.data as data 11 | from torchvision import datasets, transforms 12 | from sklearn.utils import shuffle 13 | from utils import * 14 | 15 | 16 | class PermutedMNIST(datasets.MNIST): 17 | 18 | def __init__(self, root, task_num, train=True, permute_idx=None, transform=None): 19 | super(PermutedMNIST, self).__init__(root, train, download=True) 20 | 21 | if self.train: 22 | data_file = self.training_file 23 | else: 24 | data_file = self.test_file 25 | 26 | self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) 27 | 28 | self.data = torch.stack([img.float().view(-1)[permute_idx] for img in self.data]) 29 | self.tl = (task_num) * torch.ones(len(self.data),dtype=torch.long) 30 | self.td = (task_num+1) * torch.ones(len(self.data),dtype=torch.long) 31 | 32 | 33 | def __getitem__(self, index): 34 | 35 | img, target, tl, td = self.data[index], self.targets[index], self.tl[index], self.td[index] 36 | 37 | if self.transform is not None: 38 | img = self.transform(img) 39 | if self.target_transform is not None: 40 | print ("We are transforming") 41 | target = self.target_transform(target) 42 | 43 | return img, target, tl, td 44 | 45 | def __len__(self): 46 | return self.data.size(0) 47 | 48 | 49 | 50 | class DatasetGen(object): 51 | 52 | def __init__(self, args): 53 | super(DatasetGen, self).__init__() 54 | 55 | self.seed = args.seed 56 | self.batch_size=args.batch_size 57 | self.pc_valid=args.pc_valid 58 | self.num_samples = args.samples 59 | self.num_tasks = args.ntasks 60 | self.root = args.data_dir 61 | self.use_memory = args.use_memory 62 | 63 | self.inputsize = [1, 28, 28] 64 | mean = (0.1307,) 65 | std = (0.3081,) 66 | self.transformation = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std)]) 67 | 68 | self.taskcla = [[t, 10] for t in range(self.num_tasks)] 69 | 70 | self.train_set, self.test_set = {}, {} 71 | self.indices = {} 72 | self.dataloaders = {} 73 | self.idx={} 74 | self.get_idx() 75 | 76 | self.pin_memory = True 77 | self.num_workers = args.workers 78 | 79 | self.task_memory = [] 80 | 81 | 82 | 83 | def get(self, task_id): 84 | 85 | self.dataloaders[task_id] = {} 86 | sys.stdout.flush() 87 | 88 | if task_id == 0: 89 | self.train_set[task_id] = PermutedMNIST(root=self.root, task_num=task_id, train=True, 90 | permute_idx=self.idx[task_id], transform=self.transformation) 91 | 92 | if self.use_memory == 'yes' and self.num_samples > 0: 93 | indices=torch.randperm(len(self.train_set[task_id]))[:self.num_samples] 94 | rand_subset=torch.utils.data.Subset(self.train_set[task_id], indices) 95 | self.task_memory.append(rand_subset) 96 | 97 | else: 98 | if self.use_memory == 'yes' and self.num_samples > 0: 99 | current_dataset = PermutedMNIST(root=self.root, task_num=task_id, train=True, 100 | permute_idx=self.idx[task_id], transform=self.transformation) 101 | d = [] 102 | d.append(current_dataset) 103 | for m in self.task_memory: 104 | d.append(m) 105 | self.train_set[task_id] = torch.utils.data.ConcatDataset(d) 106 | 107 | indices=torch.randperm(len(current_dataset))[:self.num_samples] 108 | rand_subset=torch.utils.data.Subset(current_dataset, indices) 
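                # the subset is drawn from the current task only, so the replay
                # buffer grows by at most `samples` (self.num_samples) examples per task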
109 | self.task_memory.append(rand_subset) 110 | 111 | else: 112 | self.train_set[task_id] = PermutedMNIST(root=self.root, task_num=task_id, train=True, 113 | permute_idx=self.idx[task_id], transform=self.transformation) 114 | 115 | self.test_set[task_id] = PermutedMNIST(root=self.root, task_num=task_id, train=False, 116 | permute_idx=self.idx[task_id], transform=self.transformation) 117 | 118 | split = int(np.floor(self.pc_valid * len(self.train_set[task_id]))) 119 | train_split, valid_split = torch.utils.data.random_split(self.train_set[task_id], 120 | [len(self.train_set[task_id]) - split, split]) 121 | 122 | train_loader = torch.utils.data.DataLoader(train_split, batch_size=self.batch_size, 123 | num_workers=self.num_workers, pin_memory=self.pin_memory,shuffle=True) 124 | valid_loader = torch.utils.data.DataLoader(valid_split, batch_size=self.batch_size, 125 | num_workers=self.num_workers, pin_memory=self.pin_memory,shuffle=True) 126 | test_loader = torch.utils.data.DataLoader(self.test_set[task_id], batch_size=self.batch_size, 127 | num_workers=self.num_workers, pin_memory=self.pin_memory,shuffle=True) 128 | 129 | self.dataloaders[task_id]['train'] = train_loader 130 | self.dataloaders[task_id]['valid'] = valid_loader 131 | self.dataloaders[task_id]['test'] = test_loader 132 | self.dataloaders[task_id]['name'] = 'pmnist-{}'.format(task_id+1) 133 | 134 | print ("Training set size: {} images of {}x{}".format(len(train_loader.dataset),self.inputsize[1],self.inputsize[1])) 135 | print ("Validation set size: {} images of {}x{}".format(len(valid_loader.dataset),self.inputsize[1],self.inputsize[1])) 136 | print ("Train+Val set size: {} images of {}x{}".format(len(valid_loader.dataset)+len(train_loader.dataset),self.inputsize[1],self.inputsize[1])) 137 | print ("Test set size: {} images of {}x{}".format(len(test_loader.dataset),self.inputsize[1],self.inputsize[1])) 138 | 139 | return self.dataloaders 140 | 141 | 142 | def get_idx(self): 143 | for i in range(len(self.taskcla)): 144 | idx = list(range(self.inputsize[1] * self.inputsize[2])) 145 | self.idx[i] = shuffle(idx, random_state=self.seed * 100 + i) 146 | 147 | 148 | -------------------------------------------------------------------------------- /src/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
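# Download and archive-extraction helpers vendored from torchvision (see the
# URL below), plus the SN3 ("Pascal Vincent") readers used to parse raw
# MNIST-format files. A representative call, sketched here with the CIFAR-100
# URL and md5 that appear in cifar100.py:
#
#     download_and_extract_archive(
#         "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz",
#         download_root="../data", md5="eb9058c3a382ffc7106e4002c42a8d85")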
5 | 6 | 7 | # https://github.com/pytorch/vision/blob/8635be94d1216f10fb8302da89233bd86445e449/torchvision/datasets/utils.py 8 | 9 | import os 10 | import os.path 11 | import hashlib 12 | import gzip 13 | import errno 14 | import tarfile 15 | import zipfile 16 | import numpy as np 17 | import torch 18 | import codecs 19 | 20 | from torch.utils.model_zoo import tqdm 21 | 22 | 23 | def gen_bar_updater(): 24 | pbar = tqdm(total=None) 25 | 26 | def bar_update(count, block_size, total_size): 27 | if pbar.total is None and total_size: 28 | pbar.total = total_size 29 | progress_bytes = count * block_size 30 | pbar.update(progress_bytes - pbar.n) 31 | 32 | return bar_update 33 | 34 | 35 | def calculate_md5(fpath, chunk_size=1024 * 1024): 36 | md5 = hashlib.md5() 37 | with open(fpath, 'rb') as f: 38 | for chunk in iter(lambda: f.read(chunk_size), b''): 39 | md5.update(chunk) 40 | return md5.hexdigest() 41 | 42 | 43 | def check_md5(fpath, md5, **kwargs): 44 | return md5 == calculate_md5(fpath, **kwargs) 45 | 46 | 47 | def check_integrity(fpath, md5=None): 48 | if not os.path.isfile(fpath): 49 | return False 50 | if md5 is None: 51 | return True 52 | return check_md5(fpath, md5) 53 | 54 | 55 | def makedir_exist_ok(dirpath): 56 | """ 57 | Python2 support for os.makedirs(.., exist_ok=True) 58 | """ 59 | try: 60 | os.makedirs(dirpath) 61 | except OSError as e: 62 | if e.errno == errno.EEXIST: 63 | pass 64 | else: 65 | raise 66 | 67 | 68 | def download_url(url, root, filename=None, md5=None): 69 | """Download a file from a url and place it in root. 70 | 71 | Args: 72 | url (str): URL to download file from 73 | root (str): Directory to place downloaded file in 74 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 75 | md5 (str, optional): MD5 checksum of the download. If None, do not check 76 | """ 77 | from six.moves import urllib 78 | 79 | root = os.path.expanduser(root) 80 | if not filename: 81 | filename = os.path.basename(url) 82 | fpath = os.path.join(root, filename) 83 | 84 | makedir_exist_ok(root) 85 | 86 | # downloads file 87 | if check_integrity(fpath, md5): 88 | print('Using downloaded and verified file: ' + fpath) 89 | else: 90 | try: 91 | print('Downloading ' + url + ' to ' + fpath) 92 | urllib.request.urlretrieve( 93 | url, fpath, 94 | reporthook=gen_bar_updater() 95 | ) 96 | except (urllib.error.URLError, IOError) as e: 97 | if url[:5] == 'https': 98 | url = url.replace('https:', 'http:') 99 | print('Failed download. Trying https -> http instead.' 
100 | ' Downloading ' + url + ' to ' + fpath) 101 | urllib.request.urlretrieve( 102 | url, fpath, 103 | reporthook=gen_bar_updater() 104 | ) 105 | else: 106 | raise e 107 | 108 | 109 | def list_dir(root, prefix=False): 110 | """List all directories at a given root 111 | 112 | Args: 113 | root (str): Path to directory whose folders need to be listed 114 | prefix (bool, optional): If true, prepends the path to each result, otherwise 115 | only returns the name of the directories found 116 | """ 117 | root = os.path.expanduser(root) 118 | directories = list( 119 | filter( 120 | lambda p: os.path.isdir(os.path.join(root, p)), 121 | os.listdir(root) 122 | ) 123 | ) 124 | 125 | if prefix is True: 126 | directories = [os.path.join(root, d) for d in directories] 127 | 128 | return directories 129 | 130 | 131 | def list_files(root, suffix, prefix=False): 132 | """List all files ending with a suffix at a given root 133 | 134 | Args: 135 | root (str): Path to directory whose folders need to be listed 136 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 137 | It uses the Python "str.endswith" method and is passed directly 138 | prefix (bool, optional): If true, prepends the path to each result, otherwise 139 | only returns the name of the files found 140 | """ 141 | root = os.path.expanduser(root) 142 | files = list( 143 | filter( 144 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 145 | os.listdir(root) 146 | ) 147 | ) 148 | 149 | if prefix is True: 150 | files = [os.path.join(root, d) for d in files] 151 | 152 | return files 153 | 154 | 155 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 156 | """Download a Google Drive file from and place it in root. 157 | 158 | Args: 159 | file_id (str): id of file to be downloaded 160 | root (str): Directory to place downloaded file in 161 | filename (str, optional): Name to save the file under. If None, use the id of the file. 162 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 163 | """ 164 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 165 | import requests 166 | url = "https://docs.google.com/uc?export=download" 167 | 168 | root = os.path.expanduser(root) 169 | if not filename: 170 | filename = file_id 171 | fpath = os.path.join(root, filename) 172 | 173 | makedir_exist_ok(root) 174 | 175 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 176 | print('Using downloaded and verified file: ' + fpath) 177 | else: 178 | session = requests.Session() 179 | 180 | response = session.get(url, params={'id': file_id}, stream=True) 181 | token = _get_confirm_token(response) 182 | 183 | if token: 184 | params = {'id': file_id, 'confirm': token} 185 | response = session.get(url, params=params, stream=True) 186 | 187 | _save_response_content(response, fpath) 188 | 189 | 190 | def _get_confirm_token(response): 191 | for key, value in response.cookies.items(): 192 | if key.startswith('download_warning'): 193 | return value 194 | 195 | return None 196 | 197 | 198 | def _save_response_content(response, destination, chunk_size=32768): 199 | with open(destination, "wb") as f: 200 | pbar = tqdm(total=None) 201 | progress = 0 202 | for chunk in response.iter_content(chunk_size): 203 | if chunk: # filter out keep-alive new chunks 204 | f.write(chunk) 205 | progress += len(chunk) 206 | pbar.update(progress - pbar.n) 207 | pbar.close() 208 | 209 | 210 | def _is_tar(filename): 211 | return filename.endswith(".tar") 212 | 213 | 214 | def _is_targz(filename): 215 | return filename.endswith(".tar.gz") 216 | 217 | 218 | def _is_gzip(filename): 219 | return filename.endswith(".gz") and not filename.endswith(".tar.gz") 220 | 221 | 222 | def _is_zip(filename): 223 | return filename.endswith(".zip") 224 | 225 | 226 | def extract_archive(from_path, to_path=None, remove_finished=False): 227 | if to_path is None: 228 | to_path = os.path.dirname(from_path) 229 | 230 | if _is_tar(from_path): 231 | with tarfile.open(from_path, 'r') as tar: 232 | tar.extractall(path=to_path) 233 | elif _is_targz(from_path): 234 | with tarfile.open(from_path, 'r:gz') as tar: 235 | tar.extractall(path=to_path) 236 | elif _is_gzip(from_path): 237 | to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) 238 | with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: 239 | out_f.write(zip_f.read()) 240 | elif _is_zip(from_path): 241 | with zipfile.ZipFile(from_path, 'r') as z: 242 | z.extractall(to_path) 243 | else: 244 | raise ValueError("Extraction of {} not supported".format(from_path)) 245 | 246 | if remove_finished: 247 | os.remove(from_path) 248 | 249 | 250 | def download_and_extract_archive(url, download_root, extract_root=None, filename=None, 251 | md5=None, remove_finished=False): 252 | download_root = os.path.expanduser(download_root) 253 | if extract_root is None: 254 | extract_root = download_root 255 | if not filename: 256 | filename = os.path.basename(url) 257 | 258 | download_url(url, download_root, filename, md5) 259 | 260 | archive = os.path.join(download_root, filename) 261 | print("Extracting {} to {}".format(archive, extract_root)) 262 | extract_archive(archive, extract_root, remove_finished) 263 | 264 | 265 | def iterable_to_str(iterable): 266 | return "'" + "', '".join([str(item) for item in iterable]) + "'" 267 | 268 | 269 | def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None): 270 | if not isinstance(value, torch._six.string_classes): 271 | if arg is 
None: 272 | msg = "Expected type str, but got type {type}." 273 | else: 274 | msg = "Expected type str for argument {arg}, but got type {type}." 275 | msg = msg.format(type=type(value), arg=arg) 276 | raise ValueError(msg) 277 | 278 | if valid_values is None: 279 | return value 280 | 281 | if value not in valid_values: 282 | if custom_msg is not None: 283 | msg = custom_msg 284 | else: 285 | msg = ("Unknown value '{value}' for argument {arg}. " 286 | "Valid values are {{{valid_values}}}.") 287 | msg = msg.format(value=value, arg=arg, 288 | valid_values=iterable_to_str(valid_values)) 289 | raise ValueError(msg) 290 | 291 | return value 292 | 293 | 294 | def get_int(b): 295 | return int(codecs.encode(b, 'hex'), 16) 296 | 297 | 298 | def open_maybe_compressed_file(path): 299 | """Return a file object that possibly decompresses 'path' on the fly. 300 | Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'. 301 | """ 302 | if not isinstance(path, torch._six.string_classes): 303 | return path 304 | if path.endswith('.gz'): 305 | import gzip 306 | return gzip.open(path, 'rb') 307 | if path.endswith('.xz'): 308 | import lzma 309 | return lzma.open(path, 'rb') 310 | return open(path, 'rb') 311 | 312 | 313 | def read_sn3_pascalvincent_tensor(path, strict=True): 314 | """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). 315 | Argument may be a filename, compressed filename, or file object. 316 | """ 317 | # typemap 318 | if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'): 319 | read_sn3_pascalvincent_tensor.typemap = { 320 | 8: (torch.uint8, np.uint8, np.uint8), 321 | 9: (torch.int8, np.int8, np.int8), 322 | 11: (torch.int16, np.dtype('>i2'), 'i2'), 323 | 12: (torch.int32, np.dtype('>i4'), 'i4'), 324 | 13: (torch.float32, np.dtype('>f4'), 'f4'), 325 | 14: (torch.float64, np.dtype('>f8'), 'f8')} 326 | # read 327 | with open_maybe_compressed_file(path) as f: 328 | data = f.read() 329 | # parse 330 | magic = get_int(data[0:4]) 331 | nd = magic % 256 332 | ty = magic // 256 333 | assert nd >= 1 and nd <= 3 334 | assert ty >= 8 and ty <= 14 335 | m = read_sn3_pascalvincent_tensor.typemap[ty] 336 | s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)] 337 | parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) 338 | assert parsed.shape[0] == np.prod(s) or not strict 339 | return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) 340 | 341 | 342 | def read_label_file(path): 343 | with open(path, 'rb') as f: 344 | x = read_sn3_pascalvincent_tensor(f, strict=False) 345 | assert(x.dtype == torch.uint8) 346 | assert(x.ndimension() == 1) 347 | return x.long() 348 | 349 | 350 | def read_image_file(path): 351 | with open(path, 'rb') as f: 352 | x = read_sn3_pascalvincent_tensor(f, strict=False) 353 | assert(x.dtype == torch.uint8) 354 | assert(x.ndimension() == 3) 355 | return x -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
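# Entry point. Reads a YAML config via OmegaConf, picks the dataloader and the
# backbone network from args.experiment, and trains the tasks sequentially with
# ACL. When task u is evaluated after learning task t, the model saved for u is
# reloaded and its shared module is swapped for the current one (acl.py's
# load_model), so only the private modules and heads remain task-specific.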
5 | 6 | import os,argparse,time 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | 10 | import torch 11 | 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim 14 | import torch.utils.data 15 | import torch.utils.data.distributed 16 | import utils 17 | 18 | tstart=time.time() 19 | 20 | 21 | # Arguments 22 | parser = argparse.ArgumentParser(description='Adversarial Continual Learning...') 23 | # Load the config file 24 | parser.add_argument('--config', type=str, default='./configs/config_mnist5.yml') 25 | flags = parser.parse_args() 26 | args = OmegaConf.load(flags.config) 27 | 28 | print() 29 | 30 | 31 | ######################################################################################################################## 32 | 33 | # Args -- Experiment 34 | if args.experiment=='pmnist': 35 | from dataloaders import pmnist as datagenerator 36 | elif args.experiment=='mnist5': 37 | from dataloaders import mnist5 as datagenerator 38 | elif args.experiment=='cifar100': 39 | from dataloaders import cifar100 as datagenerator 40 | elif args.experiment=='miniimagenet': 41 | from dataloaders import miniimagenet as datagenerator 42 | elif args.experiment=='multidatasets': 43 | from dataloaders import mulitidatasets as datagenerator 44 | else: 45 | raise NotImplementedError 46 | 47 | from acl import ACL as approach 48 | 49 | # Args -- Network 50 | if args.experiment == 'mnist5' or args.experiment == 'pmnist': 51 | from networks import mlp_acl as network 52 | elif args.experiment == 'cifar100' or args.experiment == 'miniimagenet' or args.experiment == 'multidatasets': 53 | from networks import alexnet_acl as network 54 | else: 55 | raise NotImplementedError 56 | 57 | ######################################################################################################################## 58 | 59 | def run(args, run_id): 60 | 61 | np.random.seed(args.seed) 62 | torch.manual_seed(args.seed) 63 | if torch.cuda.is_available(): 64 | torch.cuda.manual_seed(args.seed) 65 | 66 | # Faster run but not deterministic: 67 | # torch.backends.cudnn.benchmark = True 68 | 69 | # To get deterministic results that match with paper at cost of lower speed: 70 | torch.backends.cudnn.deterministic = True 71 | torch.backends.cudnn.benchmark = False 72 | 73 | # Data loader 74 | print('Instantiate data generators and model...') 75 | dataloader = datagenerator.DatasetGen(args) 76 | args.taskcla, args.inputsize = dataloader.taskcla, dataloader.inputsize 77 | if args.experiment == 'multidatasets': args.lrs = dataloader.lrs 78 | 79 | # Model 80 | net = network.Net(args) 81 | net = net.to(args.device) 82 | 83 | net.print_model_size() 84 | # print (net) 85 | 86 | # Approach 87 | appr=approach(net,args,network=network) 88 | 89 | # Loop tasks 90 | acc=np.zeros((len(args.taskcla),len(args.taskcla)),dtype=np.float32) 91 | lss=np.zeros((len(args.taskcla),len(args.taskcla)),dtype=np.float32) 92 | 93 | for t,ncla in args.taskcla: 94 | 95 | print('*'*250) 96 | dataset = dataloader.get(t) 97 | print(' '*105, 'Dataset {:2d} ({:s})'.format(t+1,dataset[t]['name'])) 98 | print('*'*250) 99 | 100 | # Train 101 | appr.train(t,dataset[t]) 102 | print('-'*250) 103 | print() 104 | 105 | for u in range(t+1): 106 | # Load previous model and replace the shared module with the current one 107 | test_model = appr.load_model(u) 108 | test_res = appr.test(dataset[u]['test'], u, model=test_model) 109 | 110 | print('>>> Test on task {:2d} - {:15s}: loss={:.3f}, acc={:5.1f}% <<<'.format(u, dataset[u]['name'], 111 | test_res['loss_t'], 112 
| test_res['acc_t'])) 113 | 114 | 115 | acc[t, u] = test_res['acc_t'] 116 | lss[t, u] = test_res['loss_t'] 117 | 118 | 119 | # Save 120 | print() 121 | print('Saved accuracies at '+os.path.join(args.checkpoint,args.output)) 122 | np.savetxt(os.path.join(args.checkpoint,args.output),acc,'%.6f') 123 | 124 | # Extract embeddings to plot in tensorboard for miniimagenet 125 | if args.tsne == 'yes' and args.experiment == 'miniimagenet': 126 | appr.get_tsne_embeddings_first_ten_tasks(dataset, model=appr.load_model(t)) 127 | appr.get_tsne_embeddings_last_three_tasks(dataset, model=appr.load_model(t)) 128 | 129 | avg_acc, gem_bwt = utils.print_log_acc_bwt(args.taskcla, acc, lss, output_path=args.checkpoint, run_id=run_id) 130 | 131 | return avg_acc, gem_bwt 132 | 133 | 134 | 135 | ####################################################################################################################### 136 | 137 | 138 | def main(args): 139 | utils.make_directories(args) 140 | utils.some_sanity_checks(args) 141 | utils.save_code(args) 142 | 143 | print('=' * 100) 144 | print('Arguments =') 145 | for arg in vars(args): 146 | print('\t' + arg + ':', getattr(args, arg)) 147 | print('=' * 100) 148 | 149 | 150 | accuracies, forgetting = [], [] 151 | for n in range(args.num_runs): 152 | args.seed = n 153 | args.output = '{}_{}_tasks_seed_{}.txt'.format(args.experiment, args.ntasks, args.seed) 154 | print ("args.output: ", args.output) 155 | 156 | print (" >>>> Run #", n) 157 | acc, bwt = run(args, n) 158 | accuracies.append(acc) 159 | forgetting.append(bwt) 160 | 161 | 162 | print('*' * 100) 163 | print ("Average over {} runs: ".format(args.num_runs)) 164 | print ('AVG ACC: {:5.4f}% \pm {:5.4f}'.format(np.array(accuracies).mean(), np.array(accuracies).std())) 165 | print ('AVG BWT: {:5.2f}% \pm {:5.4f}'.format(np.array(forgetting).mean(), np.array(forgetting).std())) 166 | 167 | 168 | print ("All Done! ") 169 | print('[Elapsed time = {:.1f} min]'.format((time.time()-tstart)/(60))) 170 | utils.print_time() 171 | 172 | 173 | ####################################################################################################################### 174 | 175 | if __name__ == '__main__': 176 | main(args) 177 | -------------------------------------------------------------------------------- /src/networks/alexnet_acl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
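# ----------------------------------------------------------------------------------
# Worked example of the spatial-size bookkeeping done in Shared.__init__ below.
# utils.compute_conv_output_size implements the standard convolution formula
# floor((L_in + 2*padding - dilation*(kernel-1) - 1)/stride + 1), which with the
# defaults (stride=1, padding=0, dilation=1) reduces to L_in - kernel + 1. For the
# CIFAR100 configuration (32x32 inputs, hiddens[2] = 256) the sizes evolve as:

import numpy as np

def conv_out(l_in, kernel, stride=1, padding=0, dilation=1):
    # same arithmetic as utils.compute_conv_output_size
    return int(np.floor((l_in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1))

s = conv_out(32, 32 // 8) // 2    # conv1 (kernel 4) -> 29, maxpool -> 14
s = conv_out(s, 32 // 10) // 2    # conv2 (kernel 3) -> 12, maxpool -> 6
s = conv_out(s, 2) // 2           # conv3 (kernel 2) -> 5,  maxpool -> 2
assert 256 * s * s == 1024        # matches fc1 = Linear(hiddens[2]*s*s, hiddens[3])
# ----------------------------------------------------------------------------------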
5 | 6 | import torch 7 | import utils 8 | 9 | class Shared(torch.nn.Module): 10 | 11 | def __init__(self,args): 12 | super(Shared, self).__init__() 13 | 14 | self.ncha,size,_=args.inputsize 15 | self.taskcla=args.taskcla 16 | self.latent_dim = args.latent_dim 17 | 18 | if args.experiment == 'cifar100': 19 | hiddens = [64, 128, 256, 1024, 1024, 512] 20 | 21 | elif args.experiment == 'miniimagenet': 22 | hiddens = [64, 128, 256, 512, 512, 512] 23 | 24 | # ---------------------------------- 25 | elif args.experiment == 'multidatasets': 26 | hiddens = [64, 128, 256, 1024, 1024, 512] 27 | 28 | else: 29 | raise NotImplementedError 30 | 31 | self.conv1=torch.nn.Conv2d(self.ncha,hiddens[0],kernel_size=size//8) 32 | s=utils.compute_conv_output_size(size,size//8) 33 | s=s//2 34 | self.conv2=torch.nn.Conv2d(hiddens[0],hiddens[1],kernel_size=size//10) 35 | s=utils.compute_conv_output_size(s,size//10) 36 | s=s//2 37 | self.conv3=torch.nn.Conv2d(hiddens[1],hiddens[2],kernel_size=2) 38 | s=utils.compute_conv_output_size(s,2) 39 | s=s//2 40 | self.maxpool=torch.nn.MaxPool2d(2) 41 | self.relu=torch.nn.ReLU() 42 | 43 | self.drop1=torch.nn.Dropout(0.2) 44 | self.drop2=torch.nn.Dropout(0.5) 45 | self.fc1=torch.nn.Linear(hiddens[2]*s*s,hiddens[3]) 46 | self.fc2=torch.nn.Linear(hiddens[3],hiddens[4]) 47 | self.fc3=torch.nn.Linear(hiddens[4],hiddens[5]) 48 | self.fc4=torch.nn.Linear(hiddens[5], self.latent_dim) 49 | 50 | 51 | def forward(self, x_s): 52 | x_s = x_s.view_as(x_s) 53 | h = self.maxpool(self.drop1(self.relu(self.conv1(x_s)))) 54 | h = self.maxpool(self.drop1(self.relu(self.conv2(h)))) 55 | h = self.maxpool(self.drop2(self.relu(self.conv3(h)))) 56 | h = h.view(x_s.size(0), -1) 57 | h = self.drop2(self.relu(self.fc1(h))) 58 | h = self.drop2(self.relu(self.fc2(h))) 59 | h = self.drop2(self.relu(self.fc3(h))) 60 | h = self.drop2(self.relu(self.fc4(h))) 61 | return h 62 | 63 | 64 | 65 | class Private(torch.nn.Module): 66 | def __init__(self, args): 67 | super(Private, self).__init__() 68 | 69 | self.ncha,self.size,_=args.inputsize 70 | self.taskcla=args.taskcla 71 | self.latent_dim = args.latent_dim 72 | self.num_tasks = args.ntasks 73 | self.device = args.device 74 | 75 | if args.experiment == 'cifar100': 76 | hiddens=[32,32] 77 | flatten=1152 78 | 79 | elif args.experiment == 'miniimagenet': 80 | # hiddens=[8,8] 81 | # flatten=1800 82 | hiddens=[16,16] 83 | flatten=3600 84 | 85 | 86 | elif args.experiment == 'multidatasets': 87 | hiddens=[32,32] 88 | flatten=1152 89 | 90 | 91 | else: 92 | raise NotImplementedError 93 | 94 | 95 | self.task_out = torch.nn.ModuleList() 96 | for _ in range(self.num_tasks): 97 | self.conv = torch.nn.Sequential() 98 | self.conv.add_module('conv1',torch.nn.Conv2d(self.ncha, hiddens[0], kernel_size=self.size // 8)) 99 | self.conv.add_module('relu1', torch.nn.ReLU(inplace=True)) 100 | self.conv.add_module('drop1', torch.nn.Dropout(0.2)) 101 | self.conv.add_module('maxpool1', torch.nn.MaxPool2d(2)) 102 | self.conv.add_module('conv2', torch.nn.Conv2d(hiddens[0], hiddens[1], kernel_size=self.size // 10)) 103 | self.conv.add_module('relu2', torch.nn.ReLU(inplace=True)) 104 | self.conv.add_module('dropout2', torch.nn.Dropout(0.5)) 105 | self.conv.add_module('maxpool2', torch.nn.MaxPool2d(2)) 106 | self.task_out.append(self.conv) 107 | self.linear = torch.nn.Sequential() 108 | 109 | self.linear.add_module('linear1', torch.nn.Linear(flatten,self.latent_dim)) 110 | self.linear.add_module('relu3', torch.nn.ReLU(inplace=True)) 111 | self.task_out.append(self.linear) 112 | 113 | 114 | def 
forward(self, x, task_id): 115 | x = x.view_as(x) 116 | out = self.task_out[2*task_id].forward(x) 117 | out = out.view(out.size(0),-1) 118 | out = self.task_out[2*task_id+1].forward(out) 119 | return out 120 | 121 | 122 | 123 | class Net(torch.nn.Module): 124 | 125 | def __init__(self, args): 126 | super(Net, self).__init__() 127 | self.ncha,size,_=args.inputsize 128 | self.taskcla=args.taskcla 129 | self.latent_dim = args.latent_dim 130 | self.num_tasks = args.ntasks 131 | self.samples = args.samples 132 | self.image_size = self.ncha*size*size 133 | self.args=args 134 | 135 | self.hidden1 = args.head_units 136 | self.hidden2 = args.head_units//2 137 | 138 | self.shared = Shared(args) 139 | self.private = Private(args) 140 | 141 | self.head = torch.nn.ModuleList() 142 | for i in range(self.num_tasks): 143 | self.head.append( 144 | torch.nn.Sequential( 145 | torch.nn.Linear(2*self.latent_dim, self.hidden1), 146 | torch.nn.ReLU(inplace=True), 147 | torch.nn.Dropout(), 148 | torch.nn.Linear(self.hidden1, self.hidden2), 149 | torch.nn.ReLU(inplace=True), 150 | torch.nn.Linear(self.hidden2, self.taskcla[i][1]) 151 | )) 152 | 153 | 154 | def forward(self, x_s, x_p, tt, task_id): 155 | 156 | x_s = x_s.view_as(x_s) 157 | x_p = x_p.view_as(x_p) 158 | 159 | x_s = self.shared(x_s) 160 | x_p = self.private(x_p, task_id) 161 | 162 | x = torch.cat([x_p, x_s], dim=1) 163 | 164 | if self.args.experiment == 'multidatasets': 165 | # if no memory is used this is faster: 166 | y=[] 167 | for i,_ in self.taskcla: 168 | y.append(self.head[i](x)) 169 | return y[task_id] 170 | else: 171 | return torch.stack([self.head[tt[i]].forward(x[i]) for i in range(x.size(0))]) 172 | 173 | 174 | def get_encoded_ftrs(self, x_s, x_p, task_id): 175 | return self.shared(x_s), self.private(x_p, task_id) 176 | 177 | def print_model_size(self): 178 | count_P = sum(p.numel() for p in self.private.parameters() if p.requires_grad) 179 | count_S = sum(p.numel() for p in self.shared.parameters() if p.requires_grad) 180 | count_H = sum(p.numel() for p in self.head.parameters() if p.requires_grad) 181 | 182 | print('Num parameters in S = %s ' % (self.pretty_print(count_S))) 183 | print('Num parameters in P = %s, per task = %s ' % (self.pretty_print(count_P),self.pretty_print(count_P/self.num_tasks))) 184 | print('Num parameters in p = %s, per task = %s ' % (self.pretty_print(count_H),self.pretty_print(count_H/self.num_tasks))) 185 | print('Num parameters in P+p = %s ' % self.pretty_print(count_P+count_H)) 186 | print('--------------------------> Architecture size: %s parameters (%sB)' % (self.pretty_print(count_S + count_P + count_H), 187 | self.pretty_print(4*(count_S + count_P + count_H)))) 188 | 189 | print("--------------------------> Memory size: %s samples per task (%sB)" % (self.samples, 190 | self.pretty_print(self.num_tasks*4*self.samples*self.image_size))) 191 | print("------------------------------------------------------------------------------") 192 | print(" TOTAL: %sB" % self.pretty_print(4*(count_S + count_P + count_H)+self.num_tasks*4*self.samples*self.image_size)) 193 | 194 | def pretty_print(self, num): 195 | magnitude=0 196 | while abs(num) >= 1000: 197 | magnitude+=1 198 | num/=1000.0 199 | return '%.1f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude]) 200 | 201 | -------------------------------------------------------------------------------- /src/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import utils 8 | 9 | class Discriminator(torch.nn.Module): 10 | def __init__(self,args,task_id): 11 | super(Discriminator, self).__init__() 12 | 13 | self.num_tasks=args.ntasks 14 | self.units=args.units 15 | self.latent_dim=args.latent_dim 16 | 17 | 18 | if args.diff == 'yes': 19 | self.dis = torch.nn.Sequential( 20 | GradientReversal(args.lam), 21 | torch.nn.Linear(self.latent_dim, args.units), 22 | torch.nn.LeakyReLU(), 23 | torch.nn.Linear(args.units, args.units), 24 | torch.nn.Linear(args.units, task_id + 2) 25 | ) 26 | else: 27 | self.dis = torch.nn.Sequential( 28 | torch.nn.Linear(self.latent_dim, args.units), 29 | torch.nn.LeakyReLU(), 30 | torch.nn.Linear(args.units, args.units), 31 | torch.nn.Linear(args.units, task_id + 2) 32 | ) 33 | 34 | 35 | def forward(self, z, labels, task_id): 36 | return self.dis(z) 37 | 38 | def pretty_print(self, num): 39 | magnitude=0 40 | while abs(num) >= 1000: 41 | magnitude+=1 42 | num/=1000.0 43 | return '%.1f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude]) 44 | 45 | 46 | def get_size(self): 47 | count=sum(p.numel() for p in self.dis.parameters() if p.requires_grad) 48 | print('Num parameters in D = %s ' % (self.pretty_print(count))) 49 | 50 | 51 | class GradientReversalFunction(torch.autograd.Function): 52 | """ 53 | From: 54 | https://github.com/jvanvugt/pytorch-domain-adaptation/blob/cb65581f20b71ff9883dd2435b2275a1fd4b90df/utils.py#L26 55 | 56 | Gradient Reversal Layer from: 57 | Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015) 58 | Forward pass is the identity function. In the backward pass, 59 | the upstream gradients are multiplied by -lambda (i.e. gradient is reversed) 60 | """ 61 | 62 | @staticmethod 63 | def forward(ctx, x, lambda_): 64 | ctx.lambda_ = lambda_ 65 | return x.clone() 66 | 67 | @staticmethod 68 | def backward(ctx, grads): 69 | lambda_ = ctx.lambda_ 70 | lambda_ = grads.new_tensor(lambda_) 71 | dx = -lambda_ * grads 72 | return dx, None 73 | 74 | 75 | class GradientReversal(torch.nn.Module): 76 | def __init__(self, lambda_): 77 | super(GradientReversal, self).__init__() 78 | self.lambda_ = lambda_ 79 | 80 | def forward(self, x): 81 | return GradientReversalFunction.apply(x, self.lambda_) -------------------------------------------------------------------------------- /src/networks/mlp_acl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
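# ----------------------------------------------------------------------------------
# A minimal numeric check of the GradientReversal layer defined in discriminator.py
# above: the forward pass is the identity, while the backward pass multiplies the
# upstream gradient by -lambda, which is what lets the shared encoder and the task
# discriminator be trained adversarially with a single backward pass. Stand-alone
# re-implementation for illustration; the values are arbitrary.

import torch

class _Reverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()                    # identity on the way forward

    @staticmethod
    def backward(ctx, grads):
        return -ctx.lambda_ * grads, None   # reversed and scaled on the way back

x = torch.ones(3, requires_grad=True)
_Reverse.apply(x, 0.5).sum().backward()
print(x.grad)   # tensor([-0.5000, -0.5000, -0.5000])
# ----------------------------------------------------------------------------------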
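# ----------------------------------------------------------------------------------
# A small sketch of the per-sample head dispatch that both mlp_acl.Net.forward
# (below) and alexnet_acl.Net.forward use: tt carries one task id per sample, and
# each row of the fused (private || shared) feature is routed through its own task
# head. The sizes here are illustrative.

import torch

heads = torch.nn.ModuleList(torch.nn.Linear(4, 2) for _ in range(3))  # 3 task heads
x = torch.randn(5, 4)                 # batch of 5 fused feature vectors
tt = torch.tensor([0, 2, 1, 0, 2])    # task id of each sample
out = torch.stack([heads[tt[i]](x[i]) for i in range(x.size(0))])
print(out.shape)                      # torch.Size([5, 2])
# ----------------------------------------------------------------------------------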
5 | 6 | import torch 7 | 8 | 9 | class Private(torch.nn.Module): 10 | def __init__(self, args): 11 | super(Private, self).__init__() 12 | 13 | self.ncha,self.size,_=args.inputsize 14 | self.taskcla=args.taskcla 15 | self.latent_dim = args.latent_dim 16 | self.num_tasks = args.ntasks 17 | self.nhid = args.units 18 | self.device = args.device 19 | 20 | self.task_out = torch.nn.ModuleList() 21 | for _ in range(self.num_tasks): 22 | self.linear = torch.nn.Sequential() 23 | self.linear.add_module('linear', torch.nn.Linear(self.ncha*self.size*self.size, self.latent_dim)) 24 | self.linear.add_module('relu', torch.nn.ReLU(inplace=True)) 25 | self.task_out.append(self.linear) 26 | 27 | def forward(self, x_p, task_id): 28 | x_p = x_p.view(x_p.size(0), -1) 29 | return self.task_out[task_id].forward(x_p) 30 | 31 | 32 | 33 | class Shared(torch.nn.Module): 34 | 35 | def __init__(self,args): 36 | super(Shared, self).__init__() 37 | 38 | ncha,self.size,_=args.inputsize 39 | self.taskcla=args.taskcla 40 | self.latent_dim = args.latent_dim 41 | self.nhid = args.units 42 | self.nlayers = args.nlayers 43 | 44 | self.relu=torch.nn.ReLU() 45 | self.drop=torch.nn.Dropout(0.2) 46 | self.fc1=torch.nn.Linear(ncha*self.size*self.size, self.nhid) 47 | 48 | if self.nlayers == 3: 49 | self.fc2 = torch.nn.Linear(self.nhid, self.nhid) 50 | self.fc3=torch.nn.Linear(self.nhid,self.latent_dim) 51 | else: 52 | self.fc2 = torch.nn.Linear(self.nhid,self.latent_dim) 53 | 54 | def forward(self, x_s): 55 | 56 | h = x_s.view(x_s.size(0), -1) 57 | h = self.drop(self.relu(self.fc1(h))) 58 | h = self.drop(self.relu(self.fc2(h))) 59 | if self.nlayers == 3: 60 | h = self.drop(self.relu(self.fc3(h))) 61 | 62 | return h 63 | 64 | 65 | class Net(torch.nn.Module): 66 | 67 | def __init__(self, args): 68 | super(Net, self).__init__() 69 | ncha,size,_=args.inputsize 70 | self.taskcla=args.taskcla 71 | self.latent_dim = args.latent_dim 72 | self.num_tasks = args.ntasks 73 | self.device = args.device 74 | 75 | if args.experiment == 'mnist5': 76 | self.hidden1 = 28 77 | self.hidden2 = 14 78 | elif args.experiment == 'pmnist': 79 | self.hidden1 = 28 80 | self.hidden2 = 28 81 | 82 | self.samples = args.samples 83 | 84 | self.shared = Shared(args) 85 | self.private = Private(args) 86 | 87 | self.head = torch.nn.ModuleList() 88 | for i in range(self.num_tasks): 89 | self.head.append( 90 | torch.nn.Sequential( 91 | torch.nn.Linear(2 * self.latent_dim, self.hidden1), 92 | torch.nn.ReLU(inplace=True), 93 | torch.nn.Dropout(), 94 | torch.nn.Linear(self.hidden1, self.hidden2), 95 | torch.nn.ReLU(inplace=True), 96 | torch.nn.Linear(self.hidden2, self.taskcla[i][1]) 97 | )) 98 | 99 | def forward(self,x_s, x_p, tt, task_id): 100 | 101 | h_s = x_s.view(x_s.size(0), -1) 102 | h_p = x_p.view(x_p.size(0), -1) 103 | 104 | x_s = self.shared(h_s) 105 | x_p = self.private(h_p, task_id) 106 | 107 | x = torch.cat([x_p, x_s], dim=1) 108 | 109 | return torch.stack([self.head[tt[i]].forward(x[i]) for i in range(x.size(0))]) 110 | 111 | 112 | def get_encoded_ftrs(self, x_s, x_p, task_id): 113 | return self.shared(x_s), self.private(x_p, task_id) 114 | 115 | 116 | def print_model_size(self): 117 | count_P = sum(p.numel() for p in self.private.parameters() if p.requires_grad) 118 | count_S = sum(p.numel() for p in self.shared.parameters() if p.requires_grad) 119 | count_H = sum(p.numel() for p in self.head.parameters() if p.requires_grad) 120 | 121 | print('Num parameters in S = %s ' % (self.pretty_print(count_S))) 122 | print('Num parameters in P = %s, per task = %s ' %
(self.pretty_print(count_P),self.pretty_print(count_P/self.num_tasks))) 123 | print('Num parameters in p = %s, per task = %s ' % (self.pretty_print(count_H),self.pretty_print(count_H/self.num_tasks))) 124 | print('Num parameters in P+p = %s ' % self.pretty_print(count_P+count_H)) 125 | print('--------------------------> Total architecture size: %s parameters (%sB)' % (self.pretty_print(count_S + count_P + count_H), 126 | self.pretty_print(4*(count_S + count_P + count_H)))) 127 | 128 | def pretty_print(self, num): 129 | magnitude=0 130 | while abs(num) >= 1000: 131 | magnitude+=1 132 | num/=1000.0 133 | return '%.2f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude]) 134 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import numpy as np 8 | from copy import deepcopy 9 | import pickle 10 | import time 11 | import uuid 12 | from subprocess import call 13 | ######################################################################################################################## 14 | 15 | def human_format(num): 16 | magnitude=0 17 | while abs(num)>=1000: 18 | magnitude+=1 19 | num/=1000.0 20 | return '%.1f%s'%(num,['','K','M','G','T','P'][magnitude]) 21 | 22 | 23 | def report_tr(res, e, sbatch, clock0, clock1): 24 | # Training performance 25 | print( 26 | '| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train losses={:.3f} | T: loss={:.3f}, acc={:5.2f}% | D: loss={:.3f}, acc={:5.1f}%, ' 27 | 'Diff loss:{:.3f} |'.format( 28 | e + 1, 29 | 1000 * sbatch * (clock1 - clock0) / res['size'], 30 | 1000 * sbatch * (time.time() - clock1) / res['size'], res['loss_tot'], 31 | res['loss_t'], res['acc_t'], res['loss_a'], res['acc_d'], res['loss_d']), end='') 32 | 33 | def report_val(res): 34 | # Validation performance 35 | print(' Valid losses={:.3f} | T: loss={:.6f}, acc={:5.2f}%, | D: loss={:.3f}, acc={:5.2f}%, Diff loss={:.3f} |'.format( 36 | res['loss_tot'], res['loss_t'], res['acc_t'], res['loss_a'], res['acc_d'], res['loss_d']), end='') 37 | 38 | 39 | ######################################################################################################################## 40 | 41 | def get_model(model): 42 | return deepcopy(model.state_dict()) 43 | 44 | ######################################################################################################################## 45 | 46 | def compute_conv_output_size(Lin,kernel_size,stride=1,padding=0,dilation=1): 47 | return int(np.floor((Lin+2*padding-dilation*(kernel_size-1)-1)/float(stride)+1)) 48 | 49 | ######################################################################################################################## 50 | 51 | def save_print_log(taskcla, acc, lss, output_path): 52 | 53 | print('*'*100) 54 | print('Accuracies =') 55 | for i in range(acc.shape[0]): 56 | print('\t',end=',') 57 | for j in range(acc.shape[1]): 58 | print('{:5.4f}% '.format(acc[i,j]),end=',') 59 | print() 60 | print ('ACC: {:5.4f}%'.format((np.mean(acc[acc.shape[0]-1,:])))) 61 | print() 62 | 63 | print ('BWD Transfer = ') 64 | 65 | print () 66 | print ("Diagonal R_ii") 67 | for i in range(acc.shape[0]): 68 | print('\t',end='') 69 | print('{:5.2f}% '.format(np.diag(acc)[i]), end=',') 70 | 71 | 72 | print() 73 | print ("Last 
row") 74 | for i in range(acc.shape[0]): 75 | print('\t', end=',') 76 | print('{:5.2f}% '.format(acc[-1][i]), end=',') 77 | 78 | print() 79 | # BWT calculated based on GEM paper (https://arxiv.org/abs/1706.08840) 80 | gem_bwt = sum(acc[-1]-np.diag(acc))/ (len(acc[-1])-1) 81 | # BWT calculated based on our UCB paper (https://openreview.net/pdf?id=HklUCCVKDB) 82 | ucb_bwt = (acc[-1] - np.diag(acc)).mean() 83 | print ('BWT: {:5.2f}%'.format(gem_bwt)) 84 | # print ('BWT (UCB paper): {:5.2f}%'.format(ucb_bwt)) 85 | 86 | print('*'*100) 87 | print('Done!') 88 | 89 | 90 | logs = {} 91 | # save results 92 | logs['name'] = output_path 93 | logs['taskcla'] = taskcla 94 | logs['acc'] = acc 95 | logs['loss'] = lss 96 | logs['gem_bwt'] = gem_bwt 97 | logs['ucb_bwt'] = ucb_bwt 98 | logs['rii'] = np.diag(acc) 99 | logs['rij'] = acc[-1] 100 | 101 | # pickle 102 | with open(os.path.join(output_path, 'logs.p'), 'wb') as output: 103 | pickle.dump(logs, output) 104 | 105 | print ("Log file saved in ", os.path.join(output_path, 'logs.p')) 106 | 107 | 108 | def print_log_acc_bwt(taskcla, acc, lss, output_path, run_id): 109 | 110 | print('*'*100) 111 | print('Accuracies =') 112 | for i in range(acc.shape[0]): 113 | print('\t',end=',') 114 | for j in range(acc.shape[1]): 115 | print('{:5.4f}% '.format(acc[i,j]),end=',') 116 | print() 117 | 118 | avg_acc = np.mean(acc[acc.shape[0]-1,:]) 119 | print ('ACC: {:5.4f}%'.format(avg_acc)) 120 | print() 121 | print() 122 | # BWT calculated based on GEM paper (https://arxiv.org/abs/1706.08840) 123 | gem_bwt = sum(acc[-1]-np.diag(acc))/ (len(acc[-1])-1) 124 | # BWT calculated based on UCB paper (https://arxiv.org/abs/1906.02425) 125 | ucb_bwt = (acc[-1] - np.diag(acc)).mean() 126 | print ('BWT: {:5.2f}%'.format(gem_bwt)) 127 | # print ('BWT (UCB paper): {:5.2f}%'.format(ucb_bwt)) 128 | 129 | print('*'*100) 130 | print('Done!') 131 | 132 | 133 | logs = {} 134 | # save results 135 | logs['name'] = output_path 136 | logs['taskcla'] = taskcla 137 | logs['acc'] = acc 138 | logs['loss'] = lss 139 | logs['gem_bwt'] = gem_bwt 140 | logs['ucb_bwt'] = ucb_bwt 141 | logs['rii'] = np.diag(acc) 142 | logs['rij'] = acc[-1] 143 | 144 | # pickle 145 | path = os.path.join(output_path, 'logs_run_id_{}.p'.format(run_id)) 146 | with open(path, 'wb') as output: 147 | pickle.dump(logs, output) 148 | 149 | print ("Log file saved in ", path) 150 | return avg_acc, gem_bwt 151 | 152 | 153 | def print_running_acc_bwt(acc, task_num): 154 | print() 155 | acc = acc[:task_num+1,:task_num+1] 156 | avg_acc = np.mean(acc[acc.shape[0] - 1, :]) 157 | gem_bwt = sum(acc[-1] - np.diag(acc)) / (len(acc[-1]) - 1) 158 | print('ACC: {:5.4f}% || BWT: {:5.2f}% '.format(avg_acc, gem_bwt)) 159 | print() 160 | 161 | 162 | def make_directories(args): 163 | uid = uuid.uuid4().hex 164 | if args.checkpoint is None: 165 | os.mkdir('checkpoints') 166 | args.checkpoint = os.path.join('./checkpoints/',uid) 167 | os.mkdir(args.checkpoint) 168 | else: 169 | if not os.path.exists(args.checkpoint): 170 | os.mkdir(args.checkpoint) 171 | args.checkpoint = os.path.join(args.checkpoint, uid) 172 | os.mkdir(args.checkpoint) 173 | 174 | 175 | 176 | 177 | def some_sanity_checks(args): 178 | # Making sure the chosen experiment matches with the number of tasks performed in the paper: 179 | datasets_tasks = {} 180 | datasets_tasks['mnist5']=[5] 181 | datasets_tasks['pmnist']=[10,20,30,40] 182 | datasets_tasks['cifar100']=[20] 183 | datasets_tasks['miniimagenet']=[20] 184 | datasets_tasks['multidatasets']=[5] 185 | 186 | 187 | if not 
args.ntasks in datasets_tasks[args.experiment]: 188 | raise Exception("Chosen number of tasks ({}) does not match the {} experiment".format(args.ntasks,args.experiment)) 189 | 190 | # Making sure the memory flags are consistent: 191 | if args.use_memory == 'yes' and not args.samples > 0: 192 | raise Exception("Flags required to use memory: --use_memory yes --samples n where n>0") 193 | 194 | if args.use_memory == 'no' and args.samples > 0: 195 | raise Exception("Flags required to run without memory: --use_memory no --samples 0") 196 | 197 | 198 | 199 | def save_code(args): 200 | cwd = os.getcwd() 201 | des = os.path.join(args.checkpoint, 'code') + '/' 202 | if not os.path.exists(des): 203 | os.mkdir(des) 204 | 205 | def get_folder(folder): 206 | return os.path.join(cwd,folder) 207 | 208 | folders = [get_folder(item) for item in ['dataloaders', 'networks', 'configs', 'main.py', 'acl.py', 'utils.py']] 209 | 210 | for folder in folders: 211 | call('cp -rf {} {}'.format(folder, des),shell=True) 212 | 213 | 214 | def print_time(): 215 | from datetime import datetime 216 | 217 | # datetime object containing the current date and time 218 | now = datetime.now() 219 | 220 | # dd/mm/YY H:M:S 221 | dt_string = now.strftime("%d/%m/%Y %H:%M:%S") 222 | print("Job finished at =", dt_string) 223 | 224 | --------------------------------------------------------------------------------
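# ----------------------------------------------------------------------------------
# A worked example of how print_log_acc_bwt in src/utils.py reduces the accuracy
# matrix. acc[i, j] is the test accuracy on task j after training on task i, so ACC
# is the mean of the last row, and BWT (GEM, https://arxiv.org/abs/1706.08840) is the
# average change between each task's final accuracy and its just-trained (diagonal)
# accuracy; the last diagonal entry contributes zero. Numbers are illustrative.

import numpy as np

acc = np.array([[90.0,  0.0,  0.0],
                [88.0, 85.0,  0.0],
                [86.0, 83.0, 80.0]])  # 3 tasks; the lower triangle fills in during training

avg_acc = acc[-1].mean()                                    # ACC = 83.0
gem_bwt = sum(acc[-1] - np.diag(acc)) / (len(acc[-1]) - 1)  # BWT = (-4 - 2 + 0)/2 = -3.0
print('ACC: {:5.4f}% || BWT: {:5.2f}%'.format(avg_acc, gem_bwt))
# ----------------------------------------------------------------------------------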