├── utils
├── euclid2polar.py
├── __init__.py
├── mixup_utils.py
├── conf.py
├── batch_norm.py
├── metrics.py
├── plot_lamda_acc.py
├── plot_perturbed_angular_acc.py
├── plot_mixup_acc.py
├── plot_training_gradient.py
├── plot_all_method_acc.py
├── continual_training.py
├── status.py
├── tb_logger.py
├── debug.py
├── simclrloss.py
├── augmentations.py
├── feature_gradient_svd.py
├── plot_acc_curve.py
├── method_variation.py
├── ring_buffer.py
├── args.py
├── plot_test_variance.py
├── main.py
├── loggers.py
├── gss_buffer.py
└── vmf_sampling.py
├── backbone
├── utils
│ ├── __init__.py
│ └── modules.py
└── __init__.py
├── datasets
├── utils
│ ├── __init__.py
│ ├── gcl_dataset.py
│ ├── validation.py
│ └── continual_dataset.py
├── transforms
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ ├── rotation.cpython-36.pyc
│ │ ├── rotation.cpython-37.pyc
│ │ ├── permutation.cpython-36.pyc
│ │ ├── permutation.cpython-37.pyc
│ │ ├── denormalization.cpython-36.pyc
│ │ └── denormalization.cpython-37.pyc
│ ├── denormalization.py
│ ├── permutation.py
│ └── rotation.py
├── rot_mnist.py
├── __init__.py
├── seq_mnist.py
├── perm_mnist.py
├── seq_cifar10.py
├── seq_cifar100.py
└── seq_tinyimagenet.py
├── models
├── utils
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ ├── continual_model.cpython-36.pyc
│ │ └── continual_model.cpython-37.pyc
│ └── continual_model.py
├── __init__.py
├── sgd.py
├── der.py
├── joint_gcl.py
├── agem_r.py
├── si.py
├── agem.py
├── ewc_on.py
├── gss.py
├── pnn.py
├── mer.py
├── fdr.py
├── gdumb.py
├── lwf.py
├── joint.py
├── rpc.py
├── gem.py
└── hal.py
├── misc
├── vmf_svd-1.png
├── moca_graph-1.png
├── 2d_feat_vis-1.png
├── MOCA_variants-1.png
├── grad_combined-1.png
├── moca_framwork-1.png
└── var_comp_v4-1.png
└── README.md
/utils/euclid2polar.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/backbone/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/datasets/transforms/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/misc/vmf_svd-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/vmf_svd-1.png
--------------------------------------------------------------------------------
/misc/moca_graph-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/moca_graph-1.png
--------------------------------------------------------------------------------
/misc/2d_feat_vis-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/2d_feat_vis-1.png
--------------------------------------------------------------------------------
/misc/MOCA_variants-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/MOCA_variants-1.png
--------------------------------------------------------------------------------
/misc/grad_combined-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/grad_combined-1.png
--------------------------------------------------------------------------------
/misc/moca_framwork-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/moca_framwork-1.png
--------------------------------------------------------------------------------
/misc/var_comp_v4-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/misc/var_comp_v4-1.png
--------------------------------------------------------------------------------
/models/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/models/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/models/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/rotation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/rotation.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/rotation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/rotation.cpython-37.pyc
--------------------------------------------------------------------------------
/models/utils/__pycache__/continual_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/models/utils/__pycache__/continual_model.cpython-36.pyc
--------------------------------------------------------------------------------
/models/utils/__pycache__/continual_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/models/utils/__pycache__/continual_model.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/permutation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/permutation.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/permutation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/permutation.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/denormalization.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/denormalization.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/transforms/__pycache__/denormalization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulonghui/MOCA/HEAD/datasets/transforms/__pycache__/denormalization.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/utils/gcl_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
class GCLDataset:
    """
    Base class for general-continual-learning (GCL) datasets.

    Subclasses override the class attributes below. ``datasets/__init__.py``
    registers every subclass in its ``NAMES`` table under the value of ``NAME``.
    """
    # NAME: registry key used to select the dataset.
    # SETTING: evaluation-setting identifier.
    # N_CLASSES: number of classes.
    # LENGTH: presumably the length of the data stream (TODO confirm in subclasses).
    # All of them default to None here.
    NAME = SETTING = N_CLASSES = LENGTH = None
14 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import os
7 |
8 |
def create_if_not_exists(path: str) -> None:
    """
    Creates the specified folder (and any missing parents) if it does not exist.

    If another process creates the same directory between the existence
    check and the creation, no error is raised. An existing non-directory
    at ``path`` is left untouched, as before.
    :param path: the complete path of the folder to be created
    """
    if not os.path.exists(path):
        # exist_ok closes the check-then-create race (e.g. several runs
        # started in parallel that log into the same folder).
        os.makedirs(path, exist_ok=True)
16 |
--------------------------------------------------------------------------------
/datasets/transforms/denormalization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
class DeNormalize(object):
    """
    Undoes a per-channel ``(x - mean) / std`` normalization.

    Channel ``c`` of the input becomes ``x * std[c] + mean[c]``.
    """

    def __init__(self, mean, std):
        """
        :param mean: per-channel means used by the forward normalization
        :param std: per-channel standard deviations used by the forward normalization
        """
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): normalized tensor image of size (C, H, W).
        Returns:
            Tensor: the same tensor object, de-normalized in place.
        """
        # zip stops at the shortest input: channels beyond len(mean)/len(std)
        # are left unchanged.
        for t, m, s in zip(tensor, self.mean, self.std):
            # t is a view of one channel, so the in-place ops write into ``tensor``
            t.mul_(s).add_(m)
        return tensor
22 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import os
7 | import importlib
8 |
def get_all_models():
    """
    Lists the model module names found in the ``models`` folder.

    The folder is resolved relative to the current working directory, so the
    program is expected to run from the repository root. Only ``*.py`` files
    are considered. Entries containing ``__`` (``__init__.py``,
    ``__pycache__``) are skipped.
    :return: module names without the ``.py`` extension
    """
    # endswith('.py') rather than "'py' in name": the latter also matched
    # compiled *.pyc files and any filename merely containing "py".
    return [os.path.splitext(entry)[0] for entry in os.listdir('models')
            if entry.endswith('.py') and '__' not in entry]
12 |
# Registry mapping each model module name (e.g. 'sgd', 'der') to its model
# class; populated at import time by importing every module in ``models``.
names = {}
for model in get_all_models():
    mod = importlib.import_module('models.' + model)
    # The class is the module attribute whose lower-cased name equals the
    # module name with underscores removed (module 'sgd' -> class 'Sgd').
    # A module without such an attribute raises KeyError here, at import time.
    # NOTE(review): if two attributes differ only in case, the one listed
    # last by __dir__() wins.
    class_name = {x.lower():x for x in mod.__dir__()}[model.replace('_', '')]
    names[model] = getattr(mod, class_name)
18 |
def get_model(args, backbone, loss, transform):
    """
    Instantiates the continual model selected by ``args.model``.

    :param args: parsed arguments; ``args.model`` must be a key of ``names``
    :param backbone: the network passed to the model constructor
    :param loss: the loss function passed to the model constructor
    :param transform: the data transform passed to the model constructor
    :return: a new instance of the selected model class
    """
    model_class = names[args.model]
    return model_class(backbone, loss, args, transform)
21 |
--------------------------------------------------------------------------------
/utils/mixup_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os, sys, time
3 | import torch
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import matplotlib
7 | matplotlib.use('agg')
8 | import matplotlib.pyplot as plt
9 |
def mixup_old_data(x, y, alpha):
    '''
    Mixes the first half of a batch with a random shuffle of its second half.

    With ``b = len(x) // 2``, each sample of ``x[:b]`` is convexly combined
    with a randomly chosen sample of ``x[b:2b]``. An odd trailing sample is
    ignored.

    :param x: input batch (tensor whose first dimension is the batch)
    :param y: targets aligned with ``x`` along the first dimension
    :param alpha: Beta(alpha, alpha) parameter; ``alpha <= 0`` disables mixing
    :return: ``(mixed_x, y_a, y_b, lam)``. ``y_a`` are the first-half targets,
        ``y_b`` are the targets of the partner samples, and ``lam`` (a float
        in [0.5, 1]) is the weight of the first-half sample.
    '''
    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.
    # Fold lam into [0.5, 1] so the first-half sample always dominates.
    lam = float(max(lam, 1. - lam))
    batch_size = x.size(0) // 2
    # Create the permutation on the input's device instead of hard-coding
    # CUDA, so CPU tensors work too.
    index = torch.randperm(batch_size, device=x.device) + batch_size
    mixed_x = lam * x[:batch_size] + (1 - lam) * x[index]
    y_a, y_b = y[:batch_size], y[index]
    return mixed_x, y_a, y_b, lam
24 |
25 | # mixed_x, y_a, y_b, lam = mixup_old_data(x, y, alpha=1.0)
26 | # print(mixed_x)
27 | # print(x)
28 | # # print(y_a)
29 | # # print(y_b)
--------------------------------------------------------------------------------
/utils/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import random
7 | import torch
8 | import numpy as np
9 |
def get_device() -> torch.device:
    """
    Picks the device computations should run on.
    :return: ``cuda:0`` when CUDA is available, ``cpu`` otherwise
    """
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")
15 |
16 |
def base_path() -> str:
    """
    Returns the base path where accuracies and tensorboard data are logged.
    :return: the base folder, relative to the current working directory
    """
    return './data/'
22 |
23 |
def set_random_seed(seed: int) -> None:
    """
    Seeds Python's, NumPy's and PyTorch's global random number generators.

    PyTorch is seeded on the CPU and on every CUDA device.
    :param seed: the value to be set
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
33 |
--------------------------------------------------------------------------------
/utils/batch_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 |
class bn_track_stats:
    """
    Context manager that temporarily sets ``track_running_stats = False`` on
    the BatchNorm layers of a module.

    When ``condition`` is False, every ``BatchNorm1d``/``BatchNorm2d`` inside
    ``module`` has its flag cleared on entry. On exit, each layer's own
    original flag is restored. When ``condition`` is True, nothing is changed.
    In training mode, PyTorch BatchNorm layers with the flag cleared do not
    update their running statistics.
    """

    def __init__(self, module: nn.Module, condition=True):
        """
        :param module: the network whose BatchNorm layers are affected
        :param condition: if True, the context manager is a no-op
        """
        self.module = module
        self.enable = condition
        # (layer, original track_running_stats) pairs captured on entry
        self._saved = []

    def __enter__(self):
        if not self.enable:
            self._saved = [
                (m, m.track_running_stats) for m in self.module.modules()
                if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d))
            ]
            for m, _ in self._saved:
                m.track_running_stats = False
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore each layer's own flag instead of forcing True. Otherwise
        # layers built with track_running_stats=False would be flipped on.
        for m, tracked in self._saved:
            m.track_running_stats = tracked
        self._saved = []
        # Returning None lets any exception from the with-body propagate.
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 |
8 |
def backward_transfer(results):
    """
    Backward transfer: how accuracy on earlier tasks changed between the
    moment each task was learned and the end of training.

    :param results: accuracy matrix; presumably ``results[i][j]`` is the
        accuracy on task ``j`` after training on task ``i`` (TODO confirm)
    :return: mean of ``results[-1][j] - results[j][j]`` over every task
        except the last
    """
    deltas = [results[-1][j] - results[j][j] for j in range(len(results) - 1)]
    return np.mean(deltas)
16 |
17 |
def forward_transfer(results, random_results):
    """
    Forward transfer: accuracy on each task just before it is trained,
    relative to a reference (e.g. randomly initialised) model.

    :param results: accuracy matrix; presumably ``results[i][j]`` is the
        accuracy on task ``j`` after training on task ``i`` (TODO confirm)
    :param random_results: reference accuracy for each task
    :return: mean of ``results[k - 1][k] - random_results[k]`` for ``k >= 1``
    """
    gains = [results[k - 1][k] - random_results[k] for k in range(1, len(results))]
    return np.mean(gains)
25 |
26 |
def forgetting(results):
    """
    Average forgetting: for each task except the last, the gap between the
    best accuracy it ever reached and its accuracy at the end of training.

    Rows may be ragged (shorter than the number of tasks). Missing entries
    count as 0.0. The caller's ``results`` is not modified.
    :param results: accuracy matrix; presumably ``results[i][j]`` is the
        accuracy on task ``j`` after training on task ``i`` (TODO confirm)
    :return: mean forgetting over the first ``len(results) - 1`` tasks
    """
    n_tasks = len(results)
    # Pad copies so the caller's rows are left untouched (padding them in
    # place changed the caller's data as a side effect).
    padded = [list(row) + [0.0] * (n_tasks - len(row)) for row in results]
    best = np.max(np.array(padded), axis=0)
    return np.mean([best[j] - padded[-1][j] for j in range(n_tasks - 1)])
38 |
--------------------------------------------------------------------------------
/models/sgd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from utils.args import *
7 | from models.utils.continual_model import ContinualModel
8 | import numpy as np
9 | import copy
10 | from torch.nn import functional as F
11 | import math
12 | from torch.optim import SGD
13 | from collections import OrderedDict
14 | EPS = 1E-20
15 |
def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for the plain-SGD (fine-tuning) baseline.

    The description names SGD; it previously named Progressive Neural
    Networks, which this model does not implement.
    :return: parser with the management and experiment arguments
    """
    parser = ArgumentParser(description='Continual learning via'
                                        ' Stochastic Gradient Descent.')
    add_management_args(parser)
    add_experiment_args(parser)
    return parser
22 |
23 |
class Sgd(ContinualModel):
    """
    Plain fine-tuning baseline: each incoming batch gets a single optimizer
    step on the task loss, with no replay or regularisation.
    """
    NAME = 'sgd'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super().__init__(backbone, loss, args, transform)

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Performs one optimizer step on the current batch.
        :param inputs: input batch
        :param labels: targets of the batch
        :param not_aug_inputs: not used by this model
        :return: the scalar loss value
        """
        self.opt.zero_grad()
        batch_loss = self.loss(self.net(inputs), labels)
        batch_loss.backward()
        self.opt.step()
        return batch_loss.item()
39 |
--------------------------------------------------------------------------------
/utils/plot_lamda_acc.py:
--------------------------------------------------------------------------------
1 | from turtle import color
2 | import numpy as np
3 | import numpy as np
4 | import matplotlib
5 | matplotlib.use("Agg")
6 | import matplotlib.pyplot as plt
7 | import copy
8 | import math
9 | import torch
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from torch.nn import functional as F
13 | import math
14 |
15 |
# Accuracy of the two MOCA variants for each perturbation magnitude lambda.
Model_agnostic = [32.25, 34.46, 36.67, 38.01, 36.21]
Model_based = [38.1, 39.48, 40.78, 41.32, 39.72]
# plt.figure(figsize=(8,8))
# title = 'CIFAR-100'
# plt.title(r'Perturbation Magnitude $\lambda$')
# plt.title(title, fontsize=14)
plt.ylim(30, 45)
plt.xlabel(r'Perturbation Magnitude $\lambda$', fontsize=14)
plt.ylabel("Accuracy", fontsize=14)

plt.grid(linewidth=1.5)
x = ['0.5', '1.0', '1.5', '2.0', '3.0']
plt.plot(x, Model_agnostic, alpha=0.8, label="Model-agnostic MOCA", linewidth=3, marker='o')
plt.plot(x, Model_based, alpha=0.8, label="Model-based MOCA", linewidth=3, marker='o')
plt.tick_params(labelsize=14)
# Dashed red reference line; presumably the baseline accuracy (TODO confirm).
plt.axhline(y=31.12, linestyle='--', color='r', linewidth=1.5)
ax = plt.gca()
bwith = 1.
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_linewidth(bwith)
plt.legend(loc="best", fontsize=14, edgecolor='black')
# Save before show(): with an interactive backend, closing the window
# releases the figure and a later savefig() would write an empty page.
plt.savefig('./acc_lamda.pdf', bbox_inches='tight')
plt.show()
41 |
--------------------------------------------------------------------------------
/utils/plot_perturbed_angular_acc.py:
--------------------------------------------------------------------------------
1 | from turtle import color
2 | import numpy as np
3 | import numpy as np
4 | import matplotlib
5 | matplotlib.use("Agg")
6 | import matplotlib.pyplot as plt
7 | import copy
8 | import math
9 | import torch
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from torch.nn import functional as F
13 | import math
14 |
15 |
# Accuracy of the two MOCA variants for each perturbation angle.
Model_agnostic = [30.48, 31.67, 32.58, 34.8, 37.98, 35.59]
# NOTE(review): the last value (20) lies below the y-limit of 30 set below,
# so that point is not drawn; confirm whether it is a placeholder.
Model_based = [34.54, 38.0, 40.55, 39.94, 38.55, 20]
# plt.figure(figsize=(8,8))
# title = 'CIFAR-100'
# plt.title(r'Perturbation Magnitude $\lambda$')
# plt.title(title, fontsize=14)
plt.ylim(30, 45)
plt.xlabel(r'Perturbation Angular', fontsize=14)
plt.ylabel("Accuracy", fontsize=14)

plt.grid(linewidth=1.5)
x = ['5', '15', '30', '45', '60', '75']
plt.plot(x, Model_agnostic, alpha=0.8, label="Model-agnostic MOCA", linewidth=3, marker='o')
plt.plot(x, Model_based, alpha=0.8, label="Model-based MOCA", linewidth=3, marker='o')
plt.tick_params(labelsize=14)
# Dashed red reference line; presumably the baseline accuracy (TODO confirm).
plt.axhline(y=31.12, linestyle='--', color='r', linewidth=1.5)
ax = plt.gca()
bwith = 1.
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_linewidth(bwith)
plt.legend(loc="best", fontsize=14, edgecolor='black')
# Save before show(): with an interactive backend, closing the window
# releases the figure and a later savefig() would write an empty page.
plt.savefig('./acc_perturbed_angular.pdf', bbox_inches='tight')
plt.show()
41 |
--------------------------------------------------------------------------------
/datasets/rot_mnist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torchvision.transforms as transforms
7 | from datasets.transforms.rotation import Rotation
8 | from torch.utils.data import DataLoader
9 | from backbone.MNISTMLP import MNISTMLP
10 | import torch.nn.functional as F
11 | from datasets.perm_mnist import store_mnist_loaders
12 | from datasets.utils.continual_dataset import ContinualDataset
13 |
14 |
class RotatedMNIST(ContinualDataset):
    """
    Domain-incremental MNIST: 20 tasks sharing the same 10 classes.

    Each task's loaders are built with a ``Rotation()`` transform applied
    before ``ToTensor()``.
    """
    NAME = 'rot-mnist'
    SETTING = 'domain-il'
    N_CLASSES_PER_TASK = 10
    N_TASKS = 20

    def get_data_loaders(self):
        """Builds the (train, test) loaders through ``store_mnist_loaders``."""
        rotate_then_tensor = transforms.Compose((Rotation(), transforms.ToTensor()))
        train_loader, test_loader = store_mnist_loaders(rotate_then_tensor, self)
        return train_loader, test_loader

    @staticmethod
    def get_backbone():
        """MLP with a 28*28 input size and one output per class (presumably
        MNISTMLP(input_size, output_size) — TODO confirm)."""
        input_size = 28 * 28
        return MNISTMLP(input_size, RotatedMNIST.N_CLASSES_PER_TASK)

    @staticmethod
    def get_transform():
        """No extra training-time transform."""
        return None

    @staticmethod
    def get_normalization_transform():
        """No normalization."""
        return None

    @staticmethod
    def get_loss():
        """Cross-entropy loss."""
        return F.cross_entropy

    @staticmethod
    def get_denormalization_transform():
        """No de-normalization (nothing was normalized)."""
        return None

    @staticmethod
    def get_scheduler(model, args):
        """No learning-rate scheduler."""
        return None
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import os
7 | import inspect
8 | import importlib
9 | from datasets.utils.gcl_dataset import GCLDataset
10 | from datasets.utils.continual_dataset import ContinualDataset
11 | from argparse import Namespace
12 |
def get_all_models():
    """
    Lists the dataset module names found in the ``datasets`` folder.

    The folder is resolved relative to the current working directory, so the
    program is expected to run from the repository root. Only ``*.py`` files
    are considered. Entries containing ``__`` (``__init__.py``,
    ``__pycache__``) and sub-folders such as ``utils`` are skipped.
    :return: module names without the ``.py`` extension
    """
    # endswith('.py') rather than "'py' in name": the latter also matched
    # compiled *.pyc files and any filename merely containing "py".
    return [os.path.splitext(entry)[0] for entry in os.listdir('datasets')
            if entry.endswith('.py') and '__' not in entry]
16 |
# Registry mapping each dataset class's NAME (e.g. 'rot-mnist') to the class;
# populated at import time by importing every module in ``datasets``.
NAMES = {}
for model in get_all_models():
    mod = importlib.import_module('datasets.' + model)
    # Select module attributes that (a) are classes, detected by the repr of
    # their type containing 'type', and (b) have ContinualDataset among their
    # bases (the MRO without the class itself, so ContinualDataset, when
    # imported into a module, is not registered).
    # NOTE(review): classes with a custom metaclass (e.g. abc.ABCMeta) do not
    # pass test (a) and would be silently skipped.
    dataset_classes_name = [x for x in mod.__dir__() if 'type' in str(type(getattr(mod, x))) and 'ContinualDataset' in str(inspect.getmro(getattr(mod, x))[1:])]
    for d in dataset_classes_name:
        c = getattr(mod, d)
        NAMES[c.NAME] = c

    # Same selection for general-continual-learning datasets (GCLDataset bases).
    gcl_dataset_classes_name = [x for x in mod.__dir__() if 'type' in str(type(getattr(mod, x))) and 'GCLDataset' in str(inspect.getmro(getattr(mod, x))[1:])]
    for d in gcl_dataset_classes_name:
        c = getattr(mod, d)
        NAMES[c.NAME] = c
33 | :param args: the arguments which contains the hyperparameters
34 | :return: the continual dataset
35 | """
36 | assert args.dataset in NAMES.keys()
37 | return NAMES[args.dataset](args)
38 |
--------------------------------------------------------------------------------
/backbone/utils/modules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn.parameter import Parameter
9 |
10 |
class AlphaModule(nn.Module):
    """
    Element-wise learnable scaling: ``forward(x) = x * alpha``.

    ``alpha`` has shape ``(1, *shape)``, so it broadcasts over a leading batch
    dimension. It is initialised uniformly in ``[0, 0.1)``.
    """

    def __init__(self, shape):
        """
        :param shape: an int or a tuple giving the per-sample shape of alpha
        """
        super(AlphaModule, self).__init__()
        per_sample = shape if isinstance(shape, tuple) else (shape,)
        init = torch.rand((1, *per_sample)) * 0.1
        self.alpha = Parameter(init, requires_grad=True)

    def forward(self, x):
        return x * self.alpha

    def parameters(self, recurse: bool = True):
        # Only alpha is exposed, whatever the value of ``recurse``.
        yield self.alpha
24 |
25 |
class ListModule(nn.Module):
    """
    Minimal list-like container that registers its items as sub-modules
    (keys '0', '1', ... in insertion order), so their parameters are tracked.
    """

    def __init__(self, *args):
        super(ListModule, self).__init__()
        # Next free key; also the number of modules appended so far.
        self.idx = 0
        for module in args:
            self.add_module(str(self.idx), module)
            self.idx += 1

    def append(self, module):
        """Registers ``module`` under the next free key."""
        self.add_module(str(self.idx), module)
        self.idx += 1

    def __getitem__(self, idx):
        """
        Returns the module at position ``idx``. Negative indices count from
        the end, as for a Python list.
        :raises IndexError: if ``idx`` is out of range in either direction
        """
        position = idx + len(self) if idx < 0 else idx
        # Check both ends: without the lower bound, e.g. ``lm[-10]`` on a
        # 3-item list would wrap to a negative position and return the
        # wrong module (or raise StopIteration on an empty list).
        if not 0 <= position < len(self):
            raise IndexError('index {} is out of range'.format(idx))
        return list(self._modules.values())[position]

    def __iter__(self):
        return iter(self._modules.values())

    def __len__(self):
        return len(self._modules)
53 |
--------------------------------------------------------------------------------
/models/der.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from utils.buffer import Buffer
7 | from torch.nn import functional as F
8 | from utils.args import *
9 | from models.utils.continual_model import ContinualModel
10 |
def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for Dark Experience Replay (DER).
    :return: parser with the management, experiment and rehearsal arguments
        plus the required ``--alpha`` penalty weight
    """
    parser = ArgumentParser(description='Continual learning via'
                                        ' Dark Experience Replay.')
    for add_group in (add_management_args, add_experiment_args,
                      add_rehearsal_args):
        add_group(parser)
    parser.add_argument('--alpha', type=float, required=True,
                        help='Penalty weight.')
    return parser
20 |
21 |
class Der(ContinualModel):
    """
    Dark Experience Replay. The loss on the current batch is complemented by
    an ``alpha``-weighted MSE term that keeps the network's outputs on replayed
    buffer samples close to the logits stored with them.
    """
    NAME = 'der'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super().__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def observe(self, inputs, labels, not_aug_inputs):
        """
        One training step on the current batch, plus a replayed minibatch
        once the buffer is non-empty.
        :return: the scalar loss value
        """
        self.opt.zero_grad()

        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)

        if not self.buffer.is_empty():
            replay_x, replay_logits = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            loss += self.args.alpha * F.mse_loss(self.net(replay_x), replay_logits)

        loss.backward()
        self.opt.step()
        # Store the un-augmented inputs with the logits computed above.
        self.buffer.add_data(examples=not_aug_inputs, logits=outputs.data)

        return loss.item()
48 |
--------------------------------------------------------------------------------
/datasets/transforms/permutation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 |
8 |
class Permutation(object):
    """
    Shuffles the elements of a numpy array with a permutation drawn on the
    first call and reused for every later call.
    """
    def __init__(self) -> None:
        """Starts without a permutation; it is drawn lazily on first use."""
        self.perm = None

    def __call__(self, sample: np.ndarray) -> np.ndarray:
        """
        Applies the (lazily drawn) permutation to ``sample``.
        :param sample: image to be permuted
        :return: permuted image with the same shape as ``sample``
        """
        flat = sample.flatten()
        if self.perm is None:
            # drawn from NumPy's global random generator
            self.perm = np.random.permutation(len(flat))
        return flat[self.perm].reshape(sample.shape)
30 |
31 |
class FixedPermutation(object):
    """
    Defines a fixed permutation (given the seed) for a numpy array.
    """
    def __init__(self, seed: int) -> None:
        """
        Defines the seed.
        :param seed: seed of the permutation
        """
        self.perm = None
        self.seed = seed

    def __call__(self, sample: np.ndarray) -> np.ndarray:
        """
        Draws the seeded permutation on first use and applies it.

        The permutation comes from a private ``RandomState`` seeded with
        ``self.seed``. This gives the same permutation as
        ``np.random.seed(seed); np.random.permutation(n)``, but unlike
        ``np.random.seed`` it does not reset NumPy's global random state,
        which would silently re-seed every later ``np.random`` call.
        :param sample: image to be permuted
        :return: permuted image
        """
        old_shape = sample.shape
        if self.perm is None:
            rng = np.random.RandomState(self.seed)
            self.perm = rng.permutation(len(sample.flatten()))

        return sample.flatten()[self.perm].reshape(old_shape)
