├── utils ├── euclid2polar.py ├── __init__.py ├── mixup_utils.py ├── conf.py ├── batch_norm.py ├── metrics.py ├── plot_lamda_acc.py ├── plot_perturbed_angular_acc.py ├── plot_mixup_acc.py ├── plot_training_gradient.py ├── plot_all_method_acc.py ├── continual_training.py ├── status.py ├── tb_logger.py ├── debug.py ├── simclrloss.py ├── augmentations.py ├── feature_gradient_svd.py ├── plot_acc_curve.py ├── method_variation.py ├── ring_buffer.py ├── args.py ├── plot_test_variance.py ├── main.py ├── loggers.py ├── gss_buffer.py └── vmf_sampling.py ├── backbone ├── utils │ ├── __init__.py │ └── modules.py └── __init__.py ├── datasets ├── utils │ ├── __init__.py │ ├── gcl_dataset.py │ ├── validation.py │ └── continual_dataset.py ├── transforms │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── rotation.cpython-36.pyc │ │ ├── rotation.cpython-37.pyc │ │ ├── permutation.cpython-36.pyc │ │ ├── permutation.cpython-37.pyc │ │ ├── denormalization.cpython-36.pyc │ │ └── denormalization.cpython-37.pyc │ ├── denormalization.py │ ├── permutation.py │ └── rotation.py ├── rot_mnist.py ├── __init__.py ├── seq_mnist.py ├── perm_mnist.py ├── seq_cifar10.py ├── seq_cifar100.py └── seq_tinyimagenet.py ├── models ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── continual_model.cpython-36.pyc │ │ └── continual_model.cpython-37.pyc │ └── continual_model.py ├── __init__.py ├── sgd.py ├── der.py ├── joint_gcl.py ├── agem_r.py ├── si.py ├── agem.py ├── ewc_on.py ├── gss.py ├── pnn.py ├── mer.py ├── fdr.py ├── gdumb.py ├── lwf.py ├── joint.py ├── rpc.py ├── gem.py └── hal.py ├── misc ├── vmf_svd-1.png ├── moca_graph-1.png ├── 2d_feat_vis-1.png ├── MOCA_variants-1.png ├── grad_combined-1.png ├── moca_framwork-1.png └── var_comp_v4-1.png └── README.md /utils/euclid2polar.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /backbone/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /misc/vmf_svd-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/vmf_svd-1.png -------------------------------------------------------------------------------- /misc/moca_graph-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/moca_graph-1.png -------------------------------------------------------------------------------- /misc/2d_feat_vis-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/2d_feat_vis-1.png -------------------------------------------------------------------------------- /misc/MOCA_variants-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/MOCA_variants-1.png -------------------------------------------------------------------------------- /misc/grad_combined-1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/grad_combined-1.png -------------------------------------------------------------------------------- /misc/moca_framwork-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/moca_framwork-1.png -------------------------------------------------------------------------------- /misc/var_comp_v4-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/var_comp_v4-1.png -------------------------------------------------------------------------------- /models/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/models/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/models/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/__init__.cpython-37.pyc 
-------------------------------------------------------------------------------- /datasets/transforms/__pycache__/rotation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/rotation.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/rotation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/rotation.cpython-37.pyc -------------------------------------------------------------------------------- /models/utils/__pycache__/continual_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/models/utils/__pycache__/continual_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/utils/__pycache__/continual_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/models/utils/__pycache__/continual_model.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/permutation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/permutation.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/transforms/__pycache__/permutation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/permutation.cpython-37.pyc 
class GCLDataset:
    """
    Descriptor of a general-continual-learning evaluation setting.

    The class only declares four class-level metadata slots, all of which
    are unset (``None``) here.
    """
    # Every metadata slot starts out unset.
    NAME = SETTING = N_CLASSES = LENGTH = None
def get_all_models():
    """
    Lists the model modules found in the ``models`` folder.

    The folder is resolved relative to the current working directory.
    Only real Python source files are reported: dunder entries
    (``__init__.py``, ``__pycache__``) are skipped, and an entry must end
    with ``.py`` -- merely containing the letters "py" (``.pyc`` files,
    folders, text notes) is no longer enough to be picked up.
    :return: the module names, i.e. the file names up to the first dot
    """
    return [model.split('.')[0] for model in os.listdir('models')
            if '__' not in model and model.endswith('.py')]