56 |
--------------------------------------------------------------------------------
/utils/plot_mixup_acc.py:
--------------------------------------------------------------------------------
1 | from turtle import color
2 | import numpy as np
3 | import numpy as np
4 | import matplotlib
5 | matplotlib.use("Agg")
6 | import matplotlib.pyplot as plt
7 | import copy
8 | import math
9 | import torch
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from torch.nn import functional as F
13 | import math
14 |
15 |
16 |
# Accuracy of each mixup/MOCA variant as a bar chart; the dashed red line
# at 31.08 is presumably the baseline accuracy (TODO confirm).
data = [29.33, 34.08, 38.76, 41.02]
labels = ['Original \n Manifold Mixup', 'Adapted \n Manifold Mixup', 'model-agnostic \n MOCA', 'model-based \n MOCA']
plt.figure(figsize=(8, 5))
plt.bar(range(len(data)), data, tick_label=labels, width=0.35, alpha=0.8, edgecolor='black')
# set_ticks() below widens the view limits to include every tick, so the
# visible range ends up as 26-44 anyway (and the first bar, 29.33, lies
# below 30). State the effective range explicitly.
plt.ylim(26, 44)
plt.axhline(y=31.08, linestyle='--', color='r', linewidth=1.5)
plt.tick_params(labelsize=12)
plt.ylabel('Accuracy', fontsize=14)
plt.xlabel('Method', fontsize=14)
ax = plt.gca()
bwith = 1.
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_linewidth(bwith)
ax.yaxis.set_ticks([26, 32, 38, 44])
# Save before show(): with an interactive backend, closing the window
# releases the figure and a later savefig() would write an empty page.
plt.savefig('./mixup_acc.pdf', bbox_inches='tight')
plt.show()
35 |
36 |
37 |
def visual_svd(s_list, method_list):
    """
    Plots the log of each method's normalised singular-value spectrum.

    Each spectrum is divided by its sum before the log is taken. Methods whose
    label contains 'V-vmf' are skipped. The figure is saved to
    './aug_feat.pdf'.
    :param s_list: sequence of 1-D tensors of singular values (any length);
        neither the list nor its tensors are modified
    :param method_list: legend labels; ``method_list[i]`` describes ``s_list[i]``
    """
    plt.ion()
    plt.clf()
    plotted = False
    for i, singular_values in enumerate(s_list):
        if 'V-vmf' in method_list[i]:
            continue
        # Normalise into a new tensor instead of overwriting the caller's entry.
        normalised = singular_values / torch.sum(singular_values)
        # The x-axis follows the spectrum's own length rather than a fixed 512.
        dims = np.arange(len(normalised))
        plt.plot(dims, torch.log(normalised), label=method_list[i])
        plotted = True
    if plotted:
        plt.xlabel('Dimension', fontsize=14)
        plt.ylabel('Log Singular Value', fontsize=14)
        plt.legend()

    ax = plt.gca()
    bwith = 1.
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_linewidth(bwith)
    plt.tick_params(labelsize=14)
    plt.savefig('./aug_feat.pdf')
    plt.show()
61 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ☕ (MOCA) Continual Learning by Modeling Intra-Class Variation
2 | This is the official implementation of the TMLR 2023 paper "Continual Learning by Modeling Intra-Class Variation" (MOCA).
3 | ## Environment
4 | This work builds on the codebase of [DER](https://github.com/aimagelab/mammoth). Install the dependencies with:
5 | ```bash
6 | pip install -r requirements.txt
7 | ```
8 | ## Setup
9 | + Use `./utils/main.py` to run experiments.
10 | + Use argument `--load_best_args` to use the best hyperparameters from the paper.
11 |
12 | ## TODO:
13 | - [x] Release code!
14 | - [x] Bash Arguments!
15 | - [ ] Clean up the code (it still needs tidying; we will do this soon).
16 | - [ ] Release the 2D visualization code and the gradient analysis code.
17 |
18 | ## Examples
19 | To reproduce the results of MOCA-Gaussian on CIFAR-100, run:
20 | ```bash
21 | python ./utils/main.py --load_best_args --model er --dataset seq-cifar100 --buffer_size 500 --para_scale 1.5 --gamma_loss 1 --norm_add norm_add --method2 gaussian --noise_type noise
22 | ```
23 |
24 | To reproduce the results of MOCA-WAP on CIFAR-100, run:
25 | ```bash
26 | python ./utils/main.py --load_best_args --model er --dataset seq-cifar100 --buffer_size 500 --para_scale 1.5 --gamma_loss 10 --norm_add norm_add --advloss none --target_type new_labels --noise_type adv --inner_iter 1
27 | ```
28 |
29 | ## Citation
30 | If you find this code or idea useful, please cite our work:
31 | ```bib
32 | @article{yu2022continual,
33 | title={Continual Learning by Modeling Intra-Class Variation},
34 | author={Yu, Longhui and Hu, Tianyang and Hong, Lanqing and Liu, Zhen and Weller, Adrian and Liu, Weiyang},
35 | journal={arXiv preprint arXiv:2210.05398},
36 | year={2022}
37 | }
38 | ```
39 |
40 | ## Contact
41 | If you have any questions, feel free to contact us by email (yulonghui@stu.pku.edu.cn). Enjoy!
42 |
43 | ## Intra-class Variation Gap
44 |
45 |
46 |
47 |
48 | ## Representation Collapse
49 |
50 |
51 |
52 |
53 | ## Gradient Collapse
54 |
55 |
56 |
57 |
58 | ## MOCA Framework
59 |
60 |
61 |
62 |
63 | ## MOCA Variants
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
/utils/plot_training_gradient.py:
--------------------------------------------------------------------------------
1 |
2 | from turtle import color
3 | import numpy as np
4 | import numpy as np
5 | import matplotlib
6 | matplotlib.use("Agg")
7 | import matplotlib.pyplot as plt
8 | import copy
9 | import math
10 | import torch
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | from torch.nn import functional as F
14 | import math
15 |
def visual_svd(s_list, method_list, order=(5, 3, 6, 7, 1, 2, 0, 4), top_k=100):
    """
    Plots the first ``top_k`` normalized log singular values of each method and
    saves the figure to ``./vmf_svd.pdf``.

    :param s_list: indexable collection of 1-D torch tensors of singular values,
        one per method (e.g. the rows of a 2-D tensor); it is not modified
    :param method_list: labels aligned with ``s_list``
    :param order: indices into ``s_list`` / ``method_list`` giving the drawing
        (and legend) order; the default matches the original hard-coded order
    :param top_k: number of leading dimensions to plot
    """
    # Per-method line styling; methods not listed use the default style.
    default_style = {'alpha': 0.9}
    styles = {
        'WAP': {'alpha': 0.9, 'color': 'purple'},
        'DOA-new': {'alpha': 0.65, 'color': 'purple'},
        'VT': {'alpha': 0.4, 'color': 'purple'},
        'vMF': {'alpha': 0.9, 'color': 'chocolate'},
    }
    plt.ion()
    plt.clf()
    for i in order:
        method = method_list[i]
        # Normalize to unit sum into a new tensor. Assigning back into s_list[i]
        # would write into the caller's tensor.
        normalized = s_list[i] / torch.sum(s_list[i])
        x = np.arange(len(normalized))
        plt.plot(x[:top_k], torch.log(normalized)[:top_k], label=method,
                 linewidth=3.0, **styles.get(method, default_style))

    plt.xlabel('Dimension', fontsize=14)
    plt.ylabel('Log Singular Value', fontsize=14)
    plt.legend(fontsize=12, edgecolor='black')
    plt.axvline(x = 50, linestyle='--', color = 'r', linewidth=1.5)
    ax = plt.gca()
    bwith = 1.
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_linewidth(bwith)
    plt.tick_params(labelsize=14)
    plt.savefig('./vmf_svd.pdf')
    plt.show()
51 |
52 |
# Labels for the rows of the saved spectra file, in row order.
method_list = ['WAP', 'VT', 'DOA-new', 'DOA-old', 'Joint', 'ER', 'Gaussian', 'vMF']

# Precomputed singular-value spectra, loaded as a torch tensor.
s_list2 = torch.from_numpy(np.loadtxt('all_perturb_gradient_31.txt'))

visual_svd(s_list2, method_list)
--------------------------------------------------------------------------------
/models/joint_gcl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from torch.optim import SGD
7 |
8 | from utils.args import *
9 | from models.utils.continual_model import ContinualModel
10 | from datasets.utils.validation import ValidationDataset
11 | from utils.status import progress_bar
12 | import torch
13 | import numpy as np
14 | import math
15 | from tqdm import tqdm
16 | from torchvision import transforms
17 |
18 |
def get_parser() -> ArgumentParser:
    """Builds the command-line parser for the joint GCL baseline."""
    parser = ArgumentParser(
        description='Joint training: a strong, simple baseline.')
    for register in (add_management_args, add_experiment_args):
        register(parser)
    return parser
24 |
25 |
class JointGCL(ContinualModel):
    """
    Joint-training baseline for the general-continual setting.

    ``observe`` only stores the incoming stream. When ``end_task`` is called,
    the network is re-initialized and trained on everything stored.
    """
    NAME = 'joint_gcl'
    COMPATIBILITY = ['general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(JointGCL, self).__init__(backbone, loss, args, transform)
        self.old_data = []    # input batches received by observe()
        self.old_labels = []  # label batches matching old_data
        self.current_task = 0

    def end_task(self, dataset):
        """
        Re-initializes the backbone and trains it on all stored data.
        :param dataset: the GCL dataset, used to build a fresh backbone
        """
        # reinit network
        self.net = dataset.get_backbone()
        self.net.to(self.device)
        self.net.train()
        self.opt = SGD(self.net.parameters(), lr=self.args.lr)

        # gather data
        all_data = torch.cat(self.old_data)
        all_labels = torch.cat(self.old_labels)
        batch_size = self.args.batch_size
        n_batches = math.ceil(len(all_data) / batch_size)

        # train
        # NOTE(review): only one epoch is run; `range(self.args.n_epochs)` was
        # left commented out in the original. Confirm this is intended.
        for e in range(1):
            # Shuffle once per epoch. The previous code indexed with `rp` inside
            # the batch loop, which copied the entire dataset for every mini-batch.
            rp = torch.randperm(len(all_data))
            epoch_data, epoch_labels = all_data[rp], all_labels[rp]
            for i in range(n_batches):
                batch = slice(i * batch_size, (i + 1) * batch_size)
                inputs = epoch_data[batch].to(self.device)
                labels = epoch_labels[batch].to(self.device)

                self.opt.zero_grad()
                outputs = self.net(inputs)
                loss = self.loss(outputs, labels.long())
                loss.backward()
                self.opt.step()
                progress_bar(i, n_batches, e, 'J', loss.item())

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Stores the batch for later joint training; no update is performed here.
        :return: 0, since no loss is computed
        """
        self.old_data.append(inputs.data)
        self.old_labels.append(labels.data)
        return 0
66 |
--------------------------------------------------------------------------------
/models/agem_r.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import numpy as np
8 | from utils.buffer import Buffer
9 | from models.gem import overwrite_grad
10 | from models.gem import store_grad
11 | from utils.args import *
12 | from models.agem import project
13 | from models.utils.continual_model import ContinualModel
14 |
def get_parser() -> ArgumentParser:
    """Builds the command-line parser for A-GEM with a reservoir buffer."""
    parser = ArgumentParser(
        description='Continual learning via A-GEM, leveraging a reservoir buffer.')
    for register in (add_management_args, add_experiment_args,
                     add_rehearsal_args):
        register(parser)
    return parser
22 |
class AGemr(ContinualModel):
    """
    A-GEM variant whose episodic memory is a reservoir buffer filled online
    from the stream.
    """
    NAME = 'agem_r'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(AGemr, self).__init__(backbone, loss, args, transform)

        self.buffer = Buffer(self.args.buffer_size, self.device)
        # Element count of each parameter, used to (un)flatten gradients.
        self.grad_dims = [param.data.numel() for param in self.parameters()]
        n_grads = int(np.sum(self.grad_dims))
        # Zero-initialized (torch.Tensor(n) returned uninitialized memory);
        # observe() fills both buffers with store_grad before reading them.
        self.grad_xy = torch.zeros(n_grads).to(self.device)
        self.grad_er = torch.zeros(n_grads).to(self.device)
        self.current_task = 0

    def observe(self, inputs, labels, not_aug_inputs):
        """
        One A-GEM step on the current batch, then adds the batch to the buffer.
        :return: the loss on the current batch
        """
        self.zero_grad()
        p = self.net.forward(inputs)
        loss = self.loss(p, labels)
        loss.backward()

        if not self.buffer.is_empty():
            store_grad(self.parameters, self.grad_xy, self.grad_dims)

            # The buffer stores `not_aug_inputs`, so apply the dataset transform
            # when replaying them. This mirrors what AGem does, so buffer
            # samples go through the same pipeline as `inputs`.
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            self.net.zero_grad()
            buf_outputs = self.net.forward(buf_inputs)
            penalty = self.loss(buf_outputs, buf_labels)
            penalty.backward()
            store_grad(self.parameters, self.grad_er, self.grad_dims)

            # Project only when the two gradients conflict (negative dot product).
            dot_prod = torch.dot(self.grad_xy, self.grad_er)
            if dot_prod.item() < 0:
                g_tilde = project(gxy=self.grad_xy, ger=self.grad_er)
                overwrite_grad(self.parameters, g_tilde, self.grad_dims)
            else:
                overwrite_grad(self.parameters, self.grad_xy, self.grad_dims)

        self.opt.step()

        self.buffer.add_data(examples=not_aug_inputs, labels=labels)

        return loss.item()
66 |
--------------------------------------------------------------------------------
/models/si.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 | from utils.args import *
9 | from models.utils.continual_model import ContinualModel
10 |
11 | import numpy as np
12 | import copy
13 | from torch.nn import functional as F
14 | import math
15 | from torch.optim import SGD
16 | from collections import OrderedDict
17 | EPS = 1E-20
18 |
def get_parser() -> ArgumentParser:
    """Builds the command-line parser for Synaptic Intelligence (SI)."""
    parser = ArgumentParser(
        description='Continual Learning Through Synaptic Intelligence.')
    add_management_args(parser)
    add_experiment_args(parser)
    for flag, text in (('--c', 'surrogate loss weight parameter c'),
                       ('--xi', 'xi parameter for EWC online')):
        parser.add_argument(flag, type=float, required=True, help=text)
    return parser
30 |
31 |
class SI(ContinualModel):
    """
    Synaptic Intelligence: a quadratic penalty anchors the parameters to the
    previous task's checkpoint, weighted by a per-parameter importance.
    """
    NAME = 'si'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(SI, self).__init__(backbone, loss, args, transform)

        self.checkpoint = self.net.get_params().data.clone().to(self.device)
        self.big_omega = None   # importance; None until the first end_task
        self.small_omega = 0    # running importance for the current task

    def penalty(self):
        """Returns sum(big_omega * (params - checkpoint)^2), or 0 before any task ends."""
        if self.big_omega is not None:
            drift = self.net.get_params() - self.checkpoint
            return (self.big_omega * (drift ** 2)).sum()
        return torch.tensor(0.0).to(self.device)

    def end_task(self, dataset):
        """Folds the running importance into big_omega and re-anchors the checkpoint."""
        current = self.net.get_params().data
        if self.big_omega is None:
            self.big_omega = torch.zeros_like(current).to(self.device)

        # xi keeps the denominator away from zero for parameters that barely moved.
        self.big_omega += self.small_omega / ((current - self.checkpoint) ** 2 + self.args.xi)

        self.checkpoint = current.clone().to(self.device)
        self.small_omega = 0

    def observe(self, inputs, labels, not_aug_inputs):
        """One SGD step on task loss + c * penalty; returns the total loss."""
        self.opt.zero_grad()
        outputs = self.net(inputs)
        reg = self.penalty()
        total = self.loss(outputs, labels) + self.args.c * reg
        total.backward()
        nn.utils.clip_grad.clip_grad_value_(self.net.parameters(), 1)
        self.opt.step()

        self.small_omega += self.args.lr * self.net.get_grads().data ** 2

        return total.item()
73 |
--------------------------------------------------------------------------------
/utils/plot_all_method_acc.py:
--------------------------------------------------------------------------------
1 | from turtle import color
2 | import numpy as np
3 | import numpy as np
4 | import matplotlib
5 | matplotlib.use("Agg")
6 | import matplotlib.pyplot as plt
7 | import copy
8 | import math
9 | import torch
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from torch.nn import functional as F
13 | import math
14 |
15 |
16 |
# Final accuracy of each MOCA variant; the dashed line marks 31.08.
data = [33.67, 37.29, 38.76, 38.75, 39.78, 41.02]
labels = ['DOA-self', 'Gaussian', 'vMF', 'DOA', 'VT', 'WAP']
# A single figure; the previous duplicate plt.figure() call left an empty
# figure open.
plt.figure(figsize=(8,4))
plt.bar(range(len(data)), data, tick_label=labels, width = 0.35, alpha = 0.8, edgecolor='black')
plt.ylim(30, 42)
plt.axhline(y = 31.08, linestyle='--', color = 'r', linewidth=1.5)
plt.tick_params(labelsize=12)
plt.ylabel('Accuracy', fontsize=14)
plt.xlabel('Method', fontsize=14)
ax = plt.gca()
bwith = 1.
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_linewidth(bwith)
ax.yaxis.set_ticks([30, 34, 38, 42])
# Save before show(): with an interactive backend, show() can leave an empty
# canvas behind and the saved PDF would be blank.
plt.savefig('./all_method_acc.pdf', bbox_inches = 'tight')
plt.show()
36 |
37 |
38 |
39 |
from matplotlib.ticker import MaxNLocator
# Old-class intra-class variance of each MOCA variant; the dashed line marks 30.12.
# NOTE(review): 51.279 and 68.75 lie above the ylim upper bound of 46, so those
# bars are clipped in the saved figure. Confirm the values or the limits.
data = [33.54, 51.279, 68.75, 41.17, 35.72, 37.08]
labels = ['DOA-self', 'Gaussian', 'vMF', 'DOA', 'VT', 'WAP']
plt.figure(figsize=(8,4))
plt.bar(range(len(data)), data, tick_label=labels, width = 0.3, alpha = 0.8, edgecolor='black')
plt.ylim(29, 46)
plt.axhline(y = 30.12, linestyle='--', color = 'r', linewidth=1.5)
plt.tick_params(labelsize=12)
plt.ylabel('Old Intra-class Variance', fontsize=14)
plt.xlabel('Method', fontsize=14)
ax = plt.gca()
bwith = 1.
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_linewidth(bwith)
ax.yaxis.set_ticks([30, 34, 38, 42, 46])
# Save before show(): with an interactive backend, show() can leave an empty
# canvas behind and the saved PDF would be blank.
plt.savefig('./all_method_variance_exp.pdf', bbox_inches = 'tight')
plt.show()
59 |
def visual_svd(s_list, method_list):
    """
    Plots the normalized log singular-value spectrum of each method and saves
    the figure to ``./aug_feat.pdf``.

    :param s_list: indexable collection of 1-D torch tensors of singular values,
        one per method (e.g. the rows of a 2-D tensor); it is not modified
    :param method_list: labels aligned with ``s_list``; entries whose label
        contains ``'V-vmf'`` are skipped
    """
    plt.ion()
    plt.clf()
    for spectrum, method in zip(s_list, method_list):
        if 'V-vmf' in method:
            continue
        # Normalize to unit sum so spectra of different scale are comparable.
        # A new tensor is built instead of writing back into the caller's data.
        normalized = spectrum / torch.sum(spectrum)
        # The x axis follows the spectrum length instead of assuming 512 values.
        x = np.arange(len(normalized))
        plt.plot(x, torch.log(normalized), label=method)
    plt.xlabel('Dimension', fontsize=14)
    plt.ylabel('Log Singular Value', fontsize=14)
    plt.legend()

    ax = plt.gca()
    bwith = 1.
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_linewidth(bwith)
    plt.tick_params(labelsize=14)
    plt.savefig('./aug_feat.pdf')
    # The original referenced `plt.show` without calling it.
    plt.show()
83 |
--------------------------------------------------------------------------------
/utils/continual_training.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | import torch
8 | from datasets import get_dataset
9 | from models import get_model
10 | from utils.status import progress_bar
11 | from utils.tb_logger import *
12 | from utils.status import create_fake_stash
13 | from models.utils.continual_model import ContinualModel
14 | from argparse import Namespace
15 |
16 |
def evaluate(model: ContinualModel, dataset) -> float:
    """
    Evaluates the final accuracy of the model.
    :param model: the model to be evaluated
    :param dataset: the GCL dataset at hand; its test stream is consumed until
        ``dataset.test_over`` becomes true
    :return: a float value that indicates the accuracy (in percent)
    """
    # Remember the current mode so that it can be restored afterwards.
    was_training = model.net.training
    model.net.eval()
    correct, total = 0, 0
    # No gradients are needed here. Without no_grad every forward pass would
    # build an autograd graph and waste memory.
    with torch.no_grad():
        while not dataset.test_over:
            inputs, labels = dataset.get_test_data()
            inputs, labels = inputs.to(model.device), labels.to(model.device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            correct += torch.sum(predicted == labels).item()
            total += labels.shape[0]
    model.net.train(was_training)

    acc = correct / total * 100
    return acc
36 |
37 |
def train(args: Namespace):
    """
    Trains a model on a general-continual (GCL) stream, then evaluates it and
    prints the final accuracy.
    :param args: the arguments of the current execution; the dataset, backbone,
        loss and model are all built from them
    """
    if args.csv_log:
        from utils.loggers import CsvLogger

    dataset = get_dataset(args)
    model = get_model(args, dataset.get_backbone(), dataset.get_loss(),
                      dataset.get_transform())
    model.net.to(model.device)

    model_stash = create_fake_stash(model, args)

    if args.csv_log:
        csv_logger = CsvLogger(dataset.SETTING, dataset.NAME, model.NAME)
    if args.tensorboard:
        tb_logger = TensorboardLogger(args, dataset.SETTING, model_stash)

    model.net.train()
    epoch = 0
    step = 0
    while not dataset.train_over:
        inputs, labels, not_aug_inputs = dataset.get_train_data()
        inputs = inputs.to(model.device)
        labels = labels.to(model.device)
        not_aug_inputs = not_aug_inputs.to(model.device)
        step_loss = model.observe(inputs, labels, not_aug_inputs)
        progress_bar(step, dataset.LENGTH // args.batch_size, epoch, 'C', step_loss)
        if args.tensorboard:
            tb_logger.log_loss_gcl(step_loss, step)
        step += 1

    # The joint baseline only trains once the whole stream has been stored.
    if model.NAME == 'joint_gcl':
        model.end_task(dataset)

    acc = evaluate(model, dataset)
    print('Accuracy:', acc)

    if args.csv_log:
        csv_logger.log(acc)
        csv_logger.write(vars(args))
82 |
--------------------------------------------------------------------------------
/models/agem.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import numpy as np
8 | from utils.buffer import Buffer
9 | from models.gem import overwrite_grad
10 | from models.gem import store_grad
11 | from utils.args import *
12 | from models.utils.continual_model import ContinualModel
13 |
14 | def get_parser() -> ArgumentParser:
15 | parser = ArgumentParser(description='Continual learning via A-GEM.')
16 | add_management_args(parser)
17 | add_experiment_args(parser)
18 | add_rehearsal_args(parser)
19 | return parser
20 |
21 | def project(gxy: torch.Tensor, ger: torch.Tensor) -> torch.Tensor:
22 | corr = torch.dot(gxy, ger) / torch.dot(ger, ger)
23 | return gxy - corr * ger
24 |
25 |
26 | class AGem(ContinualModel):
27 | NAME = 'agem'
28 | COMPATIBILITY = ['class-il', 'domain-il', 'task-il']
29 |
30 | def __init__(self, backbone, loss, args, transform):
31 | super(AGem, self).__init__(backbone, loss, args, transform)
32 |
33 | self.buffer = Buffer(self.args.buffer_size, self.device)
34 | self.grad_dims = []
35 | for param in self.parameters():
36 | self.grad_dims.append(param.data.numel())
37 | self.grad_xy = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
38 | self.grad_er = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
39 |
40 | def end_task(self, dataset):
41 | samples_per_task = self.args.buffer_size // dataset.N_TASKS
42 | loader = dataset.train_loader
43 | cur_y, cur_x = next(iter(loader))[1:]
44 | self.buffer.add_data(
45 | examples=cur_x.to(self.device),
46 | labels=cur_y.to(self.device)
47 | )
48 |
49 | def observe(self, inputs, labels, not_aug_inputs):
50 |
51 | self.zero_grad()
52 | p = self.net.forward(inputs)
53 | loss = self.loss(p, labels)
54 | loss.backward()
55 |
56 | if not self.buffer.is_empty():
57 | store_grad(self.parameters, self.grad_xy, self.grad_dims)
58 |
59 | buf_inputs, buf_labels = self.buffer.get_data(self.args.minibatch_size, transform=self.transform)
60 | self.net.zero_grad()
61 | buf_outputs = self.net.forward(buf_inputs)
62 | penalty = self.loss(buf_outputs, buf_labels)
63 | penalty.backward()
64 | store_grad(self.parameters, self.grad_er, self.grad_dims)
65 |
66 | dot_prod = torch.dot(self.grad_xy, self.grad_er)
67 | if dot_prod.item() < 0:
68 | g_tilde = project(gxy=self.grad_xy, ger=self.grad_er)
69 | overwrite_grad(self.parameters, g_tilde, self.grad_dims)
70 | else:
71 | overwrite_grad(self.parameters, self.grad_xy, self.grad_dims)
72 |
73 | self.opt.step()
74 |
75 | return loss.item()
76 |
--------------------------------------------------------------------------------
/datasets/utils/validation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from PIL import Image
8 | import numpy as np
9 | import os
10 | from utils import create_if_not_exists
11 | import torchvision.transforms.transforms as transforms
12 | from torchvision import datasets
13 |
14 |
class ValidationDataset(torch.utils.data.Dataset):
    """In-memory dataset holding a held-out split of a training set."""

    def __init__(self, data: torch.Tensor, targets: np.ndarray,
                 transform: transforms=None, target_transform: transforms=None) -> None:
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        img = self.data[index]
        target = self.targets[index]

        # Convert to a PIL Image, as the other datasets return PIL Images.
        if not isinstance(img, np.ndarray):
            img = Image.fromarray(img.numpy())
        elif np.max(img) < 2:
            # Values below 2 are treated as [0, 1] floats and rescaled to 8 bits.
            img = Image.fromarray(np.uint8(img * 255))
        else:
            img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
46 |
def get_train_val(train: datasets, test_transform: transforms,
                  dataset: str, val_perc: float=0.1):
    """
    Extract val_perc% of the training set as the validation set.

    The shuffling permutation is cached in ``datasets/val_permutations/<dataset>.pt``
    so that the split is the same across runs. A cached permutation whose length
    does not match the training set is regenerated and overwritten.
    :param train: training dataset (its ``data`` and ``targets`` are replaced)
    :param test_transform: transformation of the test dataset
    :param dataset: dataset name
    :param val_perc: percentage of the training set to be extracted
    :return: the training set and the validation set
    """
    dataset_length = train.data.shape[0]
    directory = 'datasets/val_permutations/'
    create_if_not_exists(directory)
    perm_path = os.path.join(directory, dataset + '.pt')

    perm = torch.load(perm_path) if os.path.exists(perm_path) else None
    if perm is None or len(perm) != dataset_length:
        # Missing or stale cache (e.g. saved for a differently sized training set).
        # A mismatched permutation would drop samples or index out of range.
        perm = torch.randperm(dataset_length)
        torch.save(perm, perm_path)

    train.data = train.data[perm]
    train.targets = np.array(train.targets)[perm]

    n_val = int(val_perc * dataset_length)
    test_dataset = ValidationDataset(train.data[:n_val],
                                     train.targets[:n_val],
                                     transform=test_transform)
    train.data = train.data[n_val:]
    train.targets = train.targets[n_val:]

    return train, test_dataset
75 |
--------------------------------------------------------------------------------
/models/ewc_on.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from utils.args import *
10 | from models.utils.continual_model import ContinualModel
11 |
12 |
def get_parser() -> ArgumentParser:
    """Builds the command-line parser for online EWC."""
    parser = ArgumentParser(description='Continual learning via online EWC.')
    add_management_args(parser)
    add_experiment_args(parser)
    for flag, text in (('--e_lambda', 'lambda weight for EWC'),
                       ('--gamma', 'gamma parameter for EWC online')):
        parser.add_argument(flag, type=float, required=True, help=text)
    return parser
24 |
25 |
class EwcOn(ContinualModel):
    """
    Online EWC: a quadratic penalty anchors the parameters to the last
    checkpoint, weighted by a running (gamma-decayed) Fisher estimate.
    """
    NAME = 'ewc_on'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(EwcOn, self).__init__(backbone, loss, args, transform)

        self.logsoft = nn.LogSoftmax(dim=1)
        self.checkpoint = None  # parameters at the end of the last task
        self.fish = None        # accumulated Fisher estimate

    def penalty(self):
        """Returns sum(fish * (params - checkpoint)^2), or 0 before any task ends."""
        if self.checkpoint is None:
            return torch.tensor(0.0).to(self.device)
        else:
            penalty = (self.fish * ((self.net.get_params() - self.checkpoint) ** 2)).sum()
            return penalty

    def end_task(self, dataset):
        """Estimates the Fisher on the task's training data and updates fish/checkpoint."""
        fish = torch.zeros_like(self.net.get_params())
        n_samples = 0

        for inputs, labels, _ in dataset.train_loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            # One backward pass per example: gradients are needed per sample.
            for ex, lab in zip(inputs, labels):
                self.opt.zero_grad()
                output = self.net(ex.unsqueeze(0))
                loss = - F.nll_loss(self.logsoft(output), lab.unsqueeze(0),
                                    reduction='none')
                # exp(log-prob of the true label) weights the squared gradient.
                exp_cond_prob = torch.mean(torch.exp(loss.detach().clone()))
                loss = torch.mean(loss)
                loss.backward()
                fish += exp_cond_prob * self.net.get_grads() ** 2
                n_samples += 1

        # Average over the samples actually seen. len(loader) * batch_size
        # overcounts when the last batch is partial and is 0 for an empty loader.
        fish /= max(n_samples, 1)

        if self.fish is None:
            self.fish = fish
        else:
            self.fish *= self.args.gamma
            self.fish += fish

        self.checkpoint = self.net.get_params().data.clone()

    def observe(self, inputs, labels, not_aug_inputs):
        """One SGD step on task loss + e_lambda * penalty; returns the total loss."""
        self.opt.zero_grad()
        outputs = self.net(inputs)
        penalty = self.penalty()
        loss = self.loss(outputs, labels) + self.args.e_lambda * penalty
        assert not torch.isnan(loss)
        loss.backward()
        self.opt.step()

        return loss.item()
81 |
--------------------------------------------------------------------------------
/datasets/transforms/rotation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torchvision.transforms.functional as F
8 |
9 |
class Rotation(object):
    """
    Rotates every input by one angle. The angle is drawn uniformly from
    [deg_min, deg_max) with numpy's global RNG when the object is created.
    """

    def __init__(self, deg_min: int = 0, deg_max: int = 180) -> None:
        """
        :param deg_min: lower extreme of the possible random angle
        :param deg_max: upper extreme of the possible random angle
        """
        self.deg_min, self.deg_max = deg_min, deg_max
        self.degrees = np.random.uniform(deg_min, deg_max)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Applies the rotation.
        :param x: image to be rotated
        :return: rotated image
        """
        angle = self.degrees
        return F.rotate(x, angle)
32 |
33 |
class FixedRotation(object):
    """
    Rotates every input by one angle that is derived deterministically from a
    seed.
    """

    def __init__(self, seed: int, deg_min: int = 0, deg_max: int = 180) -> None:
        """
        Initializes the rotation with a seeded random angle.
        :param seed: seed of the rotation
        :param deg_min: lower extreme of the possible random angle
        :param deg_max: upper extreme of the possible random angle
        """
        self.seed = seed
        self.deg_min = deg_min
        self.deg_max = deg_max

        # A private generator gives the same angle as np.random.seed(seed)
        # followed by np.random.uniform. Unlike that approach, it does not
        # reset numpy's global random state as a side effect.
        rng = np.random.RandomState(seed)
        self.degrees = rng.uniform(self.deg_min, self.deg_max)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Applies the rotation.
        :param x: image to be rotated
        :return: rotated image
        """
        return F.rotate(x, self.degrees)
60 |
61 |
class IncrementalRotation(object):
    """
    Rotates each successive input by a slightly larger angle: call number k
    uses (k * increase_per_iteration + init_deg) mod 360 degrees.
    """

    def __init__(self, init_deg: int = 0, increase_per_iteration: float = 0.006) -> None:
        """
        :param init_deg: angle used for the first call
        :param increase_per_iteration: degrees added for every subsequent call
        """
        self.increase_per_iteration = increase_per_iteration
        self.iteration = 0
        self.degrees = init_deg

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Applies the rotation for the current call, then advances the counter.
        :param x: image to be rotated
        :return: rotated image
        """
        step = self.iteration
        self.iteration = step + 1
        angle = (step * self.increase_per_iteration + self.degrees) % 360
        return F.rotate(x, angle)

    def set_iteration(self, x: int) -> None:
        """
        Moves the call counter to the given position.
        :param x: iteration index
        """
        self.iteration = x
93 |
--------------------------------------------------------------------------------
/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import math
7 | import torch
8 | import torch.nn as nn
9 |
10 |
def xavier(m: nn.Module) -> None:
    """
    Applies Xavier (Glorot) uniform initialization to ``Linear`` modules and
    zeroes their bias; any other module is left untouched.

    :param m: the module to be initialized

    Example::
        >>> net = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
        >>> net.apply(xavier)
    """
    if m.__class__.__name__ != 'Linear':
        return
    fan_out, fan_in = m.weight.data.size(0), m.weight.data.size(1)
    # U(-a, a) with a = sqrt(3) * std has standard deviation std.
    bound = math.sqrt(3.0) * math.sqrt(2.0 / (fan_in + fan_out))
    m.weight.data.uniform_(-bound, bound)
    if m.bias is not None:
        m.bias.data.fill_(0.0)
29 |
30 |
def num_flat_features(x: torch.Tensor) -> int:
    """
    Computes the product of all dimensions of ``x`` except the first.

    :param x: input tensor
    :return: number of items per leading index (1 for a one-dimensional tensor)
    """
    # torch.Size.numel() is the product of the sizes, and 1 for an empty Size.
    return x.size()[1:].numel()
43 |
class MammothBackbone(nn.Module):
    """
    Base class for backbones. It provides flat (concatenated) views of all
    parameters and gradients, in ``self.parameters()`` order.
    """

    def __init__(self, **kwargs) -> None:
        super(MammothBackbone, self).__init__()

    def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:
        raise NotImplementedError

    def features(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward(x, returnt='features')

    def get_params(self) -> torch.Tensor:
        """
        Returns all the parameters concatenated in a single tensor.
        :return: 1-D tensor of length sum(p.numel() for p in self.parameters())
        """
        return torch.cat([pp.reshape(-1) for pp in self.parameters()])

    def set_params(self, new_params: torch.Tensor) -> None:
        """
        Sets the parameters to a given value.
        :param new_params: 1-D tensor laid out as returned by get_params()
        """
        assert new_params.size() == self.get_params().size()
        progress = 0
        for pp in self.parameters():
            # Plain int offsets; the original built tensors for slice bounds.
            count = pp.numel()
            pp.data = new_params[progress:progress + count].view(pp.size())
            progress += count

    def get_grads(self) -> torch.Tensor:
        """
        Returns all the gradients concatenated in a single tensor.
        :return: 1-D tensor aligned element-wise with get_params()
        """
        return torch.cat(self.get_grads_list())

    def get_grads_list(self):
        """
        Returns a list containing the gradients (a flattened tensor per parameter).

        A parameter whose ``.grad`` is None (it received no gradient) contributes
        zeros. This keeps the result aligned with get_params() instead of
        raising AttributeError.
        :return: gradients list
        """
        grads = []
        for pp in self.parameters():
            if pp.grad is None:
                grads.append(torch.zeros_like(pp).reshape(-1))
            else:
                grads.append(pp.grad.reshape(-1))
        return grads
94 |
--------------------------------------------------------------------------------
/models/gss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from utils.gss_buffer import Buffer as Buffer
8 | from utils.args import *
9 | from models.utils.continual_model import ContinualModel
10 |
11 | import numpy as np
12 | import copy
13 | from torch.nn import functional as F
14 | import math
15 | from torch.optim import SGD
16 | from collections import OrderedDict
17 | EPS = 1E-20
18 |
def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for GSS.
    :return: the argument parser
    """
    # the two literals are concatenated implicitly: the separating space
    # was previously missing ("selectionfor")
    parser = ArgumentParser(description='Gradient based sample selection'
                                        ' for online continual learning')
    add_management_args(parser)
    add_experiment_args(parser)
    add_rehearsal_args(parser)
    parser.add_argument('--batch_num', type=int, required=True,
                        help='Number of batches extracted from the buffer.')
    parser.add_argument('--gss_minibatch_size', type=int, default=None,
                        help='The batch size of the gradient comparison.')
    return parser
30 |
31 |
class Gss(ContinualModel):
    """
    Gradient based Sample Selection: rehearsal with a buffer that receives
    the model itself so that it can compute per-sample gradients
    (presumably via ``get_grads``; the selection logic lives in
    utils.gss_buffer).
    """
    NAME = 'gss'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Gss, self).__init__(backbone, loss, args, transform)
        # observe() returns the loss of its last inner step, so at least one
        # step is required (otherwise `loss` would be unbound there)
        if self.args.batch_num < 1:
            raise ValueError('--batch_num must be at least 1 for GSS')
        self.buffer = Buffer(self.args.buffer_size, self.device,
                             self.args.gss_minibatch_size if
                             self.args.gss_minibatch_size is not None
                             else self.args.minibatch_size, self)
        self.alj_nepochs = self.args.batch_num

    def get_grads(self, inputs, labels):
        """
        Computes the gradient of the loss on (inputs, labels) w.r.t. all the
        network parameters. The forward pass runs in eval mode; the previous
        train/eval mode is restored afterwards and accumulated gradients are
        cleared before and after.
        :param inputs: input batch
        :param labels: target labels
        :return: the detached gradient, with a leading dimension added when
                 the backbone returns it as a flat vector
        """
        was_training = self.net.training
        self.net.eval()
        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        grads = self.net.get_grads().clone().detach()
        self.opt.zero_grad()
        # restore the caller's mode instead of forcing train mode
        self.net.train(was_training)
        if len(grads.shape) == 1:
            grads = grads.unsqueeze(0)
        return grads

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Performs ``batch_num`` optimisation steps on the current batch joined
        with freshly sampled buffer data, then offers the current batch to the
        buffer.
        :return: the loss of the last step
        """
        real_batch_size = inputs.shape[0]
        self.buffer.drop_cache()
        self.buffer.reset_fathom()

        for _ in range(self.alj_nepochs):
            self.opt.zero_grad()
            if not self.buffer.is_empty():
                buf_inputs, buf_labels = self.buffer.get_data(
                    self.args.minibatch_size, transform=self.transform)
                tinputs = torch.cat((inputs, buf_inputs))
                tlabels = torch.cat((labels, buf_labels))
            else:
                tinputs = inputs
                tlabels = labels

            outputs = self.net(tinputs)
            loss = self.loss(outputs, tlabels)
            loss.backward()
            self.opt.step()

        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:real_batch_size])

        return loss.item()
83 |
--------------------------------------------------------------------------------
/models/pnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.optim as optim
7 | from torch.optim import SGD
8 | import torch
9 | import torch.nn as nn
10 | from utils.conf import get_device
11 | from utils.args import *
12 | from datasets import get_dataset
13 |
14 |
def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for Progressive Neural Networks
    (management and experiment arguments only, no rehearsal ones).
    :return: the argument parser
    """
    parser = ArgumentParser(
        description='Continual Learning via Progressive Neural Networks.')
    for add_args in (add_management_args, add_experiment_args):
        add_args(parser)
    return parser
21 |
22 |
def get_backbone(bone, old_cols=None, x_shape=None):
    """
    Builds the PNN counterpart of a plain backbone.
    :param bone: a MNISTMLP or ResNet instance used as template
    :param old_cols: previously trained columns, forwarded to the PNN backbone
    :param x_shape: input shape, forwarded to the ResNet PNN backbone only
    :return: the new PNN column
    :raises NotImplementedError: for any other backbone type
    """
    from backbone.MNISTMLP import MNISTMLP
    from backbone.MNISTMLP_PNN import MNISTMLP_PNN
    from backbone.ResNet18 import ResNet
    from backbone.ResNet18_PNN import resnet18_pnn

    if isinstance(bone, MNISTMLP):
        return MNISTMLP_PNN(bone.input_size, bone.output_size, old_cols)
    if isinstance(bone, ResNet):
        return resnet18_pnn(bone.num_classes, bone.nf, old_cols, x_shape)
    raise NotImplementedError('Progressive Neural Networks is not implemented for this backbone')
35 |
36 |
class Pnn(nn.Module):
    """
    Progressive Neural Networks: one column (network) per task. Only the
    newest column is trained; older ones are parked on the CPU and moved
    to the device on demand in ``forward``.
    """
    NAME = 'pnn'
    COMPATIBILITY = ['task-il']

    def __init__(self, backbone, loss, args, transform):
        super(Pnn, self).__init__()
        self.loss = loss
        self.args = args
        self.transform = transform
        self.device = get_device()
        self.x_shape = None
        # a plain list (not nn.ModuleList): columns are not registered as
        # submodules, and their device placement is handled manually
        self.nets = [get_backbone(backbone).to(self.device)]
        self.net = self.nets[-1]
        self.opt = SGD(self.net.parameters(), lr=self.args.lr)

        self.soft = torch.nn.Softmax(dim=0)
        self.logsoft = torch.nn.LogSoftmax(dim=0)
        self.dataset = get_dataset(args)
        self.task_idx = 0

    def forward(self, x, task_label):
        """
        Runs the column associated with ``task_label``.
        :param x: input batch
        :param task_label: integer index of the column to use
        :return: the column's output
        """
        if self.x_shape is None:
            self.x_shape = x.shape

        if self.task_idx == 0:
            out = self.net(x)
        else:
            self.nets[task_label].to(self.device)
            out = self.nets[task_label](x)
            # only the column currently being trained stays on the device
            if self.task_idx != task_label:
                self.nets[task_label].cpu()
        return out

    def end_task(self, dataset):
        """
        Moves the current column to the CPU and instantiates a new one (built
        from ``dataset.get_backbone()`` with all previous columns passed as
        ``old_cols``) together with a fresh optimizer. Nothing is added after
        the dataset's last task.
        :param dataset: the continual dataset (must expose ``N_TASKS``)
        """
        # previously hard-coded to 4, which only fits 5-task datasets and
        # left later task labels without a column
        if self.task_idx >= dataset.N_TASKS - 1:
            return
        self.task_idx += 1
        self.nets[-1].cpu()
        self.nets.append(get_backbone(dataset.get_backbone(), self.nets, self.x_shape).to(self.device))
        self.net = self.nets[-1]
        self.opt = optim.SGD(self.net.parameters(), lr=self.args.lr)

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Performs one SGD step of the current column on the given batch.
        :return: the loss value
        """
        if self.x_shape is None:
            self.x_shape = inputs.shape

        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        return loss.item()
91 |
--------------------------------------------------------------------------------
/utils/status.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from datetime import datetime
7 | import sys
8 | import os
9 | from utils.conf import base_path
10 | from typing import Any, Dict, Union
11 | from torch import nn
12 | from argparse import Namespace
13 | from datasets.utils.continual_dataset import ContinualDataset
14 |
15 |
def _build_model_name(model: nn.Module, args: Namespace) -> str:
    """
    Builds the run name '<dataset>/<model NAME>[/buf_<buffer_size>]/<timestamp>'.
    :param model: the model (must expose a NAME attribute)
    :param args: the current arguments
    :return: the run name
    """
    name_parts = [args.dataset, model.NAME]
    # the buffer size is part of the name only when the arguments define it
    if 'buffer_size' in vars(args):
        name_parts.append('buf_' + str(args.buffer_size))
    # microsecond-resolution timestamp to make names unique across runs
    name_parts.append(datetime.now().strftime("%Y%m%d_%H%M%S_%f"))
    return '/'.join(name_parts)


def create_stash(model: nn.Module, args: Namespace,
                 dataset: 'ContinualDataset') -> Dict[str, Any]:
    """
    Creates the dictionary where to save the model status.
    :param model: the model
    :param args: the current arguments
    :param dataset: the dataset at hand
    :return: the stash, with progress counters, run name, accuracy history,
             arguments and backup folder
    """
    model_stash = {'task_idx': 0, 'epoch_idx': 0, 'batch_idx': 0}
    model_stash['model_name'] = _build_model_name(model, args)
    model_stash['mean_accs'] = []
    model_stash['args'] = args
    model_stash['backup_folder'] = os.path.join(base_path(), 'backups',
                                                dataset.SETTING,
                                                model_stash['model_name'])
    return model_stash


def create_fake_stash(model: nn.Module, args: Namespace) -> Dict[str, Any]:
    """
    Create a fake stash, containing just the model name.
    This is used in general continual, as it is useless to backup
    a lightweight MNIST-360 training.
    :param model: the model
    :param args: the arguments of the call
    :return: a dict containing a fake stash
    """
    return {'task_idx': 0, 'epoch_idx': 0,
            'model_name': _build_model_name(model, args)}
57 |
58 |
def progress_bar(i: int, max_iter: int, epoch: Union[int, str],
                 task_number: int, loss: float) -> None:
    """
    Prints out the progress bar on the stderr file, every 100 iterations
    and at the last one.
    :param i: the current iteration
    :param max_iter: the maximum number of iteration
    :param epoch: the epoch
    :param task_number: the task index (shown 1-based when it is an int)
    :param loss: the current value of the loss function
    """
    if (i + 1) % 100 and (i + 1) != max_iter:
        return
    # a non-positive max_iter would divide by zero: show a full bar instead
    progress = 1.0 if max_iter <= 0 else min(float((i + 1) / max_iter), 1)
    filled = int(50 * progress)
    bar = ('█' * filled) + ('┈' * (50 - filled))
    print('\r[ {} ] Task {} | epoch {}: |{}| loss: {}'.format(
        datetime.now().strftime("%m-%d | %H:%M"),
        task_number + 1 if isinstance(task_number, int) else task_number,
        epoch,
        bar,
        round(loss, 8)
    ), file=sys.stderr, end='', flush=True)
79 |
--------------------------------------------------------------------------------
/models/mer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from utils.buffer import Buffer
8 | from utils.args import *
9 | from models.utils.continual_model import ContinualModel
10 |
11 |
def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for Meta-Experience Replay, without the
    generic batch_size option (Mer asserts batch_size == 1 in its
    constructor, so the value is presumably fixed elsewhere).
    :return: the argument parser
    """
    parser = ArgumentParser(description='Continual Learning via'
                                        ' Meta-Experience Replay.')
    for add_args in (add_management_args, add_experiment_args, add_rehearsal_args):
        add_args(parser)
    # NOTE(review): only parser._actions is edited; the option string stays
    # in argparse's internal lookup tables — confirm this is the intent.
    for idx, action in enumerate(parser._actions):
        if action.dest == 'batch_size':
            del parser._actions[idx]
            break

    parser.add_argument('--beta', type=float, required=True,
                        help='Within-batch update beta parameter.')
    parser.add_argument('--gamma', type=float, required=True,
                        help='Across-batch update gamma parameter.')
    parser.add_argument('--batch_num', type=int, required=True,
                        help='Number of batches extracted from the buffer.')

    return parser
32 |
33 |
class Mer(ContinualModel):
    """
    Meta-Experience Replay: Reptile-style within-batch (beta) and
    across-batch (gamma) interpolations of the parameters, trained on the
    current example mixed with buffer samples.
    """
    NAME = 'mer'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Mer, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        assert args.batch_size == 1, 'Mer only works with batch_size=1'

    def draw_batches(self, inp, lab):
        """
        Builds ``batch_num`` batches, each made of ``minibatch_size`` buffer
        samples followed by the current example; while the buffer is empty a
        batch holds the current example only.
        :param inp: the current input
        :param lab: the current label
        :return: list of (inputs, labels) pairs
        """
        batches = []
        for _ in range(self.args.batch_num):
            if self.buffer.is_empty():
                batches.append((inp.unsqueeze(0), torch.tensor([lab]).unsqueeze(0).to(self.device)))
                continue
            buf_inputs, buf_labels = self.buffer.get_data(self.args.minibatch_size, transform=self.transform)
            batch_x = torch.cat((buf_inputs, inp.unsqueeze(0)))
            batch_y = torch.cat((buf_labels, torch.tensor([lab]).to(self.device)))
            batches.append((batch_x, batch_y))
        return batches

    def observe(self, inputs, labels, not_aug_inputs):
        """
        Runs one SGD step per drawn batch, each followed by a within-batch
        interpolation towards the pre-step weights, stores the current example,
        then applies the across-batch interpolation.
        :return: the loss of the last SGD step
        """
        batches = self.draw_batches(inputs, labels)
        start_across = self.net.get_params().data.clone()

        for batch_inputs, batch_labels in batches:
            start_within = self.net.get_params().data.clone()

            # plain SGD step on this batch
            self.opt.zero_grad()
            outputs = self.net(batch_inputs)
            loss = self.loss(outputs, batch_labels.squeeze(-1))
            loss.backward()
            self.opt.step()

            # theta <- theta_0 + beta * (theta - theta_0)
            self.net.set_params(start_within + self.args.beta * (self.net.get_params() - start_within))

        self.buffer.add_data(examples=not_aug_inputs.unsqueeze(0), labels=labels)

        # theta <- theta_A0 + gamma * (theta - theta_A0)
        self.net.set_params(start_across + self.args.gamma * (self.net.get_params() - start_across))

        return loss.item()
83 |
--------------------------------------------------------------------------------
/utils/tb_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | from utils.conf import base_path
8 | import os
9 | from argparse import Namespace
10 | from typing import Dict, Any
11 | import numpy as np
12 |
13 |
class TensorboardLogger:
    """
    Writes accuracies and losses to one tensorboard SummaryWriter per setting
    (class-il runs also get a task-il writer).
    """

    def __init__(self, args: Namespace, setting: str) -> None:
        from torch.utils.tensorboard import SummaryWriter

        self.settings = [setting] + (['task-il'] if setting == 'class-il' else [])
        self.name = args.model
        self.loggers = {
            a_setting: SummaryWriter(
                os.path.join(base_path(), 'tensorboard_runs', a_setting, self.name))
            for a_setting in self.settings
        }
        # every public attribute of args, rendered as "name=value"
        pairs = ["%s=%s" % (key, getattr(args, key)) for key in args.__dir__()
                 if not key.startswith('_')]
        config_text = ', '.join(pairs)
        for writer in self.loggers.values():
            writer.add_text('config', config_text)

    def get_name(self) -> str:
        """
        :return: the name of the model
        """
        return self.name

    def log_accuracy(self, all_accs: np.ndarray, all_mean_accs: np.ndarray,
                     args: Namespace, task_number: int) -> None:
        """
        Logs the per-task and mean accuracies; the task-il writer reads row 1
        of ``all_accs``, every other writer reads row 0.
        :param all_accs: the accuracies (class-il, task-il) for each task
        :param all_mean_accs: the mean accuracies for (class-il, task-il)
        :param args: the arguments of the run
        :param task_number: the task index
        """
        mean_acc_common, mean_acc_task_il = all_mean_accs
        step = task_number * args.n_epochs
        n_tasks = len(all_accs[0])
        for setting, writer in self.loggers.items():
            is_task_il = setting == 'task-il'
            row = all_accs[1 if is_task_il else 0]
            accs = [row[kk] for kk in range(n_tasks)]
            for kk, acc in enumerate(accs):
                writer.add_scalar('acc_task%02d' % (kk + 1), acc, step)
            writer.add_scalar('acc_mean',
                              mean_acc_task_il if is_task_il else mean_acc_common,
                              step)

    def log_loss(self, loss: float, args: Namespace, epoch: int,
                 task_number: int, iteration: int) -> None:
        """
        Logs the loss at step ``task_number * n_epochs + epoch``.
        :param loss: the loss value
        :param args: the arguments of the run
        :param epoch: the epoch index
        :param task_number: the task index
        :param iteration: the current iteration (not used for the step)
        """
        step = task_number * args.n_epochs + epoch
        for writer in self.loggers.values():
            writer.add_scalar('loss', loss, step)

    def log_loss_gcl(self, loss: float, iteration: int) -> None:
        """
        Logs the loss at step ``iteration``.
        :param loss: the loss value
        :param iteration: the current iteration
        """
        for writer in self.loggers.values():
            writer.add_scalar('loss', loss, iteration)

    def close(self) -> None:
        """
        Closes every writer; call it at the end of the execution.
        """
        for writer in self.loggers.values():
            writer.close()
86 |
--------------------------------------------------------------------------------
/models/fdr.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from utils.buffer import Buffer
7 | from utils.args import *
8 | from models.utils.continual_model import ContinualModel
9 | import torch
10 |
11 |
def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for FDR.
    :return: the argument parser
    """
    # the description was copy-pasted from DER; this model stores logits and
    # penalises the distance between softmax outputs (see Fdr.observe)
    parser = ArgumentParser(description='Continual learning via'
                                        ' Function Distance Regularization.')
    add_management_args(parser)
    add_experiment_args(parser)
    add_rehearsal_args(parser)
    parser.add_argument('--alpha', type=float, required=True,
                        help='Penalty weight.')
    return parser
21 |
22 |
class Fdr(ContinualModel):
    """
    Function Distance Regularization: at the end of each task the buffer is
    rebalanced across tasks and filled with inputs of the finished task and
    the logits the network produced for them; training alternates a plain
    step with a step that pulls the current softmax outputs on buffer inputs
    towards the stored ones.
    """
    NAME = 'fdr'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Fdr, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.current_task = 0
        self.i = 0
        self.soft = torch.nn.Softmax(dim=1)
        self.logsoft = torch.nn.LogSoftmax(dim=1)

    def end_task(self, dataset):
        """
        Shrinks every stored task to ``buffer_size // current_task`` examples,
        then stores up to that many examples of the finished task together
        with the network's logits and the task index.
        :param dataset: the continual dataset (its train_loader is iterated)
        """
        self.current_task += 1
        examples_per_task = self.args.buffer_size // self.current_task

        if self.current_task > 1:
            buf_x, buf_log, buf_tl = self.buffer.get_all_data()
            self.buffer.empty()

            for ttl in buf_tl.unique():
                idx = (buf_tl == ttl)
                ex, log, tasklab = buf_x[idx], buf_log[idx], buf_tl[idx]
                first = min(ex.shape[0], examples_per_task)
                self.buffer.add_data(
                    examples=ex[:first],
                    logits=log[:first],
                    task_labels=tasklab[:first]
                )

        counter = 0
        with torch.no_grad():
            for data in dataset.train_loader:
                remaining = examples_per_task - counter
                # stop as soon as the quota is met, before any wasted forward
                if remaining <= 0:
                    break
                inputs, labels, not_aug_inputs = data
                inputs = inputs.to(self.device)
                not_aug_inputs = not_aug_inputs.to(self.device)
                outputs = self.net(inputs)
                # use the real batch length (the last batch may be smaller
                # than args.batch_size) so all three tensors line up
                cur_batch = inputs.shape[0]
                self.buffer.add_data(examples=not_aug_inputs[:remaining],
                                     logits=outputs.data[:remaining],
                                     task_labels=(torch.ones(cur_batch) *
                                                  (self.current_task - 1))[:remaining])
                counter += cur_batch

    def observe(self, inputs, labels, not_aug_inputs):
        """
        One cross-entropy step on the current batch, followed (when the buffer
        is not empty) by one step minimising the mean L2 distance between the
        current and the stored softmax outputs on buffer inputs.
        :return: the loss of the last step performed
        """
        self.i += 1

        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()
        if not self.buffer.is_empty():
            self.opt.zero_grad()
            buf_inputs, buf_logits, _ = self.buffer.get_data(self.args.minibatch_size, transform=self.transform)
            buf_outputs = self.net(buf_inputs)
            loss = torch.norm(self.soft(buf_outputs) - self.soft(buf_logits), 2, 1).mean()
            assert not torch.isnan(loss)
            loss.backward()
            self.opt.step()

        return loss.item()
85 |
--------------------------------------------------------------------------------
/datasets/seq_mnist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from torchvision.datasets import MNIST
7 | import torchvision.transforms as transforms
8 | from torch.utils.data import DataLoader
9 | from backbone.MNISTMLP import MNISTMLP
10 | import torch.nn.functional as F
11 | from utils.conf import base_path
12 | from PIL import Image
13 | import numpy as np
14 | from datasets.utils.validation import get_train_val
15 | from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
16 | from typing import Tuple
17 |
18 |
class MyMNIST(MNIST):
    """
    MNIST whose items also carry the non-augmented image (as a tensor) and,
    when a ``logits`` attribute has been attached, the stored logits.
    """
    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False) -> None:
        self.not_aug_transform = transforms.ToTensor()
        super(MyMNIST, self).__init__(root, train,
                                      transform, target_transform, download)

    def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]:
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :returns: (image, target, non-augmented image[, logits])
        """
        raw, label = self.data[index], self.targets[index]

        # go through PIL so that transforms see the same type as in the
        # other datasets
        pil_img = Image.fromarray(raw.numpy(), mode='L')
        plain_img = self.not_aug_transform(pil_img.copy())

        out_img = pil_img if self.transform is None else self.transform(pil_img)
        if self.target_transform is not None:
            label = self.target_transform(label)

        if not hasattr(self, 'logits'):
            return out_img, label, plain_img
        return out_img, label, plain_img, self.logits[index]
52 |
53 |
class SequentialMNIST(ContinualDataset):
    """
    Class-incremental MNIST: 5 tasks of 2 classes each.
    """

    NAME = 'seq-mnist'
    SETTING = 'class-il'
    N_CLASSES_PER_TASK = 2
    N_TASKS = 5
    TRANSFORM = None

    def get_data_loaders(self):
        """
        Builds the train/test datasets (or a train/validation split) and
        returns the loaders produced by ``store_masked_loaders``.
        """
        to_tensor = transforms.ToTensor()
        train_dataset = MyMNIST(base_path() + 'MNIST',
                                train=True, download=True, transform=to_tensor)
        if not self.args.validation:
            test_dataset = MNIST(base_path() + 'MNIST',
                                 train=False, download=True, transform=to_tensor)
        else:
            train_dataset, test_dataset = get_train_val(train_dataset,
                                                        to_tensor, self.NAME)

        train_loader, test_loader = store_masked_loaders(train_dataset, test_dataset, self)
        return train_loader, test_loader

    @staticmethod
    def get_backbone():
        n_outputs = SequentialMNIST.N_TASKS * SequentialMNIST.N_CLASSES_PER_TASK
        return MNISTMLP(28 * 28, n_outputs)

    @staticmethod
    def get_transform():
        return None

    @staticmethod
    def get_loss():
        return F.cross_entropy

    @staticmethod
    def get_normalization_transform():
        return None

    @staticmethod
    def get_denormalization_transform():
        return None

    @staticmethod
    def get_scheduler(model, args):
        return None
--------------------------------------------------------------------------------
/datasets/perm_mnist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from torchvision.datasets import MNIST
7 | import torchvision.transforms as transforms
8 | from datasets.transforms.permutation import Permutation
9 | from torch.utils.data import DataLoader
10 | from backbone.MNISTMLP import MNISTMLP
11 | import torch.nn.functional as F
12 | from utils.conf import base_path
13 | from PIL import Image
14 | from datasets.utils.validation import get_train_val
15 | from typing import Tuple
16 | from datasets.utils.continual_dataset import ContinualDataset
17 |
18 |
def store_mnist_loaders(transform, setting):
    """
    Builds MNIST train/test loaders, registers them on ``setting``
    (appending to ``test_loaders`` and replacing ``train_loader``) and
    returns them.
    :param transform: transform applied to both splits
    :param setting: the continual dataset owning the loaders
    :return: (train_loader, test_loader)
    """
    train_dataset = MyMNIST(base_path() + 'MNIST',
                            train=True, download=True, transform=transform)
    if not setting.args.validation:
        test_dataset = MNIST(base_path() + 'MNIST',
                             train=False, download=True, transform=transform)
    else:
        train_dataset, test_dataset = get_train_val(train_dataset,
                                                    transform, setting.NAME)

    bs = setting.args.batch_size
    train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)
    setting.test_loaders.append(test_loader)
    setting.train_loader = train_loader

    return train_loader, test_loader