def mixup_old_data(x, y, alpha):
    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda.

    The batch is split in two halves: every sample of the first half is
    mixed with a randomly chosen sample of the second half.
    :param x: input batch; its first ``x.size(0) // 2`` rows form the first half
    :param y: targets, indexed along the first dimension like ``x``
    :param alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables
                  mixing (lambda is 1)
    :return: ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` are the targets of
             the first half, ``y_b`` those of the sampled second-half
             partners and ``lam`` (always >= 0.5) the weight of the first half
    '''
    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.
    batch_size = x.size(0) // 2
    # Draw the permutation on the CPU (same random stream as before), then
    # move it to wherever x lives instead of unconditionally calling .cuda(),
    # which crashed on CPU-only machines.
    index = torch.randperm(batch_size).to(x.device) + batch_size
    # Keep the first-half sample dominant in the mixture.
    if lam < 0.5:
        lam = 1. - lam
    mixed_x = lam * x[:batch_size] + (1 - lam) * x[index]
    y_a, y_b = y[:batch_size], y[index]
    return mixed_x, y_a, y_b, lam
class bn_track_stats:
    """
    Context manager that temporarily disables ``track_running_stats`` on the
    BatchNorm1d / BatchNorm2d layers of a module.

    Nothing happens when ``condition`` is truthy. When it is falsy, the flag
    is switched off on entry and, on exit, every layer gets back the value
    it had before entering (previously it was forced to ``True``, which
    silently re-enabled tracking on layers created with it disabled).
    """
    def __init__(self, module: nn.Module, condition=True):
        """
        :param module: the module whose batch-norm layers are affected
        :param condition: if True the layers are left untouched
        """
        self.module = module
        self.enable = condition
        # One entry per active `with`, so nested use of the same instance
        # restores the right values.
        self._saved_stack = []

    def __enter__(self):
        if not self.enable:
            saved = []
            for m in self.module.modules():
                if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
                    saved.append((m, m.track_running_stats))
                    m.track_running_stats = False
            self._saved_stack.append(saved)

    def __exit__(self, type, value, traceback):
        if not self.enable and self._saved_stack:
            # Put back exactly what was found on entry.
            for m, previous in self._saved_stack.pop():
                m.track_running_stats = previous
def forgetting(results):
    """
    Computes the average forgetting over all tasks but the last one.

    For each of those tasks, forgetting is the best accuracy ever recorded
    for it minus the accuracy recorded in the last row.
    :param results: one row of per-task accuracies per evaluation; rows
                    shorter than the number of rows are treated as if padded
                    with 0.0. The argument is left unmodified (the padding
                    used to be applied to the caller's lists in place).
    :return: the mean forgetting (``nan`` when there is a single row, as
             there is nothing to average)
    """
    n_tasks = len(results)
    # Pad copies of the rows so the caller's data is not altered.
    padded = [list(row) + [0.0] * (n_tasks - len(row)) for row in results]
    best = np.max(np.array(padded), axis=0)
    li = [best[i] - results[-1][i] for i in range(n_tasks - 1)]

    return np.mean(li)
def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser of the SGD model.
    :return: a parser holding the management and experiment arguments
    """
    # The description used to read "Progressive Neural Networks", a leftover
    # from another model; this file defines the plain SGD model.
    parser = ArgumentParser(description='Continual Learning via'
                                        ' naive SGD fine-tuning.')
    add_management_args(parser)
    add_experiment_args(parser)
    return parser
'1.5', '2.0', '3.0'] 28 | plt.plot(x, Model_agnostic, alpha=0.8, label="Model-agnostic MOCA", linewidth=3, marker= 'o') 29 | plt.plot(x, Model_based, alpha=0.8, label="Model-based MOCA", linewidth=3, marker= 'o') 30 | plt.tick_params(labelsize=14) 31 | plt.axhline(y = 31.12, linestyle='--', color = 'r', linewidth=1.5) 32 | ax = plt.gca() 33 | bwith = 1. 34 | ax.spines['top'].set_linewidth(bwith) 35 | ax.spines['right'].set_linewidth(bwith) 36 | ax.spines['bottom'].set_linewidth(bwith) 37 | ax.spines['left'].set_linewidth(bwith) 38 | plt.legend(loc="best", fontsize=14, edgecolor='black') 39 | plt.show() 40 | plt.savefig('./acc_lamda.pdf', bbox_inches = 'tight') 41 | -------------------------------------------------------------------------------- /utils/plot_perturbed_angular_acc.py: -------------------------------------------------------------------------------- 1 | from turtle import color 2 | import numpy as np 3 | import numpy as np 4 | import matplotlib 5 | matplotlib.use("Agg") 6 | import matplotlib.pyplot as plt 7 | import copy 8 | import math 9 | import torch 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | from torch.nn import functional as F 13 | import math 14 | 15 | 16 | Model_agnostic = [30.48, 31.67, 32.58, 34.8, 37.98, 35.59] 17 | Model_based = [34.54, 38.0, 40.55, 39.94, 38.55, 20] 18 | # plt.figure(figsize=(8,8)) 19 | # title = 'CIFAR-100' 20 | # plt.title(r'Perturbation Magnitude $\lambda$') 21 | # plt.title(title, fontsize=14) 22 | plt.ylim(30, 45) 23 | plt.xlabel(r'Perturbation Angular', fontsize=14) 24 | plt.ylabel("Accuracy", fontsize=14) 25 | 26 | plt.grid(linewidth = 1.5) 27 | x = ['5', '15', '30', '45', '60', '75'] 28 | plt.plot(x, Model_agnostic, alpha=0.8, label="Model-agnostic MOCA", linewidth=3, marker= 'o') 29 | plt.plot(x, Model_based, alpha=0.8, label="Model-based MOCA", linewidth=3, marker= 'o') 30 | plt.tick_params(labelsize=14) 31 | plt.axhline(y = 31.12, linestyle='--', color = 'r', linewidth=1.5) 32 | ax = 
plt.gca() 33 | bwith = 1. 34 | ax.spines['top'].set_linewidth(bwith) 35 | ax.spines['right'].set_linewidth(bwith) 36 | ax.spines['bottom'].set_linewidth(bwith) 37 | ax.spines['left'].set_linewidth(bwith) 38 | plt.legend(loc="best", fontsize=14, edgecolor='black') 39 | plt.show() 40 | plt.savefig('./acc_perturbed_angular.pdf', bbox_inches = 'tight') 41 | -------------------------------------------------------------------------------- /datasets/rot_mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torchvision.transforms as transforms 7 | from datasets.transforms.rotation import Rotation 8 | from torch.utils.data import DataLoader 9 | from backbone.MNISTMLP import MNISTMLP 10 | import torch.nn.functional as F 11 | from datasets.perm_mnist import store_mnist_loaders 12 | from datasets.utils.continual_dataset import ContinualDataset 13 | 14 | 15 | class RotatedMNIST(ContinualDataset): 16 | NAME = 'rot-mnist' 17 | SETTING = 'domain-il' 18 | N_CLASSES_PER_TASK = 10 19 | N_TASKS = 20 20 | 21 | def get_data_loaders(self): 22 | transform = transforms.Compose((Rotation(), transforms.ToTensor())) 23 | train, test = store_mnist_loaders(transform, self) 24 | return train, test 25 | 26 | @staticmethod 27 | def get_backbone(): 28 | return MNISTMLP(28 * 28, RotatedMNIST.N_CLASSES_PER_TASK) 29 | 30 | @staticmethod 31 | def get_transform(): 32 | return None 33 | 34 | @staticmethod 35 | def get_normalization_transform(): 36 | return None 37 | 38 | @staticmethod 39 | def get_loss(): 40 | return F.cross_entropy 41 | 42 | @staticmethod 43 | def get_denormalization_transform(): 44 | return None 45 | 46 | @staticmethod 47 | def get_scheduler(model, args): 48 | return 
def get_all_models():
    """
    Lists the dataset modules found in the ``datasets`` folder.

    The folder is resolved relative to the current working directory.
    Only real Python source files are reported: dunder entries
    (``__init__.py``, ``__pycache__``) are skipped, and an entry must end
    with ``.py`` -- merely containing the letters "py" (``.pyc`` files,
    folders, text notes) is no longer enough to be picked up.
    :return: the module names, i.e. the file names up to the first dot
    """
    return [model.split('.')[0] for model in os.listdir('datasets')
            if '__' not in model and model.endswith('.py')]
class ListModule(nn.Module):
    """
    A list-like container of sub-modules.

    Every module is registered through ``add_module`` under the string form
    of its insertion index; the container supports ``append``, integer
    indexing (negative indices included), iteration and ``len``.
    """
    def __init__(self, *args):
        """
        :param args: the modules to store, in order
        """
        super(ListModule, self).__init__()
        self.idx = 0
        for module in args:
            self.add_module(str(self.idx), module)
            self.idx += 1

    def append(self, module):
        """
        Registers a module at the end of the container.
        :param module: the module to add
        """
        self.add_module(str(self.idx), module)
        self.idx += 1

    def __getitem__(self, idx):
        """
        Returns the module at the given position.
        :param idx: position of the module; negative values count from the end
        :return: the stored module
        :raises IndexError: if the position is out of range on either side
        """
        requested = idx
        if idx < 0:
            idx += self.idx
        # The lower bound matters too: an index that is still negative here
        # used to fall through and silently return the first module.
        if idx < 0 or idx >= len(self._modules):
            raise IndexError('index {} is out of range'.format(requested))
        return list(self._modules.values())[idx]

    def __iter__(self):
        return iter(self._modules.values())

    def __len__(self):
        return len(self._modules)
class Der(ContinualModel):
    """
    Dark Experience Replay.

    Keeps a buffer of non-augmented inputs together with the network outputs
    produced for them, and adds to the task loss an MSE term (weighted by
    ``args.alpha``) between the current outputs on replayed inputs and the
    outputs stored for them.
    """
    NAME = 'der'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super().__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Runs one optimisation step on the incoming batch.
        :param inputs: the batch fed to the network
        :param labels: the targets passed to the loss
        :param not_aug_inputs: the inputs stored in the buffer
        :return: the value of the total loss of this step
        """
        self.opt.zero_grad()

        logits = self.net(inputs)
        total_loss = self.loss(logits, labels)

        buffer_has_samples = not self.buffer.is_empty()
        if buffer_has_samples:
            replay = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            replay_inputs, replay_logits = replay
            # Penalise drifting away from the outputs recorded at insertion.
            distillation = F.mse_loss(self.net(replay_inputs), replay_logits)
            total_loss += self.args.alpha * distillation

        total_loss.backward()
        self.opt.step()

        # Store the detached outputs of this batch for future replay.
        self.buffer.add_data(examples=not_aug_inputs, logits=logits.data)

        return total_loss.item()
class FixedPermutation(object):
    """
    Defines a fixed permutation (given the seed) for a numpy array.
    """
    def __init__(self, seed: int) -> None:
        """
        Defines the seed.
        :param seed: seed of the permutation
        """
        self.perm = None
        self.seed = seed

    def __call__(self, sample: np.ndarray) -> np.ndarray:
        """
        Defines the permutation and applies the transformation.

        The permutation is drawn once, on the first call, and reused
        afterwards. It comes from a private ``RandomState`` seeded with
        ``self.seed``: this yields the same permutation as seeding the
        global generator did, without resetting the random stream of the
        rest of the program.
        :param sample: image to be permuted
        :return: permuted image
        """
        old_shape = sample.shape
        if self.perm is None:
            rng = np.random.RandomState(self.seed)
            self.perm = rng.permutation(len(sample.flatten()))

        return sample.flatten()[self.perm].reshape(old_shape)
def visual_svd(s_list, method_list):
    """
    Plots the log of the normalised singular values of several methods and
    saves the figure to ``./aug_feat.pdf``.

    :param s_list: one tensor of singular values per method; each one is
                   normalised to sum to 1 before plotting. The list itself
                   is left unmodified.
    :param method_list: the legend label of each entry of ``s_list``;
                        entries whose label contains 'V-vmf' are skipped
    """
    plt.ion()
    plt.clf()
    for singular_values, method in zip(s_list, method_list):
        if 'V-vmf' in method:
            continue
        # Normalise a local copy rather than overwriting the caller's list.
        normalised = singular_values / torch.sum(singular_values)
        # The x axis follows the spectrum length (it was hard-coded to 512).
        x = np.arange(len(normalised))
        plt.plot(x, torch.log(normalised), label=method)
    plt.xlabel('Dimension', fontsize=14)
    plt.ylabel('Log Singular Value', fontsize=14)
    plt.legend()

    ax = plt.gca()
    bwith = 1.
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_linewidth(bwith)
    plt.tick_params(labelsize=14)
    plt.savefig('./aug_feat.pdf')
    # `plt.show` without parentheses was a no-op expression.
    plt.show()
17 | 18 | ## Examples 19 | For reproducing the results of our MOCA-Gaussian on Cifar-100, run: 20 | ```bash 21 | python ./utils/main.py --load_best_args --model er --dataset seq-cifar100 --buffer_size 500 --para_scale 1.5 --gamma_loss 1 --norm_add norm_add --method2 gaussian --noise_type noise 22 | ``` 23 | 24 | For reproducing the results of our MOCA-WAP on Cifar-100, run: 25 | ```bash 26 | python ./utils/main.py --load_best_args --model er --dataset seq-cifar100 --buffer_size 500 --para_scale 1.5 --gamma_loss 10 --norm_add norm_add --advloss none --target_type new_labels --noise_type adv --inner_iter 1 27 | ``` 28 | 29 | ## Citation 30 | If you find this code or idea useful, please cite our work: 31 | ```bib 32 | @article{yu2022continual, 33 | title={Continual Learning by Modeling Intra-Class Variation}, 34 | author={Yu, Longhui and Hu, Tianyang and Hong, Lanqing and Liu, Zhen and Weller, Adrian and Liu, Weiyang}, 35 | journal={arXiv preprint arXiv:2210.05398}, 36 | year={2022} 37 | } 38 | ``` 39 | 40 | ## Contact 41 | If you have any questions, feel free to contact us through email (yulonghui@stu.pku.edu.cn). Enjoy! 42 | 43 | ## Intra-class Variation Gap 44 |

45 | 46 |

47 | 48 | ## Representation Collapse 49 |

50 | 51 |

52 | 53 | ## Gradient Collapse 54 |

55 | 56 |

57 | 58 | ## MOCA Framework 59 |

60 | 61 |

62 | 63 | ## MOCA Variants 64 |

65 | 66 |

def visual_svd(s_list, method_list):
    """
    Draws the leading part of each method's log singular-value spectrum.

    Entries are visited in a fixed order, normalized in place so that each
    sums to one, and only the first 100 dimensions are plotted. The figure is
    saved to './vmf_svd.pdf'.

    :param s_list: indexable collection of torch tensors, one spectrum per method
    :param method_list: legend label for each entry of s_list
    """
    from matplotlib.ticker import MaxNLocator

    # Labels that get a dedicated look; every other label uses the default.
    highlighted = {
        'WAP': dict(alpha=0.9, color='purple'),
        'DOA-new': dict(alpha=0.65, color='purple'),
        'VT': dict(alpha=0.4, color='purple'),
        'vMF': dict(alpha=0.9, color='chocolate'),
    }
    dims = np.arange(0, 512)
    head = slice(None, 100)

    plt.ion()
    plt.clf()
    for idx in (5, 3, 6, 7, 1, 2, 0, 4):
        s_list[idx] = s_list[idx] / torch.sum(s_list[idx])
        name = method_list[idx]
        look = highlighted.get(name, dict(alpha=0.9))
        plt.plot(dims[head], torch.log(s_list[idx])[head],
                 label=name, linewidth=3.0, **look)

    plt.xlabel('Dimension', fontsize=14)
    plt.ylabel('Log Singular Value', fontsize=14)
    plt.legend(fontsize=12, edgecolor='black')
    plt.axvline(x=50, linestyle='--', color='r', linewidth=1.5)

    axes = plt.gca()
    border = 1.
    for side in ('top', 'right', 'bottom', 'left'):
        axes.spines[side].set_linewidth(border)
    plt.tick_params(labelsize=14)
    plt.savefig('./vmf_svd.pdf')
    plt.show()
class JointGCL(ContinualModel):
    """
    Joint-training baseline for the general-continual setting.

    While the stream is observed the model only records the incoming batches;
    the actual training happens in end_task on everything recorded so far.
    """
    NAME = 'joint_gcl'
    COMPATIBILITY = ['general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(JointGCL, self).__init__(backbone, loss, args, transform)
        self.old_data = []    # input batches seen so far, in arrival order
        self.old_labels = []  # label batches matching old_data
        self.current_task = 0

    def end_task(self, dataset):
        """
        Re-initializes the network and trains it on all recorded data.

        :param dataset: dataset object; only its get_backbone() is used here
        """
        # reinit network
        self.net = dataset.get_backbone()
        self.net.to(self.device)
        self.net.train()
        self.opt = SGD(self.net.parameters(), lr=self.args.lr)

        # gather data
        all_data = torch.cat(self.old_data)
        all_labels = torch.cat(self.old_labels)

        batch_size = self.args.batch_size
        n_batches = math.ceil(len(all_data) / batch_size)

        # train (a single pass is hard-coded; args.n_epochs is not used)
        for e in range(1):
            rp = torch.randperm(len(all_data))
            for i in range(n_batches):
                # Slice the permutation first and gather only this batch.
                # Indexing all_data[rp] copied the whole dataset every step.
                batch_idx = rp[i * batch_size:(i + 1) * batch_size]
                inputs = all_data[batch_idx].to(self.device)
                labels = all_labels[batch_idx].to(self.device)

                self.opt.zero_grad()
                outputs = self.net(inputs)
                loss = self.loss(outputs, labels.long())
                loss.backward()
                self.opt.step()
                progress_bar(i, n_batches, e, 'J', loss.item())

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Records the batch for later joint training; no optimization step is made.

        :param inputs: batch of inputs, stored as-is
        :param labels: labels of the batch, stored as-is
        :param not_aug_inputs: unused
        :return: always 0, since no loss is computed here
        """
        self.old_data.append(inputs.data)
        self.old_labels.append(labels.data)
        return 0
self.buffer.get_data(self.args.minibatch_size) 48 | self.net.zero_grad() 49 | buf_outputs = self.net.forward(buf_inputs) 50 | penalty = self.loss(buf_outputs, buf_labels) 51 | penalty.backward() 52 | store_grad(self.parameters, self.grad_er, self.grad_dims) 53 | 54 | dot_prod = torch.dot(self.grad_xy, self.grad_er) 55 | if dot_prod.item() < 0: 56 | g_tilde = project(gxy=self.grad_xy, ger=self.grad_er) 57 | overwrite_grad(self.parameters, g_tilde, self.grad_dims) 58 | else: 59 | overwrite_grad(self.parameters, self.grad_xy, self.grad_dims) 60 | 61 | self.opt.step() 62 | 63 | self.buffer.add_data(examples=not_aug_inputs, labels=labels) 64 | 65 | return loss.item() 66 | -------------------------------------------------------------------------------- /models/si.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
class SI(ContinualModel):
    """
    Synaptic Intelligence: a quadratic penalty, weighted per parameter, that
    pulls the network towards the parameters stored at the last task boundary.
    """
    NAME = 'si'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super().__init__(backbone, loss, args, transform)

        # Flat copy of the parameters taken at the last task boundary.
        self.checkpoint = self.net.get_params().data.clone().to(self.device)
        # Consolidated importance; stays None until the first end_task call.
        self.big_omega = None
        # Importance accumulated during the current task.
        self.small_omega = 0

    def penalty(self):
        """
        Returns the importance-weighted squared distance from the checkpoint.

        :return: scalar tensor, zero while no importance has been consolidated
        """
        if self.big_omega is None:
            return torch.tensor(0.0).to(self.device)

        drift = self.net.get_params() - self.checkpoint
        return (self.big_omega * (drift ** 2)).sum()

    def end_task(self, dataset):
        """
        Folds the importance gathered during the task into big_omega, then
        refreshes the checkpoint and clears the running accumulator.

        :param dataset: unused
        """
        # big omega calculation step
        if self.big_omega is None:
            self.big_omega = torch.zeros_like(self.net.get_params()).to(self.device)

        moved = self.net.get_params().data - self.checkpoint
        self.big_omega += self.small_omega / (moved ** 2 + self.args.xi)

        # store parameters checkpoint and reset small_omega
        self.checkpoint = self.net.get_params().data.clone().to(self.device)
        self.small_omega = 0

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Performs one optimization step on the penalized loss.

        :param inputs: batch of inputs
        :param labels: labels of the batch
        :param not_aug_inputs: unused
        :return: value of the penalized loss as a Python float
        """
        self.opt.zero_grad()
        logits = self.net(inputs)
        reg = self.penalty()
        loss = self.loss(logits, labels) + self.args.c * reg
        loss.backward()
        nn.utils.clip_grad.clip_grad_value_(self.net.parameters(), 1)
        self.opt.step()

        # Accumulated after the step, from the clipped gradients.
        self.small_omega += self.args.lr * self.net.get_grads().data ** 2

        return loss.item()
29 | ax.spines['top'].set_linewidth(bwith) 30 | ax.spines['right'].set_linewidth(bwith) 31 | ax.spines['bottom'].set_linewidth(bwith) 32 | ax.spines['left'].set_linewidth(bwith) 33 | ax.yaxis.set_ticks([30, 34, 38, 42]) 34 | plt.show() 35 | plt.savefig('./all_method_acc.pdf', bbox_inches = 'tight') 36 | 37 | 38 | 39 | 40 | from matplotlib.ticker import MaxNLocator 41 | data = [33.54, 51.279, 68.75, 41.17, 35.72, 37.08] 42 | labels = ['DOA-self', 'Gaussian', 'vMF', 'DOA', 'VT', 'WAP'] 43 | plt.figure(figsize=(8,4)) 44 | plt.bar(range(len(data)), data, tick_label=labels, width = 0.3, alpha = 0.8, edgecolor='black') 45 | plt.ylim(29, 46) 46 | plt.axhline(y = 30.12, linestyle='--', color = 'r', linewidth=1.5) 47 | plt.tick_params(labelsize=12) 48 | plt.ylabel('Old Intra-class Variance', fontsize=14) 49 | plt.xlabel('Method', fontsize=14) 50 | ax = plt.gca() 51 | bwith = 1. 52 | ax.spines['top'].set_linewidth(bwith) 53 | ax.spines['right'].set_linewidth(bwith) 54 | ax.spines['bottom'].set_linewidth(bwith) 55 | ax.spines['left'].set_linewidth(bwith) 56 | ax.yaxis.set_ticks([30, 34, 38, 42, 46]) 57 | plt.show() 58 | plt.savefig('./all_method_variance_exp.pdf', bbox_inches = 'tight') 59 | 60 | def visual_svd(s_list, method_list): 61 | x = np.arange(0,512) 62 | from matplotlib.ticker import MaxNLocator 63 | plt.ion() 64 | plt.clf() 65 | for i in range(len(s_list)): 66 | if 'V-vmf' in method_list[i]: 67 | continue 68 | s_list[i] = s_list[i] / torch.sum(s_list[i]) 69 | plt.plot(x, torch.log(s_list[i]), label=method_list[i]) 70 | plt.xlabel('Dimension', fontsize=14) 71 | plt.ylabel('Log Singular Value', fontsize=14) 72 | plt.legend() 73 | 74 | ax = plt.gca() 75 | bwith = 1. 
def evaluate(model: ContinualModel, dataset) -> float:
    """
    Evaluates the final accuracy of the model.

    The network is switched to eval mode and is left in that mode on return.

    :param model: the model to be evaluated
    :param dataset: the GCL dataset at hand
    :return: accuracy as a percentage in [0, 100]; 0.0 if the dataset yields
             no test samples
    """
    model.net.eval()
    correct, total = 0, 0
    # No gradients are needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
        while not dataset.test_over:
            inputs, labels = dataset.get_test_data()
            inputs, labels = inputs.to(model.device), labels.to(model.device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct += torch.sum(predicted == labels).item()
            total += labels.shape[0]

    # An already exhausted (or empty) test stream would otherwise divide by zero.
    if total == 0:
        return 0.0

    acc = correct / total * 100
    return acc
41 | :param model: the module to be trained 42 | :param dataset: the continual dataset at hand 43 | :param args: the arguments of the current execution 44 | """ 45 | if args.csv_log: 46 | from utils.loggers import CsvLogger 47 | 48 | dataset = get_dataset(args) 49 | backbone = dataset.get_backbone() 50 | loss = dataset.get_loss() 51 | model = get_model(args, backbone, loss, dataset.get_transform()) 52 | model.net.to(model.device) 53 | 54 | model_stash = create_fake_stash(model, args) 55 | 56 | if args.csv_log: 57 | csv_logger = CsvLogger(dataset.SETTING, dataset.NAME, model.NAME) 58 | if args.tensorboard: 59 | tb_logger = TensorboardLogger(args, dataset.SETTING, model_stash) 60 | 61 | model.net.train() 62 | epoch, i = 0, 0 63 | while not dataset.train_over: 64 | inputs, labels, not_aug_inputs = dataset.get_train_data() 65 | inputs, labels = inputs.to(model.device), labels.to(model.device) 66 | not_aug_inputs = not_aug_inputs.to(model.device) 67 | loss = model.observe(inputs, labels, not_aug_inputs) 68 | progress_bar(i, dataset.LENGTH // args.batch_size, epoch, 'C', loss) 69 | if args.tensorboard: 70 | tb_logger.log_loss_gcl(loss, i) 71 | i += 1 72 | 73 | if model.NAME == 'joint_gcl': 74 | model.end_task(dataset) 75 | 76 | acc = evaluate(model, dataset) 77 | print('Accuracy:', acc) 78 | 79 | if args.csv_log: 80 | csv_logger.log(acc) 81 | csv_logger.write(vars(args)) 82 | -------------------------------------------------------------------------------- /models/agem.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
def project(gxy: torch.Tensor, ger: torch.Tensor) -> torch.Tensor:
    """
    Removes from gxy its component along ger.

    :param gxy: flat gradient to be corrected
    :param ger: flat reference gradient, same length as gxy
    :return: new tensor gxy - (<gxy, ger> / <ger, ger>) * ger; an unchanged
             copy of gxy when ger is all zeros
    """
    ger_sq_norm = torch.dot(ger, ger)
    # A zero reference gradient defines no direction; dividing by its squared
    # norm would turn every entry of the result into NaN.
    if ger_sq_norm.item() == 0:
        return gxy.clone()
    corr = torch.dot(gxy, ger) / ger_sq_norm
    return gxy - corr * ger
self.loss(buf_outputs, buf_labels) 63 | penalty.backward() 64 | store_grad(self.parameters, self.grad_er, self.grad_dims) 65 | 66 | dot_prod = torch.dot(self.grad_xy, self.grad_er) 67 | if dot_prod.item() < 0: 68 | g_tilde = project(gxy=self.grad_xy, ger=self.grad_er) 69 | overwrite_grad(self.parameters, g_tilde, self.grad_dims) 70 | else: 71 | overwrite_grad(self.parameters, self.grad_xy, self.grad_dims) 72 | 73 | self.opt.step() 74 | 75 | return loss.item() 76 | -------------------------------------------------------------------------------- /datasets/utils/validation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from PIL import Image 8 | import numpy as np 9 | import os 10 | from utils import create_if_not_exists 11 | import torchvision.transforms.transforms as transforms 12 | from torchvision import datasets 13 | 14 | 15 | class ValidationDataset(torch.utils.data.Dataset): 16 | def __init__(self, data: torch.Tensor, targets: np.ndarray, 17 | transform: transforms=None, target_transform: transforms=None) -> None: 18 | self.data = data 19 | self.targets = targets 20 | self.transform = transform 21 | self.target_transform = target_transform 22 | 23 | def __len__(self): 24 | return self.data.shape[0] 25 | 26 | def __getitem__(self, index): 27 | img, target = self.data[index], self.targets[index] 28 | 29 | # doing this so that it is consistent with all other datasets 30 | # to return a PIL Image 31 | if isinstance(img, np.ndarray): 32 | if np.max(img) < 2: 33 | img = Image.fromarray(np.uint8(img * 255)) 34 | else: 35 | img = Image.fromarray(img) 36 | else: 37 | img = Image.fromarray(img.numpy()) 38 | 39 | if self.transform is not None: 40 | 
def get_train_val(train: datasets, test_transform: transforms,
                  dataset: str, val_perc: float=0.1):
    """
    Splits off the first val_perc fraction of a shuffled training set as a
    validation set.

    The shuffling permutation is cached under 'datasets/val_permutations/' in
    a file named after the dataset and reused whenever that file exists.
    :param train: training dataset; its data and targets are replaced in place
    :param test_transform: transformation applied to the validation samples
    :param dataset: dataset name, used as the permutation file name
    :param val_perc: fraction of the training set moved to the validation set
    :return: the reduced training set and the validation set
    """
    n_samples = train.data.shape[0]
    directory = 'datasets/val_permutations/'
    create_if_not_exists(directory)
    perm_path = directory + (dataset + '.pt')

    if not os.path.exists(perm_path):
        perm = torch.randperm(n_samples)
        torch.save(perm, perm_path)
    else:
        perm = torch.load(perm_path)

    train.data = train.data[perm]
    train.targets = np.array(train.targets)[perm]

    # Number of leading (shuffled) samples that go to validation.
    n_val = int(val_perc * n_samples)
    test_dataset = ValidationDataset(train.data[:n_val],
                                     train.targets[:n_val],
                                     transform=test_transform)
    train.data = train.data[n_val:]
    train.targets = train.targets[n_val:]

    return train, test_dataset
def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for the online EWC model.

    The shared management and experiment arguments are registered through
    add_management_args / add_experiment_args, followed by the two required
    float options '--e_lambda' and '--gamma'.
    :return: the configured argument parser
    """
    parser = ArgumentParser(description='Continual learning via'
                                        ' online EWC.')
    add_management_args(parser)
    add_experiment_args(parser)
    # Both options are mandatory; parsing fails if either is missing.
    parser.add_argument('--e_lambda', type=float, required=True,
                        help='lambda weight for EWC')
    parser.add_argument('--gamma', type=float, required=True,
                        help='gamma parameter for EWC online')

    return parser
class Rotation:
    """
    Rotation by an angle that is drawn once, at construction time, and then
    reused for every call.
    """

    def __init__(self, deg_min: int = 0, deg_max: int = 180) -> None:
        """
        Samples the rotation angle uniformly between the two extremes.
        :param deg_min: lower extreme of the possible random angle
        :param deg_max: upper extreme of the possible random angle
        """
        self.deg_min, self.deg_max = deg_min, deg_max
        low, high = self.deg_min, self.deg_max
        self.degrees = np.random.uniform(low, high)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Rotates the input by the angle chosen at construction.
        :param x: image to be rotated
        :return: rotated image
        """
        angle = self.degrees
        return F.rotate(x, angle)
class IncrementalRotation:
    """
    Rotation whose angle grows by a fixed amount every time it is applied.
    """

    def __init__(self, init_deg: int = 0, increase_per_iteration: float = 0.006) -> None:
        """
        Stores the starting angle and the per-call increment.
        :param init_deg: angle, in degrees, used at iteration 0
        :param increase_per_iteration: degrees added for each further iteration
        """
        self.degrees = init_deg
        self.iteration = 0
        self.increase_per_iteration = increase_per_iteration

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Rotates the input by the current angle and advances the iteration.
        :param x: image to be rotated
        :return: rotated image
        """
        # The angle wraps around at a full turn.
        angle = (self.iteration * self.increase_per_iteration + self.degrees) % 360
        self.iteration = self.iteration + 1
        return F.rotate(x, angle)

    def set_iteration(self, x: int) -> None:
        """
        Overrides the iteration counter.
        :param x: iteration index
        """
        self.iteration = x
class MammothBackbone(nn.Module):
    """
    Base class for backbones, adding access to all parameters and gradients
    as single flat tensors.
    """

    def __init__(self, **kwargs) -> None:
        super(MammothBackbone, self).__init__()

    def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:
        """
        Forward pass, to be implemented by subclasses.
        :param x: input tensor
        :param returnt: selector of what to return ('out' by default)
        """
        raise NotImplementedError

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Runs the forward pass with returnt='features'.
        :param x: input tensor
        :return: whatever forward returns for returnt='features'
        """
        return self.forward(x, returnt='features')

    def get_params(self) -> torch.Tensor:
        """
        Returns all the parameters concatenated in a single tensor.
        :return: 1-D tensor with one entry per scalar parameter
        """
        return torch.cat([pp.view(-1) for pp in self.parameters()])

    def set_params(self, new_params: torch.Tensor) -> None:
        """
        Sets the parameters to a given value.
        :param new_params: 1-D tensor laid out like the output of get_params
        """
        assert new_params.size() == self.get_params().size()
        progress = 0
        for pp in self.parameters():
            # numel() is a plain int, so the slice bounds stay Python integers.
            count = pp.numel()
            pp.data = new_params[progress: progress + count].view(pp.size())
            progress += count

    def get_grads(self) -> torch.Tensor:
        """
        Returns all the gradients concatenated in a single tensor.
        :return: 1-D tensor laid out like the output of get_params
        """
        return torch.cat(self.get_grads_list())

    def get_grads_list(self):
        """
        Returns a list containing the gradients (a flat tensor per parameter).

        A parameter with no gradient yet (frozen, or not reached by the last
        backward pass) contributes zeros, so the result always lines up with
        get_params instead of failing on a None gradient.
        :return: gradients list
        """
        grads = []
        for pp in self.parameters():
            if pp.grad is None:
                grads.append(torch.zeros(pp.numel(), dtype=pp.dtype, device=pp.device))
            else:
                grads.append(pp.grad.view(-1))
        return grads

import torch
from utils.gss_buffer import Buffer as Buffer
from utils.args import *
from models.utils.continual_model import ContinualModel

import numpy as np
import copy
from torch.nn import functional as F
import math
from torch.optim import SGD
from collections import OrderedDict
EPS = 1E-20


def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for GSS.

    :return: parser with management, experiment and rehearsal arguments plus
        ``--batch_num`` (required) and ``--gss_minibatch_size``
    """
    # The second literal starts with a space: the original concatenation
    # produced "selectionfor".
    parser = ArgumentParser(description='Gradient based sample selection'
                                        ' for online continual learning')
    add_management_args(parser)
    add_experiment_args(parser)
    add_rehearsal_args(parser)
    parser.add_argument('--batch_num', type=int, required=True,
                        help='Number of batches extracted from the buffer.')
    parser.add_argument('--gss_minibatch_size', type=int, default=None,
                        help='The batch size of the gradient comparison.')
    return parser


class Gss(ContinualModel):
    """Rehearsal model that trains on the incoming batch joined with buffer data."""
    NAME = 'gss'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        """
        :param backbone: network, forwarded to ``ContinualModel``
        :param loss: loss function, forwarded to ``ContinualModel``
        :param args: parsed arguments (reads ``buffer_size``,
            ``gss_minibatch_size``, ``minibatch_size``, ``batch_num``)
        :param transform: transform applied to the data read from the buffer
        """
        super(Gss, self).__init__(backbone, loss, args, transform)
        # The buffer compares gradients on gss_minibatch_size items when it
        # is given, otherwise it falls back to the replay minibatch size.
        if self.args.gss_minibatch_size is not None:
            comparison_size = self.args.gss_minibatch_size
        else:
            comparison_size = self.args.minibatch_size
        self.buffer = Buffer(self.args.buffer_size, self.device,
                             comparison_size, self)
        self.alj_nepochs = self.args.batch_num

    def get_grads(self, inputs, labels):
        """
        Computes the flattened gradient of the loss on a batch without
        leaving any gradient behind.

        The network is put in eval mode for the computation and back in
        train mode afterwards.

        :param inputs: batch of inputs
        :param labels: corresponding labels
        :return: detached gradients, with a leading dimension added when the
            flattened gradient is 1-D
        """
        self.net.eval()
        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        grads = self.net.get_grads().clone().detach()
        self.opt.zero_grad()
        self.net.train()
        if len(grads.shape) == 1:
            grads = grads.unsqueeze(0)
        return grads

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Runs ``batch_num`` optimization steps on the incoming batch (joined
        with a buffer minibatch when the buffer is not empty), then offers
        the non-augmented inputs to the buffer.

        :param inputs: batch of (augmented) inputs
        :param labels: corresponding labels
        :param not_aug_inputs: the same inputs without augmentation
        :return: loss value of the last step (0.0 if no step was run)
        """
        real_batch_size = inputs.shape[0]
        self.buffer.drop_cache()
        self.buffer.reset_fathom()

        # Stays None only when batch_num <= 0; the original then failed with
        # an unbound local at the return statement.
        loss = None
        for _ in range(self.alj_nepochs):
            self.opt.zero_grad()
            if not self.buffer.is_empty():
                buf_inputs, buf_labels = self.buffer.get_data(
                    self.args.minibatch_size, transform=self.transform)
                tinputs = torch.cat((inputs, buf_inputs))
                tlabels = torch.cat((labels, buf_labels))
            else:
                tinputs = inputs
                tlabels = labels

            outputs = self.net(tinputs)
            loss = self.loss(outputs, tlabels)
            loss.backward()
            self.opt.step()

        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:real_batch_size])

        return loss.item() if loss is not None else 0.0

# --------------------------------------------------------------------------------
# /models/pnn.py:
# --------------------------------------------------------------------------------
# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.optim as optim
from torch.optim import SGD
import torch
import torch.nn as nn
from utils.conf import get_device
from utils.args import *
from datasets import get_dataset


def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for Progressive Neural Networks.

    :return: parser with management and experiment arguments
    """
    parser = ArgumentParser(description='Continual Learning via'
                                        ' Progressive Neural Networks.')
    add_management_args(parser)
    add_experiment_args(parser)
    return parser


def get_backbone(bone, old_cols=None, x_shape=None):
    """
    Builds the PNN counterpart of a plain backbone.

    :param bone: reference backbone (an ``MNISTMLP`` or a ``ResNet``)
    :param old_cols: previously created columns, passed to the new column
    :param x_shape: input shape, passed only to the ResNet variant
    :return: the new PNN column
    :raises NotImplementedError: for any other backbone type
    """
    from backbone.MNISTMLP import MNISTMLP
    from backbone.MNISTMLP_PNN import MNISTMLP_PNN
    from backbone.ResNet18 import ResNet
    from backbone.ResNet18_PNN import resnet18_pnn

    if isinstance(bone, MNISTMLP):
        return MNISTMLP_PNN(bone.input_size, bone.output_size, old_cols)
    if isinstance(bone, ResNet):
        return resnet18_pnn(bone.num_classes, bone.nf, old_cols, x_shape)
    raise NotImplementedError('Progressive Neural Networks is not implemented for this backbone')


class Pnn(nn.Module):
    """Keeps one network column per task; only the newest column is trained."""
    NAME = 'pnn'
    COMPATIBILITY = ['task-il']

    def __init__(self, backbone, loss, args, transform):
        """
        :param backbone: reference backbone used to build the first column
        :param loss: loss function
        :param args: parsed arguments (reads ``lr``)
        :param transform: stored on the instance; not used in this class
        """
        super(Pnn, self).__init__()
        self.loss = loss
        self.args = args
        self.transform = transform
        self.device = get_device()
        self.x_shape = None
        # A plain Python list: the columns are moved between devices by hand
        # in forward/end_task rather than being registered as sub-modules.
        self.nets = [get_backbone(backbone).to(self.device)]
        self.net = self.nets[-1]
        self.opt = SGD(self.net.parameters(), lr=self.args.lr)

        self.soft = torch.nn.Softmax(dim=0)
        self.logsoft = torch.nn.LogSoftmax(dim=0)
        self.dataset = get_dataset(args)
        self.task_idx = 0

    def forward(self, x, task_label):
        """
        Runs the column associated with ``task_label``.

        :param x: input batch
        :param task_label: index of the column to use (ignored while only
            the first column exists)
        :return: output of the selected column
        """
        if self.x_shape is None:
            self.x_shape = x.shape

        if self.task_idx == 0:
            return self.net(x)

        column = self.nets[task_label]
        column.to(self.device)
        out = column(x)
        if self.task_idx != task_label:
            # older columns are parked on the CPU again after use
            column.cpu()
        return out

    def end_task(self, dataset):
        """
        Instantiates a new column (and its optimizer) for the next task.

        Nothing happens after the last task. The last task index comes from
        ``dataset.N_TASKS`` and falls back to the previously hard-coded value
        (5 tasks) when the attribute is missing.

        :param dataset: continual dataset; provides ``get_backbone()``
        """
        last_task_idx = getattr(dataset, 'N_TASKS', 5) - 1
        if self.task_idx >= last_task_idx:
            return
        self.task_idx += 1
        self.nets[-1].cpu()
        self.nets.append(get_backbone(dataset.get_backbone(), self.nets, self.x_shape).to(self.device))
        self.net = self.nets[-1]
        self.opt = optim.SGD(self.net.parameters(), lr=self.args.lr)

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Performs one optimization step of the current column.

        :param inputs: batch of inputs
        :param labels: corresponding labels
        :param not_aug_inputs: unused
        :return: the loss value
        """
        if self.x_shape is None:
            self.x_shape = inputs.shape

        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        return loss.item()

# --------------------------------------------------------------------------------
# /utils/status.py:
# --------------------------------------------------------------------------------
# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from datetime import datetime
import sys
import os
from utils.conf import base_path
from typing import Any, Dict, Union
from torch import nn
from argparse import Namespace
from datasets.utils.continual_dataset import ContinualDataset


def _build_model_name(model: nn.Module, args: Namespace, now: datetime) -> str:
    """
    Builds the ``dataset/model[/buf_<size>]/timestamp`` name shared by the
    real and the fake stash.

    :param model: the model (its ``NAME`` attribute is used)
    :param args: the current arguments
    :param now: timestamp to embed in the name
    :return: the slash-separated model name
    """
    name_parts = [args.dataset, model.NAME]
    if 'buffer_size' in vars(args):
        name_parts.append('buf_' + str(args.buffer_size))
    name_parts.append(now.strftime("%Y%m%d_%H%M%S_%f"))
    return '/'.join(name_parts)


def create_stash(model: nn.Module, args: Namespace,
                 dataset: ContinualDataset) -> Dict[str, Any]:
    """
    Creates the dictionary where to save the model status.
    :param model: the model
    :param args: the current arguments
    :param dataset: the dataset at hand
    :return: a dict with the progress indices (all zero), the model name, an
        empty list of mean accuracies, the arguments and the backup folder
    """
    now = datetime.now()
    model_stash = {'task_idx': 0, 'epoch_idx': 0, 'batch_idx': 0}
    model_stash['model_name'] = _build_model_name(model, args, now)
    model_stash['mean_accs'] = []
    model_stash['args'] = args
    model_stash['backup_folder'] = os.path.join(base_path(), 'backups',
                                                dataset.SETTING,
                                                model_stash['model_name'])
    return model_stash


def create_fake_stash(model: nn.Module, args: Namespace) -> Dict[str, Any]:
    """
    Create a fake stash, containing just the model name.
    This is used in general continual, as it is useless to backup
    a lightweight MNIST-360 training.
    :param model: the model
    :param args: the arguments of the call
    :return: a dict containing a fake stash
    """
    now = datetime.now()
    model_stash = {'task_idx': 0, 'epoch_idx': 0}
    model_stash['model_name'] = _build_model_name(model, args, now)

    return model_stash


def progress_bar(i: int, max_iter: int, epoch: Union[int, str],
                 task_number: Union[int, str], loss: float) -> None:
    """
    Prints out the progress bar on the stderr file.

    The bar is refreshed only every 100 iterations and on the last one.
    :param i: the current iteration
    :param max_iter: the maximum number of iteration
    :param epoch: the epoch
    :param task_number: the task index (printed 1-based when it is an int,
        verbatim otherwise)
    :param loss: the current value of the loss function
    """
    if (i + 1) % 100 and (i + 1) != max_iter:
        return
    progress = min(float((i + 1) / max_iter), 1)
    filled = int(50 * progress)
    # named `bar` so the local no longer shadows this function
    bar = ('█' * filled) + ('┈' * (50 - filled))
    shown_task = task_number + 1 if isinstance(task_number, int) else task_number
    print('\r[ {} ] Task {} | epoch {}: |{}| loss: {}'.format(
        datetime.now().strftime("%m-%d | %H:%M"),
        shown_task,
        epoch,
        bar,
        round(loss, 8)
    ), file=sys.stderr, end='', flush=True)

# --------------------------------------------------------------------------------
# /models/mer.py:
# --------------------------------------------------------------------------------
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from utils.buffer import Buffer
from utils.args import *
from models.utils.continual_model import ContinualModel


def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for Meta-Experience Replay.

    :return: parser with management, experiment and rehearsal arguments plus
        the required ``--beta``, ``--gamma`` and ``--batch_num``
    """
    parser = ArgumentParser(description='Continual Learning via'
                                        ' Meta-Experience Replay.')
    add_management_args(parser)
    add_experiment_args(parser)
    add_rehearsal_args(parser)
    # remove batch_size from parser
    # NOTE(review): this edits argparse's private `_actions` list only,
    # exactly as the original did.
    for idx, action in enumerate(parser._actions):
        if action.dest == 'batch_size':
            del parser._actions[idx]
            break

    parser.add_argument('--beta', type=float, required=True,
                        help='Within-batch update beta parameter.')
    parser.add_argument('--gamma', type=float, required=True,
                        help='Across-batch update gamma parameter.')
    parser.add_argument('--batch_num', type=int, required=True,
                        help='Number of batches extracted from the buffer.')

    return parser


class Mer(ContinualModel):
    """Meta-Experience Replay: Reptile-style interpolated updates over replay batches."""
    NAME = 'mer'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        """
        :param backbone: network, forwarded to ``ContinualModel``
        :param loss: loss function, forwarded to ``ContinualModel``
        :param args: parsed arguments; ``batch_size`` must be 1
        :param transform: transform applied to the data read from the buffer
        """
        super(Mer, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        assert args.batch_size == 1, 'Mer only works with batch_size=1'

    def draw_batches(self, inp, lab):
        """
        Prepares ``batch_num`` training batches for one observed example.

        When the buffer is not empty each batch is a freshly sampled buffer
        minibatch with the current example appended; otherwise it holds the
        current example alone.

        :param inp: the current input (a leading dimension is added here)
        :param lab: the current label
        :return: list of ``(inputs, labels)`` tuples
        """
        batches = []
        for _ in range(self.args.batch_num):
            if self.buffer.is_empty():
                batches.append((inp.unsqueeze(0), torch.tensor([lab]).unsqueeze(0).to(self.device)))
                continue
            buf_inputs, buf_labels = self.buffer.get_data(self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((buf_inputs, inp.unsqueeze(0)))
            labels = torch.cat((buf_labels, torch.tensor([lab]).to(self.device)))
            batches.append((inputs, labels))
        return batches

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Performs the MER update for one observed example.

        For each drawn batch an optimizer step is taken and the parameters
        are pulled back towards their pre-step value with factor ``beta``.
        After all batches the parameters are interpolated with factor
        ``gamma`` between their value at the start of the call and the
        current one. The example is stored in the buffer in between.

        :param inputs: the current input
        :param labels: the current label
        :param not_aug_inputs: the same input without augmentation
        :return: loss value of the last within-batch step
            (0.0 if ``batch_num`` is 0)
        """
        batches = self.draw_batches(inputs, labels)
        theta_A0 = self.net.get_params().data.clone()

        # Stays None only when no batch was drawn; the original then failed
        # with an unbound local at the return statement.
        loss = None
        for batch_inputs, batch_labels in batches:
            theta_Wi0 = self.net.get_params().data.clone()

            # within-batch step
            self.opt.zero_grad()
            outputs = self.net(batch_inputs)
            loss = self.loss(outputs, batch_labels.squeeze(-1))
            loss.backward()
            self.opt.step()

            # within batch reptile meta-update
            new_params = theta_Wi0 + self.args.beta * (self.net.get_params() - theta_Wi0)
            self.net.set_params(new_params)

        self.buffer.add_data(examples=not_aug_inputs.unsqueeze(0), labels=labels)

        # across batch reptile meta-update
        new_new_params = theta_A0 + self.args.gamma * (self.net.get_params() - theta_A0)
        self.net.set_params(new_new_params)

        return loss.item() if loss is not None else 0.0

# --------------------------------------------------------------------------------
# /utils/tb_logger.py:
# --------------------------------------------------------------------------------
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from utils.conf import base_path
import os
from argparse import Namespace
from typing import Dict, Any
import numpy as np


class TensorboardLogger:
    """Writes accuracies and losses to one TensorBoard writer per setting."""

    def __init__(self, args: Namespace, setting: str) -> None:
        """
        Creates the writers and records the run configuration as text.

        :param args: the arguments of the run (``args.model`` names the run)
        :param setting: the evaluation setting; ``'class-il'`` additionally
            opens a ``'task-il'`` writer
        """
        from torch.utils.tensorboard import SummaryWriter

        self.settings = [setting]
        if setting == 'class-il':
            self.settings.append('task-il')
        self.name = args.model
        self.loggers = {
            a_setting: SummaryWriter(
                os.path.join(base_path(), 'tensorboard_runs', a_setting, self.name))
            for a_setting in self.settings
        }
        config_text = ', '.join(
            ["%s=%s" % (name, getattr(args, name)) for name in args.__dir__()
             if not name.startswith('_')])
        for a_logger in self.loggers.values():
            a_logger.add_text('config', config_text)

    def get_name(self) -> str:
        """
        :return: the name of the model
        """
        return self.name

    def log_accuracy(self, all_accs: np.ndarray, all_mean_accs: np.ndarray,
                     args: Namespace, task_number: int) -> None:
        """
        Logs the current accuracy value for each task.
        :param all_accs: the accuracies (class-il, task-il) for each task
        :param all_mean_accs: the mean accuracies for (class-il, task-il)
        :param args: the arguments of the run
        :param task_number: the task index
        """
        mean_acc_common, mean_acc_task_il = all_mean_accs
        step = task_number * args.n_epochs
        for setting, a_logger in self.loggers.items():
            is_task_il = setting == 'task-il'
            mean_acc = mean_acc_task_il if is_task_il else mean_acc_common
            index = 1 if is_task_il else 0
            # the number of tasks is always read from the first row
            accs = [all_accs[index][kk] for kk in range(len(all_accs[0]))]
            for kk, acc in enumerate(accs):
                a_logger.add_scalar('acc_task%02d' % (kk + 1), acc, step)
            a_logger.add_scalar('acc_mean', mean_acc, step)

    def log_loss(self, loss: float, args: Namespace, epoch: int,
                 task_number: int, iteration: int) -> None:
        """
        Logs the loss value, using the global epoch index
        (``task_number * args.n_epochs + epoch``) as the step.
        :param loss: the loss value
        :param args: the arguments of the run
        :param epoch: the epoch index
        :param task_number: the task index
        :param iteration: the current iteration (not used for the step)
        """
        for a_logger in self.loggers.values():
            a_logger.add_scalar('loss', loss, task_number * args.n_epochs + epoch)

    def log_loss_gcl(self, loss: float, iteration: int) -> None:
        """
        Logs the loss value at each iteration.
        :param loss: the loss value
        :param iteration: the current iteration
        """
        for a_logger in self.loggers.values():
            a_logger.add_scalar('loss', loss, iteration)

    def close(self) -> None:
        """
        At the end of the execution, closes the logger.
        """
        for a_logger in self.loggers.values():
            a_logger.close()

# --------------------------------------------------------------------------------
# /models/fdr.py:
# --------------------------------------------------------------------------------
# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from utils.buffer import Buffer
from utils.args import *
from models.utils.continual_model import ContinualModel
import torch


def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for FDR.

    :return: parser with management, experiment and rehearsal arguments plus
        the required ``--alpha``
    """
    # The description used to be copy-pasted from Dark Experience Replay.
    parser = ArgumentParser(description='Continual learning via'
                                        ' Function Distance Regularization.')
    add_management_args(parser)
    add_experiment_args(parser)
    add_rehearsal_args(parser)
    parser.add_argument('--alpha', type=float, required=True,
                        help='Penalty weight.')
    return parser


class Fdr(ContinualModel):
    """Rehearsal model that matches the softmax responses stored in the buffer."""
    NAME = 'fdr'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        """
        :param backbone: network, forwarded to ``ContinualModel``
        :param loss: loss function, forwarded to ``ContinualModel``
        :param args: parsed arguments (reads ``buffer_size``,
            ``minibatch_size``)
        :param transform: transform applied to the data read from the buffer
        """
        super(Fdr, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.current_task = 0
        self.i = 0
        self.soft = torch.nn.Softmax(dim=1)
        self.logsoft = torch.nn.LogSoftmax(dim=1)

    def end_task(self, dataset):
        """
        Rebalances the buffer and stores examples of the task just finished.

        The buffer is shrunk to at most ``buffer_size // current_task``
        items per stored task; then up to the same number of examples of the
        current task are added together with the network outputs on them.

        :param dataset: continual dataset; ``dataset.train_loader`` is read
        """
        self.current_task += 1
        examples_per_task = self.args.buffer_size // self.current_task

        if self.current_task > 1:
            buf_x, buf_log, buf_tl = self.buffer.get_all_data()
            self.buffer.empty()

            for ttl in buf_tl.unique():
                idx = (buf_tl == ttl)
                ex, log, tasklab = buf_x[idx], buf_log[idx], buf_tl[idx]
                first = min(ex.shape[0], examples_per_task)
                self.buffer.add_data(
                    examples=ex[:first],
                    logits=log[:first],
                    task_labels=tasklab[:first]
                )

        counter = 0
        with torch.no_grad():
            for data in dataset.train_loader:
                remaining = examples_per_task - counter
                # `<= 0` and checked before the forward pass: the original
                # `< 0` test ran one extra forward pass and stored an empty
                # slice once the quota was exactly reached.
                if remaining <= 0:
                    break
                inputs, labels, not_aug_inputs = data
                inputs = inputs.to(self.device)
                not_aug_inputs = not_aug_inputs.to(self.device)
                outputs = self.net(inputs)
                # Sized on the real batch: the last batch of a loader can be
                # smaller than args.batch_size.
                seen = inputs.shape[0]
                task_labels = torch.ones(seen) * (self.current_task - 1)
                self.buffer.add_data(examples=not_aug_inputs[:remaining],
                                     logits=outputs.data[:remaining],
                                     task_labels=task_labels[:remaining])
                counter += seen

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Performs one step on the incoming batch and, when the buffer is not
        empty, a second step that minimizes the mean L2 distance between the
        softmax of the current outputs and of the stored logits.

        :param inputs: batch of inputs
        :param labels: corresponding labels
        :param not_aug_inputs: unused
        :return: value of the last loss computed (the buffer loss whenever
            the buffer is not empty)
        """
        self.i += 1

        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()
        if not self.buffer.is_empty():
            self.opt.zero_grad()
            buf_inputs, buf_logits, _ = self.buffer.get_data(self.args.minibatch_size, transform=self.transform)
            buf_outputs = self.net(buf_inputs)
            loss = torch.norm(self.soft(buf_outputs) - self.soft(buf_logits), 2, 1).mean()
            assert not torch.isnan(loss)
            loss.backward()
            self.opt.step()

        return loss.item()

# --------------------------------------------------------------------------------
# /datasets/seq_mnist.py:
# --------------------------------------------------------------------------------
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from backbone.MNISTMLP import MNISTMLP
import torch.nn.functional as F
from utils.conf import base_path
from PIL import Image
import numpy as np
from datasets.utils.validation import get_train_val
from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
from typing import Tuple


class MyMNIST(MNIST):
    """
    Overrides the MNIST dataset to change the getitem function.
    """
    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False) -> None:
        # set before the parent constructor runs, as in the original
        self.not_aug_transform = transforms.ToTensor()
        super(MyMNIST, self).__init__(root, train,
                                      transform, target_transform, download)

    def __getitem__(self, index: int) -> tuple:
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :returns: tuple: (image, target, not-augmented image) where target is
            the index of the target class; when the dataset has a ``logits``
            attribute, ``logits[index]`` is appended as a fourth item.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')
        original_img = self.not_aug_transform(img.copy())

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if hasattr(self, 'logits'):
            return img, target, original_img, self.logits[index]

        return img, target, original_img


class SequentialMNIST(ContinualDataset):
    """MNIST split into 5 class-incremental tasks of 2 classes each."""

    NAME = 'seq-mnist'
    SETTING = 'class-il'
    N_CLASSES_PER_TASK = 2
    N_TASKS = 5
    TRANSFORM = None

    def get_data_loaders(self):
        """
        Builds the loaders of the current task.

        With ``args.validation`` the test set is carved out of the training
        data; otherwise the official MNIST test split is used.

        :return: the (train, test) loaders produced by
            ``store_masked_loaders``
        """
        transform = transforms.ToTensor()
        train_dataset = MyMNIST(base_path() + 'MNIST',
                                train=True, download=True, transform=transform)
        if self.args.validation:
            train_dataset, test_dataset = get_train_val(train_dataset,
                                                        transform, self.NAME)
        else:
            test_dataset = MNIST(base_path() + 'MNIST',
                                 train=False, download=True, transform=transform)

        train, test = store_masked_loaders(train_dataset, test_dataset, self)
        return train, test

    @staticmethod
    def get_backbone():
        """
        :return: an ``MNISTMLP`` with 28*28 inputs and one output per class
            over all tasks
        """
        return MNISTMLP(28 * 28, SequentialMNIST.N_TASKS
                        * SequentialMNIST.N_CLASSES_PER_TASK)

    @staticmethod
    def get_transform():
        """:return: ``None`` (no augmentation is defined for this dataset)"""
        return None

    @staticmethod
    def get_loss():
        """:return: the cross-entropy function"""
        return F.cross_entropy

    @staticmethod
    def get_normalization_transform():
        """:return: ``None`` (no normalization is defined for this dataset)"""
        return None

    @staticmethod
    def get_denormalization_transform():
        """:return: ``None`` (no denormalization is defined for this dataset)"""
        return None

    @staticmethod
    def get_scheduler(model, args):
        """
        :param model: unused
        :param args: unused
        :return: ``None`` (no scheduler is defined for this dataset)
        """
        return None

# --------------------------------------------------------------------------------
# /datasets/perm_mnist.py:
# --------------------------------------------------------------------------------
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from datasets.transforms.permutation import Permutation
from torch.utils.data import DataLoader
from backbone.MNISTMLP import MNISTMLP
import torch.nn.functional as F
from utils.conf import base_path
from PIL import Image
from datasets.utils.validation import get_train_val
from typing import Tuple
from datasets.utils.continual_dataset import ContinualDataset


def store_mnist_loaders(transform, setting):
    """
    Builds the MNIST loaders of one task and registers them on ``setting``.

    The test loader is appended to ``setting.test_loaders`` and the train
    loader replaces ``setting.train_loader``.

    :param transform: transform applied to both splits
    :param setting: continual dataset (reads ``args.validation``,
        ``args.batch_size`` and ``NAME``)
    :return: the (train, test) loaders
    """
    train_dataset = MyMNIST(base_path() + 'MNIST',
                            train=True, download=True, transform=transform)
    if setting.args.validation:
        train_dataset, test_dataset = get_train_val(train_dataset,
                                                    transform, setting.NAME)
    else:
        test_dataset = MNIST(base_path() + 'MNIST',
                             train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset,
                              batch_size=setting.args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=setting.args.batch_size, shuffle=False)
    setting.test_loaders.append(test_loader)
    setting.train_loader = train_loader

    return train_loader, test_loader


class MyMNIST(MNIST):
    """
    Overrides the MNIST dataset to change the getitem function.
    """
    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False) -> None:
        super(MyMNIST, self).__init__(root, train, transform,
                                      target_transform, download)

    def __getitem__(self, index: int) -> tuple:
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :returns: tuple: (image, target, image) where target is index of the
            target class; the transformed image is returned twice (it also
            fills the "not augmented" slot).
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, img


class PermutedMNIST(ContinualDataset):
    """Domain-incremental MNIST: 20 tasks of 10 classes, one permutation per task."""

    NAME = 'perm-mnist'
    SETTING = 'domain-il'
    N_CLASSES_PER_TASK = 10
    N_TASKS = 20

    def get_data_loaders(self):
        """
        Builds the loaders of a new task with a new ``Permutation`` transform.

        :return: the (train, test) loaders
        """
        transform = transforms.Compose((transforms.ToTensor(), Permutation()))
        train, test = store_mnist_loaders(transform, self)
        return train, test

    @staticmethod
    def get_backbone():
        """:return: an ``MNISTMLP`` with 28*28 inputs and 10 outputs"""
        return MNISTMLP(28 * 28, PermutedMNIST.N_CLASSES_PER_TASK)

    @staticmethod
    def get_transform():
        """:return: ``None`` (no augmentation is defined for this dataset)"""
        return None

    @staticmethod
    def get_normalization_transform():
        """:return: ``None`` (no normalization is defined for this dataset)"""
        return None

    @staticmethod
    def get_denormalization_transform():
        """:return: ``None`` (no denormalization is defined for this dataset)"""
        return None

    @staticmethod
    def get_loss():
        """:return: the cross-entropy function"""
        return F.cross_entropy

    @staticmethod
    def get_scheduler(model, args):
        """
        :param model: unused
        :param args: unused
        :return: ``None`` (no scheduler is defined for this dataset)
        """
        return None

# --------------------------------------------------------------------------------
# /utils/debug.py:
# --------------------------------------------------------------------------------

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import copy
import math
import torch
from torch.nn import functional as F


def euclid2polar(feat):
    """
    Converts row vectors to hyperspherical angles.

    Rows are L2-normalized first, so only the direction is encoded. For a
    row x with D coordinates, angle i (i < D-1) is
    ``arccos(x_i / sqrt(x_i^2 + ... + x_{D-1}^2))``. The last angle is
    mapped to ``2*pi - angle`` when the last coordinate is negative, so that
    ``polar2euclid`` restores its sign (the original always returned an
    angle in [0, pi] and lost it).

    :param feat: 2-D tensor, one vector per row
    :return: 2-D tensor with D-1 angles per row (also when D == 2)
    """
    norm_x = torch.norm(feat.clone(), 2, 1, keepdim=True)
    feat = feat / norm_x
    feat_2 = feat * feat
    # tail_sq[:, i] = sum over j >= i of feat[:, j] ** 2, obtained with one
    # reversed cumulative sum instead of one partial sum per angle.
    tail_sq = torch.flip(torch.cumsum(torch.flip(feat_2, dims=[1]), dim=1), dims=[1])
    cosines = feat[:, :-1] / torch.sqrt(tail_sq[:, :-1])
    # clamp guards against ratios that rounding pushes slightly outside [-1, 1]
    angles = torch.arccos(torch.clamp(cosines, -1.0, 1.0))
    if angles.shape[1] == 0:
        return angles
    last = angles[:, -1]
    last = torch.where(feat[:, -1] < 0, 2 * math.pi - last, last)
    return torch.cat((angles[:, :-1], last.unsqueeze(-1)), dim=1)


def polar2euclid(feat):
    """
    Converts hyperspherical angles back to unit vectors.

    With angles a_0 .. a_{A-1} the coordinates are
    ``cos(a_0)``, ``sin(a_0) * cos(a_1)``, ..., and finally
    ``sin(a_0) * ... * sin(a_{A-1})``.

    :param feat: 2-D tensor with A angles per row
    :return: 2-D tensor with A+1 coordinates per row
    """
    sin_feat = torch.sin(feat)
    cos_feat = torch.cos(feat)
    sin_product = torch.cumprod(sin_feat, dim=1)

    first = cos_feat[:, :1]
    middle = sin_product[:, :-1] * cos_feat[:, 1:]
    # The last coordinate is the full product of sines; the original
    # `sin_product[:, idx - 2] * sin_feat[:, idx - 1]` wrapped around to the
    # wrong column when there was a single angle.
    last = sin_product[:, -1:]
    return torch.cat((first, middle, last), dim=1)
# feat = torch.normal(0.0, 1, size=(32,512))
# print(0,feat[0])
# feat = euclid2polar(feat)
# print(1,feat[0])
# polar_noise = torch.from_numpy(np.random.vonmises(0, 0000000, (feat.shape[0], feat.shape[1])))
# polar_noise.to(feat)
# feat = feat + polar_noise
# print(2,feat[0])
# feat = polar2euclid(feat)
# print(3,feat[0])


def visualize_2d(feat, labels, step):
    """
    Saves a scatter plot of the first two feature columns to
    ``./<step>.pdf``.

    :param feat: tensor with at least two columns (``.numpy()`` is called)
    :param labels: unused
    :param step: value used to build the output file name
    """
    feat = feat.numpy()
    fig = plt.figure(figsize=(6, 6))
    plt.ion()
    plt.clf()
    # NOTE(review): the same points are drawn ten times and `labels` is
    # ignored; this looks like a leftover of a per-class plot. Kept as is.
    for i in range(10):
        plt.plot(feat[:, 0], feat[:, 1], '.')
    # Each axis is bounded by its own column (the original mixed x and y).
    XMax = np.max(feat[:, 0])
    XMin = np.min(feat[:, 0])
    YMax = np.max(feat[:, 1])
    YMin = np.min(feat[:, 1])

    plt.xlim(xmin=XMin, xmax=XMax)
    plt.ylim(ymin=YMin, ymax=YMax)
    plt.savefig('./%s.pdf' % str(step))
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)


def visualize_3d(feat, labels, step):
    """
    Saves a 3-D scatter plot of the first three feature columns to
    ``./<step>.pdf``.

    :param feat: tensor with at least three columns (``.numpy()`` is called)
    :param labels: unused
    :param step: value used to build the output file name
    """
    feat = feat.numpy()
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(projection='3d')

    ax.scatter(feat[:, 0], feat[:, 1], feat[:, 2])
    # saved before show() so the file can never be written from a figure
    # that an interactive backend has already torn down
    plt.savefig('./%s.pdf' % str(step))
    plt.show()
    plt.close(fig)

# feat = torch.from_numpy(np.random.vonmises(0, 1000, (1000, 2)))
# feat = feat
# mask = feat < 0
# copy_feat = copy.deepcopy(feat)
# feat[mask] = -feat[mask]
# # feat[:, :-1] = copy_feat[:, :-1]
# feat = polar2euclid(feat)
# visualize_3d(feat, feat, 'vmf')

# feat2 = euclid2polar(feat)

# feat2 = polar2euclid(feat2)

# visualize_3d(feat2, feat2, 'vmf2')

# true_ = feat == feat2
# print(true_)
# print(feat,feat2)

# --------------------------------------------------------------------------------
# /models/gdumb.py:
# --------------------------------------------------------------------------------
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from utils.args import *
from models.utils.continual_model import ContinualModel
from torch.optim import SGD, lr_scheduler
import math
from utils.buffer import Buffer
import torch
from utils.augmentations import cutmix_data
import numpy as np
from utils.status import progress_bar


def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for GDumb.

    :return: parser with management, rehearsal and experiment arguments plus
        ``--maxlr``, ``--minlr``, ``--fitting_epochs`` and ``--cutmix_alpha``
    """
    # Description and help texts used to be copy-pasted from other models.
    parser = ArgumentParser(description='Continual Learning via GDumb.')
    add_management_args(parser)
    add_rehearsal_args(parser)
    parser.add_argument('--maxlr', type=float, default=5e-2,
                        help='Maximum learning rate used to fit the buffer.')
    parser.add_argument('--minlr', type=float, default=5e-4,
                        help='Minimum learning rate of the cosine schedule.')
    parser.add_argument('--fitting_epochs', type=int, default=256,
                        help='Number of epochs used to fit the buffer.')
    parser.add_argument('--cutmix_alpha', type=float, default=None,
                        help='CutMix alpha parameter (CutMix is disabled when omitted).')
    add_experiment_args(parser)
    return parser


def fit_buffer(self, epochs):
    """
    Trains ``self.net`` on the buffer content for ``epochs`` epochs.

    The first epoch runs at ``0.1 * maxlr`` (warm start), the second at
    ``maxlr``; from the third on the learning rate follows cosine annealing
    with warm restarts down to ``minlr``.

    :param self: the model (reads ``net``, ``args``, ``buffer``,
        ``transform``, ``loss``, ``device``)
    :param epochs: number of passes over the buffer
    """
    # Built once: the original re-created both objects at every epoch, which
    # reset the momentum buffers and restarted the schedule each time, so the
    # learning rate never annealed.
    optimizer = SGD(self.net.parameters(), lr=self.args.maxlr, momentum=self.args.optim_mom, weight_decay=self.args.optim_wd, nesterov=self.args.optim_nesterov)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=2, eta_min=self.args.minlr)
    batch_size = self.args.batch_size

    for epoch in range(epochs):
        if epoch <= 0:  # Warm start of 1 epoch
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.args.maxlr * 0.1
        elif epoch == 1:  # Then set to maxlr
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.args.maxlr
        else:
            scheduler.step()

        all_inputs, all_labels = self.buffer.get_data(
            len(self.buffer.examples), transform=self.transform)

        # stays None when the buffer yields no data, so nothing is reported
        loss = None
        while len(all_inputs):
            optimizer.zero_grad()
            buf_inputs, buf_labels = all_inputs[:batch_size], all_labels[:batch_size]
            all_inputs, all_labels = all_inputs[batch_size:], all_labels[batch_size:]

            if self.args.cutmix_alpha is not None:
                inputs, labels_a, labels_b, lam = cutmix_data(x=buf_inputs.cpu(), y=buf_labels.cpu(), alpha=self.args.cutmix_alpha)
                buf_inputs = inputs.to(self.device)
                buf_labels_a = labels_a.to(self.device)
                buf_labels_b = labels_b.to(self.device)
                buf_outputs = self.net(buf_inputs)
                loss = lam * self.loss(buf_outputs, buf_labels_a) + (1 - lam) * self.loss(buf_outputs, buf_labels_b)
            else:
                buf_outputs = self.net(buf_inputs)
                loss = self.loss(buf_outputs, buf_labels)

            loss.backward()
            optimizer.step()
        if loss is not None:
            progress_bar(epoch, epochs, 1, 'G', loss.item())


class GDumb(ContinualModel):
    """Only stores data while tasks are observed; trains after the last task."""
    NAME = 'gdumb'
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        """
        :param backbone: network, forwarded to ``ContinualModel``
        :param loss: loss function, forwarded to ``ContinualModel``
        :param args: parsed arguments (reads ``buffer_size``,
            ``fitting_epochs``)
        :param transform: transform applied to the data read from the buffer
        """
        super(GDumb, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.task = 0

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Stores the non-augmented inputs in the buffer; no training happens.

        :param inputs: unused
        :param labels: labels of the batch
        :param not_aug_inputs: inputs without augmentation
        :return: always 0
        """
        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels)
        return 0

    def end_task(self, dataset):
        """
        After the last task, replaces the network with a fresh backbone and
        fits it on the buffer; does nothing for earlier tasks.

        :param dataset: continual dataset (reads ``N_TASKS`` and
            ``get_backbone()``)
        """
        # new model
        self.task += 1
        if not (self.task == dataset.N_TASKS):
            return
        self.net = dataset.get_backbone().to(self.device)
        fit_buffer(self, self.args.fitting_epochs)

# --------------------------------------------------------------------------------
# /utils/simclrloss.py:
# --------------------------------------------------------------------------------
"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
Source: https://github.com/HobbitLong/SupContrast/blob/master/losses.py
"""
# from __future__ import print_function
# (kept as a comment: a __future__ import is only legal at the very top of a
# module, and print_function is already the default on Python 3)

import torch
import torch.nn as nn


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07, reduction='mean'):
        """
        Args:
            temperature: divisor applied to the feature dot products.
            contrast_mode: 'all' uses every view as anchor, 'one' only the
                first view.
            base_temperature: the loss is scaled by
                temperature / base_temperature.
            reduction: 'mean' averages the per-sample losses; any other
                value sums them.
        """
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
        self.reduction = reduction

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does
                not match the batch size, or if `contrast_mode` is unknown.
        """
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # An anchor without any positive has a zero numerator; dividing by 1
        # instead of 0 makes it contribute 0 rather than NaN.
        pos_per_anchor = mask.sum(1)
        pos_per_anchor = torch.where(pos_per_anchor < 1e-6,
                                     torch.ones_like(pos_per_anchor),
                                     pos_per_anchor)
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_per_anchor

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean(0)

        return loss.mean() if self.reduction == 'mean' else loss.sum()

# --------------------------------------------------------------------------------
# /models/lwf.py:
# --------------------------------------------------------------------------------
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from datasets import get_dataset
from torch.optim import SGD
from utils.args import *
from models.utils.continual_model import ContinualModel


def get_parser() -> ArgumentParser:
    """Build the command-line parser for Learning without Forgetting."""
    parser = ArgumentParser(description='Continual learning via'
                            ' Learning without Forgetting.')
    add_management_args(parser)
    add_experiment_args(parser)
    parser.add_argument('--alpha', type=float, default=0.5,
                        help='Penalty weight.')
    parser.add_argument('--softmax_temp', type=float, default=2,
                        help='Temperature of the softmax function.')
    return parser


def smooth(logits, temp, dim):
    """Raise ``logits`` to the power ``1 / temp`` and renormalise along ``dim``.

    ``keepdim=True`` makes the normaliser broadcast along whichever ``dim`` is
    requested; the previous ``unsqueeze(1)`` was only correct for ``dim == 1``.
    """
    log = logits ** (1 / temp)
    return log / torch.sum(log, dim, keepdim=True)


def modified_kl_div(old, new):
    """Batch mean of the cross-entropy ``-sum(old * log(new))`` over dim 1."""
    return -torch.mean(torch.sum(old * torch.log(new), 1))


class Lwf(ContinualModel):
    """Learning without Forgetting: distils the previous responses of the
    network (stored as ``logits`` on the training dataset) into the new one."""
    NAME = 'lwf'
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(Lwf, self).__init__(backbone, loss, args, transform)
        self.old_net = None
        self.soft = torch.nn.Softmax(dim=1)
        self.logsoft = torch.nn.LogSoftmax(dim=1)
        # Build the dataset once and reuse it for the class/task counts.
        self.dataset = get_dataset(args)
        self.current_task = 0
        self.cpt = self.dataset.N_CLASSES_PER_TASK
        nc = self.dataset.N_TASKS * self.cpt
        # Row i of the lower-triangular matrix is a boolean mask selecting
        # classes 0..i.
        self.eye = torch.tril(torch.ones((nc, nc))).bool().to(self.device)

    def begin_task(self, dataset):
        """From the second task on: warm up the classifier on the new classes
        with frozen features, then record the network's logits on the
        non-augmented training samples as distillation targets."""
        self.net.eval()
        if self.current_task > 0:
            # warm-up
            opt = SGD(self.net.classifier.parameters(), lr=self.args.lr)
            for epoch in range(self.args.n_epochs):
                for i, data in enumerate(dataset.train_loader):
                    inputs, labels, not_aug_inputs = data
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    opt.zero_grad()
                    with torch.no_grad():
                        feats = self.net(inputs, returnt='features')
                    # XOR of two prefix masks keeps only the current task's classes.
                    mask = self.eye[(self.current_task + 1) * self.cpt - 1] ^ self.eye[self.current_task * self.cpt - 1]
                    outputs = self.net.classifier(feats)[:, mask]
                    loss = self.loss(outputs, labels - self.current_task * self.cpt)
                    loss.backward()
                    opt.step()

            logits = []
            with torch.no_grad():
                for i in range(0, dataset.train_loader.dataset.data.shape[0], self.args.batch_size):
                    # Element [2] of a dataset item is the non-augmented input.
                    inputs = torch.stack([dataset.train_loader.dataset.__getitem__(j)[2]
                                          for j in range(i, min(i + self.args.batch_size,
                                                                len(dataset.train_loader.dataset)))])
                    log = self.net(inputs.to(self.device)).cpu()
                    logits.append(log)
            setattr(dataset.train_loader.dataset, 'logits', torch.cat(logits))
        self.net.train()

        self.current_task += 1

    def observe(self, inputs, labels, not_aug_inputs, logits=None):
        """Run one optimisation step and return the loss value.

        Cross-entropy is computed on the classes seen so far; when ``logits``
        are given, the distillation term weighted by ``args.alpha`` is added
        on the classes of the previous tasks.
        """
        self.opt.zero_grad()
        outputs = self.net(inputs)

        mask = self.eye[self.current_task * self.cpt - 1]
        loss = self.loss(outputs[:, mask], labels)
        if logits is not None:
            mask = self.eye[(self.current_task - 1) * self.cpt - 1]
            # Honour --softmax_temp (default 2) instead of a hard-coded 2.
            temp = self.args.softmax_temp
            loss += self.args.alpha * modified_kl_div(smooth(self.soft(logits[:, mask]).to(self.device), temp, 1),
                                                      smooth(self.soft(outputs[:, mask]), temp, 1))

        loss.backward()
        self.opt.step()

        return loss.item()

# -------------------------------------------------------------------------------
# /utils/augmentations.py:
# -------------------------------------------------------------------------------
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
import numpy as np


def rand_bbox(size, lam):
    """Sample a random box whose area is about ``1 - lam`` of the image.

    :param size: 4-element size; entries 2 and 3 are the two spatial extents
    :param lam: mixing coefficient
    :return: ``(bbx1, bby1, bbx2, bby2)`` clipped to the image
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is equivalent.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


def cutmix_data(x, y, alpha=1.0, cutmix_prob=0.5):
    """Apply CutMix to a batch, modifying ``x`` in place.

    :param x: 4-D input batch
    :param y: labels of the batch
    :param alpha: Beta-distribution parameter, must be positive
    :param cutmix_prob: unused, kept for interface compatibility
    :return: ``(x, y_a, y_b, lam)`` with ``lam`` adjusted to the pasted area
    """
    assert(alpha > 0)
    # generate mixed sample
    lam = np.random.beta(alpha, alpha)

    batch_size = x.size()[0]
    # Keep the permutation on the same device as the data. Moving it to CUDA
    # whenever CUDA is available broke indexing of CPU tensors.
    index = torch.randperm(batch_size).to(x.device)

    y_a, y_b = y, y[index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    return x, y_a, y_b, lam


def normalize(x, mean, std):
    """Channel-wise ``(x - mean) / std`` for a 4-D batch."""
    assert len(x.shape) == 4
    return (x - torch.tensor(mean).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device)) \
        / torch.tensor(std).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device)


def random_flip(x):
    """Flip each sample along the last axis with probability 0.5 (in place)."""
    assert len(x.shape) == 4
    mask = torch.rand(x.shape[0]) < 0.5
    x[mask] = x[mask].flip(3)
    return x


def random_grayscale(x, prob=0.2):
    """Replace each sample by its 3-channel grayscale version with
    probability ``prob`` (in place)."""
    assert len(x.shape) == 4
    mask = torch.rand(x.shape[0]) < prob
    x[mask] = (x[mask] * torch.tensor([[0.299, 0.587, 0.114]]).unsqueeze(2).unsqueeze(2).to(x.device)).sum(1, keepdim=True).repeat_interleave(3, 1)
    return x


def random_crop(x, padding):
    """Zero-pad each sample by ``padding`` and crop it back to its original
    size at a random per-sample offset."""
    assert len(x.shape) == 4
    crop_x = torch.randint(-padding, padding, size=(x.shape[0],))
    crop_y = torch.randint(-padding, padding, size=(x.shape[0],))

    crop_x_start, crop_y_start = crop_x + padding, crop_y + padding
    crop_x_end, crop_y_end = crop_x_start + x.shape[-1], crop_y_start + x.shape[-2]

    oboe = F.pad(x, (padding, padding, padding, padding))
    mask_x = torch.arange(x.shape[-1] + padding * 2).repeat(x.shape[0], x.shape[-1] + padding * 2, 1)
    mask_y = mask_x.transpose(1, 2)
    mask_x = ((mask_x >= crop_x_start.unsqueeze(1).unsqueeze(2)) & (mask_x < crop_x_end.unsqueeze(1).unsqueeze(2)))
    mask_y = ((mask_y >= crop_y_start.unsqueeze(1).unsqueeze(2)) & (mask_y < crop_y_end.unsqueeze(1).unsqueeze(2)))
    keep = mask_x.unsqueeze(1).repeat(1, x.shape[1], 1, 1) & mask_y.unsqueeze(1).repeat(1, x.shape[1], 1, 1)
    # Use the real channel count instead of a hard-coded 3.
    return oboe[keep].reshape(x.shape[0], x.shape[1], x.shape[2], x.shape[3])


class soft_aug():
    """Random crop (padding 4) + random flip + normalisation."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, x):
        return normalize(
            random_flip(
                random_crop(x, 4)
            ),
            self.mean, self.std)


class strong_aug():
    """Random flip, per-sample resized crop and colour jitter, random
    grayscale, then normalisation."""

    def __init__(self, size, mean, std):
        from torchvision import transforms
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.ToTensor()
        ])
        self.mean = mean
        self.std = std

    def __call__(self, x):
        flip = random_flip(x)
        return normalize(random_grayscale(
            torch.stack(
                [self.transform(a) for a in flip]
            )), self.mean, self.std)

# -------------------------------------------------------------------------------
# /models/utils/continual_model.py:
# -------------------------------------------------------------------------------
# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from torch.optim import SGD
import torch
import torchvision
from argparse import Namespace
from utils.conf import get_device
import numpy as np
import copy
from torch.nn import functional as F
import math
from torch.optim import SGD
from collections import OrderedDict
EPS = 1E-20


class ContinualModel(nn.Module):
    """
    Continual learning model.
    """
    NAME = None
    COMPATIBILITY = []

    def __init__(self, backbone: nn.Module, loss: nn.Module,
                 args: Namespace, transform: torchvision.transforms) -> None:
        super(ContinualModel, self).__init__()

        self.net = backbone
        self.loss = loss
        self.args = args
        self.transform = transform
        self.opt = SGD(self.net.parameters(), lr=self.args.lr)
        self.device = get_device()

        self.current_task = 0
        self.buff_feature = []
        self.buff_labels = []
        self.new_feature = []
        self.new_labels = []
        self.buff_noise = []
        self.EMA_cLass_mean = {}
        # Proxy copy of the network, perturbed by ``calc_awp``.
        self.proxy = copy.deepcopy(self.net)
        self.proxy.to(self.device)
        self.proxy_optim = SGD(self.proxy.parameters(), lr=self.args.lr)
        self.theta_list = []

        self.all_iteration = 0
        # One 512-d vector per key 0..200, initialised to 0.01.
        for i in range(201):
            self.EMA_cLass_mean[i] = 0.01 + torch.zeros(512).to(self.device)

    def diff_in_weights(self, model, proxy):
        """
        Per-layer weight difference between ``proxy`` and ``model``, rescaled
        to the norm of the model's weight.
        :param model: reference network
        :param proxy: perturbed network with the same state-dict layout
        :return: OrderedDict keyed by the model's parameter names; only
                 entries with more than one dimension whose key contains
                 'weight' are included
        """
        diff_dict = OrderedDict()
        model_state_dict = model.state_dict()
        proxy_state_dict = proxy.state_dict()
        for (old_k, old_w), (new_k, new_w) in zip(model_state_dict.items(), proxy_state_dict.items()):
            if len(old_w.size()) <= 1:
                continue
            if 'weight' in old_k:
                diff_w = new_w - old_w
                # EPS guards against a zero-norm difference.
                diff_dict[old_k] = old_w.norm() / (diff_w.norm() + EPS) * diff_w
        return diff_dict

    def calc_awp(self, inputs, targets):
        """
        One optimisation step on the proxy network, returning the rescaled
        weight difference between proxy and current network.
        :param inputs: batch of examples
        :param targets: ground-truth labels
        :return: the dictionary produced by ``diff_in_weights``
        """
        # NOTE(review): the if/else nesting is reconstructed from a dump that
        # lost indentation. Reading used here: sync the proxy on every call
        # for the first 1000 iterations, afterwards only every
        # ``args.inner_iter`` iterations -- confirm against the original file.
        if self.all_iteration > 1000:
            if self.all_iteration % self.args.inner_iter == 0:
                self.proxy.load_state_dict(self.net.state_dict())
        else:
            self.proxy.load_state_dict(self.net.state_dict())
        self.proxy.train()

        if self.args.advloss == 'nega':
            loss = - F.cross_entropy(self.proxy(inputs), targets)
        else:
            loss = F.cross_entropy(self.proxy(inputs), targets)
        loss = self.args.gamma_loss * loss
        self.proxy_optim.zero_grad()
        loss.backward()
        self.proxy_optim.step()

        # the adversary weight perturb
        diff = self.diff_in_weights(self.net, self.proxy)
        return diff

    def norm_scale(self, theta, theta_limit):
        """
        Compute ``sin(limit) / sin(arccos(theta) - limit)``.
        :param theta: tensor of cosine values
        :param theta_limit: angle limit in degrees
        :return: tensor with the same shape as ``theta``
        """
        # Clamp first: a cosine that overshoots [-1, 1] by rounding error
        # would otherwise turn into NaN inside arccos.
        theta = torch.arccos(torch.clamp(theta, -1.0, 1.0))
        theta_limit = torch.tensor(theta_limit / 180 * math.pi).to(theta.device)
        norm_scale = torch.sin(theta_limit) / (torch.sin(theta) * torch.cos(theta_limit) - torch.cos(theta) * torch.sin(theta_limit))
        return norm_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes a forward pass.
        :param x: batch of inputs
        :return: the result of the computation
        """
        return self.net(x)

    def observe(self, inputs: torch.Tensor, labels: torch.Tensor,
                not_aug_inputs: torch.Tensor) -> float:
        """
        Compute a training step over a given batch of examples.
        The base implementation does nothing; subclasses override it.
        :param inputs: batch of examples
        :param labels: ground-truth labels
        :param not_aug_inputs: the same batch without augmentation
        :return: the value of the loss function
        """
        pass

# -------------------------------------------------------------------------------
# /datasets/seq_cifar10.py:
# -------------------------------------------------------------------------------
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from backbone.ResNet18 import resnet18
import torch.nn.functional as F
from datasets.seq_tinyimagenet import base_path
from PIL import Image
from datasets.utils.validation import get_train_val
from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
from typing import Tuple
from datasets.transforms.denormalization import DeNormalize


class MyCIFAR10(CIFAR10):
    """
    Overrides the CIFAR10 dataset to change the getitem function.
    """
    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False) -> None:
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        super(MyCIFAR10, self).__init__(root, train, transform, target_transform, download)

    def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]:
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :returns: tuple: (image, target, not_aug_image); a fourth element,
                  ``self.logits[index]``, is appended when the dataset has a
                  ``logits`` attribute.
        """
        img, target = self.data[index], self.targets[index]

        # to return a PIL Image
        img = Image.fromarray(img, mode='RGB')
        original_img = img.copy()

        not_aug_img = self.not_aug_transform(original_img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if hasattr(self, 'logits'):
            return img, target, not_aug_img, self.logits[index]

        return img, target, not_aug_img


class SequentialCIFAR10(ContinualDataset):
    """CIFAR-10 split into 5 class-incremental tasks of 2 classes each."""

    NAME = 'seq-cifar10'
    SETTING = 'class-il'
    N_CLASSES_PER_TASK = 2
    N_TASKS = 5
    TRANSFORM = transforms.Compose(
        [transforms.RandomCrop(32, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465),
                              (0.2470, 0.2435, 0.2615))])

    def get_data_loaders(self):
        """Build the train/test datasets and return the loaders produced by
        ``store_masked_loaders``."""
        transform = self.TRANSFORM

        test_transform = transforms.Compose(
            [transforms.ToTensor(), self.get_normalization_transform()])

        train_dataset = MyCIFAR10(base_path() + 'CIFAR10', train=True,
                                  download=True, transform=transform)
        if self.args.validation:
            train_dataset, test_dataset = get_train_val(train_dataset,
                                                        test_transform, self.NAME)
        else:
            test_dataset = CIFAR10(base_path() + 'CIFAR10', train=False,
                                   download=True, transform=test_transform)

        train, test = store_masked_loaders(train_dataset, test_dataset, self)
        return train, test

    @staticmethod
    def get_transform():
        transform = transforms.Compose(
            [transforms.ToPILImage(), SequentialCIFAR10.TRANSFORM])
        return transform

    @staticmethod
    def get_backbone():
        return resnet18(SequentialCIFAR10.N_CLASSES_PER_TASK
                        * SequentialCIFAR10.N_TASKS)

    @staticmethod
    def get_loss():
        return F.cross_entropy

    @staticmethod
    def get_normalization_transform():
        transform = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2470, 0.2435, 0.2615))
        return transform

    @staticmethod
    def get_denormalization_transform():
        transform = DeNormalize((0.4914, 0.4822, 0.4465),
                                (0.2470, 0.2435, 0.2615))
        return transform

    @staticmethod
    def get_scheduler(model, args):
        return None

# -------------------------------------------------------------------------------
# /models/joint.py:
# -------------------------------------------------------------------------------
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torch.optim import SGD

from utils.args import *
from models.utils.continual_model import ContinualModel
from datasets.utils.validation import ValidationDataset
from utils.status import progress_bar
import torch
import numpy as np
import math
from torchvision import transforms


def get_parser() -> ArgumentParser:
    """Build the command-line parser for joint training."""
    parser = ArgumentParser(description='Joint training: a strong, simple baseline.')
    add_management_args(parser)
    add_experiment_args(parser)
    return parser


class Joint(ContinualModel):
    """Joint-training baseline: collects the data of every task and trains
    on all of it once the last task has ended."""
    NAME = 'joint'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(Joint, self).__init__(backbone, loss, args, transform)
        self.old_data = []
        self.old_labels = []
        self.current_task = 0

    def end_task(self, dataset):
        """Record the data of the finished task; after the last task, train
        on everything that was recorded."""
        if dataset.SETTING != 'domain-il':
            self.old_data.append(dataset.train_loader.dataset.data)
            self.old_labels.append(torch.tensor(dataset.train_loader.dataset.targets))
            self.current_task += 1

            # # for non-incremental joint training
            if len(dataset.test_loaders) != dataset.N_TASKS:
                return

            # reinit network
            self.net = dataset.get_backbone()
            self.net.to(self.device)
            self.net.train()
            self.opt = SGD(self.net.parameters(), lr=self.args.lr)

            # prepare dataloader
            all_data, all_labels = None, None
            for i in range(len(self.old_data)):
                if all_data is None:
                    all_data = self.old_data[i]
                    all_labels = self.old_labels[i]
                else:
                    all_data = np.concatenate([all_data, self.old_data[i]])
                    all_labels = np.concatenate([all_labels, self.old_labels[i]])

            transform = dataset.TRANSFORM if dataset.TRANSFORM is not None else transforms.ToTensor()
            temp_dataset = ValidationDataset(all_data, all_labels, transform=transform)
            loader = torch.utils.data.DataLoader(temp_dataset, batch_size=self.args.batch_size, shuffle=True)

            # train
            for e in range(self.args.n_epochs):
                for i, batch in enumerate(loader):
                    inputs, labels = batch
                    inputs, labels = inputs.to(self.device), labels.to(self.device)

                    self.opt.zero_grad()
                    outputs = self.net(inputs)
                    loss = self.loss(outputs, labels.long())
                    loss.backward()
                    self.opt.step()
                    progress_bar(i, len(loader), e, 'J', loss.item())
        else:
            self.old_data.append(dataset.train_loader)
            # train
            if len(dataset.test_loaders) != dataset.N_TASKS:
                return

            all_inputs = []
            all_labels = []
            for source in self.old_data:
                for x, l, _ in source:
                    all_inputs.append(x)
                    all_labels.append(l)
            all_inputs = torch.cat(all_inputs)
            all_labels = torch.cat(all_labels)
            bs = self.args.batch_size
            scheduler = dataset.get_scheduler(self, self.args)
            n_batches = int(math.ceil(len(all_inputs) / bs))

            for e in range(self.args.n_epochs):
                order = torch.randperm(len(all_inputs))
                for i in range(n_batches):
                    # Slice the permutation, then gather only this batch.
                    # ``all_inputs[order][...]`` copied the whole dataset for
                    # every single mini-batch; the selected samples are the same.
                    batch_idx = order[i * bs: (i + 1) * bs]
                    inputs = all_inputs[batch_idx]
                    labels = all_labels[batch_idx]
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    self.opt.zero_grad()
                    outputs = self.net(inputs)
                    loss = self.loss(outputs, labels.long())
                    loss.backward()
                    self.opt.step()
                    progress_bar(i, n_batches, e, 'J', loss.item())

                if scheduler is not None:
                    scheduler.step()

    def observe(self, inputs, labels, not_aug_inputs):
        """No per-batch training; all learning happens in ``end_task``."""
        return 0

# -------------------------------------------------------------------------------
# /utils/feature_gradient_svd.py:
# -------------------------------------------------------------------------------
import numpy as np
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import copy
import math
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.nn import functional as F
import math


def visual_svd(s_list):
    """Plot the log of the normalised singular values and save ``./svd.pdf``.

    :param s_list: list of five 1-D tensors of 512 singular values; it is
                   normalised in place. Labels come from the module-level
                   ``method_list``.
    """
    x = np.arange(0, 512)
    from matplotlib.ticker import MaxNLocator
    plt.ion()
    plt.clf()
    for i in range(len(s_list)):
        s_list[i] = s_list[i] / torch.sum(s_list[i])
    # Fixed drawing order: index 3 first, then 0, 1, 2, 4.
    plt.plot(x, torch.log(s_list[3]), label=method_list[3])
    plt.plot(x, torch.log(s_list[0]), label=method_list[0])
    plt.plot(x, torch.log(s_list[1]), label=method_list[1])
    plt.plot(x, torch.log(s_list[2]), label=method_list[2])
    plt.plot(x, torch.log(s_list[4]), label=method_list[4])
    plt.xlabel('Dimension', fontsize=14)
    plt.ylabel('Log Singular Value', fontsize=14)
    plt.legend()

    ax = plt.gca()
    bwith = 1.
    ax.spines['top'].set_linewidth(bwith)
    ax.spines['right'].set_linewidth(bwith)
    ax.spines['bottom'].set_linewidth(bwith)
    ax.spines['left'].set_linewidth(bwith)
    plt.tick_params(labelsize=14)
    plt.savefig('./svd.pdf')
    # The original ended with a bare ``plt.show`` (attribute access, never
    # called). The figure is written to disk above and the Agg backend cannot
    # display windows, so nothing is shown here.


def trans_feat(feat, label, ori_feat, new_feat, new_label):
    """Keep the entries whose feature length equals that of ``feat[2]``,
    convert all five sequences to tensors and merge their first two axes."""
    feat_all = []
    label_all = []
    new_feat_all = []
    new_label_all = []
    ori_feat_all = []
    for len_i in range(label.shape[0]):
        if len(feat[len_i]) == len(feat[2]):
            feat_all.append(feat[len_i])
            label_all.append(label[len_i])
            ori_feat_all.append(ori_feat[len_i])
            new_feat_all.append(new_feat[len_i])
            new_label_all.append(new_label[len_i])
    feat = torch.tensor(list(feat_all))
    label = torch.tensor(list(label_all))
    ori_feat = torch.tensor(list(ori_feat_all))
    new_feat = torch.tensor(list(new_feat_all))
    new_label = torch.tensor(list(new_label_all))

    new_feat = torch.reshape(new_feat, (new_feat.shape[0] * new_feat.shape[1], new_feat.shape[2]))
    new_label = torch.reshape(new_label, (new_label.shape[0] * new_label.shape[1],))
    feat = torch.reshape(feat, (feat.shape[0] * feat.shape[1], feat.shape[2]))
    ori_feat = torch.reshape(ori_feat, (ori_feat.shape[0] * ori_feat.shape[1], ori_feat.shape[2]))
    label = torch.reshape(label, (label.shape[0] * label.shape[1],))
    return feat, label, ori_feat, new_feat, new_label


EPS = 1E-20
current_task = '5'
target_type = 'target_buf_labels'
gamma_loss = '1.0'
noise_type = 'noise'
method2 = 'gaussian'
c_theta = '45.0'
para_scale = '1.0'

s_list = []
method_list = ['drop_self', 'drop_new', 'class_mean', 'gaussian', 'none']
for method2 in method_list:
    if method2 == 'none':
        target_type = 'new_labels'
        gamma_loss = '50.0'
        noise_type = 'adv'

    epoch = 50
    name = str(current_task) + '_' + target_type + \
        '_' + str(gamma_loss) + '_' + noise_type + \
        '_' + method2 + '_' + str(c_theta) +\
        '_' + str(para_scale)
    model_name = target_type + \
        '_' + str(gamma_loss) + '_' + noise_type + \
        '_' + method2 + '_' + str(c_theta) +\
        '_' + str(para_scale)

    model_path = './output/task_models/seq-cifar100/%s' % model_name + '/task_5_model.ph'
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    net = torch.load(model_path)
    net.train()
    feat = np.load('output/buff_featurte_task_%s.npy' % name, allow_pickle=True)
    label = np.load('output/buff_labels_task_%s.npy' % name, allow_pickle=True)
    new_feat = np.load('output/new_featurte_task_%s.npy' % name, allow_pickle=True)
    new_label = np.load('output/new_labels_task_%s.npy' % name, allow_pickle=True)
    ori_feat = np.load('output/ori_buffer_feat_task_%s.npy' % name, allow_pickle=True)
    feat, label, ori_feat, new_feat, new_label = trans_feat(feat, label, ori_feat, new_feat, new_label)

    aug_feat_1, ori_feat_1 = copy.deepcopy(feat), copy.deepcopy(ori_feat)
    feat_noise = aug_feat_1.detach() - ori_feat_1.detach()

    feat = feat + feat_noise
    # ``feat`` must be a leaf tensor that requires grad; otherwise
    # ``feat.grad`` is still None after ``backward()`` and the SVD below fails.
    feat = feat.detach().requires_grad_(True)
    num_per_epoch = int(feat_noise.shape[0] / epoch)
    mask = label[-num_per_epoch:] < 80

    feat_all = torch.cat((feat, new_feat), 0)
    label_all = torch.cat((label, new_label), 0)

    pred = net.classifier(feat_all)
    loss = F.cross_entropy(pred, label_all)
    loss.backward()

    u, s, v = torch.svd(feat.grad[-num_per_epoch:][mask])
    s_list.append(s)

method_list = ['V-self-drop', 'V-new-drop', 'V-trans', 'V-gaussian', 'V-adv']
visual_svd(s_list)

# -------------------------------------------------------------------------------
# /models/rpc.py:
# -------------------------------------------------------------------------------
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from utils.buffer import Buffer
from utils.args import *
from models.utils.continual_model import ContinualModel
from datasets import get_dataset


def dsimplex(num_classes=10):
    """Return the vertex coordinates of a regular simplex.

    The result is a ``(num_classes - 1, num_classes)`` array: columns are
    centred on the origin and then divided by the norm of the first column.
    """
    def simplex_coordinates2(m):
        # add the credit
        import numpy as np

        coords = np.zeros([m, m + 1])
        for k in range(0, m):
            coords[k, k] = 1.0

        last_column_value = (1.0 - np.sqrt(float(1 + m))) / float(m)

        for row in range(0, m):
            coords[row, m] = last_column_value

        # Adjust coordinates so the centroid is at zero.
        centroid = np.zeros(m)
        for row in range(0, m):
            total = 0.0
            for col in range(0, m + 1):
                total = total + coords[row, col]
            centroid[row] = total / float(m + 1)

        for col in range(0, m + 1):
            for row in range(0, m):
                coords[row, col] = coords[row, col] - centroid[row]

        # Scale so each column has norm 1. UNIT NORMALIZED
        norm = 0.0
        for row in range(0, m):
            norm = norm + coords[row, 0] ** 2
        norm = np.sqrt(norm)

        for col in range(0, m + 1):
            for row in range(0, m):
                coords[row, col] = coords[row, col] / norm

        return coords

    return simplex_coordinates2(num_classes - 1)


def get_parser() -> ArgumentParser:
    """Build the command-line parser for the RPC model."""
    parser = ArgumentParser(description='Continual learning via'
                            ' Experience Replay.')
    add_management_args(parser)
    add_experiment_args(parser)
    add_rehearsal_args(parser)
    return parser


class RPC(ContinualModel):
    """Rehearsal model with a fixed simplex head and a class-balanced buffer
    that is rebuilt at the end of every task."""
    NAME = 'rpc'
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(RPC, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.cpt = get_dataset(args).N_CLASSES_PER_TASK
        self.tasks = get_dataset(args).N_TASKS
        self.task = 0
        simplex = dsimplex(self.cpt * self.tasks)
        self.rpchead = torch.from_numpy(simplex).float().to(self.device)

    def forward(self, x):
        """Drop the last backbone output and project through the simplex head."""
        trimmed = self.net(x)[:, :-1]
        return trimmed @ self.rpchead

    def end_task(self, dataset):
        """Shrink the stored classes to their new quota, then fill the freed
        slots with samples of the task that just ended."""
        # reduce coreset
        if self.task > 0:
            quota = self.args.buffer_size // ((self.task + 1) * self.cpt)
            stored_x, stored_y = self.buffer.get_all_data()
            self.buffer.empty()
            for cls in stored_y.unique():
                selector = cls == stored_y
                cls_x, cls_y = stored_x[selector], stored_y[selector]
                keep = min(cls_x.shape[0], quota)
                self.buffer.add_data(
                    examples=cls_x[:keep],
                    labels=cls_y[:keep]
                )

        # add new task
        free_slots = self.buffer.buffer_size - self.buffer.num_seen_examples
        per_class = free_slots // self.cpt
        ce = torch.tensor([per_class] * self.cpt).int()
        # Hand the remainder out to randomly chosen classes.
        ce[torch.randperm(self.cpt)[:free_slots - (per_class * self.cpt)]] += 1

        with torch.no_grad():
            for data in dataset.train_loader:
                _, labels, not_aug_inputs = data
                not_aug_inputs = not_aug_inputs.to(self.device)
                if all(ce == 0):
                    break

                flags = torch.zeros(len(labels)).bool()
                for j in range(len(flags)):
                    slot = labels[j] % self.cpt
                    if ce[slot] > 0:
                        flags[j] = True
                        ce[slot] -= 1

                self.buffer.add_data(examples=not_aug_inputs[flags],
                                     labels=labels[flags])
        self.task += 1

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one optimisation step on the batch, joined with a buffer
        mini-batch when the buffer is not empty, and return the loss value."""
        self.opt.zero_grad()
        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        # NOTE(review): this calls ``self.net`` directly, so the simplex head
        # in ``forward`` is not used during training -- confirm that is intended.
        outputs = self.net(inputs)
        per_sample = self.loss(outputs, labels, reduction='none')
        loss = per_sample.mean()

        loss.backward()
        self.opt.step()

        return loss.item()
-------------------------------------------------------------------------------- /utils/plot_acc_curve.py: -------------------------------------------------------------------------------- 1 | from turtle import color 2 | import numpy as np 3 | import numpy as np 4 | import matplotlib 5 | matplotlib.use("Agg") 6 | import matplotlib.pyplot as plt 7 | import copy 8 | import math 9 | import torch 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | from torch.nn import functional as F 13 | import math 14 | 15 | 16 | # Baseline = [83.05, 62.15, 46.12, 40.19, 31.08] 17 | # Self_drop = [83.05, 61.32, 49.02, 41.62, 33.67] 18 | # gaussian = [83.05, 64.05, 51.13, 42.66, 37.29] 19 | # vMF = [83.05, 66.45, 53.22, 44.17, 38.76] 20 | # new_drop = [83.05, 65.15, 52.13, 43.93, 38.75] 21 | # Trans = [83.05, 65.88, 53.10, 45.30, 39.78] 22 | # Adv = [83.05, 66.92, 54.58, 46.72, 41.02] 23 | # # plt.figure(figsize=(8,8)) 24 | # title = 'CIFAR-100' 25 | # plt.title(title, fontsize=14) 26 | # plt.ylim(30, 68) 27 | # plt.xlabel("Continual Task", fontsize=14) 28 | # plt.ylabel("Accuracy", fontsize=14) 29 | 30 | # plt.grid(linewidth = 1.5) 31 | # x = ['Task1', 'Task2', 'Task3', 'Task4', 'Task5'] 32 | # plt.plot(x, Baseline, alpha=0.8, label="ER", linewidth=3, marker= 'o') 33 | # plt.plot(x, Self_drop, alpha=0.8, label="DOA-self", linewidth=3, marker= 'o') 34 | # plt.plot(x, gaussian, alpha=0.8, label="Gaussian", linewidth=3, marker= 'o') 35 | # plt.plot(x, vMF, alpha=0.8, label="vMF", linewidth=3, marker= 'o') 36 | # plt.plot(x, new_drop, alpha=0.8, label="DOA", linewidth=3, marker= 'o') 37 | # plt.plot(x, Trans, alpha=0.8, label="VT", linewidth=3, marker= 'o') 38 | # plt.plot(x, Adv, alpha=0.8, label="WAP", linewidth=3, marker= 'o') 39 | # plt.tick_params(labelsize=14) 40 | # ax = plt.gca() 41 | # bwith = 1. 
42 | # ax.spines['top'].set_linewidth(bwith) 43 | # ax.spines['right'].set_linewidth(bwith) 44 | # ax.spines['bottom'].set_linewidth(bwith) 45 | # ax.spines['left'].set_linewidth(bwith) 46 | # plt.legend(loc="best", fontsize=14, edgecolor='black') 47 | # plt.show() 48 | # plt.savefig('./acc_curve_%s.pdf' %title, bbox_inches = 'tight') 49 | 50 | 51 | 52 | # Baseline = [74.6, 52.2, 40.63, 32.22, 24.6, 20.38, 16.71, 14.68, 12.17, 11.24] 53 | # Self_drop = [74.6, 53.75, 40.33, 32.42, 26.94, 21.9, 17.17, 15.41, 12.31, 11.15] 54 | # gaussian = [74.6, 59.1, 46.63, 37.52, 28.74, 25.17, 20.56, 19.12, 16.21, 14.01] 55 | # vMF = [74.6, 57.21, 43.24, 34.55, 29.62, 24.27, 19.24, 19.95, 16.87, 14.62] 56 | # new_drop = [74.6, 55.85, 42.53, 34.8, 25.06, 21.8, 18.9, 19.06, 15.58, 14.96] 57 | # Trans = [74.6, 56.0, 44.63, 36.58, 31.1, 23.67, 19.93, 18.51, 14.91, 15.03] 58 | # Adv = [74.6, 59.2, 45.77, 35.58, 30.12, 26.52, 23.93, 20.35, 18.05, 16.68] 59 | # # plt.figure(figsize=(8,8)) 60 | # title = 'TinyImageNet' 61 | # plt.title(title, fontsize=14) 62 | # plt.ylim(0, 60) 63 | # plt.xlabel("Continual Task", fontsize=14) 64 | # plt.ylabel("Accuracy", fontsize=14) 65 | 66 | # plt.grid(linewidth = 1.5) 67 | # x = ['Task1', 'Task2', 'Task3', 'Task4', 'Task5', 'Task6', 'Task7', 'Task8', 'Task9', 'Task10'] 68 | # plt.plot(x, Baseline, alpha=0.8, label="ER", linewidth=3, marker= 'o') 69 | # plt.plot(x, Self_drop, alpha=0.8, label="DOA-self", linewidth=3, marker= 'o') 70 | # plt.plot(x, gaussian, alpha=0.8, label="Gaussian", linewidth=3, marker= 'o') 71 | # plt.plot(x, vMF, alpha=0.8, label="vMF", linewidth=3, marker= 'o') 72 | # plt.plot(x, new_drop, alpha=0.8, label="DOA", linewidth=3, marker= 'o') 73 | # plt.plot(x, Trans, alpha=0.8, label="VT", linewidth=3, marker= 'o') 74 | # plt.plot(x, Adv, alpha=0.8, label="WAP", linewidth=3, marker= 'o') 75 | # plt.tick_params(labelsize=10) 76 | # ax = plt.gca() 77 | # bwith = 1. 
78 | # ax.spines['top'].set_linewidth(bwith) 79 | # ax.spines['right'].set_linewidth(bwith) 80 | # ax.spines['bottom'].set_linewidth(bwith) 81 | # ax.spines['left'].set_linewidth(bwith) 82 | # plt.legend(loc="best", fontsize=14, edgecolor='black') 83 | # plt.show() 84 | # plt.savefig('./acc_curve_%s.pdf' %title, bbox_inches = 'tight') 85 | 86 | 87 | 88 | Baseline = [98.25, 90.22, 74.22, 71.03, 62.53] 89 | Self_drop = [98.25, 90.7, 79.23, 70.88, 72.18] 90 | gaussian = [98.25, 88.82, 78.68, 74.23, 67.85] 91 | vMF = [98.25, 88.82, 78.87, 76.18, 71.58] 92 | new_drop = [98.25, 90.15, 72.68, 75.42, 67.88] 93 | Trans = [98.25, 89.5, 78.98, 69.24, 71.66] 94 | Adv = [98.25, 91.35, 80.03, 74.24, 72.99] 95 | 96 | # plt.figure(figsize=(8,8)) 97 | title = 'CIFAR-10' 98 | plt.title(title, fontsize=14) 99 | plt.ylim(50, 100) 100 | plt.xlabel("Continual Task", fontsize=14) 101 | plt.ylabel("Accuracy", fontsize=14) 102 | 103 | plt.grid(linewidth = 1.5) 104 | x = ['Task1', 'Task2', 'Task3', 'Task4', 'Task5'] 105 | plt.plot(x, Baseline, alpha=0.8, label="ER", linewidth=3, marker= 'o') 106 | plt.plot(x, Self_drop, alpha=0.8, label="DOA-self", linewidth=3, marker= 'o') 107 | plt.plot(x, gaussian, alpha=0.8, label="Gaussian", linewidth=3, marker= 'o') 108 | plt.plot(x, vMF, alpha=0.8, label="vMF", linewidth=3, marker= 'o') 109 | plt.plot(x, new_drop, alpha=0.8, label="DOA", linewidth=3, marker= 'o') 110 | plt.plot(x, Trans, alpha=0.8, label="VT", linewidth=3, marker= 'o') 111 | plt.plot(x, Adv, alpha=0.8, label="WAP", linewidth=3, marker= 'o') 112 | plt.tick_params(labelsize=14) 113 | ax = plt.gca() 114 | bwith = 1. 
115 | ax.spines['top'].set_linewidth(bwith) 116 | ax.spines['right'].set_linewidth(bwith) 117 | ax.spines['bottom'].set_linewidth(bwith) 118 | ax.spines['left'].set_linewidth(bwith) 119 | plt.legend(loc="best", fontsize=14, edgecolor='black') 120 | plt.show() 121 | plt.savefig('./acc_curve_%s.pdf' %title, bbox_inches = 'tight') 122 | -------------------------------------------------------------------------------- /utils/method_variation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy as np 3 | import matplotlib 4 | matplotlib.use("Agg") 5 | import matplotlib.pyplot as plt 6 | import copy 7 | import math 8 | import torch 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from torch.nn import functional as F 12 | import math 13 | 14 | def visual_svd(s_list): 15 | x = np.arange(0,512) 16 | from matplotlib.ticker import MaxNLocator 17 | plt.ion() 18 | plt.clf() 19 | for i in range(len(s_list)): 20 | s_list[i] = s_list[i] / torch.sum(s_list[i]) 21 | plt.plot(x, torch.log(s_list[3]), label=method_list[3]) 22 | plt.plot(x, torch.log(s_list[0]), label=method_list[0]) 23 | plt.plot(x, torch.log(s_list[1]), label=method_list[1]) 24 | plt.plot(x, torch.log(s_list[2]), label=method_list[2]) 25 | plt.plot(x, torch.log(s_list[4]), label=method_list[4]) 26 | plt.xlabel('Dimension', fontsize=14) 27 | plt.ylabel('Log Singular Value', fontsize=14) 28 | plt.legend() 29 | 30 | ax = plt.gca() 31 | bwith = 1. 
32 | ax.spines['top'].set_linewidth(bwith) 33 | ax.spines['right'].set_linewidth(bwith) 34 | ax.spines['bottom'].set_linewidth(bwith) 35 | ax.spines['left'].set_linewidth(bwith) 36 | plt.tick_params(labelsize=14) 37 | plt.savefig('./svd.pdf') 38 | plt.show 39 | 40 | def trans_feat(feat, label, ori_feat, new_feat, new_label): 41 | feat_all = [] 42 | label_all = [] 43 | new_feat_all = [] 44 | new_label_all = [] 45 | ori_feat_all = [] 46 | for len_i in range(label.shape[0]): 47 | if len(feat[len_i]) == len(feat[2]): 48 | feat_all.append(feat[len_i]) 49 | label_all.append(label[len_i]) 50 | ori_feat_all.append(ori_feat[len_i]) 51 | new_feat_all.append(new_feat[len_i]) 52 | new_label_all.append(new_label[len_i]) 53 | feat = torch.tensor(list(feat_all)) 54 | label = torch.tensor(list(label_all)) 55 | ori_feat = torch.tensor(list(ori_feat_all)) 56 | new_feat = torch.tensor(list(new_feat_all)) 57 | new_label = torch.tensor(list(new_label_all)) 58 | 59 | new_feat = torch.reshape(new_feat, (new_feat.shape[0] * new_feat.shape[1], new_feat.shape[2])) 60 | new_label = torch.reshape(new_label, (new_label.shape[0] * new_label.shape[1],)) 61 | feat = torch.reshape(feat, (feat.shape[0] * feat.shape[1], feat.shape[2])) 62 | ori_feat = torch.reshape(ori_feat, (ori_feat.shape[0] * ori_feat.shape[1], ori_feat.shape[2])) 63 | label = torch.reshape(label, (label.shape[0] * label.shape[1],)) 64 | return feat, label, ori_feat, new_feat, new_label 65 | 66 | def compute_variance(class_feat, m_feature): 67 | RMSSTD = 0 68 | for i_ter in range(class_feat.shape[0]): 69 | RMSSTD += 1 - torch.cosine_similarity(m_feature[0], class_feat[i_ter], dim=0) 70 | RMSSTD = RMSSTD / (class_feat.shape[0]) 71 | return RMSSTD 72 | 73 | EPS = 1E-20 74 | current_task = '5' 75 | target_type = 'target_buf_labels' 76 | gamma_loss = '1.0' 77 | noise_type = 'noise' 78 | method2 = 'gaussian' 79 | c_theta = '45.0' 80 | para_scale = '1.0' 81 | 82 | s_list = [] 83 | # method_list = ['drop_self', 'drop_new', 'class_mean', 
'gaussian', 'none'] #'drop_self', 'drop_new', 'class_mean', 'gaussian', 84 | method_list = ['drop_new', 'drop_self'] 85 | for idx, method2 in enumerate(method_list): 86 | if method2 == 'none': 87 | target_type = 'new_labels' 88 | gamma_loss = '50.0' 89 | noise_type = 'adv' 90 | if idx == 0: 91 | para_scale = '1.0' 92 | if idx == 1: 93 | para_scale = '1.5' 94 | 95 | name = str(current_task) + '_' + target_type + \ 96 | '_' + str(gamma_loss) + '_' + noise_type + \ 97 | '_' + method2 + '_' + str(c_theta) +\ 98 | '_' + str(para_scale) 99 | 100 | # dirs = './out_gradient/' + 'mnist_2d/gaussian/'+ 'er_ours_3_50_task_2_model.ph' 101 | feat = np.load('output_7_22_all_method/buff_featurte_task_%s.npy' %name, allow_pickle=True) 102 | label = np.load('output_7_22_all_method/buff_labels_task_%s.npy' %name, allow_pickle=True) 103 | new_feat = np.load('output_7_22_all_method/new_featurte_task_%s.npy' %name, allow_pickle=True) 104 | new_label = np.load('output_7_22_all_method/new_labels_task_%s.npy' %name, allow_pickle=True) 105 | ori_feat = np.load('output_7_22_all_method/ori_buffer_feat_task_%s.npy' %name, allow_pickle=True) 106 | feat, label, ori_feat, new_feat, new_label = trans_feat(feat, label, ori_feat, new_feat, new_label) 107 | 108 | aug_feat_1, ori_feat_1 = copy.deepcopy(feat), copy.deepcopy(ori_feat) 109 | feat_noise = aug_feat_1.detach() - ori_feat_1.detach() 110 | 111 | epoch = 50 112 | num_per_epoch = int(feat_noise.shape[0]/epoch) 113 | all_feat = torch.cat((feat, new_feat), 0) 114 | all_label = torch.cat((label, new_label), 0) 115 | RMSSTD_list = [] 116 | for i_class in range(0,100): 117 | mask = all_label == i_class 118 | class_feat = all_feat[mask] 119 | 120 | class_feat = F.normalize(class_feat, dim=1, p=2) 121 | m_feature = torch.mean(class_feat, dim=0).unsqueeze(0) 122 | m_feature = F.normalize(m_feature, dim=1, p=2) 123 | 124 | RMSSTD = compute_variance(class_feat, m_feature) 125 | 126 | RMSSTD_list.append(RMSSTD) 127 | 128 | print(sum(RMSSTD_list[:80])/80, 
sum(RMSSTD_list[80:])/20) 129 | 130 | 131 | 132 | # method_list = ['V-self-drop', 'V-new-drop', 'V-trans', 'V-gaussian', 'V-adv'] 133 | # visual_svd(s_list) -------------------------------------------------------------------------------- /utils/ring_buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import numpy as np 8 | from typing import Tuple 9 | from torchvision import transforms 10 | 11 | 12 | def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int: 13 | return num_seen_examples % buffer_portion_size + task * buffer_portion_size 14 | 15 | 16 | class RingBuffer: 17 | """ 18 | The memory buffer of rehearsal method. 19 | """ 20 | def __init__(self, buffer_size, device, n_tasks): 21 | self.buffer_size = buffer_size 22 | self.buffer_portion_size = buffer_size // n_tasks 23 | self.device = device 24 | self.task_number = 0 25 | self.num_seen_examples = 0 26 | self.attributes = ['examples', 'labels', 'logits', 'task_labels'] 27 | 28 | def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor, 29 | logits: torch.Tensor, task_labels: torch.Tensor) -> None: 30 | """ 31 | Initializes just the required tensors. 
32 | :param examples: tensor containing the images 33 | :param labels: tensor containing the labels 34 | :param logits: tensor containing the outputs of the network 35 | :param task_labels: tensor containing the task labels 36 | """ 37 | for attr_str in self.attributes: 38 | attr = eval(attr_str) 39 | if attr is not None and not hasattr(self, attr_str): 40 | typ = torch.int64 if attr_str.endswith('els') else torch.float32 41 | setattr(self, attr_str, torch.zeros((self.buffer_size, 42 | *attr.shape[1:]), dtype=typ, device=self.device)) 43 | 44 | self.labels -= 1 45 | 46 | def add_data(self, examples, labels=None, logits=None, task_labels=None): 47 | """ 48 | Adds the data to the memory buffer according to the reservoir strategy. 49 | :param examples: tensor containing the images 50 | :param labels: tensor containing the labels 51 | :param logits: tensor containing the outputs of the network 52 | :param task_labels: tensor containing the task labels 53 | :return: 54 | """ 55 | if not hasattr(self, 'examples'): 56 | self.init_tensors(examples, labels, logits, task_labels) 57 | 58 | for i in range(examples.shape[0]): 59 | index = ring(self.num_seen_examples, self.buffer_portion_size, self.task_number) 60 | self.num_seen_examples += 1 61 | if index >= 0: 62 | self.examples[index] = examples[i].to(self.device) 63 | if labels is not None: 64 | self.labels[index] = labels[i].to(self.device) 65 | if logits is not None: 66 | self.logits[index] = logits[i].to(self.device) 67 | if task_labels is not None: 68 | self.task_labels[index] = task_labels[i].to(self.device) 69 | 70 | def get_data(self, size: int, transform: transforms=None) -> Tuple: 71 | """ 72 | Random samples a batch of size items. 
73 | :param size: the number of requested items 74 | :param transform: the transformation to be applied (data augmentation) 75 | :return: 76 | """ 77 | populated_portion_length = (self.labels != -1).sum().item() 78 | 79 | if size > populated_portion_length: 80 | size = populated_portion_length 81 | 82 | 83 | choice = np.random.choice(populated_portion_length, size=size, replace=False) 84 | if transform is None: transform = lambda x: x 85 | ret_tuple = (torch.stack([transform(ee.cpu()) 86 | for ee in self.examples[choice]]).to(self.device),) 87 | for attr_str in self.attributes[1:]: 88 | if hasattr(self, attr_str): 89 | attr = getattr(self, attr_str) 90 | ret_tuple += (attr[choice],) 91 | 92 | return ret_tuple 93 | 94 | def is_empty(self) -> bool: 95 | """ 96 | Returns true if the buffer is empty, false otherwise. 97 | """ 98 | if self.num_seen_examples == 0 and self.task_number == 0: 99 | return True 100 | else: 101 | return False 102 | 103 | def get_all_data(self, transform: transforms=None) -> Tuple: 104 | """ 105 | Return all the items in the memory buffer. 106 | :param transform: the transformation to be applied (data augmentation) 107 | :return: a tuple with all the items in the memory buffer 108 | """ 109 | if transform is None: transform = lambda x: x 110 | ret_tuple = (torch.stack([transform(ee.cpu()) 111 | for ee in self.examples]).to(self.device),) 112 | for attr_str in self.attributes[1:]: 113 | if hasattr(self, attr_str): 114 | attr = getattr(self, attr_str) 115 | ret_tuple += (attr,) 116 | return ret_tuple 117 | 118 | def empty(self) -> None: 119 | """ 120 | Set all the tensors to None. 
121 | """ 122 | for attr_str in self.attributes: 123 | if hasattr(self, attr_str): 124 | delattr(self, attr_str) 125 | self.num_seen_examples = 0 126 | -------------------------------------------------------------------------------- /utils/args.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from argparse import ArgumentParser 7 | from datasets import NAMES as DATASET_NAMES 8 | from models import get_all_models 9 | 10 | 11 | def add_experiment_args(parser: ArgumentParser) -> None: 12 | """ 13 | Adds the arguments used by all the models. 14 | :param parser: the parser instance 15 | """ 16 | parser.add_argument('--dataset', type=str, required=True, 17 | choices=DATASET_NAMES, 18 | help='Which dataset to perform experiments on.') 19 | parser.add_argument('--model', type=str, required=True, 20 | help='Model name.', choices=get_all_models()) 21 | 22 | parser.add_argument('--lr', type=float, required=True, 23 | help='Learning rate.') 24 | 25 | parser.add_argument('--optim_wd', type=float, default=0., 26 | help='optimizer weight decay.') 27 | parser.add_argument('--optim_mom', type=float, default=0., 28 | help='optimizer momentum.') 29 | parser.add_argument('--optim_nesterov', type=int, default=0, 30 | help='optimizer nesterov momentum.') 31 | 32 | parser.add_argument('--n_epochs', type=int, 33 | help='Batch size.') 34 | parser.add_argument('--batch_size', type=int, 35 | help='Batch size.') 36 | 37 | def add_management_args(parser: ArgumentParser) -> None: 38 | parser.add_argument('--seed', type=int, default=None, 39 | help='The random seed.') 40 | parser.add_argument('--notes', type=str, default=None, 41 | help='Notes for this run.') 42 | 43 | 
parser.add_argument('--non_verbose', action='store_true') 44 | parser.add_argument('--csv_log', action='store_true', 45 | help='Enable csv logging') 46 | parser.add_argument('--tensorboard', action='store_true', 47 | help='Enable tensorboard logging') 48 | parser.add_argument('--validation', action='store_true', 49 | help='Test on the validation set') 50 | 51 | ############################################################# 52 | parser.add_argument('--noise_std', type=float, default=0.01, required=False, 53 | help='noise_std', ) 54 | parser.add_argument('--c_theta', type=float, default=30, required=False, 55 | help='c_theta', ) 56 | parser.add_argument('--on_sphere', type=str, default='none', required=False, 57 | help='on_sphere.') 58 | parser.add_argument('--noise_type', type=str, default='adv', required=False, 59 | help='noise_type.') 60 | parser.add_argument('--noise_factor', type=float, default=0.01, required=False, 61 | help='noise_factor', ) 62 | parser.add_argument('--drop_rate', type=float, default=0.5, required=False, 63 | help='drop_rate', ) 64 | parser.add_argument('--mix_rate', type=float, default=0.5, required=False, 65 | help='mix_rate', ) 66 | 67 | parser.add_argument('--drop_factor', type=float, default=0.7, required=False, 68 | help='drop_factor', ) 69 | parser.add_argument('--gaussian_factor', type=float, default=0.3, required=False, 70 | help='gaussian_factor', ) 71 | 72 | parser.add_argument('--para_scale', type=float, default=1, required=False, 73 | help='para_scale', ) 74 | parser.add_argument('--inner_iter', type=float, default=5, required=False, 75 | help='inner_iter', ) 76 | parser.add_argument('--gamma_loss', type=float, default=0.01, required=False, 77 | help='gamma', ) 78 | parser.add_argument('--method2', default="mean", type=str, 79 | help='Directory where data files are stored.') 80 | parser.add_argument('--norm_add', default="none", type=str, 81 | help='Directory where data files are stored.') 82 | parser.add_argument('--target_type', 
default="mean", type=str, 83 | help='Directory where data files are stored.') 84 | parser.add_argument('--advloss', default="nega", type=str, 85 | help='Directory where data files are stored.') 86 | 87 | parser.add_argument('--optimizer', default="SGD", type=str, 88 | help='Directory where data files are stored.') 89 | 90 | parser.add_argument('--epsilon', default= 0.05, type=float, 91 | help='Directory where data files are stored.') 92 | parser.add_argument('--cos_temp', type=float, default=15, required=False, 93 | help='cos_temp', ) 94 | 95 | 96 | def add_rehearsal_args(parser: ArgumentParser) -> None: 97 | """ 98 | Adds the arguments used by all the rehearsal-based methods 99 | :param parser: the parser instance 100 | """ 101 | parser.add_argument('--buffer_size', type=int, required=True, 102 | help='The size of the memory buffer.') 103 | parser.add_argument('--minibatch_size', type=int, 104 | help='The batch size of the memory buffer.') 105 | -------------------------------------------------------------------------------- /utils/plot_test_variance.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import numpy # needed (don't change it) 7 | import importlib 8 | import os 9 | import sys 10 | import socket 11 | mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.append(mammoth_path) 13 | sys.path.append(mammoth_path + '/datasets') 14 | sys.path.append(mammoth_path + '/backbone') 15 | sys.path.append(mammoth_path + '/models') 16 | 17 | from datasets import NAMES as DATASET_NAMES 18 | from models import get_all_models 19 | from argparse import ArgumentParser 20 | from utils.args import add_management_args 21 | from datasets import ContinualDataset 22 | from utils.continual_training import train as ctrain 23 | from datasets import get_dataset 24 | from models import get_model 25 | from utils.training import train, test_visual, test_variance 26 | from utils.best_args import best_args 27 | from utils.conf import set_random_seed 28 | import setproctitle 29 | import torch 30 | import uuid 31 | import datetime 32 | 33 | import numpy as np 34 | import random 35 | import torch 36 | import torch.nn as nn 37 | 38 | 39 | def lecun_fix(): 40 | # Yann moved his website to CloudFlare. You need this now 41 | from six.moves import urllib 42 | opener = urllib.request.build_opener() 43 | opener.addheaders = [('User-agent', 'Mozilla/5.0')] 44 | urllib.request.install_opener(opener) 45 | 46 | def parse_args(): 47 | parser = ArgumentParser(description='mammoth', allow_abbrev=False) 48 | parser.add_argument('--model', type=str, required=True, 49 | help='Model name.', choices=get_all_models()) 50 | parser.add_argument('--load_best_args', action='store_true', 51 | help='Loads the best arguments for each method, ' 52 | 'dataset and memory buffer.') 53 | torch.set_num_threads(4) 54 | add_management_args(parser) 55 | args = parser.parse_known_args()[0] 56 | mod = importlib.import_module('models.' 
+ args.model) 57 | 58 | if args.load_best_args: 59 | parser.add_argument('--dataset', type=str, required=True, 60 | choices=DATASET_NAMES, 61 | help='Which dataset to perform experiments on.') 62 | if hasattr(mod, 'Buffer'): 63 | parser.add_argument('--buffer_size', type=int, required=True, 64 | help='The size of the memory buffer.') 65 | args = parser.parse_args() 66 | if args.model == 'joint': 67 | best = best_args[args.dataset]['sgd'] 68 | else: 69 | best = best_args[args.dataset][args.model] 70 | if hasattr(mod, 'Buffer'): 71 | best = best[args.buffer_size] 72 | else: 73 | best = best[-1] 74 | get_parser = getattr(mod, 'get_parser') 75 | parser = get_parser() 76 | to_parse = sys.argv[1:] + ['--' + k + '=' + str(v) for k, v in best.items()] 77 | to_parse.remove('--load_best_args') 78 | args = parser.parse_args(to_parse) 79 | if args.model == 'joint' and args.dataset == 'mnist-360': 80 | args.model = 'joint_gcl' 81 | else: 82 | get_parser = getattr(mod, 'get_parser') 83 | parser = get_parser() 84 | args = parser.parse_args() 85 | 86 | if args.seed is not None: 87 | set_random_seed(args.seed) 88 | 89 | return args 90 | 91 | def main(args=None): 92 | lecun_fix() 93 | if args is None: 94 | args = parse_args() 95 | 96 | random.seed(args.seed) 97 | os.environ['PYTHONHASHSEED'] = str(args.seed) # 为了禁止hash随机化,使得实验可复现 98 | np.random.seed(args.seed) 99 | torch.manual_seed(args.seed) 100 | torch.cuda.manual_seed(args.seed) 101 | torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU. 
102 | torch.backends.cudnn.benchmark = False 103 | torch.backends.cudnn.deterministic = True 104 | 105 | # Add uuid, timestamp and hostname for logging 106 | args.conf_jobnum = str(uuid.uuid4()) 107 | args.conf_timestamp = str(datetime.datetime.now()) 108 | args.conf_host = socket.gethostname() 109 | dataset = get_dataset(args) 110 | 111 | if args.n_epochs is None and isinstance(dataset, ContinualDataset): 112 | args.n_epochs = dataset.get_epochs() 113 | if args.batch_size is None: 114 | args.batch_size = dataset.get_batch_size() 115 | if hasattr(importlib.import_module('models.' + args.model), 'Buffer') and args.minibatch_size is None: 116 | args.minibatch_size = dataset.get_minibatch_size() 117 | 118 | backbone = dataset.get_backbone() 119 | loss = dataset.get_loss() 120 | model = get_model(args, backbone, loss, dataset.get_transform()) 121 | 122 | # model.net.linear.scale = args.cos_temp 123 | 124 | # set job name 125 | setproctitle.setproctitle('{}_{}_{}'.format(args.model, args.buffer_size if 'buffer_size' in args else 0, args.dataset)) 126 | 127 | # test_visual(model, dataset, args) 128 | if isinstance(dataset, ContinualDataset): 129 | test_variance(model, dataset, args) 130 | else: 131 | assert not hasattr(model, 'end_task') or model.NAME == 'joint_gcl' 132 | ctrain(args) 133 | 134 | 135 | if __name__ == '__main__': 136 | main() 137 | -------------------------------------------------------------------------------- /datasets/utils/continual_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from abc import abstractmethod 7 | from argparse import Namespace 8 | from torch import nn as nn 9 | from torchvision.transforms import transforms 10 | from torch.utils.data import DataLoader 11 | from typing import Tuple 12 | from torchvision import datasets 13 | import numpy as np 14 | import torch.optim 15 | 16 | class ContinualDataset: 17 | """ 18 | Continual learning evaluation setting. 19 | """ 20 | NAME = None 21 | SETTING = None 22 | N_CLASSES_PER_TASK = None 23 | N_TASKS = None 24 | TRANSFORM = None 25 | 26 | def __init__(self, args: Namespace) -> None: 27 | """ 28 | Initializes the train and test lists of dataloaders. 29 | :param args: the arguments which contains the hyperparameters 30 | """ 31 | self.train_loader = None 32 | self.test_loaders = [] 33 | self.i = 0 34 | self.args = args 35 | 36 | @abstractmethod 37 | def get_data_loaders(self) -> Tuple[DataLoader, DataLoader]: 38 | """ 39 | Creates and returns the training and test loaders for the current task. 40 | The current training loader and all test loaders are stored in self. 41 | :return: the current training and test loaders 42 | """ 43 | pass 44 | 45 | @staticmethod 46 | @abstractmethod 47 | def get_backbone() -> nn.Module: 48 | """ 49 | Returns the backbone to be used for to the current dataset. 50 | """ 51 | pass 52 | 53 | @staticmethod 54 | @abstractmethod 55 | def get_transform() -> transforms: 56 | """ 57 | Returns the transform to be used for to the current dataset. 58 | """ 59 | pass 60 | 61 | @staticmethod 62 | @abstractmethod 63 | def get_loss() -> nn.functional: 64 | """ 65 | Returns the loss to be used for to the current dataset. 66 | """ 67 | pass 68 | 69 | @staticmethod 70 | @abstractmethod 71 | def get_normalization_transform() -> transforms: 72 | """ 73 | Returns the transform used for normalizing the current dataset. 
74 | """ 75 | pass 76 | 77 | @staticmethod 78 | @abstractmethod 79 | def get_denormalization_transform() -> transforms: 80 | """ 81 | Returns the transform used for denormalizing the current dataset. 82 | """ 83 | pass 84 | 85 | @staticmethod 86 | @abstractmethod 87 | def get_scheduler(model, args: Namespace) -> torch.optim.lr_scheduler: 88 | """ 89 | Returns the scheduler to be used for to the current dataset. 90 | """ 91 | pass 92 | 93 | @staticmethod 94 | def get_epochs(): 95 | pass 96 | 97 | @staticmethod 98 | def get_batch_size(): 99 | pass 100 | 101 | @staticmethod 102 | def get_minibatch_size(): 103 | pass 104 | 105 | 106 | 107 | def store_masked_loaders(train_dataset: datasets, test_dataset: datasets, 108 | setting: ContinualDataset) -> Tuple[DataLoader, DataLoader]: 109 | """ 110 | Divides the dataset into tasks. 111 | :param train_dataset: train dataset 112 | :param test_dataset: test dataset 113 | :param setting: continual learning setting 114 | :return: train and test loaders 115 | """ 116 | train_mask = np.logical_and(np.array(train_dataset.targets) >= setting.i, 117 | np.array(train_dataset.targets) < setting.i + setting.N_CLASSES_PER_TASK) 118 | test_mask = np.logical_and(np.array(test_dataset.targets) >= setting.i, 119 | np.array(test_dataset.targets) < setting.i + setting.N_CLASSES_PER_TASK) 120 | 121 | train_dataset.data = train_dataset.data[train_mask] 122 | test_dataset.data = test_dataset.data[test_mask] 123 | 124 | train_dataset.targets = np.array(train_dataset.targets)[train_mask] 125 | test_dataset.targets = np.array(test_dataset.targets)[test_mask] 126 | 127 | train_loader = DataLoader(train_dataset, 128 | batch_size=setting.args.batch_size, shuffle=True, num_workers=4) 129 | test_loader = DataLoader(test_dataset, 130 | batch_size=setting.args.batch_size, shuffle=False, num_workers=4) 131 | setting.test_loaders.append(test_loader) 132 | setting.train_loader = train_loader 133 | 134 | setting.i += setting.N_CLASSES_PER_TASK 135 | return 
train_loader, test_loader 136 | 137 | 138 | def get_previous_train_loader(train_dataset: datasets, batch_size: int, 139 | setting: ContinualDataset) -> DataLoader: 140 | """ 141 | Creates a dataloader for the previous task. 142 | :param train_dataset: the entire training set 143 | :param batch_size: the desired batch size 144 | :param setting: the continual dataset at hand 145 | :return: a dataloader 146 | """ 147 | train_mask = np.logical_and(np.array(train_dataset.targets) >= 148 | setting.i - setting.N_CLASSES_PER_TASK, np.array(train_dataset.targets) 149 | < setting.i - setting.N_CLASSES_PER_TASK + setting.N_CLASSES_PER_TASK) 150 | 151 | train_dataset.data = train_dataset.data[train_mask] 152 | train_dataset.targets = np.array(train_dataset.targets)[train_mask] 153 | 154 | return DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 155 | -------------------------------------------------------------------------------- /utils/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import numpy # needed (don't change it) 7 | import importlib 8 | import os 9 | import sys 10 | import socket 11 | mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.append(mammoth_path) 13 | sys.path.append(mammoth_path + '/datasets') 14 | sys.path.append(mammoth_path + '/backbone') 15 | sys.path.append(mammoth_path + '/models') 16 | 17 | from datasets import NAMES as DATASET_NAMES 18 | from models import get_all_models 19 | from argparse import ArgumentParser 20 | from utils.args import add_management_args 21 | from datasets import ContinualDataset 22 | from utils.continual_training import train as ctrain 23 | from datasets import get_dataset 24 | from models import get_model 25 | from utils.training import train, test_visual 26 | from utils.best_args import best_args 27 | from utils.conf import set_random_seed 28 | import setproctitle 29 | import torch 30 | import uuid 31 | import datetime 32 | 33 | import numpy as np 34 | import random 35 | import torch 36 | import torch.nn as nn 37 | 38 | 39 | def lecun_fix(): 40 | # Yann moved his website to CloudFlare. You need this now 41 | from six.moves import urllib 42 | opener = urllib.request.build_opener() 43 | opener.addheaders = [('User-agent', 'Mozilla/5.0')] 44 | urllib.request.install_opener(opener) 45 | 46 | def parse_args(): 47 | parser = ArgumentParser(description='mammoth', allow_abbrev=False) 48 | parser.add_argument('--model', type=str, required=True, 49 | help='Model name.', choices=get_all_models()) 50 | parser.add_argument('--load_best_args', action='store_true', 51 | help='Loads the best arguments for each method, ' 52 | 'dataset and memory buffer.') 53 | torch.set_num_threads(4) 54 | add_management_args(parser) 55 | args = parser.parse_known_args()[0] 56 | mod = importlib.import_module('models.' 
def main(args=None):
    """
    Runs one experiment: seeds the RNGs, builds dataset and model, trains.

    :param args: an already-parsed namespace; when None the command line
                 is parsed with parse_args()
    """
    lecun_fix()
    if args is None:
        args = parse_args()

    # Seed only when a seed was requested: torch.manual_seed(None) raises,
    # and parse_args() applies the same guard.
    if args.seed is not None:
        random.seed(args.seed)
        # Disable hash randomization so that experiments are reproducible.
        # NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so
        # this presumably only affects processes spawned afterwards.
        os.environ['PYTHONHASHSEED'] = str(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Add uuid, timestamp and hostname for logging
    args.conf_jobnum = str(uuid.uuid4())
    args.conf_timestamp = str(datetime.datetime.now())
    args.conf_host = socket.gethostname()
    dataset = get_dataset(args)

    # Fill in dataset defaults for every hyper-parameter left unset.
    if args.n_epochs is None and isinstance(dataset, ContinualDataset):
        args.n_epochs = dataset.get_epochs()
    if args.batch_size is None:
        args.batch_size = dataset.get_batch_size()
    uses_buffer = hasattr(importlib.import_module('models.' + args.model), 'Buffer')
    # getattr: some model parsers (e.g. gem) remove --minibatch_size, in
    # which case the namespace has no such attribute at all.
    if uses_buffer and getattr(args, 'minibatch_size', None) is None:
        args.minibatch_size = dataset.get_minibatch_size()

    backbone = dataset.get_backbone()
    loss = dataset.get_loss()
    model = get_model(args, backbone, loss, dataset.get_transform())

    # set job name
    setproctitle.setproctitle('{}_{}_{}'.format(args.model, args.buffer_size if 'buffer_size' in args else 0, args.dataset))

    if isinstance(dataset, ContinualDataset):
        train(model, dataset, args)
    else:
        assert not hasattr(model, 'end_task') or model.NAME == 'joint_gcl'
        ctrain(args)


if __name__ == '__main__':
    main()
class TCIFAR100(CIFAR100):
    """
    CIFAR100 wrapper that triggers a download only when the integrity
    check fails. The ``download`` argument is accepted but not used.
    """
    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False) -> None:
        # root is assigned before the parent constructor runs because the
        # integrity check below is called first (presumably it reads
        # self.root - confirm against torchvision).
        self.root = root
        must_download = not self._check_integrity()
        super().__init__(root, train, transform, target_transform,
                         download=must_download)
class SequentialCIFAR100(ContinualDataset):
    """
    CIFAR-100 presented as a sequence of class-incremental tasks
    (N_TASKS tasks of N_CLASSES_PER_TASK classes each).
    """

    NAME = 'seq-cifar100'
    SETTING = 'class-il'
    # Alternative splits of the 100 classes: 10 tasks x 10 classes,
    # or 2 tasks x 50 classes.
    N_CLASSES_PER_TASK = 20
    N_TASKS = 5
    TRANSFORM = transforms.Compose(
            [transforms.RandomCrop(32, padding=4),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize((0.5071, 0.4867, 0.4408),
                                  (0.2675, 0.2565, 0.2761))])

    def get_examples_number(self):
        """
        Returns the number of samples in the training split.
        """
        # Same directory as get_data_loaders(); the former 'CIFAR10' path
        # made the CIFAR-100 archive land in a second, unrelated folder.
        train_dataset = MyCIFAR100(base_path() + 'CIFAR100', train=True,
                                   download=True)
        return len(train_dataset.data)

    def get_data_loaders(self):
        """
        Builds the train and test loaders for the current task.
        :return: the (train, test) pair produced by store_masked_loaders
        """
        transform = self.TRANSFORM

        test_transform = transforms.Compose(
            [transforms.ToTensor(), self.get_normalization_transform()])

        train_dataset = MyCIFAR100(base_path() + 'CIFAR100', train=True,
                                   download=True, transform=transform)
        if self.args.validation:
            train_dataset, test_dataset = get_train_val(train_dataset,
                                                        test_transform, self.NAME)
        else:
            test_dataset = TCIFAR100(base_path() + 'CIFAR100', train=False,
                                     download=True, transform=test_transform)

        train, test = store_masked_loaders(train_dataset, test_dataset, self)

        return train, test

    @staticmethod
    def get_transform():
        """Returns the training augmentation preceded by ToPILImage."""
        transform = transforms.Compose(
            [transforms.ToPILImage(), SequentialCIFAR100.TRANSFORM])
        return transform

    @staticmethod
    def get_backbone():
        """Returns a resnet18 sized for all classes of all tasks."""
        return resnet18(SequentialCIFAR100.N_CLASSES_PER_TASK
                        * SequentialCIFAR100.N_TASKS)

    @staticmethod
    def get_loss():
        """Returns the loss function (cross entropy)."""
        return F.cross_entropy

    @staticmethod
    def get_normalization_transform():
        """Returns the normalization used by TRANSFORM."""
        transform = transforms.Normalize((0.5071, 0.4867, 0.4408),
                                         (0.2675, 0.2565, 0.2761))
        return transform

    @staticmethod
    def get_denormalization_transform():
        """Returns a DeNormalize built from the same statistics."""
        transform = DeNormalize((0.5071, 0.4867, 0.4408),
                                (0.2675, 0.2565, 0.2761))
        return transform

    @staticmethod
    def get_epochs():
        """Returns the default number of epochs per task."""
        return 50

    @staticmethod
    def get_batch_size():
        """Returns the default batch size."""
        return 32

    @staticmethod
    def get_minibatch_size():
        """Returns the default buffer minibatch size (same as batch size)."""
        return SequentialCIFAR100.get_batch_size()

    @staticmethod
    def get_scheduler(model, args) -> torch.optim.lr_scheduler.MultiStepLR:
        """
        Replaces ``model.opt`` with a fresh SGD optimizer and returns a
        step scheduler for it (learning rate x0.1 at epochs 35 and 45).
        :param model: the continual model; its ``opt`` attribute is overwritten
        :param args: namespace providing lr, optim_wd and optim_mom
        :return: the MultiStepLR scheduler bound to the new optimizer
        """
        model.opt = torch.optim.SGD(model.net.parameters(), lr=args.lr, weight_decay=args.optim_wd, momentum=args.optim_mom)
        # ``verbose`` is omitted: False was already the default and the
        # argument is deprecated in recent PyTorch releases.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(model.opt, [35, 45], gamma=0.1)
        return scheduler
class CsvLogger:
    """
    Collects per-task mean accuracies and transfer metrics, then appends
    them as one row to a CSV file under ``results/``.
    """
    def __init__(self, setting_str: str, dataset_str: str,
                 model_str: str) -> None:
        self.accs = []
        # Always allocated: log() appends to it for every setting that
        # reports a (class-il, task-il) pair, not only for 'class-il'.
        self.accs_mask_classes = []
        self.setting = setting_str
        self.dataset = dataset_str
        self.model = model_str
        self.fwt = None
        self.fwt_mask_classes = None
        self.bwt = None
        self.bwt_mask_classes = None
        self.forgetting = None
        self.forgetting_mask_classes = None

    def add_fwt(self, results, accs, results_mask_classes, accs_mask_classes):
        """Stores forward transfer (masked variant only for 'class-il')."""
        self.fwt = forward_transfer(results, accs)
        if self.setting == 'class-il':
            self.fwt_mask_classes = forward_transfer(results_mask_classes, accs_mask_classes)

    def add_bwt(self, results, results_mask_classes):
        """Stores backward transfer for both result sets."""
        self.bwt = backward_transfer(results)
        self.bwt_mask_classes = backward_transfer(results_mask_classes)

    def add_forgetting(self, results, results_mask_classes):
        """Stores forgetting for both result sets."""
        self.forgetting = forgetting(results)
        self.forgetting_mask_classes = forgetting(results_mask_classes)

    def log(self, mean_acc: np.ndarray) -> None:
        """
        Logs a mean accuracy value.
        :param mean_acc: a single value for 'general-continual', otherwise
                         a pair whose first item is logged in ``accs`` and,
                         outside 'domain-il', whose second item is logged
                         in ``accs_mask_classes``
        """
        if self.setting == 'general-continual':
            self.accs.append(mean_acc)
        elif self.setting == 'domain-il':
            mean_acc, _ = mean_acc
            self.accs.append(mean_acc)
        else:
            mean_acc_class_il, mean_acc_task_il = mean_acc
            self.accs.append(mean_acc_class_il)
            self.accs_mask_classes.append(mean_acc_task_il)

    def _append_row(self, setting: str, columns, row) -> None:
        """
        Appends ``row`` to results/<setting>/<dataset>/<model>/mean_accs.csv,
        writing the header first when the file does not exist yet.
        """
        folder = base_path() + "results/" + setting
        create_if_not_exists(folder)
        for part in (self.dataset, self.model):
            folder += "/" + part
            create_if_not_exists(folder)

        path = folder + "/mean_accs.csv"
        write_headers = not os.path.exists(path)
        # newline='' is what the csv module requires of files it writes to.
        with open(path, 'a', newline='') as tmp:
            writer = csv.DictWriter(tmp, fieldnames=columns)
            if write_headers:
                writer.writeheader()
            writer.writerow(row)

    def write(self, args: Dict[str, Any]) -> None:
        """
        Writes out the logged values along with the experiment arguments.
        :param args: the arguments of the current experiment as a dict;
                     it is not modified
        """
        # Work on a filtered copy: deleting keys in place would also strip
        # attributes from a namespace when the caller passes vars(namespace).
        row = {key: value for key, value in args.items()
               if key not in useless_args}

        columns = list(row.keys())

        new_cols = []
        for i, acc in enumerate(self.accs):
            row['task' + str(i + 1)] = acc
            new_cols.append('task' + str(i + 1))

        row['forward_transfer'] = self.fwt
        new_cols.append('forward_transfer')

        row['backward_transfer'] = self.bwt
        new_cols.append('backward_transfer')

        row['forgetting'] = self.forgetting
        new_cols.append('forgetting')

        columns = new_cols + columns

        self._append_row(self.setting, columns, row)

        if self.setting == 'class-il':
            # Second file: same columns, filled with the masked (task-il) values.
            for i, acc in enumerate(self.accs_mask_classes):
                row['task' + str(i + 1)] = acc

            row['forward_transfer'] = self.fwt_mask_classes
            row['backward_transfer'] = self.bwt_mask_classes
            row['forgetting'] = self.forgetting_mask_classes

            self._append_row('task-il', columns, row)
def store_grad(params, grads, grad_dims):
    """
    Copies the current parameter gradients into one flat vector.

    params: callable returning the iterable of parameters
    grads: flat tensor that receives the gradients (zeroed first)
    grad_dims: list with number of elements of each parameter, in the
               same order as params() yields them

    Parameters whose ``grad`` is None keep their zeroed slice.
    """
    # store the gradients
    grads.fill_(0.0)
    begin = 0
    # A running offset replaces re-summing grad_dims[:count] for every
    # parameter, which was quadratic in the number of parameters.
    for param, dim in zip(params(), grad_dims):
        end = begin + dim
        if param.grad is not None:
            grads[begin: end].copy_(param.grad.data.view(-1))
        begin = end
def project2cone2(gradient, memories, margin=0.5, eps=1e-3):
    """
    Solves the GEM dual QP described in the paper given a proposed
    gradient "gradient", and a memory of task gradients "memories".
    Overwrites "gradient" with the final projected update.

    input:  gradient, p-vector
    input:  memories, (t * p)-vector
    input:  margin, lower bound of the QP variables; None falls back to 0.5
    input:  eps, value added to the diagonal of the QP matrix
    output: x, p-vector (written into "gradient" in place)
    """
    # Callers forward the --gamma option, whose parser default is None;
    # adding None to an array below would raise a TypeError.
    if margin is None:
        margin = 0.5
    memories_np = memories.cpu().t().double().numpy()
    gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()
    n_rows = memories_np.shape[0]
    self_prod = np.dot(memories_np, memories_np.transpose())
    # symmetrize and add eps on the diagonal before handing to the solver
    self_prod = 0.5 * (self_prod + self_prod.transpose()) + np.eye(n_rows) * eps
    grad_prod = np.dot(memories_np, gradient_np) * -1
    G = np.eye(n_rows)
    h = np.zeros(n_rows) + margin
    v = quadprog.solve_qp(self_prod, grad_prod, G, h)[0]
    x = np.dot(v, memories_np) + gradient_np
    gradient.copy_(torch.from_numpy(x).view(-1, 1))
110 | 111 | def end_task(self, dataset): 112 | self.current_task += 1 113 | self.grads_cs.append(torch.zeros( 114 | np.sum(self.grad_dims)).to(self.device)) 115 | 116 | # add data to the buffer 117 | samples_per_task = self.args.buffer_size // dataset.N_TASKS 118 | 119 | loader = dataset.train_loader 120 | cur_y, cur_x = next(iter(loader))[1:] 121 | self.buffer.add_data( 122 | examples=cur_x.to(self.device), 123 | labels=cur_y.to(self.device), 124 | task_labels=torch.ones(samples_per_task, 125 | dtype=torch.long).to(self.device) * (self.current_task - 1) 126 | ) 127 | 128 | 129 | def observe(self, inputs, labels, not_aug_inputs): 130 | 131 | if not self.buffer.is_empty(): 132 | buf_inputs, buf_labels, buf_task_labels = self.buffer.get_data( 133 | self.args.buffer_size, transform=self.transform) 134 | 135 | for tt in buf_task_labels.unique(): 136 | # compute gradient on the memory buffer 137 | self.opt.zero_grad() 138 | cur_task_inputs = buf_inputs[buf_task_labels == tt] 139 | cur_task_labels = buf_labels[buf_task_labels == tt] 140 | cur_task_outputs = self.forward(cur_task_inputs) 141 | penalty = self.loss(cur_task_outputs, cur_task_labels) 142 | penalty.backward() 143 | store_grad(self.parameters, self.grads_cs[tt], self.grad_dims) 144 | 145 | # now compute the grad on the current data 146 | self.opt.zero_grad() 147 | outputs = self.forward(inputs) 148 | loss = self.loss(outputs, labels) 149 | loss.backward() 150 | 151 | # check if gradient violates buffer constraints 152 | if not self.buffer.is_empty(): 153 | # copy gradient 154 | store_grad(self.parameters, self.grads_da, self.grad_dims) 155 | 156 | dot_prod = torch.mm(self.grads_da.unsqueeze(0), 157 | torch.stack(self.grads_cs).T) 158 | if (dot_prod < 0).sum() != 0: 159 | project2cone2(self.grads_da.unsqueeze(1), 160 | torch.stack(self.grads_cs).T, margin=self.args.gamma) 161 | # copy gradients back 162 | overwrite_grad(self.parameters, self.grads_da, 163 | self.grad_dims) 164 | 165 | self.opt.step() 166 | 167 
| return loss.item() 168 | -------------------------------------------------------------------------------- /datasets/seq_tinyimagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torchvision.transforms as transforms 8 | from torch.utils.data import Dataset 9 | from backbone.ResNet18 import resnet18 10 | import torch.nn.functional as F 11 | from utils.conf import base_path 12 | from PIL import Image 13 | import os 14 | from datasets.utils.validation import get_train_val 15 | from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders 16 | from datasets.utils.continual_dataset import get_previous_train_loader 17 | from datasets.transforms.denormalization import DeNormalize 18 | 19 | 20 | class TinyImagenet(Dataset): 21 | """ 22 | Defines Tiny Imagenet as for the others pytorch datasets. 
23 | """ 24 | def __init__(self, root: str, train: bool=True, transform: transforms=None, 25 | target_transform: transforms=None, download: bool=False) -> None: 26 | self.not_aug_transform = transforms.Compose([transforms.ToTensor()]) 27 | self.root = root 28 | self.train = train 29 | self.transform = transform 30 | self.target_transform = target_transform 31 | self.download = download 32 | 33 | if download: 34 | if os.path.isdir(root) and len(os.listdir(root)) > 0: 35 | print('Download not needed, files already on disk.') 36 | else: 37 | from google_drive_downloader import GoogleDriveDownloader as gdd 38 | 39 | # https://drive.google.com/file/d/1Sy3ScMBr0F4se8VZ6TAwDYF-nNGAAdxj/view 40 | print('Downloading dataset') 41 | gdd.download_file_from_google_drive( 42 | file_id='1Sy3ScMBr0F4se8VZ6TAwDYF-nNGAAdxj', 43 | 44 | dest_path=os.path.join(root, 'tiny-imagenet-processed.zip'), 45 | unzip=True) 46 | 47 | self.data = [] 48 | for num in range(20): 49 | self.data.append(np.load(os.path.join( 50 | root, 'processed/x_%s_%02d.npy' % 51 | ('train' if self.train else 'val', num+1)))) 52 | self.data = np.concatenate(np.array(self.data)) 53 | 54 | self.targets = [] 55 | for num in range(20): 56 | self.targets.append(np.load(os.path.join( 57 | root, 'processed/y_%s_%02d.npy' % 58 | ('train' if self.train else 'val', num+1)))) 59 | self.targets = np.concatenate(np.array(self.targets)) 60 | 61 | def __len__(self): 62 | return len(self.data) 63 | 64 | def __getitem__(self, index): 65 | img, target = self.data[index], self.targets[index] 66 | 67 | # doing this so that it is consistent with all other datasets 68 | # to return a PIL Image 69 | img = Image.fromarray(np.uint8(255 * img)) 70 | original_img = img.copy() 71 | 72 | if self.transform is not None: 73 | img = self.transform(img) 74 | 75 | if self.target_transform is not None: 76 | target = self.target_transform(target) 77 | 78 | if hasattr(self, 'logits'): 79 | return img, target, original_img, self.logits[index] 80 | 81 | 
class MyTinyImagenet(TinyImagenet):
    """
    Tiny Imagenet variant whose items additionally carry the image
    processed by ``not_aug_transform`` only.
    """
    def __init__(self, root: str, train: bool=True, transform: transforms=None,
                target_transform: transforms=None, download: bool=False) -> None:
        super().__init__(root, train, transform, target_transform, download)

    def __getitem__(self, index):
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :returns: (image, target, not augmented image), followed by the
                  stored logits when the dataset has a ``logits`` attribute
        """
        label = self.targets[index]

        # the stored array is scaled by 255 and cast to uint8 so that,
        # as for the other datasets, a PIL Image is produced
        pil_img = Image.fromarray(np.uint8(255 * self.data[index]))
        not_aug_img = self.not_aug_transform(pil_img.copy())

        img = pil_img if self.transform is None else self.transform(pil_img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        if not hasattr(self, 'logits'):
            return img, label, not_aug_img

        return img, label, not_aug_img, self.logits[index]
download=True, transform=test_transform) 142 | 143 | train, test = store_masked_loaders(train_dataset, test_dataset, self) 144 | return train, test 145 | 146 | @staticmethod 147 | def get_backbone(): 148 | return resnet18(SequentialTinyImagenet.N_CLASSES_PER_TASK 149 | * SequentialTinyImagenet.N_TASKS) 150 | 151 | @staticmethod 152 | def get_loss(): 153 | return F.cross_entropy 154 | 155 | def get_transform(self): 156 | transform = transforms.Compose( 157 | [transforms.ToPILImage(), self.TRANSFORM]) 158 | return transform 159 | 160 | @staticmethod 161 | def get_normalization_transform(): 162 | transform = transforms.Normalize((0.4802, 0.4480, 0.3975), 163 | (0.2770, 0.2691, 0.2821)) 164 | return transform 165 | 166 | @staticmethod 167 | def get_denormalization_transform(): 168 | transform = DeNormalize((0.4802, 0.4480, 0.3975), 169 | (0.2770, 0.2691, 0.2821)) 170 | return transform 171 | 172 | @staticmethod 173 | def get_scheduler(model, args): 174 | return None -------------------------------------------------------------------------------- /models/hal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for the HAL model: the shared
    management, experiment and rehearsal arguments plus the HAL options.
    """
    # The description previously read 'Experience Replay' (copy-paste).
    parser = ArgumentParser(description='Continual learning via'
                                        ' Hindsight Anchor Learning.')
    add_management_args(parser)
    add_experiment_args(parser)
    add_rehearsal_args(parser)

    parser.add_argument('--hal_lambda', type=float, default=0.1,
                        help='Weight of the anchor loss.')
    parser.add_argument('--beta', type=float, default=0.5,
                        help='Decay of the running feature mean (phi).')
    parser.add_argument('--gamma', type=float, default=0.1,
                        help='Weight of the feature-distance term used '
                             'when optimizing the anchors.')

    return parser
self.task_number * dataset.N_CLASSES_PER_TASK: 63 | self.get_anchors(dataset) 64 | del self.phi 65 | 66 | def get_anchors(self, dataset): 67 | theta_t = self.net.get_params().detach().clone() 68 | self.spare_model.set_params(theta_t) 69 | 70 | # fine tune on memory buffer 71 | for _ in range(self.finetuning_epochs): 72 | inputs, labels = self.buffer.get_data(self.args.batch_size, transform=self.transform) 73 | self.spare_opt.zero_grad() 74 | out = self.spare_model(inputs) 75 | loss = self.loss(out, labels) 76 | loss.backward() 77 | self.spare_opt.step() 78 | 79 | theta_m = self.spare_model.get_params().detach().clone() 80 | 81 | classes_for_this_task = np.unique(dataset.train_loader.dataset.targets) 82 | 83 | for a_class in classes_for_this_task: 84 | e_t = torch.rand(self.input_shape, requires_grad=True, device=self.device) 85 | e_t_opt = SGD([e_t], lr=self.args.lr) 86 | print(file=sys.stderr) 87 | for i in range(self.anchor_optimization_steps): 88 | e_t_opt.zero_grad() 89 | cum_loss = 0 90 | 91 | self.spare_opt.zero_grad() 92 | self.spare_model.set_params(theta_m.detach().clone()) 93 | loss = -torch.sum(self.loss(self.spare_model(e_t.unsqueeze(0)), torch.tensor([a_class]).to(self.device))) 94 | loss.backward() 95 | cum_loss += loss.item() 96 | 97 | self.spare_opt.zero_grad() 98 | self.spare_model.set_params(theta_t.detach().clone()) 99 | loss = torch.sum(self.loss(self.spare_model(e_t.unsqueeze(0)), torch.tensor([a_class]).to(self.device))) 100 | loss.backward() 101 | cum_loss += loss.item() 102 | 103 | self.spare_opt.zero_grad() 104 | loss = torch.sum(self.gamma * (self.spare_model(e_t.unsqueeze(0), returnt='features') - self.phi) ** 2) 105 | assert not self.phi.requires_grad 106 | loss.backward() 107 | cum_loss += loss.item() 108 | 109 | e_t_opt.step() 110 | 111 | e_t = e_t.detach() 112 | e_t.requires_grad = False 113 | self.anchors = torch.cat((self.anchors, e_t.unsqueeze(0))) 114 | del e_t 115 | print('Total anchors:', len(self.anchors), file=sys.stderr) 116 
    def observe(self, inputs, labels, not_aug_inputs):
        """
        Performs one training step on a batch, followed by an anchor
        step when anchors exist.
        :param inputs: batch of (augmented) inputs
        :param labels: labels of the batch
        :param not_aug_inputs: non-augmented inputs, stored in the buffer
        :return: loss of the regular step plus, when anchors exist, the
                 anchor loss
        """
        real_batch_size = inputs.shape[0]
        # Lazily created state: input shape, empty anchor tensor and phi.
        if not hasattr(self, 'input_shape'):
            self.input_shape = inputs.shape[1:]
        if not hasattr(self, 'anchors'):
            self.anchors = torch.zeros(tuple([0] + list(self.input_shape))).to(self.device)
        if not hasattr(self, 'phi'):
            print('Building phi', file=sys.stderr)
            with torch.no_grad():
                # zero tensor shaped like the features of a single input
                self.phi = torch.zeros_like(self.net(inputs[0].unsqueeze(0), returnt='features'), requires_grad=False)
            assert not self.phi.requires_grad

        # Replay: append a buffer minibatch to the current batch.
        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        # Snapshot of the weights before the step, restored further down.
        old_weights = self.net.get_params().detach().clone()

        self.opt.zero_grad()
        outputs = self.net(inputs)

        k = self.task_number

        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        first_loss = 0

        # one anchor per class of every finished task is expected
        assert len(self.anchors) == self.dataset.N_CLASSES_PER_TASK * k

        if len(self.anchors) > 0:
            first_loss = loss.item()
            # anchor predictions with the just-updated weights (no grad)
            with torch.no_grad():
                pred_anchors = self.net(self.anchors)

            # restore the pre-step weights, then penalize the squared
            # change of the anchor predictions
            self.net.set_params(old_weights)
            pred_anchors -= self.net(self.anchors)
            loss = self.hal_lambda * (pred_anchors ** 2).mean()
            loss.backward()
            self.opt.step()

        # Running mean of the features of the current (non-buffer) samples.
        with torch.no_grad():
            self.phi = self.beta * self.phi + (1 - self.beta) * self.net(inputs[:real_batch_size], returnt='features').mean(0)

        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:real_batch_size])

        return first_loss + loss.item()
class Buffer:
    """
    Rehearsal memory buffer whose replacement policy is driven by gradient
    similarity scores.

    ``examples``, ``labels`` and ``scores`` tensors are created lazily on the
    first call to :meth:`add_data`. ``model`` must expose
    ``get_grads(inputs, labels)``; its results are concatenated along dim 0 and
    compared with cosine similarity along dim 1.
    """
    def __init__(self, buffer_size, device, minibatch_size, model=None):
        """
        :param buffer_size: maximum number of stored examples
        :param device: device on which the buffer tensors are allocated
        :param minibatch_size: number of stored examples scored against each
                               incoming batch
        :param model: object providing ``get_grads``
        """
        self.buffer_size = buffer_size
        self.device = device
        self.num_seen_examples = 0
        self.attributes = ['examples', 'labels']
        self.model = model
        self.minibatch_size = minibatch_size
        # buffer index -> gradient returned by model.get_grads for that slot
        self.cache = {}
        self.fathom = 0
        self.fathom_mask = None
        self.reset_fathom()

        self.conterone = 0

    def reset_fathom(self):
        """
        Restart sequential reading: reset the cursor and draw a new random
        permutation over the filled part of the buffer.
        """
        self.fathom = 0
        limit = self.num_seen_examples
        if hasattr(self, 'examples'):
            limit = min(limit, self.examples.shape[0])
        self.fathom_mask = torch.randperm(limit)

    def get_grad_score(self, x, y, X, Y, indices):
        """
        Score a batch against stored examples.

        :param x: inputs whose gradient is scored
        :param y: labels for ``x``
        :param X: stored inputs to compare against
        :param Y: labels for ``X``
        :param indices: buffer positions of ``X`` (used as cache keys)
        :return: the largest cosine similarity between the gradient of
                 ``(x, y)`` and the gradients of the stored examples, plus 1
        """
        g = self.model.get_grads(x, y)
        G = []
        for stored_x, stored_y, idx in zip(X, Y, indices):
            # Plain ints make stable dictionary keys; a 0-d tensor would hash
            # by identity and never hit the cache.
            key = int(idx)
            if key in self.cache:
                grd = self.cache[key]
            else:
                grd = self.model.get_grads(stored_x.unsqueeze(0), stored_y.unsqueeze(0))
                self.cache[key] = grd
            G.append(grd)
        G = torch.cat(G).to(g.device)
        c_score = 0
        grads_at_a_time = 5
        # let's split this so your gpu does not melt. You're welcome.
        for it in range(int(np.ceil(G.shape[0] / grads_at_a_time))):
            chunk = G[it * grads_at_a_time: (it + 1) * grads_at_a_time]
            tmp = F.cosine_similarity(g, chunk, dim=1).max().item() + 1
            c_score = max(c_score, tmp)
        return c_score

    def functional_reservoir(self, x, y, batch_c, bigX=None, bigY=None, indices=None):
        """
        Decide where (if anywhere) a single example is stored.

        :param x: the candidate example
        :param y: its label
        :param batch_c: score of the batch the example belongs to
        :param bigX: stored inputs used to score the candidate
        :param bigY: labels for ``bigX``
        :param indices: buffer positions of ``bigX``
        :return: ``(index, score)``; ``index`` is -1 when the example is dropped
        """
        if self.num_seen_examples < self.buffer_size:
            # Buffer not full yet: take the next free slot.
            return self.num_seen_examples, batch_c

        elif batch_c < 1:
            single_c = self.get_grad_score(x.unsqueeze(0), y.unsqueeze(0), bigX, bigY, indices)
            s = self.scores.cpu().numpy()
            # Candidate slot drawn with probability proportional to its score.
            i = np.random.choice(np.arange(0, self.buffer_size), size=1, p=s / s.sum())[0]
            rand = np.random.rand(1)[0]
            # print(rand, s[i] / (s[i] + c))
            if rand < s[i] / (s[i] + single_c):
                return i, single_c

        return -1, 0

    def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor) -> None:
        """
        Initializes just the required tensors.
        :param examples: tensor containing the images
        :param labels: tensor containing the labels
        """
        # Explicit lookup instead of eval() on the attribute name.
        provided = {'examples': examples, 'labels': labels}
        for attr_str in self.attributes:
            attr = provided.get(attr_str)
            if attr is not None and not hasattr(self, attr_str):
                typ = torch.int64 if attr_str.endswith('els') else torch.float32
                setattr(self, attr_str, torch.zeros((self.buffer_size,
                        *attr.shape[1:]), dtype=typ, device=self.device))
        # One score per slot, following the trailing shape of the labels.
        score_tail = labels.shape[1:] if labels is not None else ()
        self.scores = torch.zeros((self.buffer_size, *score_tail),
                                  dtype=torch.float32, device=self.device)

    def add_data(self, examples, labels=None):
        """
        Offers a batch to the buffer; each example is stored or dropped
        according to :meth:`functional_reservoir`.
        :param examples: tensor containing the images
        :param labels: tensor containing the labels
        :return:
        """
        if not hasattr(self, 'examples'):
            self.init_tensors(examples, labels)

        # compute buffer score
        if self.num_seen_examples > 0:
            bigX, bigY, indices = self.get_data(min(self.minibatch_size, self.num_seen_examples), give_index=True,
                                                random=True)
            c = self.get_grad_score(examples, labels, bigX, bigY, indices)
        else:
            bigX, bigY, indices = None, None, None
            c = 0.1

        for i in range(examples.shape[0]):
            index, score = self.functional_reservoir(examples[i], labels[i], c, bigX, bigY, indices)
            self.num_seen_examples += 1
            if index >= 0:
                self.examples[index] = examples[i].to(self.device)
                if labels is not None:
                    self.labels[index] = labels[i].to(self.device)
                self.scores[index] = score
                # The cached gradient belonged to the example just replaced.
                if index in self.cache:
                    del self.cache[index]

    def drop_cache(self):
        """
        Forget every cached gradient.
        """
        self.cache = {}

    def get_data(self, size: int, transform=None, give_index=False, random=False) -> Tuple:
        """
        Returns a batch of at most size items.
        :param size: the number of requested items
        :param transform: callable applied to each example (data augmentation);
                          None leaves the examples unchanged
        :param give_index: if True, the chosen buffer positions are appended
                           to the returned tuple
        :param random: if True, sample without replacement from the filled
                       slots; otherwise read the next items of the permutation
                       created by reset_fathom
        :return: a tuple (examples, labels[, indices])
        """

        if size > self.examples.shape[0]:
            size = self.examples.shape[0]

        if random:
            choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]),
                                      size=min(size, self.num_seen_examples),
                                      replace=False)
        else:
            choice = np.arange(self.fathom, min(self.fathom + size, self.examples.shape[0], self.num_seen_examples))
            choice = self.fathom_mask[choice]
            self.fathom += len(choice)
            # Wrap around once every filled slot has been visited.
            if self.fathom >= self.examples.shape[0] or self.fathom >= self.num_seen_examples:
                self.fathom = 0
        if transform is None:
            transform = lambda x: x
        ret_tuple = (torch.stack([transform(ee.cpu())
                                  for ee in self.examples[choice]]).to(self.device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr[choice],)
        if give_index:
            ret_tuple += (choice,)

        return ret_tuple

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        """
        return self.num_seen_examples == 0

    def get_all_data(self, transform=None) -> Tuple:
        """
        Return all the items in the memory buffer.
        :param transform: callable applied to each example (data augmentation)
        :return: a tuple with all the items in the memory buffer
        """
        if transform is None:
            transform = lambda x: x
        ret_tuple = (torch.stack([transform(ee.cpu())
                                  for ee in self.examples]).to(self.device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr,)
        return ret_tuple

    def empty(self) -> None:
        """
        Remove the stored tensors and reset the example counter.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0
        # Cached gradients describe examples that are no longer stored.
        self.cache = {}
class VonMisesFisher(torch.distributions.Distribution):
    """
    Sampler for directions concentrated around ``loc``.

    A sample is assembled from a scalar ``w`` (the coordinate along the first
    axis), a random unit vector ``v`` for the remaining coordinates and a
    Householder reflection that maps the first basis vector onto ``loc``.
    Only sampling is implemented in this class body.

    :param loc: tensor whose last dimension is the ambient dimension
    :param scale: positive tensor; it is concatenated with ``loc``-shaped
                  data along the last axis, so a trailing dimension of size 1
                  is presumably expected -- confirm against callers
    :param validate_args: forwarded to ``torch.distributions.Distribution``
    :param k: number of proposals drawn per entry in each rejection round
    """

    arg_constraints = {
        "loc": torch.distributions.constraints.real,
        "scale": torch.distributions.constraints.positive,
    }
    support = torch.distributions.constraints.real
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(self, loc, scale, validate_args=None, k=1):
        self.dtype = loc.dtype
        self.loc = loc
        self.scale = scale
        self.device = loc.device
        self.__m = loc.shape[-1]
        # First standard basis vector, the reference axis of the reflection.
        self.__e1 = (torch.Tensor([1.0] + [0] * (loc.shape[-1] - 1))).to(self.device)
        self.k = k

        super().__init__(self.loc.size(), validate_args=validate_args)

    def sample(self, shape=torch.Size()):
        """Draw a sample without tracking gradients."""
        with torch.no_grad():
            return self.rsample(shape)

    def rsample(self, shape=torch.Size()):
        """
        Draw a sample of shape ``shape + loc.shape``.

        :param shape: extra leading sample dimensions (``torch.Size`` or int)
        :return: tensor cast to ``loc``'s dtype
        """
        shape = shape if isinstance(shape, torch.Size) else torch.Size([shape])

        # Closed-form draw in three dimensions, rejection sampling otherwise.
        w = (
            self.__sample_w3(shape=shape)
            if self.__m == 3
            else self.__sample_w_rej(shape=shape)
        )

        # Standard normal draw with the first coordinate dropped, normalised
        # to a unit vector for the remaining m - 1 coordinates.
        v = (
            torch.distributions.Normal(0, 1)
            .sample(shape + torch.Size(self.loc.shape))
            .to(self.device)
            .transpose(0, -1)[1:]
        ).transpose(0, -1)
        v = v / v.norm(dim=-1, keepdim=True)

        # Clamp keeps the square root (and its gradient) finite as |w| -> 1.
        w_ = torch.sqrt(torch.clamp(1 - (w ** 2), 1e-10))
        x = torch.cat((w, w_ * v), -1)
        z = self.__householder_rotation(x)

        return z.type(self.dtype)

    def __sample_w3(self, shape):
        """Draw ``w`` directly from a uniform variate (used when m == 3)."""
        shape = shape + torch.Size(self.scale.shape)
        u = torch.distributions.Uniform(0, 1).sample(shape).to(self.device)
        self.__w = (
            1
            + torch.stack(
                [torch.log(u), torch.log(1 - u) - 2 * self.scale], dim=0
            ).logsumexp(0)
            / self.scale
        )
        return self.__w

    def __sample_w_rej(self, shape):
        """Draw ``w`` by rejection sampling (used when m != 3)."""
        c = torch.sqrt((4 * (self.scale ** 2)) + (self.__m - 1) ** 2)
        b_true = (-2 * self.scale + c) / (self.__m - 1)

        # using Taylor approximation with a smooth switch for 10 < scale < 11
        # to avoid numerical errors for large scale
        b_app = (self.__m - 1) / (4 * self.scale)
        s = torch.min(
            torch.max(
                torch.tensor([0.0], dtype=self.dtype, device=self.device),
                self.scale - 10,
            ),
            torch.tensor([1.0], dtype=self.dtype, device=self.device),
        )
        b = b_app * s + b_true * (1 - s)

        a = (self.__m - 1 + 2 * self.scale + c) / 4
        d = (4 * a * b) / (1 + b) - (self.__m - 1) * math.log(self.__m - 1)

        self.__b, (self.__e, self.__w) = b, self.__while_loop(b, a, d, shape, k=self.k)
        return self.__w

    @staticmethod
    def first_nonzero(x, dim, invalid_val=-1):
        """
        Index of the first positive entry of ``x`` along ``dim``.

        :param x: tensor to scan
        :param dim: dimension to scan along
        :param invalid_val: value returned where no entry is positive
        :return: integer tensor with ``dim`` reduced away
        """
        mask = x > 0
        # argmax over the same dimension as any(); the first maximum of a 0/1
        # mask is the first positive entry.
        idx = torch.where(
            mask.any(dim=dim),
            mask.float().argmax(dim=dim),
            torch.tensor(invalid_val, device=x.device),
        )
        return idx

    def __while_loop(self, b, a, d, shape, k=20, eps=1e-20):
        """
        Vectorised rejection loop.

        Every round draws ``k`` proposals per entry, keeps the first accepted
        one and repeats for the entries that are still pending.

        :return: ``(e, w)`` reshaped to ``shape + scale.shape``
        """
        # matrix while loop: samples a matrix of [A, k] samples, to avoid looping all together
        b, a, d = [
            e.repeat(*shape, *([1] * len(self.scale.shape))).reshape(-1, 1)
            for e in (b, a, d)
        ]
        w, e, bool_mask = (
            torch.zeros_like(b).to(self.device),
            torch.zeros_like(b).to(self.device),
            (torch.ones_like(b) == 1).to(self.device),
        )

        sample_shape = torch.Size([b.shape[0], k])
        shape = shape + torch.Size(self.scale.shape)

        # The proposal distributions do not change between rounds.
        half_dim = torch.tensor((self.__m - 1) / 2, dtype=torch.float64)
        proposal = torch.distributions.Beta(half_dim, half_dim)
        uniform = torch.distributions.Uniform(0 + eps, 1 - eps)

        while bool_mask.sum() != 0:
            e_ = proposal.sample(sample_shape).to(self.device).type(self.dtype)
            u = uniform.sample(sample_shape).to(self.device).type(self.dtype)

            w_ = (1 - (1 + b) * e_) / (1 - (1 - b) * e_)
            t = (2 * a * b) / (1 - (1 - b) * e_)

            accept = ((self.__m - 1.0) * t.log() - t + d) > torch.log(u)
            accept_idx = self.first_nonzero(accept, dim=-1, invalid_val=-1).unsqueeze(1)
            # clamp the -1 "nothing accepted" marker to a valid gather index;
            # the unclamped value is still used below to detect rejections
            accept_idx_clamped = accept_idx.clamp(0)
            w_ = w_.gather(1, accept_idx_clamped.view(-1, 1))
            e_ = e_.gather(1, accept_idx_clamped.view(-1, 1))

            reject = accept_idx < 0
            accept = ~reject

            # Entries that are still pending and accepted in this round.
            done_now = bool_mask & accept
            w[done_now] = w_[done_now]
            e[done_now] = e_[done_now]

            bool_mask[done_now] = False

        return e.reshape(shape), w.reshape(shape)

    def __householder_rotation(self, x):
        """Reflect ``x`` across the hyperplane orthogonal to ``e1 - loc``."""
        u = self.__e1 - self.loc
        # Small constant avoids division by zero when loc equals e1.
        u = u / (u.norm(dim=-1, keepdim=True) + 1e-5)
        z = x - 2 * (x * u).sum(-1, keepdim=True) * u
        return z
vmf_buf_features.shape[1]))) 172 | # if 'vmf_g' == self.args.method2: 173 | # polar_noise = torch.normal(0.0, 1, size=vmf_buf_features.shape).to(vmf_buf_features.device) 174 | # polar_mask = polar_noise < 0 175 | # polar_noise[polar_mask] = -polar_noise[polar_mask] 176 | # polar_noise = polar_noise.to(vmf_buf_features).cuda() 177 | 178 | # # vmf_buf_features = vmf_buf_features.cuda() + polar_noise.cuda() 179 | # # vmf_buf_features = polar2euclid(vmf_buf_features) 180 | # # buf_features_noise = vmf_buf_features - buf_features 181 | # buf_features_noise = polar2euclid(polar_noise) 182 | # buf_features_noise = buf_features_noise.detach().cuda() 183 | # buf_features_noise = buf_features_noise.to(torch.float32) 184 | # if 'vmf_pa' == self.args.method2: 185 | # vmf_buf_features = vmf_buf_features.cuda() + polar_noise.cuda() 186 | # buf_features_noise = polar2euclid(vmf_buf_features) 187 | # buf_features_noise = buf_features_noise.detach().cuda() 188 | # buf_features_noise = buf_features_noise.to(torch.float32) 189 | # buf_features_noise = buf_features_noise - buf_features.detach() --------------------------------------------------------------------------------