37 |
38 |
class MyMNIST(MNIST):
    """
    MNIST whose items are (image, target, image): the transformed image is
    also returned in the "non-augmented" slot.
    """
    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False) -> None:
        super().__init__(root, train, transform, target_transform, download)

    def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]:
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :returns: (image, target, image) with target an int class index
        """
        pixels = self.data[index]
        label = int(self.targets[index])

        # go through PIL so that transforms see the same type as in the
        # other datasets
        sample = Image.fromarray(pixels.numpy(), mode='L')
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return sample, label, sample
67 |
68 |
class PermutedMNIST(ContinualDataset):
    """
    Domain-incremental MNIST: 20 tasks, each with its own pixel permutation
    (a new ``Permutation`` is created at every ``get_data_loaders`` call).
    """

    NAME = 'perm-mnist'
    SETTING = 'domain-il'
    N_CLASSES_PER_TASK = 10
    N_TASKS = 20

    def get_data_loaders(self):
        pipeline = transforms.Compose((transforms.ToTensor(), Permutation()))
        train_loader, test_loader = store_mnist_loaders(pipeline, self)
        return train_loader, test_loader

    @staticmethod
    def get_backbone():
        return MNISTMLP(28 * 28, PermutedMNIST.N_CLASSES_PER_TASK)

    @staticmethod
    def get_transform():
        return None

    @staticmethod
    def get_normalization_transform():
        return None

    @staticmethod
    def get_denormalization_transform():
        return None

    @staticmethod
    def get_loss():
        return F.cross_entropy

    @staticmethod
    def get_scheduler(model, args):
        return None
104 |
--------------------------------------------------------------------------------
/utils/debug.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import numpy as np
4 | import matplotlib
5 | matplotlib.use("Agg")
6 | import matplotlib.pyplot as plt
7 | import copy
8 | import math
9 | import torch
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from torch.nn import functional as F
13 | import math
14 |
def euclid2polar(feat):
    """
    Converts each row of ``feat`` into hyperspherical angles.

    Rows are first normalised to unit L2 norm, so only the direction is kept.
    An n-dimensional row yields n - 1 angles: the first n - 2 lie in
    [0, pi] and the last one in [0, 2*pi). The sign of the last coordinate
    is therefore preserved, which makes the mapping invertible.

    :param feat: tensor of shape (batch, n)
    :return: tensor of shape (batch, n - 1) with the angles (always 2-D)
    """
    # NOTE: an all-zero row has no direction and produces NaNs here
    feat = feat / torch.norm(feat, 2, 1, keepdim=True)
    feat_2 = feat * feat
    n_angles = feat.shape[1] - 1
    if n_angles < 1:
        return feat.new_zeros((feat.shape[0], 0))

    angles = []
    for idx in range(n_angles):
        # norm of the tail x[idx:], i.e. what is left for the remaining coords
        tail_norm = torch.sqrt(torch.sum(feat_2[:, idx:], dim=1))
        # a zero tail means all remaining coords are 0 (any angle is valid):
        # avoid the 0/0 that would otherwise produce NaN
        safe_norm = torch.where(tail_norm > 0, tail_norm, torch.ones_like(tail_norm))
        # rounding can push the ratio slightly outside [-1, 1] -> NaN in arccos
        cos_angle = torch.clamp(feat[:, idx] / safe_norm, -1.0, 1.0)
        angle = torch.arccos(cos_angle)
        if idx == n_angles - 1:
            # arccos only covers [0, pi]; the last coordinate's sign selects
            # the other half of the circle
            angle = torch.where(feat[:, -1] < 0, 2 * math.pi - angle, angle)
        angles.append(angle)
    return torch.stack(angles, dim=1)
35 |
def polar2euclid(feat):
    """
    Converts hyperspherical angles back into unit-norm Euclidean vectors:
    x_1 = cos(a_1), x_k = sin(a_1)...sin(a_{k-1}) cos(a_k) for 1 < k < n,
    and x_n = sin(a_1)...sin(a_{n-1}).

    :param feat: tensor of shape (batch, n - 1) with the angles
    :return: tensor of shape (batch, n) with unit-norm rows
    """
    if feat.shape[1] == 0:
        # a 1-D "sphere" has no angles: the only direction kept is +1
        return feat.new_ones((feat.shape[0], 1))
    sin_feat = torch.sin(feat)
    cos_feat = torch.cos(feat)
    # sin_product[:, k] = sin(a_1) * ... * sin(a_{k+1})
    sin_product = torch.cumprod(sin_feat, dim=1)
    # the last column uses the full product of sines; computing it directly
    # (instead of sin_product[:, idx - 2]) avoids a negative-index wrap that
    # gave sin^2 for a single angle
    return torch.cat((cos_feat[:, :1],
                      sin_product[:, :-1] * cos_feat[:, 1:],
                      sin_product[:, -1:]), dim=1)
62 | # feat = torch.normal(0.0, 1, size=(32,512))
63 | # print(0,feat[0])
64 | # feat = euclid2polar(feat)
65 | # print(1,feat[0])
66 | # polar_noise = torch.from_numpy(np.random.vonmises(0, 0000000, (feat.shape[0], feat.shape[1])))
67 | # polar_noise.to(feat)
68 | # feat = feat + polar_noise
69 | # print(2,feat[0])
70 | # feat = polar2euclid(feat)
71 | # print(3,feat[0])
72 |
def visualize_2d(feat, labels, step):
    """
    Scatter-plots the first two feature dimensions, with axis limits fitted
    to the data, and saves the figure as ``./<step>.pdf``.

    :param feat: CPU tensor with at least two columns
    :param labels: currently unused (kept for interface compatibility)
    :param step: used as the output file name
    """
    feat = feat.numpy()
    fig = plt.figure(figsize=(6, 6))
    plt.plot(feat[:, 0], feat[:, 1], '.')
    # x limits come from column 0 and y limits from column 1 (they were mixed up)
    plt.xlim(left=np.min(feat[:, 0]), right=np.max(feat[:, 0]))
    plt.ylim(bottom=np.min(feat[:, 1]), top=np.max(feat[:, 1]))
    plt.savefig('./%s.pdf' % str(step))
    # release the figure: repeated calls would otherwise accumulate them
    plt.close(fig)
90 |
def visualize_3d(feat, labels, step):
    """
    3-D scatter plot of the first three feature dimensions, saved as
    ``./<step>.pdf``.

    :param feat: CPU tensor with at least three columns
    :param labels: currently unused (kept for interface compatibility)
    :param step: used as the output file name
    """
    feat = feat.numpy()
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(projection='3d')
    ax.scatter(feat[:, 0], feat[:, 1], feat[:, 2])
    # save before showing: with an interactive backend show() blocks and the
    # figure may be gone (blank) once the window is closed
    fig.savefig('./%s.pdf' % str(step))
    plt.show()
    plt.close(fig)
101 |
102 | # feat = torch.from_numpy(np.random.vonmises(0, 1000, (1000, 2)))
103 | # feat = feat
104 | # mask = feat < 0
105 | # copy_feat = copy.deepcopy(feat)
106 | # feat[mask] = -feat[mask]
107 | # # feat[:, :-1] = copy_feat[:, :-1]
108 | # feat = polar2euclid(feat)
109 | # visualize_3d(feat, feat, 'vmf')
110 |
111 | # feat2 = euclid2polar(feat)
112 |
113 | # feat2 = polar2euclid(feat2)
114 |
115 | # visualize_3d(feat2, feat2, 'vmf2')
116 |
117 | # true_ = feat == feat2
118 | # print(true_)
119 | # print(feat,feat2)
--------------------------------------------------------------------------------
/models/gdumb.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from utils.args import *
7 | from models.utils.continual_model import ContinualModel
8 | from torch.optim import SGD, lr_scheduler
9 | import math
10 | from utils.buffer import Buffer
11 | import torch
12 | from utils.augmentations import cutmix_data
13 | import numpy as np
14 | from utils.status import progress_bar
15 |
def get_parser() -> ArgumentParser:
    """
    Builds the command-line parser for GDumb.
    :return: the argument parser
    """
    # description and help strings were copy-pasted from other models
    parser = ArgumentParser(description='Continual Learning via'
                                        ' GDumb (Greedy Sampler and Dumb Learner).')
    add_management_args(parser)
    add_rehearsal_args(parser)
    parser.add_argument('--maxlr', type=float, default=5e-2,
                        help='Maximum learning rate used when fitting the buffer.')
    parser.add_argument('--minlr', type=float, default=5e-4,
                        help='Minimum learning rate of the cosine annealing schedule.')
    parser.add_argument('--fitting_epochs', type=int, default=256,
                        help='Number of epochs used to fit the model on the buffer.')
    parser.add_argument('--cutmix_alpha', type=float, default=None,
                        help='Alpha parameter of CutMix (CutMix is disabled when not set).')
    add_experiment_args(parser)
    return parser
31 |
def fit_buffer(self, epochs):
    """Train ``self.net`` on the full content of ``self.buffer``.

    Learning-rate schedule: epoch 0 runs at ``0.1 * maxlr`` (warm start),
    epoch 1 at ``maxlr``, and later epochs follow cosine annealing with warm
    restarts (``T_0=1``, ``T_mult=2``) down to ``minlr``. When
    ``args.cutmix_alpha`` is set, every mini-batch is CutMix-augmented.

    :param self: the continual model owning ``net``, ``buffer``, ``args``,
        ``loss``, ``transform`` and ``device`` (this is a free function that
        is called with the model as first argument).
    :param epochs: number of passes over the buffer.
    """
    # Build the optimizer and scheduler ONCE so that momentum buffers and the
    # scheduler's position survive across epochs. Re-creating them every epoch
    # reset the schedule, pinning the LR at maxlr and ignoring --minlr.
    optimizer = SGD(self.net.parameters(), lr=self.args.maxlr, momentum=self.args.optim_mom,
                    weight_decay=self.args.optim_wd, nesterov=self.args.optim_nesterov)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=2,
                                                         eta_min=self.args.minlr)
    batch_size = self.args.batch_size

    for epoch in range(epochs):
        if epoch <= 0:  # warm start of 1 epoch at a tenth of maxlr
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.args.maxlr * 0.1
        elif epoch == 1:  # then set to maxlr
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.args.maxlr
        else:  # afterwards follow the cosine warm-restart schedule
            scheduler.step()

        # Fetch every buffered example, passing self.transform to get_data.
        all_inputs, all_labels = self.buffer.get_data(
            len(self.buffer.examples), transform=self.transform)

        while len(all_inputs):
            optimizer.zero_grad()
            buf_inputs, buf_labels = all_inputs[:batch_size], all_labels[:batch_size]
            all_inputs, all_labels = all_inputs[batch_size:], all_labels[batch_size:]

            if self.args.cutmix_alpha is not None:
                inputs, labels_a, labels_b, lam = cutmix_data(
                    x=buf_inputs.cpu(), y=buf_labels.cpu(), alpha=self.args.cutmix_alpha)
                buf_inputs = inputs.to(self.device)
                buf_labels_a = labels_a.to(self.device)
                buf_labels_b = labels_b.to(self.device)
                buf_outputs = self.net(buf_inputs)
                # CutMix loss: mix the two targets by the pasted-area ratio.
                loss = lam * self.loss(buf_outputs, buf_labels_a) \
                    + (1 - lam) * self.loss(buf_outputs, buf_labels_b)
            else:
                buf_outputs = self.net(buf_inputs)
                loss = self.loss(buf_outputs, buf_labels)

            loss.backward()
            optimizer.step()
            progress_bar(epoch, epochs, 1, 'G', loss.item())
69 |
class GDumb(ContinualModel):
    """GDumb: during training only fill a replay buffer; after the final task,
    train a freshly initialised backbone on the buffer alone."""
    NAME = 'gdumb'
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(GDumb, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.task = 0  # number of tasks completed so far

    def observe(self, inputs, labels, not_aug_inputs):
        # No optimisation step here: the stream only populates the buffer.
        self.buffer.add_data(examples=not_aug_inputs, labels=labels)
        return 0

    def end_task(self, dataset):
        self.task += 1
        is_last_task = self.task == dataset.N_TASKS
        if is_last_task:
            # Replace the network with a new backbone and fit it on the buffer.
            self.net = dataset.get_backbone().to(self.device)
            fit_buffer(self, self.args.fitting_epochs)
--------------------------------------------------------------------------------
/utils/simclrloss.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Yonglong Tian (yonglong@mit.edu)
3 | Date: May 07, 2020
4 | Source: https://github.com/HobbitLong/SupContrast/blob/master/losses.py
5 | """
6 | from __future__ import print_function
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.

    :param temperature: softmax temperature applied to the similarity logits.
    :param contrast_mode: 'all' uses every view as an anchor, 'one' only the first view.
    :param base_temperature: the loss is scaled by ``temperature / base_temperature``.
    :param reduction: 'mean' averages over anchors; any other value sums them.
    """
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07, reduction='mean'):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
        self.reduction = reduction

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar. Anchors without any positive pair contribute 0.
        Raises:
            ValueError: on malformed `features`, on both `labels` and `mask`
                given, on a label count mismatch or an unknown contrast mode.
        """
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood over positives. An anchor with no positive
        # (e.g. a class seen once with a single view) would give 0 / 0 = NaN
        # and poison the whole batch, so its denominator is set to 1 instead.
        pos_per_anchor = mask.sum(1)
        pos_per_anchor = torch.where(pos_per_anchor < 1e-6,
                                     torch.ones_like(pos_per_anchor),
                                     pos_per_anchor)
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_per_anchor

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean(0)

        return loss.mean() if self.reduction == 'mean' else loss.sum()
--------------------------------------------------------------------------------
/models/lwf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from datasets import get_dataset
8 | from torch.optim import SGD
9 | from utils.args import *
10 | from models.utils.continual_model import ContinualModel
11 |
12 |
def get_parser() -> ArgumentParser:
    """Command-line parser for LwF: common options plus the distillation
    weight (``--alpha``) and softmax temperature (``--softmax_temp``)."""
    parser = ArgumentParser(description='Continual learning via Learning without Forgetting.')
    add_management_args(parser)
    add_experiment_args(parser)
    lwf_options = (
        ('--alpha', 0.5, 'Penalty weight.'),
        ('--softmax_temp', 2, 'Temperature of the softmax function.'),
    )
    for flag, default_value, help_text in lwf_options:
        parser.add_argument(flag, type=float, default=default_value, help=help_text)
    return parser
23 |
24 |
def smooth(logits, temp, dim):
    """Re-temper a probability distribution along ``dim``.

    Raises every entry to the power ``1 / temp`` and renormalises so the
    entries along ``dim`` sum to 1. For ``p = softmax(z)`` this equals
    ``softmax(z / temp)``.

    :param logits: non-negative tensor (despite the name, probabilities are
        expected, e.g. a softmax output).
    :param temp: temperature; values > 1 flatten, values < 1 sharpen.
    :param dim: dimension to normalise over.
    :return: tensor with the same shape as ``logits``.
    """
    powered = logits ** (1 / temp)
    # keepdim=True broadcasts correctly for any ``dim``; the previous
    # unsqueeze(1) was only right when dim == 1.
    return powered / torch.sum(powered, dim, keepdim=True)
28 |
29 |
def modified_kl_div(old, new):
    """Batch-averaged cross-entropy of ``new`` against the target ``old``.

    Computes ``-mean_rows(sum_dim1(old * log(new)))``. This differs from the KL
    divergence only by the entropy of ``old``, which is constant w.r.t. ``new``.
    """
    per_row = (old * new.log()).sum(dim=1)
    return -per_row.mean()
32 |
33 |
class Lwf(ContinualModel):
    """Learning without Forgetting.

    Before every task after the first, the classifier is warmed up on the new
    classes, and the current network's logits on the new task's training set
    are cached on that dataset (``logits`` attribute). ``observe`` then adds a
    distillation term that keeps the outputs on previous-task classes close
    to those cached logits.
    """
    NAME = 'lwf'
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(Lwf, self).__init__(backbone, loss, args, transform)
        self.old_net = None
        self.soft = torch.nn.Softmax(dim=1)
        self.logsoft = torch.nn.LogSoftmax(dim=1)
        self.dataset = get_dataset(args)
        self.current_task = 0
        # Reuse the dataset built above instead of instantiating two more.
        self.cpt = self.dataset.N_CLASSES_PER_TASK
        nc = self.dataset.N_TASKS * self.cpt
        # Row k of this lower-triangular boolean matrix is True for columns
        # 0..k, so ``self.eye[c - 1]`` selects the first c classes.
        self.eye = torch.tril(torch.ones((nc, nc))).bool().to(self.device)

    def begin_task(self, dataset):
        self.net.eval()
        if self.current_task > 0:
            # Warm-up: train only the classifier on the incoming task's
            # classes; backbone features are computed without gradients.
            opt = SGD(self.net.classifier.parameters(), lr=self.args.lr)
            # Columns of the classes introduced by the incoming task.
            new_cls_mask = self.eye[(self.current_task + 1) * self.cpt - 1] \
                ^ self.eye[self.current_task * self.cpt - 1]
            for epoch in range(self.args.n_epochs):
                for i, data in enumerate(dataset.train_loader):
                    inputs, labels, not_aug_inputs = data
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    opt.zero_grad()
                    with torch.no_grad():
                        feats = self.net(inputs, returnt='features')
                    outputs = self.net.classifier(feats)[:, new_cls_mask]
                    # Shift labels so the new task's classes start at 0.
                    loss = self.loss(outputs, labels - self.current_task * self.cpt)
                    loss.backward()
                    opt.step()

            # Cache the current network's logits for every training example.
            # NOTE(review): these are computed on item index 2 (the
            # non-augmented image), which for seq-cifar10 is ToTensor() only,
            # i.e. without the Normalize used on training inputs -- confirm intended.
            train_set = dataset.train_loader.dataset
            batch_size = self.args.batch_size
            logits = []
            with torch.no_grad():
                for i in range(0, train_set.data.shape[0], batch_size):
                    inputs = torch.stack([train_set.__getitem__(j)[2]
                                          for j in range(i, min(i + batch_size, len(train_set)))])
                    log = self.net(inputs.to(self.device)).cpu()
                    logits.append(log)
            setattr(train_set, 'logits', torch.cat(logits))
        self.net.train()

        self.current_task += 1

    def observe(self, inputs, labels, not_aug_inputs, logits=None):
        self.opt.zero_grad()
        outputs = self.net(inputs)

        # Cross-entropy over every class seen so far (current task included).
        mask = self.eye[self.current_task * self.cpt - 1]
        loss = self.loss(outputs[:, mask], labels)
        if logits is not None:
            # Distillation restricted to the classes of previous tasks. Use the
            # --softmax_temp option (default 2) instead of a hard-coded 2.
            temp = getattr(self.args, 'softmax_temp', 2)
            mask = self.eye[(self.current_task - 1) * self.cpt - 1]
            loss += self.args.alpha * modified_kl_div(
                smooth(self.soft(logits[:, mask]).to(self.device), temp, 1),
                smooth(self.soft(outputs[:, mask]), temp, 1))

        loss.backward()
        self.opt.step()

        return loss.item()
95 |
--------------------------------------------------------------------------------
/utils/augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | import numpy as np
9 |
def rand_bbox(size, lam):
    """Sample a random box covering roughly ``1 - lam`` of a batch's spatial area.

    :param size: tensor size ``(N, C, D2, D3)``. The box is expressed in
        (dim 2, dim 3) coordinates. The locals are named W/H, but W indexes dim 2 and
        H indexes dim 3, and callers slice with them in that same order.
    :param lam: mixing coefficient in [0, 1]; the box side ratio is sqrt(1 - lam).
    :return: ``(bbx1, bby1, bbx2, bby2)`` clipped to the image, with
        ``bbx1 <= bbx2`` and ``bby1 <= bby2``.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # np.int was removed in NumPy 1.24; the builtin int truncates identically.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform box centre
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


def cutmix_data(x, y, alpha=1.0, cutmix_prob=0.5):
    """CutMix a batch in place: paste a random box from a shuffled copy of the batch.

    :param x: image batch ``(N, C, D2, D3)``; modified in place and returned.
    :param y: labels ``(N,)``.
    :param alpha: Beta(alpha, alpha) parameter for the mixing ratio; must be > 0.
    :param cutmix_prob: unused; kept for signature compatibility.
    :return: ``(x, y_a, y_b, lam)``. ``y_a`` are the original labels, ``y_b`` the
        labels of the pasted patches, and ``lam`` the fraction of each image
        NOT covered by the patch.
    """
    assert(alpha > 0)
    # generate mixed sample
    lam = np.random.beta(alpha, alpha)

    batch_size = x.size()[0]
    # Keep the permutation on the same device as the tensors it indexes. The
    # previous unconditional .cuda() crashed on CPU inputs whenever a GPU existed.
    index = torch.randperm(batch_size).to(x.device)

    y_a, y_b = y, y[index.to(y.device)]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    return x, y_a, y_b, lam
46 |
def normalize(x, mean, std):
    """Channel-wise ``(x - mean) / std`` for a 4-D (N, C, H, W) batch.

    ``mean`` and ``std`` are per-channel sequences, turned into (1, C, 1, 1)
    tensors on ``x``'s device so they broadcast over batch and space.
    """
    assert len(x.shape) == 4
    mu = torch.tensor(mean).view(1, -1, 1, 1).to(x.device)
    sigma = torch.tensor(std).view(1, -1, 1, 1).to(x.device)
    return (x - mu) / sigma
51 |
def random_flip(x):
    """Flip each image of a 4-D batch along its last axis with probability 0.5.

    The batch is modified in place and also returned.
    """
    assert len(x.shape) == 4
    flip_mask = torch.rand(x.shape[0]) < 0.5
    x[flip_mask] = torch.flip(x[flip_mask], dims=[3])
    return x
57 |
def random_grayscale(x, prob=0.2):
    """Convert each image of a 4-D RGB batch to grayscale with probability ``prob``.

    Grayscale uses the 0.299/0.587/0.114 channel weights, replicated back to
    3 channels. The batch is modified in place and also returned.
    """
    assert len(x.shape) == 4
    gray_mask = torch.rand(x.shape[0]) < prob
    rgb_weights = torch.tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1).to(x.device)
    luma = (x[gray_mask] * rgb_weights).sum(1, keepdim=True)
    x[gray_mask] = luma.repeat_interleave(3, 1)
    return x
63 |
def random_crop(x, padding):
    """Zero-pad every image by ``padding`` pixels and crop it back at a random offset.

    Each image of the (N, C, H, W) batch gets an independent offset in
    ``[0, 2 * padding]`` along both spatial axes, matching torchvision's
    RandomCrop with padding. Works for any channel count and non-square images.

    :param x: 4-D batch (N, C, H, W).
    :param padding: non-negative padding in pixels; 0 returns an unchanged copy.
    :return: new tensor with the same shape as ``x``.
    """
    assert len(x.shape) == 4
    n, c, h, w = x.shape
    # randint's upper bound is exclusive: + 1 makes the full +padding shift
    # reachable. The previous range [-p, p) made the crop asymmetric.
    off_x = torch.randint(0, 2 * padding + 1, size=(n,)).to(x.device)
    off_y = torch.randint(0, 2 * padding + 1, size=(n,)).to(x.device)

    padded = F.pad(x, (padding, padding, padding, padding))
    rows = off_y.unsqueeze(1) + torch.arange(h, device=x.device)  # (n, h)
    cols = off_x.unsqueeze(1) + torch.arange(w, device=x.device)  # (n, w)
    batch = torch.arange(n, device=x.device).view(n, 1, 1)
    # Channels-last, so the three advanced indices broadcast to (n, h, w)
    # and the channel axis trails; then move channels back to dim 1.
    cropped = padded.permute(0, 2, 3, 1)[batch, rows.unsqueeze(2), cols.unsqueeze(1)]
    return cropped.permute(0, 3, 1, 2).contiguous()
78 |
class soft_aug():
    """Light tensor augmentation: random 4-pixel crop, random horizontal flip,
    then channel-wise normalisation with the given ``mean``/``std``."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, x):
        cropped = random_crop(x, 4)
        flipped = random_flip(cropped)
        return normalize(flipped, self.mean, self.std)


class strong_aug():
    """Heavier augmentation: batch-level random flip, a per-image torchvision
    pipeline (random resized crop, colour jitter with p=0.8), random grayscale,
    and finally channel-wise normalisation."""

    def __init__(self, size, mean, std):
        from torchvision import transforms
        per_image_steps = [
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.ToTensor(),
        ]
        self.transform = transforms.Compose(per_image_steps)
        self.mean = mean
        self.std = std

    def __call__(self, x):
        flipped = random_flip(x)
        augmented = torch.stack([self.transform(img) for img in flipped])
        return normalize(random_grayscale(augmented), self.mean, self.std)
112 |
--------------------------------------------------------------------------------
/models/utils/continual_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn as nn
7 | from torch.optim import SGD
8 | import torch
9 | import torchvision
10 | from argparse import Namespace
11 | from utils.conf import get_device
12 | import numpy as np
13 | import copy
14 | from torch.nn import functional as F
15 | import math
16 | from torch.optim import SGD
17 | from collections import OrderedDict
18 | EPS = 1E-20
19 |
class ContinualModel(nn.Module):
    """
    Continual learning model.

    Base class for all methods. It holds the backbone (``net``), the loss, the
    parsed arguments, the data transform and a plain SGD optimizer (``opt``).
    It also holds the state used for adversarial weight perturbation: a
    ``proxy`` copy of the backbone with its own optimizer (see ``calc_awp``).
    """
    NAME = None
    COMPATIBILITY = []

    def __init__(self, backbone: nn.Module, loss: nn.Module,
                 args: Namespace, transform: torchvision.transforms) -> None:
        super(ContinualModel, self).__init__()

        self.net = backbone
        self.loss = loss
        self.args = args
        self.transform = transform
        self.opt = SGD(self.net.parameters(), lr=self.args.lr)
        self.device = get_device()

        self.current_task = 0
        # Containers not used inside this class; presumably filled by subclasses.
        self.buff_feature = []
        self.buff_labels = []
        self.new_feature = []
        self.new_labels = []
        self.buff_noise = []
        self.theta_list = []
        self.all_iteration = 0

        # Proxy copy of the backbone, perturbed by calc_awp with its own optimizer.
        self.proxy = copy.deepcopy(self.net)
        self.proxy.to(self.device)
        self.proxy_optim = SGD(self.proxy.parameters(), lr=self.args.lr)

        # Keys 0..200, each a 512-element vector filled with 0.01.
        self.EMA_cLass_mean = {
            cls_idx: 0.01 + torch.zeros(512).to(self.device) for cls_idx in range(201)
        }

    def diff_in_weights(self, model, proxy):
        """Return ``proxy - model`` weight differences, rescaled to the model's weight norm.

        Entries of the two state dicts are paired by position, which assumes
        identical key order (as for a deep copy). Only tensors with more than
        one dimension whose key contains ``'weight'`` are kept. Each difference
        is scaled by ``||w_model|| / (||delta|| + EPS)``.
        :return: OrderedDict keyed by the ``model`` state-dict names.
        """
        diff_dict = OrderedDict()
        paired = zip(model.state_dict().items(), proxy.state_dict().items())
        for (key, w_model), (_, w_proxy) in paired:
            if w_model.dim() <= 1 or 'weight' not in key:
                continue
            delta = w_proxy - w_model
            diff_dict[key] = w_model.norm() / (delta.norm() + EPS) * delta
        return diff_dict

    def calc_awp(self, inputs, targets):
        """Take one optimizer step on the proxy and return its weight perturbation.

        The proxy is re-synchronised with ``net`` on every call while
        ``all_iteration <= 1000``, afterwards only when ``all_iteration`` is a
        multiple of ``args.inner_iter``. With ``args.advloss == 'nega'`` the
        proxy minimises the negated cross-entropy. The loss is scaled by
        ``args.gamma_loss``.
        :return: the dict produced by ``diff_in_weights(self.net, self.proxy)``.
        """
        if self.all_iteration <= 1000 or self.all_iteration % self.args.inner_iter == 0:
            self.proxy.load_state_dict(self.net.state_dict())
        self.proxy.train()

        negate = self.args.advloss == 'nega'
        ce = F.cross_entropy(self.proxy(inputs), targets)
        adv_loss = self.args.gamma_loss * (-ce if negate else ce)
        self.proxy_optim.zero_grad()
        adv_loss.backward()
        self.proxy_optim.step()

        # the adversary weight perturb
        return self.diff_in_weights(self.net, self.proxy)

    def norm_scale(self, theta, theta_limit):
        """Compute ``sin(L) / (sin(A) cos(L) - cos(A) sin(L))``, i.e. ``sin(L) / sin(A - L)``.

        :param theta: cosine(s) of the angle A (converted with arccos).
        :param theta_limit: limit angle L in degrees.
        """
        angle = torch.arccos(theta)
        limit = torch.tensor(theta_limit / 180 * math.pi).to(angle.device)
        sin_lim, cos_lim = torch.sin(limit), torch.cos(limit)
        return sin_lim / (torch.sin(angle) * cos_lim - torch.cos(angle) * sin_lim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes a forward pass through the backbone.
        :param x: batch of inputs
        :return: the backbone's output for ``x``
        """
        return self.net(x)

    def observe(self, inputs: torch.Tensor, labels: torch.Tensor,
                not_aug_inputs: torch.Tensor) -> float:
        """
        Compute a training step over a given batch of examples.
        Subclasses override this; the base implementation does nothing and returns None.
        :param inputs: batch of (augmented) examples
        :param labels: ground-truth labels
        :param not_aug_inputs: the same examples without augmentation
        :return: the value of the loss function
        """
--------------------------------------------------------------------------------
/datasets/seq_cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from torchvision.datasets import CIFAR10
7 | import torchvision.transforms as transforms
8 | from backbone.ResNet18 import resnet18
9 | import torch.nn.functional as F
10 | from datasets.seq_tinyimagenet import base_path
11 | from PIL import Image
12 | from datasets.utils.validation import get_train_val
13 | from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
14 | from typing import Tuple
15 | from datasets.transforms.denormalization import DeNormalize
16 |
class MyCIFAR10(CIFAR10):
    """
    CIFAR10 whose items also carry the non-augmented image (ToTensor only) and,
    once a ``logits`` attribute has been attached, the cached logits.
    """
    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False) -> None:
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        super(MyCIFAR10, self).__init__(root, train, transform, target_transform, download)

    def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]:
        """
        Return ``(img, target, not_aug_img)``, plus ``logits[index]`` when the
        dataset has a ``logits`` attribute.
        :param index: index of the element to be returned
        """
        raw, target = self.data[index], self.targets[index]

        # to return a PIL Image
        pil_img = Image.fromarray(raw, mode='RGB')
        not_aug_img = self.not_aug_transform(pil_img.copy())

        img = pil_img if self.transform is None else self.transform(pil_img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        extra = (self.logits[index],) if hasattr(self, 'logits') else ()
        return (img, target, not_aug_img) + extra
50 |
51 |
class SequentialCIFAR10(ContinualDataset):
    """Sequential CIFAR-10: 5 class-incremental tasks of 2 classes each."""

    NAME = 'seq-cifar10'
    SETTING = 'class-il'
    N_CLASSES_PER_TASK = 2
    N_TASKS = 5
    TRANSFORM = transforms.Compose(
        [transforms.RandomCrop(32, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465),
                              (0.2470, 0.2435, 0.2615))])

    def get_data_loaders(self):
        """Build the train/test loaders for the current task (validation split if requested)."""
        train_tf = self.TRANSFORM
        eval_tf = transforms.Compose(
            [transforms.ToTensor(), self.get_normalization_transform()])

        train_dataset = MyCIFAR10(base_path() + 'CIFAR10', train=True,
                                  download=True, transform=train_tf)
        if not self.args.validation:
            test_dataset = CIFAR10(base_path() + 'CIFAR10', train=False,
                                   download=True, transform=eval_tf)
        else:
            train_dataset, test_dataset = get_train_val(train_dataset,
                                                        eval_tf, self.NAME)

        train_loader, test_loader = store_masked_loaders(train_dataset, test_dataset, self)
        return train_loader, test_loader

    @staticmethod
    def get_transform():
        return transforms.Compose([transforms.ToPILImage(), SequentialCIFAR10.TRANSFORM])

    @staticmethod
    def get_backbone():
        n_classes = SequentialCIFAR10.N_CLASSES_PER_TASK * SequentialCIFAR10.N_TASKS
        return resnet18(n_classes)

    @staticmethod
    def get_loss():
        return F.cross_entropy

    @staticmethod
    def get_normalization_transform():
        return transforms.Normalize((0.4914, 0.4822, 0.4465),
                                    (0.2470, 0.2435, 0.2615))

    @staticmethod
    def get_denormalization_transform():
        return DeNormalize((0.4914, 0.4822, 0.4465),
                           (0.2470, 0.2435, 0.2615))

    @staticmethod
    def get_scheduler(model, args):
        # No learning-rate scheduler for this benchmark.
        return None
113 |
--------------------------------------------------------------------------------
/models/joint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from torch.optim import SGD
7 |
8 | from utils.args import *
9 | from models.utils.continual_model import ContinualModel
10 | from datasets.utils.validation import ValidationDataset
11 | from utils.status import progress_bar
12 | import torch
13 | import numpy as np
14 | import math
15 | from torchvision import transforms
16 |
17 |
def get_parser() -> ArgumentParser:
    """Command-line parser for joint training (common options only)."""
    parser = ArgumentParser(description='Joint training: a strong, simple baseline.')
    for register in (add_management_args, add_experiment_args):
        register(parser)
    return parser
23 |
24 |
class Joint(ContinualModel):
    """Joint-training baseline.

    Every task's data is stored at the end of the task. Once the number of
    test loaders equals ``dataset.N_TASKS`` (i.e. after the last task), a
    network is trained on all stored data at once. ``observe`` does nothing.
    """
    NAME = 'joint'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(Joint, self).__init__(backbone, loss, args, transform)
        self.old_data = []
        self.old_labels = []
        self.current_task = 0

    def end_task(self, dataset):
        if dataset.SETTING != 'domain-il':
            self.old_data.append(dataset.train_loader.dataset.data)
            self.old_labels.append(torch.tensor(dataset.train_loader.dataset.targets))
            self.current_task += 1

            # for non-incremental joint training
            if len(dataset.test_loaders) != dataset.N_TASKS: return

            # reinit network
            self.net = dataset.get_backbone()
            self.net.to(self.device)
            self.net.train()
            self.opt = SGD(self.net.parameters(), lr=self.args.lr)

            # prepare dataloader
            all_data, all_labels = None, None
            for i in range(len(self.old_data)):
                if all_data is None:
                    all_data = self.old_data[i]
                    all_labels = self.old_labels[i]
                else:
                    all_data = np.concatenate([all_data, self.old_data[i]])
                    all_labels = np.concatenate([all_labels, self.old_labels[i]])

            transform = dataset.TRANSFORM if dataset.TRANSFORM is not None else transforms.ToTensor()
            temp_dataset = ValidationDataset(all_data, all_labels, transform=transform)
            loader = torch.utils.data.DataLoader(temp_dataset, batch_size=self.args.batch_size, shuffle=True)

            # train
            for e in range(self.args.n_epochs):
                for i, batch in enumerate(loader):
                    inputs, labels = batch
                    inputs, labels = inputs.to(self.device), labels.to(self.device)

                    self.opt.zero_grad()
                    outputs = self.net(inputs)
                    loss = self.loss(outputs, labels.long())
                    loss.backward()
                    self.opt.step()
                    progress_bar(i, len(loader), e, 'J', loss.item())
        else:
            self.old_data.append(dataset.train_loader)
            # train
            if len(dataset.test_loaders) != dataset.N_TASKS: return

            all_inputs = []
            all_labels = []
            for source in self.old_data:
                for x, l, _ in source:
                    all_inputs.append(x)
                    all_labels.append(l)
            all_inputs = torch.cat(all_inputs)
            all_labels = torch.cat(all_labels)
            bs = self.args.batch_size
            n_batches = int(math.ceil(len(all_inputs) / bs))
            scheduler = dataset.get_scheduler(self, self.args)

            for e in range(self.args.n_epochs):
                order = torch.randperm(len(all_inputs))
                # Shuffle once per epoch. Indexing with ``order`` inside the
                # batch loop copied the whole dataset for every batch.
                shuffled_inputs = all_inputs[order]
                shuffled_labels = all_labels[order]
                for i in range(n_batches):
                    inputs = shuffled_inputs[i * bs: (i + 1) * bs]
                    labels = shuffled_labels[i * bs: (i + 1) * bs]
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    self.opt.zero_grad()
                    outputs = self.net(inputs)
                    loss = self.loss(outputs, labels.long())
                    loss.backward()
                    self.opt.step()
                    progress_bar(i, n_batches, e, 'J', loss.item())

                if scheduler is not None:
                    scheduler.step()

    def observe(self, inputs, labels, not_aug_inputs):
        # Training happens in end_task; nothing to do per batch.
        return 0
110 |
--------------------------------------------------------------------------------
/utils/feature_gradient_svd.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy as np
3 | import matplotlib
4 | matplotlib.use("Agg")
5 | import matplotlib.pyplot as plt
6 | import copy
7 | import math
8 | import torch
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | from torch.nn import functional as F
12 | import math
13 |
def visual_svd(s_list, labels=None):
    """Plot log-normalised singular-value spectra, one curve per method, to ``./svd.pdf``.

    :param s_list: 1-D torch tensors of singular values, one per method. Curves
        are drawn in the order 3, 0, 1, 2, 4, so at least five entries are needed.
        The list is not modified.
    :param labels: legend labels aligned with ``s_list``. Defaults to the
        module-level ``method_list`` (looked up at call time).
    """
    if labels is None:
        labels = method_list
    # Normalise each spectrum to sum to 1 without overwriting the caller's list.
    spectra = [s / torch.sum(s) for s in s_list]

    plt.ion()
    plt.clf()
    # Fixed drawing (and hence legend) order.
    for idx in (3, 0, 1, 2, 4):
        dims = np.arange(len(spectra[idx]))
        plt.plot(dims, torch.log(spectra[idx]), label=labels[idx])
    plt.xlabel('Dimension', fontsize=14)
    plt.ylabel('Log Singular Value', fontsize=14)
    plt.legend()

    ax = plt.gca()
    bwith = 1.
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_linewidth(bwith)
    plt.tick_params(labelsize=14)
    # The module selects the non-interactive Agg backend, so the saved PDF is
    # the output (the former bare ``plt.show`` was a no-op attribute access).
    plt.savefig('./svd.pdf')
39 |
def trans_feat(feat, label, ori_feat, new_feat, new_label):
    """Stack per-batch records into flat tensors.

    Keeps only the entries whose ``feat`` length equals that of ``feat[2]``, so
    every kept batch has the same size. It then turns each of the five inputs
    into a tensor and merges the first two axes (batches x samples).

    :return: ``(feat, label, ori_feat, new_feat, new_label)``. Feature tensors
        are 2-D (samples, dim) and label tensors are 1-D.
    """
    keep = [k for k in range(label.shape[0]) if len(feat[k]) == len(feat[2])]

    def to_tensor(records):
        return torch.tensor([records[k] for k in keep])

    def merge_feat(t):
        return t.reshape(t.shape[0] * t.shape[1], t.shape[2])

    def merge_label(t):
        return t.reshape(t.shape[0] * t.shape[1])

    return (merge_feat(to_tensor(feat)),
            merge_label(to_tensor(label)),
            merge_feat(to_tensor(ori_feat)),
            merge_feat(to_tensor(new_feat)),
            merge_label(to_tensor(new_label)))
65 |
EPS = 1E-20
current_task = '5'
target_type = 'target_buf_labels'
gamma_loss = '1.0'
noise_type = 'noise'
method2 = 'gaussian'
c_theta = '45.0'
para_scale = '1.0'

# For each augmentation variant, load the task-5 model and the saved features.
# Back-propagate the classifier loss to the perturbed buffer features, and
# collect the singular values of the last epoch's feature gradients.
s_list = []
method_list = ['drop_self', 'drop_new', 'class_mean', 'gaussian', 'none']
for method2 in method_list:
    # The adversarial baseline ('none') uses a different run configuration.
    # It is the last entry, so these overrides do not affect other methods.
    if method2 == 'none':
        target_type = 'new_labels'
        gamma_loss = '50.0'
        noise_type = 'adv'

    epoch = 50
    name = str(current_task) + '_' + target_type + \
        '_' + str(gamma_loss) + '_' + noise_type + \
        '_' + method2 + '_' + str(c_theta) + \
        '_' + str(para_scale)
    model_name = target_type + \
        '_' + str(gamma_loss) + '_' + noise_type + \
        '_' + method2 + '_' + str(c_theta) + \
        '_' + str(para_scale)

    model_path = './output/task_models/seq-cifar100/%s' % model_name + '/task_5_model.ph'
    # NOTE(review): torch.load unpickles a whole model object. Only load trusted
    # checkpoints; recent PyTorch versions may also require weights_only=False.
    net = torch.load(model_path)
    net.train()
    feat = np.load('output/buff_featurte_task_%s.npy' % name, allow_pickle=True)
    label = np.load('output/buff_labels_task_%s.npy' % name, allow_pickle=True)
    new_feat = np.load('output/new_featurte_task_%s.npy' % name, allow_pickle=True)
    new_label = np.load('output/new_labels_task_%s.npy' % name, allow_pickle=True)
    ori_feat = np.load('output/ori_buffer_feat_task_%s.npy' % name, allow_pickle=True)
    feat, label, ori_feat, new_feat, new_label = trans_feat(feat, label, ori_feat, new_feat, new_label)

    aug_feat_1, ori_feat_1 = copy.deepcopy(feat), copy.deepcopy(ori_feat)
    feat_noise = aug_feat_1.detach() - ori_feat_1.detach()

    # Make the perturbed buffer features a leaf tensor that requires grad.
    # Otherwise ``feat.grad`` stays None after backward() and the SVD fails.
    feat = (feat + feat_noise).detach().requires_grad_(True)
    num_per_epoch = int(feat_noise.shape[0] / epoch)
    mask = label[-num_per_epoch:] < 80

    feat_all = torch.cat((feat, new_feat), 0)
    label_all = torch.cat((label, new_label), 0)

    pred = net.classifier(feat_all)
    loss = F.cross_entropy(pred, label_all)
    loss.backward()

    u, s, v = torch.svd(feat.grad[-num_per_epoch:][mask])
    s_list.append(s)

method_list = ['V-self-drop', 'V-new-drop', 'V-trans', 'V-gaussian', 'V-adv']
visual_svd(s_list)
--------------------------------------------------------------------------------
/models/rpc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from utils.buffer import Buffer
8 | from utils.args import *
9 | from models.utils.continual_model import ContinualModel
10 | from datasets import get_dataset
11 |
def dsimplex(num_classes=10):
    """Return the vertices of a regular simplex centred at the origin.

    The result has shape ``(num_classes - 1, num_classes)``. Each column is a
    unit-norm vertex, and the columns sum to the zero vector.
    """
    import numpy as np

    def simplex_coordinates2(m):
        # add the credit
        # Start from the m standard basis vectors plus one extra vertex whose
        # coordinates all equal a = (1 - sqrt(m + 1)) / m.
        verts = np.zeros([m, m + 1])
        verts[:, :m] = np.eye(m)
        verts[:, m] = (1.0 - np.sqrt(float(1 + m))) / float(m)

        # Shift every row so that the centroid of the m + 1 vertices is the origin.
        # Builtin sum keeps the same left-to-right accumulation order.
        centroid = np.array([sum(row) for row in verts]) / float(m + 1)
        verts = verts - centroid[:, None]

        # Scale by the first vertex's norm (all vertices share it), making
        # every column unit-norm.
        scale = np.sqrt(sum(v ** 2 for v in verts[:, 0]))
        return verts / scale

    return simplex_coordinates2(num_classes - 1)
53 |
def get_parser() -> ArgumentParser:
    """Build the command-line parser for the RPC model.

    Registers, in order, the management, experiment and rehearsal options.
    """
    # NOTE(review): the description text appears to be copied from the ER model.
    parser = ArgumentParser(
        description='Continual learning via Experience Replay.')
    for register in (add_management_args, add_experiment_args,
                     add_rehearsal_args):
        register(parser)
    return parser
61 |
62 |
class RPC(ContinualModel):
    """Rehearsal model with a fixed regular-simplex head.

    A class-balanced replay buffer is rebuilt at the end of every task: the
    classes already stored are shrunk to an equal share and the freed slots
    are filled with examples of the task that just ended. ``forward`` projects
    the backbone output onto the vertices of a simplex built by ``dsimplex``.
    """
    NAME = 'rpc'
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(RPC, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        # Build the dataset once to read its class-level constants
        # (it used to be instantiated twice just for this).
        dataset = get_dataset(args)
        self.cpt = dataset.N_CLASSES_PER_TASK  # classes per task
        self.tasks = dataset.N_TASKS
        self.task = 0  # number of tasks completed so far
        # (num_classes - 1, num_classes) matrix: one unit-norm column per class.
        self.rpchead = torch.from_numpy(dsimplex(self.cpt * self.tasks)).float().to(self.device)

    def forward(self, x):
        """Score inputs against the simplex head.

        The last backbone output is dropped so the remaining values match the
        ``num_classes - 1`` rows of ``self.rpchead``; the product gives one
        score per class.
        """
        x = self.net(x)[:, :-1]
        x = x @ self.rpchead
        return x

    def end_task(self, dataset):
        """Rebalance the buffer and store examples of the task just finished.

        :param dataset: the continual dataset; its ``train_loader`` is scanned
            for examples of the finished task.
        """
        # 1) Shrink the stored classes so that every class seen so far,
        #    including those of the finished task, gets an equal share.
        if self.task > 0:
            examples_per_class = self.args.buffer_size // ((self.task + 1) * self.cpt)
            buf_x, buf_lab = self.buffer.get_all_data()
            self.buffer.empty()
            for tl in buf_lab.unique():
                idx = tl == buf_lab
                ex, lab = buf_x[idx], buf_lab[idx]
                first = min(ex.shape[0], examples_per_class)
                self.buffer.add_data(
                    examples=ex[:first],
                    labels=lab[:first]
                )

        # 2) Split the free capacity evenly across the finished task's classes;
        #    the remainder goes to randomly chosen classes.
        examples_last_task = self.buffer.buffer_size - self.buffer.num_seen_examples
        examples_per_class = examples_last_task // self.cpt
        ce = torch.tensor([examples_per_class] * self.cpt).int()
        ce[torch.randperm(self.cpt)[:examples_last_task - (examples_per_class * self.cpt)]] += 1

        with torch.no_grad():
            for data in dataset.train_loader:
                # Stop scanning as soon as every class quota is filled.
                if (ce == 0).all():
                    break
                _, labels, not_aug_inputs = data
                not_aug_inputs = not_aug_inputs.to(self.device)

                # Keep samples whose class still has quota, consuming it.
                # Assumes a task's labels are consecutive, so label % cpt is the
                # class position within the task -- TODO confirm per dataset.
                flags = torch.zeros(len(labels)).bool()
                for j in range(len(flags)):
                    class_slot = labels[j] % self.cpt
                    if ce[class_slot] > 0:
                        flags[j] = True
                        ce[class_slot] -= 1

                self.buffer.add_data(examples=not_aug_inputs[flags],
                                     labels=labels[flags])
        self.task += 1

    def observe(self, inputs, labels, not_aug_inputs):
        """One optimisation step on the current batch plus a replayed batch.

        :return: the mean loss as a Python float.
        """
        self.opt.zero_grad()
        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        # NOTE(review): training calls self.net directly, so the simplex head
        # used in forward() is not applied here -- confirm this is intended.
        outputs = self.net(inputs)
        losses = self.loss(outputs, labels, reduction='none')
        loss = losses.mean()

        loss.backward()
        self.opt.step()

        return loss.item()
135 |
--------------------------------------------------------------------------------
/utils/plot_acc_curve.py:
--------------------------------------------------------------------------------
1 | from turtle import color
2 | import numpy as np
3 | import numpy as np
4 | import matplotlib
5 | matplotlib.use("Agg")
6 | import matplotlib.pyplot as plt
7 | import copy
8 | import math
9 | import torch
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from torch.nn import functional as F
13 | import math
14 |
15 |
16 | # Baseline = [83.05, 62.15, 46.12, 40.19, 31.08]
17 | # Self_drop = [83.05, 61.32, 49.02, 41.62, 33.67]
18 | # gaussian = [83.05, 64.05, 51.13, 42.66, 37.29]
19 | # vMF = [83.05, 66.45, 53.22, 44.17, 38.76]
20 | # new_drop = [83.05, 65.15, 52.13, 43.93, 38.75]
21 | # Trans = [83.05, 65.88, 53.10, 45.30, 39.78]
22 | # Adv = [83.05, 66.92, 54.58, 46.72, 41.02]
23 | # # plt.figure(figsize=(8,8))
24 | # title = 'CIFAR-100'
25 | # plt.title(title, fontsize=14)
26 | # plt.ylim(30, 68)
27 | # plt.xlabel("Continual Task", fontsize=14)
28 | # plt.ylabel("Accuracy", fontsize=14)
29 |
30 | # plt.grid(linewidth = 1.5)
31 | # x = ['Task1', 'Task2', 'Task3', 'Task4', 'Task5']
32 | # plt.plot(x, Baseline, alpha=0.8, label="ER", linewidth=3, marker= 'o')
33 | # plt.plot(x, Self_drop, alpha=0.8, label="DOA-self", linewidth=3, marker= 'o')
34 | # plt.plot(x, gaussian, alpha=0.8, label="Gaussian", linewidth=3, marker= 'o')
35 | # plt.plot(x, vMF, alpha=0.8, label="vMF", linewidth=3, marker= 'o')
36 | # plt.plot(x, new_drop, alpha=0.8, label="DOA", linewidth=3, marker= 'o')
37 | # plt.plot(x, Trans, alpha=0.8, label="VT", linewidth=3, marker= 'o')
38 | # plt.plot(x, Adv, alpha=0.8, label="WAP", linewidth=3, marker= 'o')
39 | # plt.tick_params(labelsize=14)
40 | # ax = plt.gca()
41 | # bwith = 1.
42 | # ax.spines['top'].set_linewidth(bwith)
43 | # ax.spines['right'].set_linewidth(bwith)
44 | # ax.spines['bottom'].set_linewidth(bwith)
45 | # ax.spines['left'].set_linewidth(bwith)
46 | # plt.legend(loc="best", fontsize=14, edgecolor='black')
47 | # plt.show()
48 | # plt.savefig('./acc_curve_%s.pdf' %title, bbox_inches = 'tight')
49 |
50 |
51 |
52 | # Baseline = [74.6, 52.2, 40.63, 32.22, 24.6, 20.38, 16.71, 14.68, 12.17, 11.24]
53 | # Self_drop = [74.6, 53.75, 40.33, 32.42, 26.94, 21.9, 17.17, 15.41, 12.31, 11.15]
54 | # gaussian = [74.6, 59.1, 46.63, 37.52, 28.74, 25.17, 20.56, 19.12, 16.21, 14.01]
55 | # vMF = [74.6, 57.21, 43.24, 34.55, 29.62, 24.27, 19.24, 19.95, 16.87, 14.62]
56 | # new_drop = [74.6, 55.85, 42.53, 34.8, 25.06, 21.8, 18.9, 19.06, 15.58, 14.96]
57 | # Trans = [74.6, 56.0, 44.63, 36.58, 31.1, 23.67, 19.93, 18.51, 14.91, 15.03]
58 | # Adv = [74.6, 59.2, 45.77, 35.58, 30.12, 26.52, 23.93, 20.35, 18.05, 16.68]
59 | # # plt.figure(figsize=(8,8))
60 | # title = 'TinyImageNet'
61 | # plt.title(title, fontsize=14)
62 | # plt.ylim(0, 60)
63 | # plt.xlabel("Continual Task", fontsize=14)
64 | # plt.ylabel("Accuracy", fontsize=14)
65 |
66 | # plt.grid(linewidth = 1.5)
67 | # x = ['Task1', 'Task2', 'Task3', 'Task4', 'Task5', 'Task6', 'Task7', 'Task8', 'Task9', 'Task10']
68 | # plt.plot(x, Baseline, alpha=0.8, label="ER", linewidth=3, marker= 'o')
69 | # plt.plot(x, Self_drop, alpha=0.8, label="DOA-self", linewidth=3, marker= 'o')
70 | # plt.plot(x, gaussian, alpha=0.8, label="Gaussian", linewidth=3, marker= 'o')
71 | # plt.plot(x, vMF, alpha=0.8, label="vMF", linewidth=3, marker= 'o')
72 | # plt.plot(x, new_drop, alpha=0.8, label="DOA", linewidth=3, marker= 'o')
73 | # plt.plot(x, Trans, alpha=0.8, label="VT", linewidth=3, marker= 'o')
74 | # plt.plot(x, Adv, alpha=0.8, label="WAP", linewidth=3, marker= 'o')
75 | # plt.tick_params(labelsize=10)
76 | # ax = plt.gca()
77 | # bwith = 1.
78 | # ax.spines['top'].set_linewidth(bwith)
79 | # ax.spines['right'].set_linewidth(bwith)
80 | # ax.spines['bottom'].set_linewidth(bwith)
81 | # ax.spines['left'].set_linewidth(bwith)
82 | # plt.legend(loc="best", fontsize=14, edgecolor='black')
83 | # plt.show()
84 | # plt.savefig('./acc_curve_%s.pdf' %title, bbox_inches = 'tight')
85 |
86 |
87 |
# CIFAR-10 accuracy after each continual task, one list per method.
Baseline = [98.25, 90.22, 74.22, 71.03, 62.53]
Self_drop = [98.25, 90.7, 79.23, 70.88, 72.18]
gaussian = [98.25, 88.82, 78.68, 74.23, 67.85]
vMF = [98.25, 88.82, 78.87, 76.18, 71.58]
new_drop = [98.25, 90.15, 72.68, 75.42, 67.88]
Trans = [98.25, 89.5, 78.98, 69.24, 71.66]
Adv = [98.25, 91.35, 80.03, 74.24, 72.99]

# plt.figure(figsize=(8,8))
title = 'CIFAR-10'
plt.title(title, fontsize=14)
plt.ylim(50, 100)
plt.xlabel("Continual Task", fontsize=14)
plt.ylabel("Accuracy", fontsize=14)

plt.grid(linewidth=1.5)
x = ['Task%d' % (i + 1) for i in range(len(Baseline))]

# (legend label, accuracies) in plotting order; the order fixes which colour
# of matplotlib's default cycle each method receives.
curves = [
    ("ER", Baseline),
    ("DOA-self", Self_drop),
    ("Gaussian", gaussian),
    ("vMF", vMF),
    ("DOA", new_drop),
    ("VT", Trans),
    ("WAP", Adv),
]
for curve_label, accs in curves:
    plt.plot(x, accs, alpha=0.8, label=curve_label, linewidth=3, marker='o')

plt.tick_params(labelsize=14)
ax = plt.gca()
bwith = 1.
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_linewidth(bwith)
plt.legend(loc="best", fontsize=14, edgecolor='black')
# Save before show(). With an interactive backend, show() releases the figure
# and a later savefig() writes a blank page. With the Agg backend selected at
# the top of this file, show() only emits a warning.
plt.savefig('./acc_curve_%s.pdf' % title, bbox_inches='tight')
plt.show()
122 |
--------------------------------------------------------------------------------
/utils/method_variation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy as np
3 | import matplotlib
4 | matplotlib.use("Agg")
5 | import matplotlib.pyplot as plt
6 | import copy
7 | import math
8 | import torch
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | from torch.nn import functional as F
12 | import math
13 |
def visual_svd(s_list, labels=None, order=(3, 0, 1, 2, 4)):
    """Plot normalised singular-value spectra on a log scale and save ./svd.pdf.

    Each spectrum is divided by its sum before ``torch.log``, so methods are
    compared by the shape of their spectrum rather than its overall scale.

    :param s_list: sequence of 1-D torch tensors of singular values, one per
        method. The caller's list is not modified (the previous version
        overwrote its entries with the normalised tensors).
    :param labels: legend labels aligned with ``s_list``; defaults to the
        module-level ``method_list``.
    :param order: indices into ``s_list`` in plotting order. The default keeps
        the original order, which fixes each method's colour in matplotlib's
        default colour cycle.
    """
    if labels is None:
        labels = method_list
    plt.ion()
    plt.clf()
    normalised = [s / torch.sum(s) for s in s_list]
    for i in order:
        spectrum = normalised[i]
        # One x position per singular value (this used to be hard-coded to 512).
        plt.plot(np.arange(spectrum.shape[0]), torch.log(spectrum), label=labels[i])
    plt.xlabel('Dimension', fontsize=14)
    plt.ylabel('Log Singular Value', fontsize=14)
    plt.legend()

    ax = plt.gca()
    bwith = 1.
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_linewidth(bwith)
    plt.tick_params(labelsize=14)
    plt.savefig('./svd.pdf')
    # Previously written as `plt.show` (no call), so it never ran.
    plt.show()
39 |
def trans_feat(feat, label, ori_feat, new_feat, new_label):
    """Turn per-batch saved arrays into flat tensors.

    A batch position ``k`` is kept only when ``len(feat[k])`` equals
    ``len(feat[2])``, which drops batches of a different size. The same
    positions are kept in all five inputs. Kept batches are converted with
    ``torch.tensor`` and their first two dimensions (batch, sample) are merged.

    :return: ``(feat, label, ori_feat, new_feat, new_label)``. The feature
        tensors are 2-D and the label tensors are 1-D.
    """
    keep = [k for k in range(label.shape[0]) if len(feat[k]) == len(feat[2])]

    def _gather(seq):
        # Select the kept batches and stack them into one tensor.
        return torch.tensor([seq[k] for k in keep])

    def _merge_features(t):
        n_batches, batch, dim = t.shape[0], t.shape[1], t.shape[2]
        return torch.reshape(t, (n_batches * batch, dim))

    def _merge_labels(t):
        return torch.reshape(t, (t.shape[0] * t.shape[1],))

    return (_merge_features(_gather(feat)),
            _merge_labels(_gather(label)),
            _merge_features(_gather(ori_feat)),
            _merge_features(_gather(new_feat)),
            _merge_labels(_gather(new_label)))
65 |
def compute_variance(class_feat, m_feature):
    """Mean cosine distance between feature rows and a reference direction.

    Returns ``mean_i (1 - cos(m_feature[0], class_feat[i]))``. The script below
    calls this value RMSSTD.

    :param class_feat: 2-D tensor of feature rows, shape ``(N, D)``.
    :param m_feature: tensor whose first row ``m_feature[0]`` (shape ``(D,)``)
        is the reference direction, e.g. the ``(1, D)`` class mean.
    :return: 0-dim tensor with the mean cosine distance.
    :raises ValueError: if ``class_feat`` has no rows. The previous loop
        divided by zero in that case.
    """
    if class_feat.shape[0] == 0:
        raise ValueError('compute_variance: class_feat has no rows')
    center = m_feature[0].unsqueeze(0).expand_as(class_feat)
    # One vectorised call instead of a Python loop over rows.
    distances = 1 - torch.cosine_similarity(center, class_feat, dim=1)
    return distances.mean()
72 |
EPS = 1E-20
current_task = '5'
target_type = 'target_buf_labels'
gamma_loss = '1.0'
noise_type = 'noise'
method2 = 'gaussian'
c_theta = '45.0'
para_scale = '1.0'

# Directory holding the features/labels dumped during training.
DATA_DIR = 'output_7_22_all_method'
# Class layout assumed by the summary below: 100 classes. The first
# NUM_OLD_CLASSES are averaged as "old" classes, the rest as "new" ones.
NUM_CLASSES = 100
NUM_OLD_CLASSES = 80


def _load_saved(prefix, run_name):
    """Load ``<DATA_DIR>/<prefix>_task_<run_name>.npy``.

    The files hold object arrays, hence ``allow_pickle=True``. Unpickling can
    execute arbitrary code, so only load files produced by your own runs.
    """
    return np.load('%s/%s_task_%s.npy' % (DATA_DIR, prefix, run_name), allow_pickle=True)


s_list = []
# method_list = ['drop_self', 'drop_new', 'class_mean', 'gaussian', 'none'] #'drop_self', 'drop_new', 'class_mean', 'gaussian',
method_list = ['drop_new', 'drop_self']
for idx, method2 in enumerate(method_list):
    # Per-method run settings. They are reset on every iteration so the
    # overrides for 'none' no longer leak into methods processed after it.
    target_type = 'target_buf_labels'
    gamma_loss = '1.0'
    noise_type = 'noise'
    if method2 == 'none':
        target_type = 'new_labels'
        gamma_loss = '50.0'
        noise_type = 'adv'
    # para_scale is chosen by position. Positions >= 2 keep the last value
    # assigned (unchanged from the original behaviour).
    if idx == 0:
        para_scale = '1.0'
    if idx == 1:
        para_scale = '1.5'

    name = '_'.join([str(current_task), target_type, str(gamma_loss), noise_type,
                     method2, str(c_theta), str(para_scale)])

    # dirs = './out_gradient/' + 'mnist_2d/gaussian/'+ 'er_ours_3_50_task_2_model.ph'
    feat = _load_saved('buff_featurte', name)
    label = _load_saved('buff_labels', name)
    new_feat = _load_saved('new_featurte', name)
    new_label = _load_saved('new_labels', name)
    ori_feat = _load_saved('ori_buffer_feat', name)
    feat, label, ori_feat, new_feat, new_label = trans_feat(feat, label, ori_feat, new_feat, new_label)

    all_feat = torch.cat((feat, new_feat), 0)
    all_label = torch.cat((label, new_label), 0)
    RMSSTD_list = []
    for i_class in range(NUM_CLASSES):
        class_feat = F.normalize(all_feat[all_label == i_class], dim=1, p=2)
        # Unit-normalised class mean, shape (1, D).
        m_feature = F.normalize(torch.mean(class_feat, dim=0).unsqueeze(0), dim=1, p=2)
        RMSSTD_list.append(compute_variance(class_feat, m_feature))

    # Average spread over the old classes, then over the new classes.
    print(sum(RMSSTD_list[:NUM_OLD_CLASSES]) / NUM_OLD_CLASSES,
          sum(RMSSTD_list[NUM_OLD_CLASSES:]) / (NUM_CLASSES - NUM_OLD_CLASSES))


# method_list = ['V-self-drop', 'V-new-drop', 'V-trans', 'V-gaussian', 'V-adv']
# visual_svd(s_list)
--------------------------------------------------------------------------------
/utils/ring_buffer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import numpy as np
8 | from typing import Tuple
9 | from torchvision import transforms
10 |
11 |
def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int:
    """Return the buffer slot for the next example of ``task``.

    Task ``task`` owns the ``buffer_portion_size`` slots starting at
    ``task * buffer_portion_size``; within that portion the slots are reused
    cyclically as ``num_seen_examples`` grows.
    """
    portion_start = task * buffer_portion_size
    offset = num_seen_examples % buffer_portion_size
    return portion_start + offset
14 |
15 |
class RingBuffer:
    """
    The memory buffer of rehearsal method.

    The ``buffer_size`` slots are split into ``n_tasks`` equal portions; task
    ``t`` overwrites its own portion cyclically (see ``ring``). Empty slots are
    marked with label ``-1``. This class never advances ``task_number``;
    presumably callers increment it between tasks -- TODO confirm.
    """
    def __init__(self, buffer_size, device, n_tasks):
        self.buffer_size = buffer_size
        self.buffer_portion_size = buffer_size // n_tasks
        self.device = device
        self.task_number = 0
        self.num_seen_examples = 0
        self.attributes = ['examples', 'labels', 'logits', 'task_labels']

    def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor,
                     logits: torch.Tensor, task_labels: torch.Tensor) -> None:
        """
        Initializes just the required tensors.
        :param examples: tensor containing the images
        :param labels: tensor containing the labels
        :param logits: tensor containing the outputs of the network
        :param task_labels: tensor containing the task labels
        """
        # Explicit name -> value mapping instead of eval() on local names.
        provided = {
            'examples': examples,
            'labels': labels,
            'logits': logits,
            'task_labels': task_labels,
        }
        for attr_str in self.attributes:
            attr = provided[attr_str]
            if attr is not None and not hasattr(self, attr_str):
                # 'labels' / 'task_labels' are integer tensors, the rest float.
                typ = torch.int64 if attr_str.endswith('els') else torch.float32
                setattr(self, attr_str, torch.zeros((self.buffer_size,
                        *attr.shape[1:]), dtype=typ, device=self.device))

        # Mark every slot as empty (-1). Guarded so that filling the buffer
        # without labels no longer raises AttributeError here.
        if hasattr(self, 'labels'):
            self.labels -= 1

    def add_data(self, examples, labels=None, logits=None, task_labels=None):
        """
        Adds the data to the memory buffer, writing each example into the next
        slot of the current task's portion (ring strategy, see ``ring``).
        :param examples: tensor containing the images
        :param labels: tensor containing the labels
        :param logits: tensor containing the outputs of the network
        :param task_labels: tensor containing the task labels
        :return:
        """
        if not hasattr(self, 'examples'):
            self.init_tensors(examples, labels, logits, task_labels)

        for i in range(examples.shape[0]):
            # ring() always yields a non-negative slot inside the buffer.
            index = ring(self.num_seen_examples, self.buffer_portion_size, self.task_number)
            self.num_seen_examples += 1
            self.examples[index] = examples[i].to(self.device)
            if labels is not None:
                self.labels[index] = labels[i].to(self.device)
            if logits is not None:
                self.logits[index] = logits[i].to(self.device)
            if task_labels is not None:
                self.task_labels[index] = task_labels[i].to(self.device)

    def get_data(self, size: int, transform: transforms=None) -> Tuple:
        """
        Random samples a batch of size items.
        Sampling is without replacement among the populated slots
        (label != -1); ``size`` is clamped to their number.
        :param size: the number of requested items
        :param transform: the transformation to be applied (data augmentation)
        :return: tuple (examples, labels[, logits][, task_labels])
        """
        # Written slots need not form a prefix of the buffer: each task writes
        # into its own portion. So sample among the actual populated indices
        # instead of range(populated_count). That older approach could return
        # empty slots and skip filled ones.
        populated = np.flatnonzero((self.labels != -1).cpu().numpy())
        size = min(size, len(populated))

        choice = populated[np.random.choice(len(populated), size=size, replace=False)]
        if transform is None:
            transform = lambda x: x
        ret_tuple = (torch.stack([transform(ee.cpu())
                                  for ee in self.examples[choice]]).to(self.device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr[choice],)

        return ret_tuple

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        Only the counters are checked: after ``empty()`` with
        ``task_number > 0`` this still returns False.
        """
        return self.num_seen_examples == 0 and self.task_number == 0

    def get_all_data(self, transform: transforms=None) -> Tuple:
        """
        Return all the items in the memory buffer, including slots that were
        never written (their label is -1).
        :param transform: the transformation to be applied (data augmentation)
        :return: a tuple with all the items in the memory buffer
        """
        if transform is None:
            transform = lambda x: x
        ret_tuple = (torch.stack([transform(ee.cpu())
                                  for ee in self.examples]).to(self.device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr,)
        return ret_tuple

    def empty(self) -> None:
        """
        Delete all the stored tensors and reset the seen-examples counter.
        ``task_number`` is left unchanged.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0
126 |
--------------------------------------------------------------------------------
/utils/args.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from argparse import ArgumentParser
7 | from datasets import NAMES as DATASET_NAMES
8 | from models import get_all_models
9 |
10 |
def add_experiment_args(parser: ArgumentParser) -> None:
    """
    Adds the arguments used by all the models.
    :param parser: the parser instance
    """
    parser.add_argument('--dataset', type=str, required=True,
                        choices=DATASET_NAMES,
                        help='Which dataset to perform experiments on.')
    parser.add_argument('--model', type=str, required=True,
                        help='Model name.', choices=get_all_models())

    parser.add_argument('--lr', type=float, required=True,
                        help='Learning rate.')

    parser.add_argument('--optim_wd', type=float, default=0.,
                        help='optimizer weight decay.')
    parser.add_argument('--optim_mom', type=float, default=0.,
                        help='optimizer momentum.')
    parser.add_argument('--optim_nesterov', type=int, default=0,
                        help='optimizer nesterov momentum.')

    # Both default to None. Entry points such as utils/plot_test_variance.py
    # fill them from the dataset's defaults when left unset.
    parser.add_argument('--n_epochs', type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size', type=int,
                        help='Batch size.')
36 |
def add_management_args(parser: ArgumentParser) -> None:
    """
    Adds the run-management arguments (seed, notes, logging, validation) and
    the method-specific hyper-parameters (noise, dropout, mixing, adversarial
    loss, ...) read by this project's models.
    :param parser: the parser instance
    """
    parser.add_argument('--seed', type=int, default=None,
                        help='The random seed.')
    parser.add_argument('--notes', type=str, default=None,
                        help='Notes for this run.')

    parser.add_argument('--non_verbose', action='store_true')
    parser.add_argument('--csv_log', action='store_true',
                        help='Enable csv logging')
    parser.add_argument('--tensorboard', action='store_true',
                        help='Enable tensorboard logging')
    parser.add_argument('--validation', action='store_true',
                        help='Test on the validation set')

    #############################################################
    parser.add_argument('--noise_std', type=float, default=0.01, required=False,
                        help='noise_std', )
    parser.add_argument('--c_theta', type=float, default=30, required=False,
                        help='c_theta', )
    parser.add_argument('--on_sphere', type=str, default='none', required=False,
                        help='on_sphere.')
    parser.add_argument('--noise_type', type=str, default='adv', required=False,
                        help='noise_type.')
    parser.add_argument('--noise_factor', type=float, default=0.01, required=False,
                        help='noise_factor', )
    parser.add_argument('--drop_rate', type=float, default=0.5, required=False,
                        help='drop_rate', )
    parser.add_argument('--mix_rate', type=float, default=0.5, required=False,
                        help='mix_rate', )

    parser.add_argument('--drop_factor', type=float, default=0.7, required=False,
                        help='drop_factor', )
    parser.add_argument('--gaussian_factor', type=float, default=0.3, required=False,
                        help='gaussian_factor', )

    parser.add_argument('--para_scale', type=float, default=1, required=False,
                        help='para_scale', )
    # NOTE(review): an iteration count parsed as float; confirm consumers do
    # not pass it to range().
    parser.add_argument('--inner_iter', type=float, default=5, required=False,
                        help='inner_iter', )
    parser.add_argument('--gamma_loss', type=float, default=0.01, required=False,
                        help='gamma', )
    # The options below previously all had the copy-pasted help text
    # "Directory where data files are stored.", which did not describe them.
    parser.add_argument('--method2', default="mean", type=str,
                        help='method2.')
    parser.add_argument('--norm_add', default="none", type=str,
                        help='norm_add.')
    parser.add_argument('--target_type', default="mean", type=str,
                        help='target_type.')
    parser.add_argument('--advloss', default="nega", type=str,
                        help='advloss.')

    parser.add_argument('--optimizer', default="SGD", type=str,
                        help='Optimizer name.')

    parser.add_argument('--epsilon', default=0.05, type=float,
                        help='epsilon.')
    parser.add_argument('--cos_temp', type=float, default=15, required=False,
                        help='cos_temp', )
94 |
95 |
def add_rehearsal_args(parser: ArgumentParser) -> None:
    """
    Registers the options shared by rehearsal (replay-buffer) methods.
    :param parser: the parser instance
    """
    options = (
        ('--buffer_size', dict(type=int, required=True,
                               help='The size of the memory buffer.')),
        ('--minibatch_size', dict(type=int,
                                  help='The batch size of the memory buffer.')),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)
105 |
--------------------------------------------------------------------------------
/utils/plot_test_variance.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy # needed (don't change it)
7 | import importlib
8 | import os
9 | import sys
10 | import socket
11 | mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12 | sys.path.append(mammoth_path)
13 | sys.path.append(mammoth_path + '/datasets')
14 | sys.path.append(mammoth_path + '/backbone')
15 | sys.path.append(mammoth_path + '/models')
16 |
17 | from datasets import NAMES as DATASET_NAMES
18 | from models import get_all_models
19 | from argparse import ArgumentParser
20 | from utils.args import add_management_args
21 | from datasets import ContinualDataset
22 | from utils.continual_training import train as ctrain
23 | from datasets import get_dataset
24 | from models import get_model
25 | from utils.training import train, test_visual, test_variance
26 | from utils.best_args import best_args
27 | from utils.conf import set_random_seed
28 | import setproctitle
29 | import torch
30 | import uuid
31 | import datetime
32 |
33 | import numpy as np
34 | import random
35 | import torch
36 | import torch.nn as nn
37 |
38 |
def lecun_fix():
    """Install a process-wide urllib opener that sends a browser User-agent.

    Yann moved his website to CloudFlare. You need this now: presumably
    downloads without a browser-like User-agent are refused.
    Uses the standard-library ``urllib.request`` directly. It is the same
    module that ``six.moves.urllib.request`` resolves to on Python 3, so the
    ``six`` dependency is dropped.
    """
    import urllib.request

    opener = urllib.request.build_opener()
    opener.addheaders = [('User-agent', 'Mozilla/5.0')]
    urllib.request.install_opener(opener)
45 |
def parse_args():
    """Parse the command line for this entry point.

    A first, partial parse reads ``--model`` and ``--load_best_args``. With
    ``--load_best_args`` the stored best hyper-parameters for the dataset (and
    buffer size, for models exposing ``Buffer``) are appended to the user's
    flags and parsed by the model's own parser. Otherwise that parser reads
    ``sys.argv`` directly. A given seed is applied before returning.
    """
    parser = ArgumentParser(description='mammoth', allow_abbrev=False)
    parser.add_argument('--model', type=str, required=True,
                        help='Model name.', choices=get_all_models())
    parser.add_argument('--load_best_args', action='store_true',
                        help='Loads the best arguments for each method, '
                             'dataset and memory buffer.')
    torch.set_num_threads(4)
    add_management_args(parser)
    args = parser.parse_known_args()[0]
    mod = importlib.import_module('models.' + args.model)

    if not args.load_best_args:
        args = getattr(mod, 'get_parser')().parse_args()
    else:
        uses_buffer = hasattr(mod, 'Buffer')
        parser.add_argument('--dataset', type=str, required=True,
                            choices=DATASET_NAMES,
                            help='Which dataset to perform experiments on.')
        if uses_buffer:
            parser.add_argument('--buffer_size', type=int, required=True,
                                help='The size of the memory buffer.')
        args = parser.parse_args()

        # 'joint' reuses the stored settings of 'sgd'.
        model_key = 'sgd' if args.model == 'joint' else args.model
        best = best_args[args.dataset][model_key]
        best = best[args.buffer_size] if uses_buffer else best[-1]

        model_parser = getattr(mod, 'get_parser')()
        to_parse = sys.argv[1:] + ['--' + k + '=' + str(v) for k, v in best.items()]
        to_parse.remove('--load_best_args')
        args = model_parser.parse_args(to_parse)
        if args.model == 'joint' and args.dataset == 'mnist-360':
            args.model = 'joint_gcl'

    if args.seed is not None:
        set_random_seed(args.seed)

    return args
90 |
def main(args=None):
    """Build the dataset and model described by ``args`` and run the variance test.

    :param args: pre-parsed arguments; when None they are read from the
        command line with ``parse_args``.
    """
    lecun_fix()
    if args is None:
        args = parse_args()

    # Seed the RNGs only when a seed was given. Without one,
    # torch.manual_seed(None) raises TypeError, and PYTHONHASHSEED='None' is an
    # invalid value for any child interpreter.
    if args.seed is not None:
        random.seed(args.seed)
        # Disable hash randomisation so experiments are reproducible. This only
        # reaches child processes started afterwards; the running
        # interpreter's hash seed was fixed at startup.
        os.environ['PYTHONHASHSEED'] = str(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Add uuid, timestamp and hostname for logging
    args.conf_jobnum = str(uuid.uuid4())
    args.conf_timestamp = str(datetime.datetime.now())
    args.conf_host = socket.gethostname()
    dataset = get_dataset(args)

    # Fill unset sizes from the dataset's defaults.
    if args.n_epochs is None and isinstance(dataset, ContinualDataset):
        args.n_epochs = dataset.get_epochs()
    if args.batch_size is None:
        args.batch_size = dataset.get_batch_size()
    if hasattr(importlib.import_module('models.' + args.model), 'Buffer') and args.minibatch_size is None:
        args.minibatch_size = dataset.get_minibatch_size()

    backbone = dataset.get_backbone()
    loss = dataset.get_loss()
    model = get_model(args, backbone, loss, dataset.get_transform())

    # model.net.linear.scale = args.cos_temp

    # set job name
    setproctitle.setproctitle('{}_{}_{}'.format(args.model, args.buffer_size if 'buffer_size' in args else 0, args.dataset))

    # test_visual(model, dataset, args)
    if isinstance(dataset, ContinualDataset):
        test_variance(model, dataset, args)
    else:
        assert not hasattr(model, 'end_task') or model.NAME == 'joint_gcl'
        ctrain(args)


if __name__ == '__main__':
    main()
137 |
--------------------------------------------------------------------------------
/datasets/utils/continual_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from abc import abstractmethod
7 | from argparse import Namespace
8 | from torch import nn as nn
9 | from torchvision.transforms import transforms
10 | from torch.utils.data import DataLoader
11 | from typing import Tuple
12 | from torchvision import datasets
13 | import numpy as np
14 | import torch.optim
15 |
class ContinualDataset:
    """
    Continual learning evaluation setting.

    Concrete benchmarks set the class attributes below and override the hook
    methods. This class does not inherit from ``abc.ABC``, so the
    ``@abstractmethod`` markers are not enforced: a hook that is not
    overridden simply returns None.
    """
    NAME = None
    SETTING = None
    N_CLASSES_PER_TASK = None
    N_TASKS = None
    TRANSFORM = None

    def __init__(self, args: Namespace) -> None:
        """
        Start a run with no loaders yet.
        :param args: the arguments which contains the hyperparameters
        """
        self.args = args
        # First class of the next task to be loaded (advanced by
        # store_masked_loaders).
        self.i = 0
        self.test_loaders = []
        self.train_loader = None

    @abstractmethod
    def get_data_loaders(self) -> Tuple[DataLoader, DataLoader]:
        """
        Build the train and test loaders of the current task, keeping the
        train loader and every test loader seen so far on ``self``.
        :return: the current training and test loaders
        """

    @staticmethod
    @abstractmethod
    def get_backbone() -> nn.Module:
        """Return the network backbone used for this dataset."""

    @staticmethod
    @abstractmethod
    def get_transform() -> transforms:
        """Return the training transform used for this dataset."""

    @staticmethod
    @abstractmethod
    def get_loss() -> nn.functional:
        """Return the loss function used for this dataset."""

    @staticmethod
    @abstractmethod
    def get_normalization_transform() -> transforms:
        """Return the transform that normalises this dataset's inputs."""

    @staticmethod
    @abstractmethod
    def get_denormalization_transform() -> transforms:
        """Return the transform that undoes the normalisation."""

    @staticmethod
    @abstractmethod
    def get_scheduler(model, args: Namespace) -> torch.optim.lr_scheduler:
        """Return the learning-rate scheduler used for this dataset."""

    @staticmethod
    def get_epochs():
        """Default number of training epochs (None unless overridden)."""

    @staticmethod
    def get_batch_size():
        """Default training batch size (None unless overridden)."""

    @staticmethod
    def get_minibatch_size():
        """Default replay minibatch size (None unless overridden)."""
104 |
105 |
106 |
def store_masked_loaders(train_dataset: 'torch.utils.data.Dataset',
                         test_dataset: 'torch.utils.data.Dataset',
                         setting: 'ContinualDataset',
                         num_workers: int = 4) -> Tuple[DataLoader, DataLoader]:
    """
    Divides the dataset into tasks.

    Keeps only the samples whose label lies in
    ``[setting.i, setting.i + setting.N_CLASSES_PER_TASK)``. Both datasets are
    filtered in place: their ``data`` and ``targets`` attributes are replaced,
    and ``targets`` becomes a NumPy array. The new loaders are stored on
    ``setting`` (the test loader is appended to ``setting.test_loaders``), and
    ``setting.i`` is advanced to the first class of the next task.

    :param train_dataset: train dataset exposing ``data`` and ``targets``
    :param test_dataset: test dataset exposing ``data`` and ``targets``
    :param setting: continual learning setting
    :param num_workers: worker processes used by each DataLoader
    :return: train and test loaders
    """
    low = setting.i
    high = setting.i + setting.N_CLASSES_PER_TASK

    for dataset in (train_dataset, test_dataset):
        # Convert the targets once and reuse the array for both comparisons.
        targets = np.array(dataset.targets)
        mask = (targets >= low) & (targets < high)
        dataset.data = dataset.data[mask]
        dataset.targets = targets[mask]

    train_loader = DataLoader(train_dataset, batch_size=setting.args.batch_size,
                              shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=setting.args.batch_size,
                             shuffle=False, num_workers=num_workers)
    setting.test_loaders.append(test_loader)
    setting.train_loader = train_loader

    setting.i += setting.N_CLASSES_PER_TASK
    return train_loader, test_loader
136 |
137 |
def get_previous_train_loader(train_dataset: 'torch.utils.data.Dataset', batch_size: int,
                              setting: 'ContinualDataset') -> DataLoader:
    """
    Creates a dataloader for the previous task.

    Keeps the samples whose label lies in
    ``[setting.i - setting.N_CLASSES_PER_TASK, setting.i)``. The dataset is
    filtered in place: its ``data`` and ``targets`` attributes are replaced.

    :param train_dataset: the entire training set
    :param batch_size: the desired batch size
    :param setting: the continual dataset at hand
    :return: a shuffled dataloader over the previous task's samples
    """
    # store_masked_loaders advances setting.i past each task it builds, so
    # (assuming it ran for that task) the previous task ends at setting.i.
    low = setting.i - setting.N_CLASSES_PER_TASK
    high = setting.i

    targets = np.array(train_dataset.targets)
    mask = (targets >= low) & (targets < high)
    train_dataset.data = train_dataset.data[mask]
    train_dataset.targets = targets[mask]

    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
155 |
--------------------------------------------------------------------------------
/utils/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy # needed (don't change it)
7 | import importlib
8 | import os
9 | import sys
10 | import socket
mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make the repository root and its package folders importable when this file
# is executed directly; the `datasets`/`models` imports below depend on it.
sys.path.extend([mammoth_path] +
                [mammoth_path + '/' + sub for sub in ('datasets', 'backbone', 'models')])
16 |
17 | from datasets import NAMES as DATASET_NAMES
18 | from models import get_all_models
19 | from argparse import ArgumentParser
20 | from utils.args import add_management_args
21 | from datasets import ContinualDataset
22 | from utils.continual_training import train as ctrain
23 | from datasets import get_dataset
24 | from models import get_model
25 | from utils.training import train, test_visual
26 | from utils.best_args import best_args
27 | from utils.conf import set_random_seed
28 | import setproctitle
29 | import torch
30 | import uuid
31 | import datetime
32 |
33 | import numpy as np
34 | import random
35 | import torch
36 | import torch.nn as nn
37 |
38 |
def lecun_fix():
    """
    Install a process-wide urllib opener that sends a browser-like User-agent.

    Yann LeCun's website (the MNIST host) moved behind CloudFlare; per the
    original note, downloads now need this header. After this call,
    ``urllib.request.urlopen`` uses the installed opener.
    """
    # The stdlib module is what six.moves.urllib.request resolved to on
    # Python 3, so the third-party `six` shim is not needed.
    import urllib.request

    opener = urllib.request.build_opener()
    opener.addheaders = [('User-agent', 'Mozilla/5.0')]
    urllib.request.install_opener(opener)
45 |
def parse_args():
    """
    Build the command line for the requested model and parse it.

    A first, permissive pass reads only ``--model`` (plus the management
    options), so that the model module can be imported and its own
    ``get_parser`` used for the real parse.

    With ``--load_best_args``, the stored hyper-parameters in
    ``best_args[dataset][model]`` are appended to the command line before
    parsing. That table is indexed by buffer size for models whose module
    exposes ``Buffer``, and by ``-1`` otherwise.

    Also limits torch to 4 threads and, when ``--seed`` is given, seeds the
    RNGs through ``set_random_seed``.

    :return: the parsed argument namespace
    """
    parser = ArgumentParser(description='mammoth', allow_abbrev=False)
    parser.add_argument('--model', type=str, required=True,
                        help='Model name.', choices=get_all_models())
    parser.add_argument('--load_best_args', action='store_true',
                        help='Loads the best arguments for each method, '
                             'dataset and memory buffer.')
    torch.set_num_threads(4)
    add_management_args(parser)
    first_pass = parser.parse_known_args()[0]
    model_module = importlib.import_module('models.' + first_pass.model)

    if not first_pass.load_best_args:
        args = model_module.get_parser().parse_args()
    else:
        parser.add_argument('--dataset', type=str, required=True,
                            choices=DATASET_NAMES,
                            help='Which dataset to perform experiments on.')
        uses_buffer = hasattr(model_module, 'Buffer')
        if uses_buffer:
            parser.add_argument('--buffer_size', type=int, required=True,
                                help='The size of the memory buffer.')
        args = parser.parse_args()

        # 'joint' has no entry of its own and reuses the SGD settings.
        table_key = 'sgd' if args.model == 'joint' else args.model
        best = best_args[args.dataset][table_key]
        best = best[args.buffer_size] if uses_buffer else best[-1]

        model_parser = model_module.get_parser()
        cli = sys.argv[1:]
        cli.extend('--' + name + '=' + str(value) for name, value in best.items())
        cli.remove('--load_best_args')
        args = model_parser.parse_args(cli)
        if args.model == 'joint' and args.dataset == 'mnist-360':
            args.model = 'joint_gcl'

    if args.seed is not None:
        set_random_seed(args.seed)

    return args
90 |
def _seed_everything(seed) -> None:
    """Seed Python, NumPy and PyTorch (CPU and all GPUs); a None seed leaves the RNGs as they are."""
    if seed is None:
        # torch.manual_seed(None) raises TypeError, so nothing is seeded here.
        return
    # Python reads PYTHONHASHSEED only at interpreter start-up, so this affects
    # interpreters launched afterwards (e.g. spawned worker processes), not the
    # already-running one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.


def main(args=None):
    """
    Entry point: parse arguments (unless given), build dataset and model, and train.
    :param args: a pre-parsed namespace; when None the command line is parsed
    """
    lecun_fix()
    if args is None:
        args = parse_args()

    _seed_everything(args.seed)
    # Favour run-to-run reproducibility over cuDNN autotuning speed.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Add uuid, timestamp and hostname for logging
    args.conf_jobnum = str(uuid.uuid4())
    args.conf_timestamp = str(datetime.datetime.now())
    args.conf_host = socket.gethostname()
    dataset = get_dataset(args)

    # Fill unspecified training hyper-parameters with the dataset defaults.
    if args.n_epochs is None and isinstance(dataset, ContinualDataset):
        args.n_epochs = dataset.get_epochs()
    if args.batch_size is None:
        args.batch_size = dataset.get_batch_size()
    # Some model parsers (e.g. models/gem.py) delete the minibatch_size action,
    # so the attribute may be missing rather than None.
    if hasattr(importlib.import_module('models.' + args.model), 'Buffer') and \
            getattr(args, 'minibatch_size', None) is None:
        args.minibatch_size = dataset.get_minibatch_size()

    backbone = dataset.get_backbone()
    loss = dataset.get_loss()
    model = get_model(args, backbone, loss, dataset.get_transform())

    # model.net.linear.scale = args.cos_temp

    # if args.model == 'lucir':
    #     if args.dataset == 'seq-cifar10':
    #         model.net.linear = nn.Linear(512, 10)
    #     elif args.dataset == 'seq-cifar100':
    #         model.net.linear = nn.Linear(512, 100)
    #     else:
    #         model.net.linear = nn.Linear(512, 200)

    # set job name
    setproctitle.setproctitle('{}_{}_{}'.format(
        args.model, args.buffer_size if 'buffer_size' in args else 0, args.dataset))

    # test_visual(model, dataset, args)
    if isinstance(dataset, ContinualDataset):
        train(model, dataset, args)
    else:
        # General-continual datasets use the dedicated training loop.
        assert not hasattr(model, 'end_task') or model.NAME == 'joint_gcl'
        ctrain(args)


if __name__ == '__main__':
    main()
145 |
--------------------------------------------------------------------------------
/datasets/seq_cifar100.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from torchvision.datasets import CIFAR100
7 | import torchvision.transforms as transforms
8 | from backbone.ResNet18 import resnet18
9 | import torch.nn.functional as F
10 | from utils.conf import base_path
11 | from PIL import Image
12 | from datasets.utils.validation import get_train_val
13 | from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
14 | from typing import Tuple
15 | from datasets.transforms.denormalization import DeNormalize
16 | import torch.optim
17 |
class TCIFAR100(CIFAR100):
    """
    CIFAR-100 used for the test split. The data is downloaded only when the
    files under ``root`` fail the integrity check.
    """

    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False) -> None:
        # self.root is set by hand because _check_integrity is called before
        # the parent constructor runs. The ``download`` argument is not used.
        self.root = root
        needs_download = not self._check_integrity()
        super().__init__(root, train, transform, target_transform,
                         download=needs_download)
class MyCIFAR100(CIFAR100):
    """
    Overrides the CIFAR100 dataset to change the getitem function.

    Each item also carries the non-augmented image as a tensor and, when a
    ``logits`` attribute has been attached, the stored logits.
    """

    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False) -> None:
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        # self.root is needed by _check_integrity before the parent constructor runs;
        # the ``download`` argument is not used.
        self.root = root
        auto_download = not self._check_integrity()
        super().__init__(root, train, transform, target_transform, auto_download)

    def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]:
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :returns: (image, target, not_aug_image), or
                  (image, target, not_aug_image, logits) when ``self.logits`` exists
        """
        # to return a PIL Image
        pil_img = Image.fromarray(self.data[index], mode='RGB')
        target = self.targets[index]

        not_aug_img = self.not_aug_transform(pil_img.copy())
        img = pil_img if self.transform is None else self.transform(pil_img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if not hasattr(self, 'logits'):
            return img, target, not_aug_img
        return img, target, not_aug_img, self.logits[index]
58 |
59 |
class SequentialCIFAR100(ContinualDataset):
    """
    Class-incremental CIFAR-100: the 100 classes are split into ``N_TASKS``
    tasks of ``N_CLASSES_PER_TASK`` classes each, with a ResNet-18 backbone.
    """

    NAME = 'seq-cifar100'
    SETTING = 'class-il'
    # Other splits kept around for experiments: 10 tasks x 10 classes,
    # 2 tasks x 50 classes.
    N_CLASSES_PER_TASK = 20
    N_TASKS = 5
    # Per-channel normalization statistics shared by all transforms below.
    MEAN = (0.5071, 0.4867, 0.4408)
    STD = (0.2675, 0.2565, 0.2761)
    TRANSFORM = transforms.Compose(
        [transforms.RandomCrop(32, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(MEAN, STD)])

    def get_examples_number(self):
        """Return the number of CIFAR-100 training images."""
        # Same root as get_data_loaders, so the downloaded copy is reused.
        train_dataset = MyCIFAR100(base_path() + 'CIFAR100', train=True,
                                   download=True)
        return len(train_dataset.data)

    def get_data_loaders(self):
        """
        Build the train/test loaders for the next task (class offset tracked in ``self.i``).

        With ``args.validation``, the test set is derived from the training data
        by ``get_train_val`` (presumably a held-out split); otherwise the
        official CIFAR-100 test split is used.
        """
        transform = self.TRANSFORM

        test_transform = transforms.Compose(
            [transforms.ToTensor(), self.get_normalization_transform()])

        train_dataset = MyCIFAR100(base_path() + 'CIFAR100', train=True,
                                   download=True, transform=transform)
        if self.args.validation:
            train_dataset, test_dataset = get_train_val(train_dataset,
                                                        test_transform, self.NAME)
        else:
            test_dataset = TCIFAR100(base_path() + 'CIFAR100', train=False,
                                     download=True, transform=test_transform)

        train, test = store_masked_loaders(train_dataset, test_dataset, self)

        return train, test

    @staticmethod
    def get_transform():
        """Training augmentation for tensor inputs: convert to PIL, then apply TRANSFORM."""
        transform = transforms.Compose(
            [transforms.ToPILImage(), SequentialCIFAR100.TRANSFORM])
        return transform

    @staticmethod
    def get_backbone():
        """ResNet-18 with one output per class over all tasks."""
        return resnet18(SequentialCIFAR100.N_CLASSES_PER_TASK
                        * SequentialCIFAR100.N_TASKS)

    @staticmethod
    def get_loss():
        return F.cross_entropy

    @staticmethod
    def get_normalization_transform():
        transform = transforms.Normalize(SequentialCIFAR100.MEAN,
                                         SequentialCIFAR100.STD)
        return transform

    @staticmethod
    def get_denormalization_transform():
        transform = DeNormalize(SequentialCIFAR100.MEAN,
                                SequentialCIFAR100.STD)
        return transform

    @staticmethod
    def get_epochs():
        return 50

    @staticmethod
    def get_batch_size():
        return 32

    @staticmethod
    def get_minibatch_size():
        return SequentialCIFAR100.get_batch_size()

    @staticmethod
    def get_scheduler(model, args) -> torch.optim.lr_scheduler:
        """
        Recreate ``model.opt`` as SGD over ``model.net`` and return a
        MultiStepLR that multiplies the learning rate by 0.1 at milestones 35
        and 45. Milestones count scheduler steps, presumably one per epoch.
        """
        model.opt = torch.optim.SGD(model.net.parameters(), lr=args.lr,
                                    weight_decay=args.optim_wd, momentum=args.optim_mom)
        # `verbose` is left at its default: it was False, and explicitly
        # passing it is deprecated in recent PyTorch releases.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(model.opt, [35, 45], gamma=0.1)
        return scheduler
145 |
146 |
--------------------------------------------------------------------------------
/utils/loggers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import csv
7 | import os
8 | import sys
9 | from typing import Dict, Any
10 | from utils.metrics import *
11 |
12 | from utils import create_if_not_exists
13 | from utils.conf import base_path
14 | import numpy as np
15 |
# Argument names that CsvLogger.write leaves out of the CSV rows.
useless_args = [
    'dataset',
    'tensorboard',
    'validation',
    'model',
    'csv_log',
    'notes',
    'load_best_args',
]
18 |
19 |
def print_mean_accuracy(mean_acc: np.ndarray, task_number: int,
                        setting: str) -> None:
    """
    Report the mean accuracy reached after ``task_number`` tasks on stderr.

    :param mean_acc: a pair of accuracies. For 'domain-il' only the first one
                     is shown; otherwise they are the Class-IL and Task-IL values.
    :param task_number: task index
    :param setting: the setting of the benchmark
    """
    first, second = mean_acc
    if setting != 'domain-il':
        message = ('\nAccuracy for {} task(s): \t [Class-IL]: {} %'
                   ' \t [Task-IL]: {} %\n').format(task_number, round(first, 2),
                                                   round(second, 2))
    else:
        message = '\nAccuracy for {} task(s): {} %'.format(task_number,
                                                           round(first, 2))
    print(message, file=sys.stderr)
37 |
38 |
class CsvLogger:
    """
    Collects per-task mean accuracies and transfer/forgetting metrics for a run.

    The collected values are appended, together with the run arguments, as CSV
    rows under ``base_path() + "results/"``.
    """

    def __init__(self, setting_str: str, dataset_str: str,
                 model_str: str) -> None:
        self.accs = []
        if setting_str == 'class-il':
            # Class-IL runs also record the Task-IL accuracy after every task.
            self.accs_mask_classes = []
        self.setting = setting_str
        self.dataset = dataset_str
        self.model = model_str
        self.fwt = None
        self.fwt_mask_classes = None
        self.bwt = None
        self.bwt_mask_classes = None
        self.forgetting = None
        self.forgetting_mask_classes = None

    def add_fwt(self, results, accs, results_mask_classes, accs_mask_classes):
        """Compute forward transfer (and its Task-IL counterpart for Class-IL runs)."""
        self.fwt = forward_transfer(results, accs)
        if self.setting == 'class-il':
            self.fwt_mask_classes = forward_transfer(results_mask_classes, accs_mask_classes)

    def add_bwt(self, results, results_mask_classes):
        """Compute backward transfer for both accuracy variants."""
        self.bwt = backward_transfer(results)
        self.bwt_mask_classes = backward_transfer(results_mask_classes)

    def add_forgetting(self, results, results_mask_classes):
        """Compute forgetting for both accuracy variants."""
        self.forgetting = forgetting(results)
        self.forgetting_mask_classes = forgetting(results_mask_classes)

    def log(self, mean_acc: np.ndarray) -> None:
        """
        Logs a mean accuracy value.
        :param mean_acc: a scalar for 'general-continual'; otherwise a pair whose
                         first entry is logged (and, outside 'domain-il', whose
                         second entry is logged as the Task-IL accuracy)
        """
        if self.setting == 'general-continual':
            self.accs.append(mean_acc)
        elif self.setting == 'domain-il':
            mean_acc, _ = mean_acc
            self.accs.append(mean_acc)
        else:
            mean_acc_class_il, mean_acc_task_il = mean_acc
            self.accs.append(mean_acc_class_il)
            self.accs_mask_classes.append(mean_acc_task_il)

    def write(self, args: Dict[str, Any]) -> None:
        """
        Write the logged values, along with the run arguments, to the results CSV(s).

        One row goes to ``results/<setting>/<dataset>/<model>/mean_accs.csv``.
        Class-IL runs write a second row, holding the Task-IL numbers, to
        ``results/task-il/<dataset>/<model>/mean_accs.csv``. Headers are
        written only when a file is first created.

        :param args: the arguments of the current experiment as a dict (e.g.
                     ``vars(args)``). It is copied and never modified.
        """
        # Work on a copy: editing the caller's dict in place would also strip
        # keys from the argparse namespace when called with vars(args).
        row = {key: value for key, value in args.items() if key not in useless_args}
        arg_cols = list(row.keys())

        task_cols = ['task' + str(i + 1) for i in range(len(self.accs))]
        for col, acc in zip(task_cols, self.accs):
            row[col] = acc
        row['forward_transfer'] = self.fwt
        row['backward_transfer'] = self.bwt
        row['forgetting'] = self.forgetting
        columns = task_cols + ['forward_transfer', 'backward_transfer', 'forgetting'] + arg_cols

        results_dir = base_path() + "results/"
        model_dir = results_dir + self.setting + "/" + self.dataset + "/" + self.model
        create_if_not_exists(results_dir + self.setting)
        create_if_not_exists(results_dir + self.setting + "/" + self.dataset)
        create_if_not_exists(model_dir)
        self._append_row(model_dir + "/mean_accs.csv", columns, row)

        if self.setting == 'class-il':
            task_il_dir = results_dir + "task-il/" + self.dataset
            create_if_not_exists(task_il_dir)
            create_if_not_exists(task_il_dir + "/" + self.model)

            # Same columns, Task-IL values.
            for i, acc in enumerate(self.accs_mask_classes):
                row['task' + str(i + 1)] = acc
            row['forward_transfer'] = self.fwt_mask_classes
            row['backward_transfer'] = self.bwt_mask_classes
            row['forgetting'] = self.forgetting_mask_classes
            self._append_row(task_il_dir + "/" + self.model + "/mean_accs.csv", columns, row)

    @staticmethod
    def _append_row(path: str, columns, row) -> None:
        """Append ``row`` to the CSV at ``path``, writing the header only if the file is new."""
        write_headers = not os.path.exists(path)
        # newline='' is required by the csv module so its own line endings
        # are not translated (otherwise blank lines appear on Windows).
        with open(path, 'a', newline='') as tmp:
            writer = csv.DictWriter(tmp, fieldnames=columns)
            if write_headers:
                writer.writeheader()
            writer.writerow(row)
150 |
--------------------------------------------------------------------------------
/models/gem.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the gem_license file in the root of this source tree.
6 |
7 | import quadprog
8 |
9 | import numpy as np
10 | import torch
11 | from models.utils.continual_model import ContinualModel
12 |
13 | from utils.buffer import Buffer
14 | from utils.args import *
15 |
16 |
def get_parser() -> ArgumentParser:
    """
    Build the command-line parser for GEM.

    It contains the shared management, experiment and rehearsal options, minus
    ``minibatch_size``, plus ``--gamma``.
    """
    parser = ArgumentParser(description='Continual learning via'
                                        ' Gradient Episodic Memory.')
    for add_group in (add_management_args, add_experiment_args, add_rehearsal_args):
        add_group(parser)

    # GEM draws buffer_size samples from the buffer in observe, so the
    # rehearsal minibatch size is not used. Only the action list is edited:
    # argparse no longer sets a default for it, but the option string stays
    # registered in the parser.
    positions = [pos for pos, action in enumerate(parser._actions)
                 if action.dest == 'minibatch_size']
    if positions:
        del parser._actions[positions[0]]

    parser.add_argument('--gamma', type=float, default=None,
                        help='Margin parameter for GEM.')
    return parser
32 |
33 |
def store_grad(params, grads, grad_dims):
    """
    This stores parameter gradients of past tasks.

    Each parameter's gradient is flattened into its slice of the buffer.

    params: callable returning the parameters in a fixed order
            (e.g. ``model.parameters``)
    grads: flat tensor with ``sum(grad_dims)`` entries. It is zeroed first, so
           parameters without a gradient leave zeros in their slice.
    grad_dims: list with number of parameters per layer, in ``params()`` order
    """
    # store the gradients
    grads.fill_(0.0)
    begin = 0
    for idx, param in enumerate(params()):
        # A running offset avoids re-summing grad_dims for every parameter.
        end = begin + grad_dims[idx]
        if param.grad is not None:
            # reshape (unlike view) also handles non-contiguous gradients.
            grads[begin:end].copy_(param.grad.detach().reshape(-1))
        begin = end
50 |
51 |
def overwrite_grad(params, newgrad, grad_dims):
    """
    This is used to overwrite the gradients with a new gradient
    vector, whenever violations occur.

    params: callable returning the parameters in the same order used to build
            ``newgrad`` (e.g. ``model.parameters``)
    newgrad: flat corrected gradient with ``sum(grad_dims)`` entries
    grad_dims: list storing number of parameters at each layer
    Parameters whose ``grad`` is None are left untouched.
    """
    begin = 0
    for idx, param in enumerate(params()):
        # A running offset avoids re-summing grad_dims for every parameter.
        end = begin + grad_dims[idx]
        if param.grad is not None:
            param.grad.data.copy_(newgrad[begin:end].reshape(param.grad.shape))
        begin = end
69 |
70 |
def project2cone2(gradient, memories, margin=0.5, eps=1e-3):
    """
    Solves the GEM dual QP described in the paper given a proposed
    gradient "gradient", and a memory of task gradients "memories".
    Overwrites "gradient" with the final projected update.

    :param gradient: proposed update as a column tensor (the Gem model passes its
                     flat gradient unsqueezed to shape (p, 1)); modified in place
    :param memories: past-task gradients with one column per task, shape (p, t)
    :param margin: lower bound imposed on every dual variable
    :param eps: ridge added to the diagonal of the QP matrix
    """
    past = memories.cpu().t().double().numpy()                        # (t, p)
    proposed = gradient.cpu().contiguous().view(-1).double().numpy()  # (p,)
    n_constraints = past.shape[0]

    # quadprog solves: minimize 1/2 x^T G x - a^T x  subject to  C^T x >= b.
    # Here G = sym(M M^T) + eps*I and a = -(M g), with C = I and b = margin.
    quad = past.dot(past.transpose())
    quad = 0.5 * (quad + quad.transpose()) + np.eye(n_constraints) * eps
    lin = -past.dot(proposed)
    constraint_mat = np.eye(n_constraints)
    lower_bounds = np.zeros(n_constraints) + margin

    dual = quadprog.solve_qp(quad, lin, constraint_mat, lower_bounds)[0]
    projected = dual.dot(past) + proposed
    gradient.copy_(torch.from_numpy(projected).view(-1, 1))
92 |
93 |
class Gem(ContinualModel):
    """
    Gradient Episodic Memory.

    Stores samples from each finished task and, at every step, projects the
    current gradient whenever it conflicts with a past task's gradient.
    """
    NAME = 'gem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(Gem, self).__init__(backbone, loss, args, transform)
        self.current_task = 0
        self.buffer = Buffer(self.args.buffer_size, self.device)

        # Allocate temporary synaptic memory
        self.grad_dims = [pp.data.numel() for pp in self.parameters()]

        # One flat gradient per finished task, plus one for the current batch.
        self.grads_cs = []
        self.grads_da = torch.zeros(np.sum(self.grad_dims)).to(self.device)

    def end_task(self, dataset):
        """Allocate a gradient slot for the finished task and store some of its samples."""
        self.current_task += 1
        self.grads_cs.append(torch.zeros(
            np.sum(self.grad_dims)).to(self.device))

        # add data to the buffer
        samples_per_task = self.args.buffer_size // dataset.N_TASKS

        loader = dataset.train_loader
        batch = next(iter(loader))
        # Training items are (augmented, label, non-augmented[, logits]);
        # keep the labels and the non-augmented images.
        # NOTE(review): only the first batch is used, so at most one batch of
        # samples per task is stored, even when samples_per_task is larger.
        cur_y, cur_x = batch[1], batch[2]
        # Trim to this task's share so examples, labels and task labels have
        # matching lengths (Buffer.add_data presumably indexes them per example).
        cur_x = cur_x[:samples_per_task]
        cur_y = cur_y[:samples_per_task]
        self.buffer.add_data(
            examples=cur_x.to(self.device),
            labels=cur_y.to(self.device),
            task_labels=torch.full((cur_x.shape[0],), self.current_task - 1,
                                   dtype=torch.long, device=self.device)
        )

    def observe(self, inputs, labels, not_aug_inputs):
        """
        One GEM step.

        First, a reference gradient is stored for every past task present in
        the buffer. Then the gradient on the current batch is computed and, if
        it has a negative dot product with any stored gradient, it is projected
        before the optimizer step.

        :return: the loss on the current batch
        """
        if not self.buffer.is_empty():
            buf_inputs, buf_labels, buf_task_labels = self.buffer.get_data(
                self.args.buffer_size, transform=self.transform)

            for tt in buf_task_labels.unique():
                # compute gradient on the memory buffer
                self.opt.zero_grad()
                cur_task_inputs = buf_inputs[buf_task_labels == tt]
                cur_task_labels = buf_labels[buf_task_labels == tt]
                cur_task_outputs = self.forward(cur_task_inputs)
                penalty = self.loss(cur_task_outputs, cur_task_labels)
                penalty.backward()
                store_grad(self.parameters, self.grads_cs[int(tt)], self.grad_dims)

        # now compute the grad on the current data
        self.opt.zero_grad()
        outputs = self.forward(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()

        # check if gradient violates buffer constraints
        if not self.buffer.is_empty():
            # copy gradient
            store_grad(self.parameters, self.grads_da, self.grad_dims)

            past_grads = torch.stack(self.grads_cs).T  # (num_params, num_past_tasks)
            dot_prod = torch.mm(self.grads_da.unsqueeze(0), past_grads)
            if (dot_prod < 0).sum() != 0:
                # grads_da.unsqueeze(1) is a view, so the projection lands in grads_da.
                project2cone2(self.grads_da.unsqueeze(1), past_grads,
                              margin=self.args.gamma)
                # copy gradients back
                overwrite_grad(self.parameters, self.grads_da,
                               self.grad_dims)

        self.opt.step()

        return loss.item()
168 |
--------------------------------------------------------------------------------
/datasets/seq_tinyimagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torchvision.transforms as transforms
8 | from torch.utils.data import Dataset
9 | from backbone.ResNet18 import resnet18
10 | import torch.nn.functional as F
11 | from utils.conf import base_path
12 | from PIL import Image
13 | import os
14 | from datasets.utils.validation import get_train_val
15 | from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
16 | from datasets.utils.continual_dataset import get_previous_train_loader
17 | from datasets.transforms.denormalization import DeNormalize
18 |
19 |
class TinyImagenet(Dataset):
    """
    Defines Tiny Imagenet as for the others pytorch datasets.

    Loads the preprocessed arrays ``processed/{x,y}_{train,val}_NN.npy``
    (20 chunks each) from ``root``.
    """
    def __init__(self, root: str, train: bool=True, transform: transforms=None,
                 target_transform: transforms=None, download: bool=False) -> None:
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download

        if download:
            if os.path.isdir(root) and len(os.listdir(root)) > 0:
                print('Download not needed, files already on disk.')
            else:
                from google_drive_downloader import GoogleDriveDownloader as gdd

                # https://drive.google.com/file/d/1Sy3ScMBr0F4se8VZ6TAwDYF-nNGAAdxj/view
                print('Downloading dataset')
                gdd.download_file_from_google_drive(
                    file_id='1Sy3ScMBr0F4se8VZ6TAwDYF-nNGAAdxj',

                    dest_path=os.path.join(root, 'tiny-imagenet-processed.zip'),
                    unzip=True)

        self.data = self._load_chunks('x')
        self.targets = self._load_chunks('y')

    def _load_chunks(self, prefix: str) -> np.ndarray:
        """Load and concatenate the 20 ``processed/<prefix>_<split>_NN.npy`` chunks."""
        split = 'train' if self.train else 'val'
        chunks = [np.load(os.path.join(self.root, 'processed/%s_%s_%02d.npy' %
                                       (prefix, split, num + 1)))
                  for num in range(20)]
        # Concatenate the list directly. Wrapping it in np.array first makes an
        # extra full copy and raises on NumPy >= 1.24 if chunk lengths differ.
        return np.concatenate(chunks)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image (the stored arrays are scaled by 255 first)
        img = Image.fromarray(np.uint8(255 * img))
        original_img = img.copy()

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if hasattr(self, 'logits'):
            # NOTE(review): the third element here is a PIL image, not a tensor
            # as in MyTinyImagenet; default batching would presumably fail on
            # it. Confirm before relying on this branch.
            return img, target, original_img, self.logits[index]

        return img, target
82 |
83 |
class MyTinyImagenet(TinyImagenet):
    """
    Tiny Imagenet variant whose items also carry the non-augmented image as a
    tensor and, when a ``logits`` attribute exists, the stored logits.
    """
    def __init__(self, root: str, train: bool=True, transform: transforms=None,
                 target_transform: transforms=None, download: bool=False) -> None:
        super().__init__(root, train, transform, target_transform, download)

    def __getitem__(self, index):
        # Stored arrays are scaled by 255 and turned into a PIL image, matching the
        # other datasets.
        pil_img = Image.fromarray(np.uint8(255 * self.data[index]))
        target = self.targets[index]

        not_aug_img = self.not_aug_transform(pil_img.copy())
        img = self.transform(pil_img) if self.transform is not None else pil_img

        if self.target_transform is not None:
            target = self.target_transform(target)

        extra = (self.logits[index],) if hasattr(self, 'logits') else ()
        return (img, target, not_aug_img) + extra
113 |
114 |
class SequentialTinyImagenet(ContinualDataset):
    """
    Class-incremental Tiny ImageNet: 200 classes split into ``N_TASKS`` tasks
    of ``N_CLASSES_PER_TASK`` classes, with a ResNet-18 backbone.

    NOTE(review): get_epochs/get_batch_size/get_minibatch_size are not
    overridden, so main() receives None from the base class for them. Pass
    --n_epochs/--batch_size explicitly.
    """

    NAME = 'seq-tinyimg'
    SETTING = 'class-il'
    N_CLASSES_PER_TASK = 20
    N_TASKS = 10
    # Per-channel normalization statistics shared by all transforms below.
    MEAN = (0.4802, 0.4480, 0.3975)
    STD = (0.2770, 0.2691, 0.2821)
    TRANSFORM = transforms.Compose(
        [transforms.RandomCrop(64, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(MEAN, STD)])

    def get_data_loaders(self):
        """Build the train/test loaders for the next task (class offset tracked in ``self.i``)."""
        transform = self.TRANSFORM

        test_transform = transforms.Compose(
            [transforms.ToTensor(), self.get_normalization_transform()])

        train_dataset = MyTinyImagenet(base_path() + 'TINYIMG',
                                       train=True, download=True, transform=transform)
        if self.args.validation:
            train_dataset, test_dataset = get_train_val(train_dataset,
                                                        test_transform, self.NAME)
        else:
            test_dataset = TinyImagenet(base_path() + 'TINYIMG',
                                        train=False, download=True, transform=test_transform)

        train, test = store_masked_loaders(train_dataset, test_dataset, self)
        return train, test

    @staticmethod
    def get_backbone():
        """ResNet-18 with one output per class over all tasks."""
        return resnet18(SequentialTinyImagenet.N_CLASSES_PER_TASK
                        * SequentialTinyImagenet.N_TASKS)

    @staticmethod
    def get_loss():
        return F.cross_entropy

    @classmethod
    def get_transform(cls):
        """
        Training augmentation for tensor inputs: convert to PIL, then apply TRANSFORM.

        A classmethod, so it can be called on the class (as the base-class
        contract allows) or on an instance, and still honours a subclass's TRANSFORM.
        """
        transform = transforms.Compose(
            [transforms.ToPILImage(), cls.TRANSFORM])
        return transform

    @staticmethod
    def get_normalization_transform():
        transform = transforms.Normalize(SequentialTinyImagenet.MEAN,
                                         SequentialTinyImagenet.STD)
        return transform

    @staticmethod
    def get_denormalization_transform():
        transform = DeNormalize(SequentialTinyImagenet.MEAN,
                                SequentialTinyImagenet.STD)
        return transform

    @staticmethod
    def get_scheduler(model, args):
        """No learning-rate scheduler is used for this dataset."""
        return None
--------------------------------------------------------------------------------
/models/hal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from utils.buffer import Buffer
8 | from utils.args import *
9 | from models.utils.continual_model import ContinualModel
10 | from datasets import get_dataset
11 | import numpy as np
12 | from torch.optim import SGD
13 | import sys
14 |
15 | import numpy as np
16 | import copy
17 | from torch.nn import functional as F
18 | import math
19 | from torch.optim import SGD
20 | from collections import OrderedDict
21 | EPS = 1E-20
22 |
def get_parser() -> ArgumentParser:
    """
    Build the command-line parser for HAL (Hindsight Anchor Learning).

    Adds ``--hal_lambda`` (default 0.1), ``--beta`` (default 0.5) and
    ``--gamma`` (default 0.1) on top of the shared management, experiment and
    rehearsal options.
    """
    # The description previously said "Experience Replay", copied from
    # another model's parser.
    parser = ArgumentParser(description='Continual learning via'
                                        ' Hindsight Anchor Learning.')
    add_management_args(parser)
    add_experiment_args(parser)
    add_rehearsal_args(parser)

    parser.add_argument('--hal_lambda', type=float, default=0.1)
    parser.add_argument('--beta', type=float, default=0.5)
    parser.add_argument('--gamma', type=float, default=0.1)

    return parser
35 |
36 |
class HAL(ContinualModel):
    """HAL continual-learning model: experience replay plus anchor preservation.

    Each step trains on the incoming batch concatenated with a replay batch
    drawn from a ring buffer. Once at least one task has ended, a second
    update penalises how much that step moved the network's outputs on a set
    of learned anchor inputs (one per class of every finished task, built in
    ``get_anchors``).
    """
    NAME = 'hal'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(HAL, self).__init__(backbone, loss, args, transform)
        self.task_number = 0
        # Build the dataset descriptor once; it provides both the number of
        # tasks for the ring buffer and the spare backbone below.
        self.dataset = get_dataset(args)
        self.buffer = Buffer(self.args.buffer_size, self.device,
                             self.dataset.N_TASKS, mode='ring')
        self.hal_lambda = args.hal_lambda  # weight of the anchor penalty in observe()
        self.beta = args.beta  # decay of the feature moving average `phi`
        self.gamma = args.gamma  # weight of the feature term in get_anchors()
        self.anchor_optimization_steps = 100  # SGD steps per anchor
        self.finetuning_epochs = 1  # buffer minibatches used to fine-tune the spare model
        # Scratch network used to build anchors without touching self.net.
        self.spare_model = self.dataset.get_backbone()
        self.spare_model.to(self.device)
        self.spare_opt = SGD(self.spare_model.parameters(), lr=self.args.lr)

    def end_task(self, dataset):
        """Close the current task: advance the ring buffer and build anchors.

        :param dataset: the continual dataset whose task just ended
        """
        self.task_number += 1
        # Ring buffer management. The comparison presumably skips this when
        # the buffer state was restored from a checkpoint (verify).
        if self.task_number > self.buffer.task_number:
            self.buffer.num_seen_examples = 0
            self.buffer.task_number = self.task_number
        # Build anchors for the finished task unless they are already present
        # (e.g. loaded together with the model).
        if len(self.anchors) < self.task_number * dataset.N_CLASSES_PER_TASK:
            self.get_anchors(dataset)
            # phi is rebuilt from scratch by the next observe() call.
            del self.phi

    def get_anchors(self, dataset):
        """Learn one anchor input per class of the task that just ended.

        The spare model is first fine-tuned on the memory buffer (weights
        ``theta_m``). Each anchor is then optimised by SGD to:
        - increase its loss under ``theta_m``;
        - decrease its loss under the current weights ``theta_t``;
        - keep its features close to ``self.phi``.
        The resulting anchors are appended to ``self.anchors``.

        :param dataset: the continual dataset; the classes are read from
            ``dataset.train_loader.dataset.targets``
        """
        theta_t = self.net.get_params().detach().clone()
        self.spare_model.set_params(theta_t)

        # Fine-tune a copy of the current weights on the memory buffer.
        for _ in range(self.finetuning_epochs):
            inputs, labels = self.buffer.get_data(self.args.batch_size, transform=self.transform)
            self.spare_opt.zero_grad()
            out = self.spare_model(inputs)
            loss = self.loss(out, labels)
            loss.backward()
            self.spare_opt.step()

        theta_m = self.spare_model.get_params().detach().clone()

        classes_for_this_task = np.unique(dataset.train_loader.dataset.targets)

        for a_class in classes_for_this_task:
            assert not self.phi.requires_grad
            e_t = torch.rand(self.input_shape, requires_grad=True, device=self.device)
            e_t_opt = SGD([e_t], lr=self.args.lr)
            # The class target is the same for every optimisation step.
            target = torch.tensor([a_class]).to(self.device)
            print(file=sys.stderr)
            for _ in range(self.anchor_optimization_steps):
                e_t_opt.zero_grad()

                # Increase the anchor's loss under the fine-tuned weights...
                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_m.detach().clone())
                loss = -torch.sum(self.loss(self.spare_model(e_t.unsqueeze(0)), target))
                loss.backward()

                # ...decrease it under the weights at the end of the task...
                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_t.detach().clone())
                loss = torch.sum(self.loss(self.spare_model(e_t.unsqueeze(0)), target))
                loss.backward()

                # ...and keep its features close to the feature moving average.
                self.spare_opt.zero_grad()
                loss = torch.sum(self.gamma * (self.spare_model(e_t.unsqueeze(0), returnt='features') - self.phi) ** 2)
                loss.backward()

                # The three backward passes above accumulate into e_t.grad.
                e_t_opt.step()

            self.anchors = torch.cat((self.anchors, e_t.detach().unsqueeze(0)))
            print('Total anchors:', len(self.anchors), file=sys.stderr)

        self.spare_model.zero_grad()

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one HAL training step on a batch and store the batch in the buffer.

        :param inputs: (augmented) input batch
        :param labels: labels of ``inputs``
        :param not_aug_inputs: the batch stored in the buffer
        :return: the replay loss, plus the anchor penalty once anchors exist
        """
        real_batch_size = inputs.shape[0]
        # Lazily initialise state that depends on the input shape.
        if not hasattr(self, 'input_shape'):
            self.input_shape = inputs.shape[1:]
        if not hasattr(self, 'anchors'):
            self.anchors = torch.zeros(tuple([0] + list(self.input_shape))).to(self.device)
        if not hasattr(self, 'phi'):
            print('Building phi', file=sys.stderr)
            with torch.no_grad():
                self.phi = torch.zeros_like(self.net(inputs[0].unsqueeze(0), returnt='features'), requires_grad=False)
            assert not self.phi.requires_grad

        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        old_weights = self.net.get_params().detach().clone()

        # First update on current + replay data. It is undone below (weights
        # restored to old_weights) when anchors exist.
        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        first_loss = 0

        assert len(self.anchors) == self.dataset.N_CLASSES_PER_TASK * self.task_number

        if len(self.anchors) > 0:
            first_loss = loss.item()
            # Anchor outputs after the first update...
            with torch.no_grad():
                pred_anchors = self.net(self.anchors)

            # ...minus the outputs at the restored pre-update weights.
            # The replay-loss gradients from the first backward are deliberately
            # NOT zeroed. This second step therefore applies their sum with the
            # anchor-penalty gradient, both taken at old_weights.
            self.net.set_params(old_weights)
            pred_anchors -= self.net(self.anchors)
            loss = self.hal_lambda * (pred_anchors ** 2).mean()
            loss.backward()
            self.opt.step()

        # Exponential moving average of features over the current
        # (non-replay) part of the batch.
        with torch.no_grad():
            self.phi = self.beta * self.phi + (1 - self.beta) * self.net(inputs[:real_batch_size], returnt='features').mean(0)

        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:real_batch_size])

        return first_loss + loss.item()
170 |
--------------------------------------------------------------------------------
/utils/gss_buffer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
11 |
class Buffer:
    """
    The memory buffer of the GSS rehearsal method.

    Slots are filled in arrival order until ``buffer_size`` examples have been
    seen. After that, an incoming example may replace a stored one, chosen
    with gradient-similarity scores (see ``get_grad_score`` and
    ``functional_reservoir``).
    """
    def __init__(self, buffer_size, device, minibatch_size, model=None):
        """
        :param buffer_size: number of slots in the buffer
        :param device: device on which the buffer tensors are allocated
        :param minibatch_size: size of the buffer sample used to score new data
        :param model: object exposing ``get_grads(inputs, labels)``; required
            by ``add_data`` as soon as the buffer holds data
        """
        self.buffer_size = buffer_size
        self.device = device
        self.num_seen_examples = 0
        self.attributes = ['examples', 'labels']
        self.model = model
        self.minibatch_size = minibatch_size
        # buffer index -> gradient of the example currently stored there
        self.cache = {}
        # position of the sequential (non-random) reader inside fathom_mask
        self.fathom = 0
        self.fathom_mask = None
        self.reset_fathom()

        # not read anywhere in this class; kept for compatibility
        self.conterone = 0

    def reset_fathom(self):
        """Restart sequential sampling over a fresh random permutation of the filled slots."""
        self.fathom = 0
        filled = self.num_seen_examples
        if hasattr(self, 'examples'):
            filled = min(filled, self.examples.shape[0])
        self.fathom_mask = torch.randperm(filled)

    def get_grad_score(self, x, y, X, Y, indices):
        """
        Score how similar the gradient of (x, y) is to those of buffer samples.

        :param x: input batch to score
        :param y: labels of ``x``
        :param X: buffer examples to compare against
        :param Y: labels of ``X``
        :param indices: buffer indices of ``X``, used as gradient-cache keys
        :return: 1 + the maximum cosine similarity between the gradient of
            (x, y) and the per-sample gradients of (X, Y); a float in [0, 2]
        """
        g = self.model.get_grads(x, y)
        G = []
        for xi, yi, idx in zip(X, Y, indices):
            if idx in self.cache:
                grd = self.cache[idx]
            else:
                grd = self.model.get_grads(xi.unsqueeze(0), yi.unsqueeze(0))
                self.cache[idx] = grd
            G.append(grd)
        G = torch.cat(G).to(g.device)
        c_score = 0
        # Compare against a few gradients at a time to bound memory use.
        grads_at_a_time = 5
        for start in range(0, G.shape[0], grads_at_a_time):
            chunk = G[start:start + grads_at_a_time]
            tmp = F.cosine_similarity(g, chunk, dim=1).max().item() + 1
            c_score = max(c_score, tmp)
        return c_score

    def functional_reservoir(self, x, y, batch_c, bigX=None, bigY=None, indices=None):
        """
        Decide where, if anywhere, a single example should be stored.

        While the buffer is not full, the next free slot is used with the batch
        score. Once full, and only if ``batch_c < 1``:
        - the example is scored individually;
        - a candidate slot is drawn with probability proportional to the
          stored scores;
        - that slot is replaced with probability ``s[i] / (s[i] + score)``.

        :param x: the example
        :param y: its label
        :param batch_c: score of the whole incoming batch
        :param bigX: buffer examples used for scoring
        :param bigY: labels of ``bigX``
        :param indices: buffer indices of ``bigX``
        :return: (index, score) of the slot to write, or (-1, 0) to discard
        """
        if self.num_seen_examples < self.buffer_size:
            return self.num_seen_examples, batch_c

        elif batch_c < 1:
            single_c = self.get_grad_score(x.unsqueeze(0), y.unsqueeze(0), bigX, bigY, indices)
            s = self.scores.cpu().numpy()
            i = np.random.choice(np.arange(0, self.buffer_size), size=1, p=s / s.sum())[0]
            rand = np.random.rand(1)[0]
            if rand < s[i] / (s[i] + single_c):
                return i, single_c

        return -1, 0

    def init_tensors(self, examples: torch.Tensor, labels: Optional[torch.Tensor]) -> None:
        """
        Initializes just the required tensors.
        :param examples: tensor containing the images
        :param labels: tensor containing the labels (may be None)
        """
        provided = {'examples': examples, 'labels': labels}
        created = False
        for attr_str in self.attributes:
            attr = provided.get(attr_str)
            if attr is not None and not hasattr(self, attr_str):
                typ = torch.int64 if attr_str.endswith('els') else torch.float32
                setattr(self, attr_str, torch.zeros((self.buffer_size,
                        *attr.shape[1:]), dtype=typ, device=self.device))
                created = True
        if created:
            # Exactly one scalar score per slot. It is written per index in
            # add_data and used as sampling weights in functional_reservoir,
            # so its shape must not depend on the attribute shapes.
            self.scores = torch.zeros(self.buffer_size,
                                      dtype=torch.float32, device=self.device)

    def add_data(self, examples, labels=None):
        """
        Adds the data to the memory buffer according to the gradient-score
        reservoir strategy (see functional_reservoir).
        :param examples: tensor containing the images
        :param labels: tensor containing the labels; needed in practice, since
            labels[i] is indexed for every example
        """
        if not hasattr(self, 'examples'):
            self.init_tensors(examples, labels)

        # Score the incoming batch against a random sample of the buffer.
        if self.num_seen_examples > 0:
            bigX, bigY, indices = self.get_data(min(self.minibatch_size, self.num_seen_examples),
                                                give_index=True, random=True)
            c = self.get_grad_score(examples, labels, bigX, bigY, indices)
        else:
            bigX, bigY, indices = None, None, None
            c = 0.1

        for i in range(examples.shape[0]):
            index, score = self.functional_reservoir(examples[i], labels[i], c, bigX, bigY, indices)
            self.num_seen_examples += 1
            if index >= 0:
                self.examples[index] = examples[i].to(self.device)
                if labels is not None:
                    self.labels[index] = labels[i].to(self.device)
                self.scores[index] = score
                # The cached gradient belonged to the replaced example.
                self.cache.pop(index, None)

    def drop_cache(self):
        """Forget all cached per-example gradients."""
        self.cache = {}

    def get_data(self, size: int, transform: Optional[Callable] = None, give_index=False, random=False) -> Tuple:
        """
        Samples a batch of up to size items.
        :param size: the number of requested items (capped at the buffer size)
        :param transform: the transformation to be applied (data augmentation)
        :param give_index: if True, also return the buffer indices of the items
        :param random: if True, draw uniformly without replacement among the
            filled slots. Otherwise, read sequentially through fathom_mask and
            restart from the beginning once its end is reached. fathom_mask is
            only rebuilt by reset_fathom, so call that after adding data.
        :return: tuple (examples, [labels], [indices])
        """
        if size > self.examples.shape[0]:
            size = self.examples.shape[0]

        if random:
            choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]),
                                      size=min(size, self.num_seen_examples),
                                      replace=False)
        else:
            choice = np.arange(self.fathom, min(self.fathom + size, self.examples.shape[0], self.num_seen_examples))
            choice = self.fathom_mask[choice]
            self.fathom += len(choice)
            if self.fathom >= self.examples.shape[0] or self.fathom >= self.num_seen_examples:
                self.fathom = 0
        if transform is None:
            transform = lambda x: x
        ret_tuple = (torch.stack([transform(ee.cpu())
                                  for ee in self.examples[choice]]).to(self.device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr[choice],)
        if give_index:
            ret_tuple += (choice,)

        return ret_tuple

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        """
        return self.num_seen_examples == 0

    def get_all_data(self, transform: Optional[Callable] = None) -> Tuple:
        """
        Return all the items in the memory buffer.
        :param transform: the transformation to be applied (data augmentation)
        :return: a tuple with all the items in the memory buffer
        """
        if transform is None:
            transform = lambda x: x
        ret_tuple = (torch.stack([transform(ee.cpu())
                                  for ee in self.examples]).to(self.device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr,)
        return ret_tuple

    def empty(self) -> None:
        """
        Delete the stored tensors and reset the seen-examples counter.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0
185 |
--------------------------------------------------------------------------------
/utils/vmf_sampling.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
class VonMisesFisher(torch.distributions.Distribution):
    """von Mises-Fisher distribution with reparameterised sampling.

    A sample is built from three pieces:
    - a scalar component ``w`` along the first axis, in closed form when the
      ambient dimension ``m`` is 3 and by vectorised rejection sampling
      otherwise;
    - a random direction orthogonal to that axis;
    - a Householder reflection that maps the first basis vector onto ``loc``
      (exactly so when ``loc`` is unit-norm).

    :param loc: mean direction; its last dimension is the ambient dimension
        ``m`` (assumed unit-norm; TODO confirm at call sites)
    :param scale: positive concentration (kappa). Assumed to carry a trailing
        singleton dim that broadcasts against ``loc``, e.g. ``(N, 1)`` for
        ``loc`` of shape ``(N, m)``; verify against callers.
    :param validate_args: forwarded to ``torch.distributions.Distribution``
    :param k: proposals drawn per row in each rejection-sampling round
    """

    arg_constraints = {
        "loc": torch.distributions.constraints.real,
        "scale": torch.distributions.constraints.positive,
    }
    support = torch.distributions.constraints.real
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(self, loc, scale, validate_args=None, k=1):
        self.dtype = loc.dtype
        self.loc = loc
        self.scale = scale
        self.device = loc.device
        self.__m = loc.shape[-1]
        # First standard basis vector (default dtype) for the Householder map.
        self.__e1 = torch.zeros(self.__m, device=self.device)
        self.__e1[0] = 1.0
        self.k = k

        super().__init__(self.loc.size(), validate_args=validate_args)

    def sample(self, shape=torch.Size()):
        """Draw samples without tracking gradients."""
        with torch.no_grad():
            return self.rsample(shape)

    def rsample(self, shape=torch.Size()):
        """Draw reparameterised samples of shape ``shape + loc.shape``."""
        shape = shape if isinstance(shape, torch.Size) else torch.Size([shape])

        w = (
            self.__sample_w3(shape=shape)
            if self.__m == 3
            else self.__sample_w_rej(shape=shape)
        )

        # Random direction in the (m-1)-dim subspace orthogonal to e1.
        v = (
            torch.distributions.Normal(0, 1)
            .sample(shape + torch.Size(self.loc.shape))
            .to(self.device)
            .transpose(0, -1)[1:]
        ).transpose(0, -1)
        v = v / v.norm(dim=-1, keepdim=True)

        w_ = torch.sqrt(torch.clamp(1 - (w ** 2), 1e-10))
        x = torch.cat((w, w_ * v), -1)
        z = self.__householder_rotation(x)

        return z.type(self.dtype)

    def __sample_w3(self, shape):
        # Closed-form inverse-CDF sample of w for m == 3.
        shape = shape + torch.Size(self.scale.shape)
        u = torch.distributions.Uniform(0, 1).sample(shape).to(self.device)
        self.__w = (
            1
            + torch.stack(
                [torch.log(u), torch.log(1 - u) - 2 * self.scale], dim=0
            ).logsumexp(0)
            / self.scale
        )
        return self.__w

    def __sample_w_rej(self, shape):
        c = torch.sqrt((4 * (self.scale ** 2)) + (self.__m - 1) ** 2)
        b_true = (-2 * self.scale + c) / (self.__m - 1)

        # using Taylor approximation with a smooth shift from 10 < scale < 11
        # to avoid numerical errors for large scale
        b_app = (self.__m - 1) / (4 * self.scale)
        s = torch.min(
            torch.max(
                torch.tensor([0.0], dtype=self.dtype, device=self.device),
                self.scale - 10,
            ),
            torch.tensor([1.0], dtype=self.dtype, device=self.device),
        )
        b = b_app * s + b_true * (1 - s)

        a = (self.__m - 1 + 2 * self.scale + c) / 4
        d = (4 * a * b) / (1 + b) - (self.__m - 1) * math.log(self.__m - 1)

        self.__b, (self.__e, self.__w) = b, self.__while_loop(b, a, d, shape, k=self.k)
        return self.__w

    @staticmethod
    def first_nonzero(x, dim, invalid_val=-1):
        """Index of the first positive entry of ``x`` along ``dim``.

        Positions with no positive entry along ``dim`` get ``invalid_val``.
        """
        mask = x > 0
        idx = torch.where(
            mask.any(dim=dim),
            mask.float().argmax(dim=dim),
            torch.tensor(invalid_val, device=x.device),
        )
        return idx

    def __while_loop(self, b, a, d, shape, k=20, eps=1e-20):
        # Vectorised rejection sampling. Each round draws a [rows, k] matrix of
        # proposals and keeps, for every still-pending row, its first accepted
        # proposal.
        b, a, d = [
            e.repeat(*shape, *([1] * len(self.scale.shape))).reshape(-1, 1)
            for e in (b, a, d)
        ]
        w = torch.zeros_like(b).to(self.device)
        e = torch.zeros_like(b).to(self.device)
        # True for rows that still need an accepted proposal.
        bool_mask = (torch.ones_like(b) == 1).to(self.device)

        sample_shape = torch.Size([b.shape[0], k])
        shape = shape + torch.Size(self.scale.shape)

        # Proposal distributions do not change between rounds.
        concentration = torch.tensor((self.__m - 1) / 2, dtype=torch.float64)
        beta = torch.distributions.Beta(concentration, concentration)
        uniform = torch.distributions.Uniform(0 + eps, 1 - eps)

        while bool_mask.sum() != 0:
            e_ = beta.sample(sample_shape).to(self.device).type(self.dtype)
            u = uniform.sample(sample_shape).to(self.device).type(self.dtype)

            w_ = (1 - (1 + b) * e_) / (1 - (1 - b) * e_)
            t = (2 * a * b) / (1 - (1 - b) * e_)

            accept = ((self.__m - 1.0) * t.log() - t + d) > torch.log(u)
            accept_idx = self.first_nonzero(accept, dim=-1, invalid_val=-1).unsqueeze(1)
            # Clamp so rows without an accepted proposal (-1) still gather a
            # valid column; those rows are discarded through `reject` below.
            accept_idx_clamped = accept_idx.clamp(0)
            w_ = w_.gather(1, accept_idx_clamped.view(-1, 1))
            e_ = e_.gather(1, accept_idx_clamped.view(-1, 1))

            reject = accept_idx < 0
            accept = ~reject

            w[bool_mask * accept] = w_[bool_mask * accept]
            e[bool_mask * accept] = e_[bool_mask * accept]

            bool_mask[bool_mask * accept] = reject[bool_mask * accept]

        return e.reshape(shape), w.reshape(shape)

    def __householder_rotation(self, x):
        # Reflect across the hyperplane orthogonal to (e1 - loc).
        u = self.__e1 - self.loc
        u = u / (u.norm(dim=-1, keepdim=True) + 1e-5)
        z = x - 2 * (x * u).sum(-1, keepdim=True) * u
        return z
152 |
153 |
154 | # polar_noise = torch.full([32, 512], 100.0) #torch.normal(0.0, 1, size=(32,512))
155 | # norm_x = torch.norm(polar_noise.clone(), 2, 1, keepdim=True)
156 | # polar_noise = polar_noise / norm_x
157 | # print(polar_noise)
158 | # z_var = torch.full([32, 1], 1.0)
159 | # q_z = VonMisesFisher(polar_noise, z_var)
160 |
161 | # z = q_z.rsample()
162 | # print(z.shape)
163 | # print(z)
164 |
165 |
166 |
167 | # if 'vmf' in self.args.method2:
168 | # vmf_buf_features = copy.deepcopy(buf_features.detach())
169 | # vmf_buf_features = euclid2polar(vmf_buf_features)
170 | # kappa = self.args.gamma_loss
171 | # polar_noise = torch.from_numpy(np.random.vonmises(0, kappa, (vmf_buf_features.shape[0], vmf_buf_features.shape[1])))
172 | # if 'vmf_g' == self.args.method2:
173 | # polar_noise = torch.normal(0.0, 1, size=vmf_buf_features.shape).to(vmf_buf_features.device)
174 | # polar_mask = polar_noise < 0
175 | # polar_noise[polar_mask] = -polar_noise[polar_mask]
176 | # polar_noise = polar_noise.to(vmf_buf_features).cuda()
177 |
178 | # # vmf_buf_features = vmf_buf_features.cuda() + polar_noise.cuda()
179 | # # vmf_buf_features = polar2euclid(vmf_buf_features)
180 | # # buf_features_noise = vmf_buf_features - buf_features
181 | # buf_features_noise = polar2euclid(polar_noise)
182 | # buf_features_noise = buf_features_noise.detach().cuda()
183 | # buf_features_noise = buf_features_noise.to(torch.float32)
184 | # if 'vmf_pa' == self.args.method2:
185 | # vmf_buf_features = vmf_buf_features.cuda() + polar_noise.cuda()
186 | # buf_features_noise = polar2euclid(vmf_buf_features)
187 | # buf_features_noise = buf_features_noise.detach().cuda()
188 | # buf_features_noise = buf_features_noise.to(torch.float32)
189 | # buf_features_noise = buf_features_noise - buf_features.detach()
--------------------------------------------------------------------------------