├── agents
├── __init__.py
├── cndpm.py
├── lwf.py
├── icarl.py
├── scr.py
├── gdumb.py
├── summarize.py
├── agem.py
├── ewc_pp.py
├── exp_replay.py
└── base.py
├── utils
├── __init__.py
├── buffer
│ ├── __init__.py
│ ├── sc_retrieve.py
│ ├── random_retrieve.py
│ ├── mem_match.py
│ ├── reservoir_update.py
│ ├── mir_retrieve.py
│ ├── aser_retrieve.py
│ ├── buffer.py
│ ├── aser_update.py
│ ├── gss_greedy_update.py
│ ├── aser_utils.py
│ └── buffer_utils.py
├── kd_manager.py
├── logging.py
├── global_vars.py
├── io.py
├── name_match.py
├── setup_elements.py
├── loss.py
└── utils.py
├── continuum
├── __init__.py
├── dataset_scripts
│ ├── __init__.py
│ ├── dataset_base.py
│ ├── cifar100.py
│ ├── cifar10.py
│ ├── tiny_imagenet.py
│ ├── mini_imagenet.py
│ ├── openloris.py
│ └── core50.py
├── continuum.py
├── data_utils.py
└── non_stationary.py
├── experiment
├── __init__.py
├── tune_hyperparam.py
└── metrics.py
├── models
├── ndpm
│ ├── __init__.py
│ ├── utils.py
│ ├── loss.py
│ ├── priors.py
│ ├── expert.py
│ ├── component.py
│ ├── ndpm.py
│ ├── classifier.py
│ └── vae.py
├── __init__.py
├── pretrained.py
├── convnet.py
└── resnet.py
├── figs
└── pipeline-ssd.png
├── requirements.txt
├── process_tiny_imagenet.py
├── main_config.py
├── README.md
├── main_tune.py
└── .gitignore
/agents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/continuum/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiment/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/ndpm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/buffer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/continuum/dataset_scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import resnet
--------------------------------------------------------------------------------
/figs/pipeline-ssd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vimar-gu/SSD/HEAD/figs/pipeline-ssd.png
--------------------------------------------------------------------------------
/models/pretrained.py:
--------------------------------------------------------------------------------
1 | import torchvision.models as models
2 | import torch
3 |
def ResNet18_pretrained(n_classes):
    """Return an ImageNet-pretrained ResNet-18 with a fresh n_classes-way head.

    Args:
        n_classes: number of output classes for the replacement fc layer.

    Returns:
        torchvision ResNet-18 whose final fc layer is re-initialized.
    """
    classifier = models.resnet18(pretrained=True)
    # Read the input width from the existing head instead of hard-coding 512,
    # so this stays correct if the backbone variant ever changes.
    classifier.fc = torch.nn.Linear(classifier.fc.in_features, n_classes)
    return classifier
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -f https://download.pytorch.org/whl/torch_stable.html
2 | torch==1.7.1
3 | torchvision==0.8.2
4 | matplotlib==3.2.1
5 | scipy==1.4.1
6 | scikit-image==0.14.2
7 | scikit-learn==0.23.0
8 | pandas==1.0.5
9 | PyYAML==5.3.1
10 | psutil==5.7.0
11 | kornia==0.4.1
12 |
--------------------------------------------------------------------------------
/models/ndpm/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
class Lambda(nn.Module):
    """Wrap an arbitrary callable so it can be used as an nn.Module layer."""

    def __init__(self, f=None):
        super().__init__()
        # Fall back to the identity function when no callable is supplied.
        self.f = (lambda x: x) if f is None else f

    def forward(self, *args, **kwargs):
        return self.f(*args, **kwargs)
--------------------------------------------------------------------------------
/utils/buffer/sc_retrieve.py:
--------------------------------------------------------------------------------
1 | from utils.buffer.buffer_utils import match_retrieve
2 | import torch
3 |
class Match_retrieve(object):
    """Retrieve memory samples whose labels match the current input batch.

    During warm-up (before the buffer has seen num_retrieve * warmup
    samples) it returns empty tensors instead.
    """

    def __init__(self, params):
        super().__init__()
        # Number of samples to retrieve per call.
        self.num_retrieve = params.eps_mem_batch
        # Warm-up multiplier before retrieval kicks in.
        self.warmup = params.warmup

    def retrieve(self, buffer, **kwargs):
        """Return (x, y) from `buffer` with labels matching kwargs['y'].

        Only the labels are needed for matching; the previously-read
        kwargs['x'] was an unused local and has been removed.
        """
        if buffer.n_seen_so_far > self.num_retrieve * self.warmup:
            return match_retrieve(buffer, kwargs['y'])
        return torch.tensor([]), torch.tensor([])
--------------------------------------------------------------------------------
/utils/buffer/random_retrieve.py:
--------------------------------------------------------------------------------
1 | from utils.buffer.buffer_utils import random_retrieve, balanced_retrieve
2 |
3 |
class Random_retrieve(object):
    """Uniformly sample a fixed-size batch from the replay buffer."""

    def __init__(self, params):
        super().__init__()
        # How many memory samples to draw per retrieval.
        self.num_retrieve = params.eps_mem_batch

    def retrieve(self, buffer, **kwargs):
        """Delegate to the shared uniform-sampling helper."""
        return random_retrieve(buffer, self.num_retrieve)
11 |
12 |
class BalancedRetrieve(object):
    """Sample a class-balanced batch from the replay buffer."""

    def __init__(self, params):
        super().__init__()
        # How many memory samples to draw per retrieval.
        self.num_retrieve = params.eps_mem_batch

    def retrieve(self, buffer, **kwargs):
        """Delegate to the shared class-balanced sampling helper."""
        return balanced_retrieve(buffer, self.num_retrieve)
20 |
21 |
--------------------------------------------------------------------------------
/utils/kd_manager.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from torch.nn import functional as F
4 |
5 |
def loss_fn_kd(scores, target_scores, T=2.):
    """Temperature-scaled knowledge-distillation loss (see Li & Hoiem, 2017).

    Cross-entropy between the softened teacher distribution and the softened
    student log-probabilities, rescaled by T**2 so gradient magnitudes stay
    comparable across temperatures.
    """
    student_log_probs = F.log_softmax(scores / T, dim=1)
    teacher_probs = F.softmax(target_scores / T, dim=1)
    per_sample = -(teacher_probs * student_log_probs).sum(dim=1)
    return per_sample.mean() * T ** 2
12 |
13 |
class KdManager:
    """Hold a frozen snapshot of a past model and score distillation against it."""

    def __init__(self):
        self.teacher_model = None

    def update_teacher(self, model):
        # Deep-copy so later training of `model` cannot alter the teacher.
        self.teacher_model = copy.deepcopy(model)

    def get_kd_loss(self, cur_model_logits, x):
        """Distillation loss vs. the stored teacher on x; 0 when no teacher yet."""
        if self.teacher_model is None:
            return 0
        with torch.no_grad():
            prev_model_logits = self.teacher_model.forward(x)
        return loss_fn_kd(cur_model_logits, prev_model_logits)
29 |
--------------------------------------------------------------------------------
/utils/buffer/mem_match.py:
--------------------------------------------------------------------------------
1 | from utils.buffer.buffer_utils import match_retrieve
2 | from utils.buffer.buffer_utils import random_retrieve
3 | import torch
4 |
class MemMatch_retrieve(object):
    """Retrieve a random candidate batch plus label-matched partners from memory."""

    def __init__(self, params):
        super().__init__()
        # Samples per retrieval and warm-up multiplier before retrieval starts.
        self.num_retrieve = params.eps_mem_batch
        self.warmup = params.warmup

    def retrieve(self, buffer, **kwargs):
        """Return (candidate_x, candidate_y, match_x, match_y).

        All four are empty tensors during warm-up or when the buffer cannot
        supply candidates; otherwise candidates are re-drawn until a
        label-matched partner set is found.
        """
        candidate_x = torch.tensor([])
        candidate_y = torch.tensor([])
        match_x = torch.tensor([])
        match_y = torch.tensor([])
        if buffer.n_seen_so_far <= self.num_retrieve * self.warmup:
            # Warm-up phase: nothing to retrieve yet.
            return candidate_x, candidate_y, match_x, match_y
        while match_x.size(0) == 0:
            candidate_x, candidate_y, indices = random_retrieve(
                buffer, self.num_retrieve, return_indices=True)
            if candidate_x.size(0) == 0:
                # Buffer produced no candidates; give up with empties.
                break
            match_x, match_y = match_retrieve(buffer, candidate_y, indices)
        return candidate_x, candidate_y, match_x, match_y
--------------------------------------------------------------------------------
/utils/logging.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 | import errno
5 |
6 |
def mkdir_if_missing(dir_path):
    """Create dir_path (including parents) if it does not already exist.

    Uses makedirs(exist_ok=True), which is race-free under concurrent
    creation. Unlike the previous try/except-EEXIST form, a non-directory
    already occupying the path now raises instead of being silently ignored.
    """
    os.makedirs(dir_path, exist_ok=True)
13 |
class Logger(object):
    """Tee stdout-style writes to the console and, optionally, a log file.

    Typical use: `sys.stdout = Logger(fpath)` or `with Logger(fpath) as log:`.
    """

    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(os.path.dirname(fpath))
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        # Return self so `with Logger(...) as log:` binds the logger
        # (previously returned None, making the `as` target useless).
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            # Push through the OS buffer so logs survive a hard crash.
            os.fsync(self.file.fileno())

    def close(self):
        # Close only the log file. self.console is sys.stdout, which this
        # logger does not own — closing it (as the old code did) would kill
        # all subsequent console output for the whole process.
        if self.file is not None:
            self.file.close()
            self.file = None
46 |
47 |
--------------------------------------------------------------------------------
/continuum/continuum.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # other imports
4 | from utils.name_match import data_objects
5 |
class continuum(object):
    """Iterator over the task sequence of a continual-learning dataset."""

    def __init__(self, dataset, scenario, params):
        """Build the underlying data object and reset run/task counters."""
        self.data_object = data_objects[dataset](scenario, params)
        self.run = params.num_runs
        self.task_nums = self.data_object.task_nums
        self.cur_task = 0
        self.cur_run = -1

    def __iter__(self):
        return self

    def __next__(self):
        # Stop once every task of the current run has been consumed.
        if self.cur_task == self.data_object.task_nums:
            raise StopIteration
        task = self.data_object.new_task(self.cur_task, cur_run=self.cur_run)
        x_train, y_train, labels = task
        self.cur_task += 1
        return x_train, y_train, labels

    def test_data(self):
        return self.data_object.get_test_set()

    def clean_mem_test_set(self):
        self.data_object.clean_mem_test_set()

    def reset_run(self):
        # Rewind the task pointer without advancing the run counter.
        self.cur_task = 0

    def new_run(self):
        # Advance to the next run and rewind the task pointer.
        self.cur_task = 0
        self.cur_run += 1
        self.data_object.new_run(cur_run=self.cur_run)
39 |
40 |
--------------------------------------------------------------------------------
/utils/global_vars.py:
--------------------------------------------------------------------------------
# Hyper-parameters for the CN-DPM (Continual Neural Dirichlet Process Mixture)
# model family, grouped by the sub-module that consumes them.

# NOTE: the first name is misspelled ("MODLES") but is kept because existing
# code may import it; prefer the correctly spelled alias on the next line.
MODLES_NDPM_VAE_NF_BASE = 32
MODELS_NDPM_VAE_NF_BASE = MODLES_NDPM_VAE_NF_BASE  # correctly spelled alias
MODELS_NDPM_VAE_NF_EXT = 4
MODELS_NDPM_VAE_PRECURSOR_CONDITIONED_DECODER = False
MODELS_NDPM_VAE_Z_DIM = 64
MODELS_NDPM_VAE_RECON_LOSS = 'gaussian'
MODELS_NDPM_VAE_LEARN_X_LOG_VAR = False
MODELS_NDPM_VAE_X_LOG_VAR_PARAM = 0
MODELS_NDPM_VAE_Z_SAMPLES = 16
MODELS_NDPM_CLASSIFIER_NUM_BLOCKS = [1, 1, 1, 1]
MODELS_NDPM_CLASSIFIER_NORM_LAYER = 'InstanceNorm2d'
MODELS_NDPM_CLASSIFIER_CLS_NF_BASE = 20
MODELS_NDPM_CLASSIFIER_CLS_NF_EXT = 4
MODELS_NDPM_NDPM_DISABLE_D = False
MODELS_NDPM_NDPM_SEND_TO_STM_ALWAYS = False
MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE = 50
MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS = 0
MODELS_NDPM_NDPM_SLEEP_STEP_G = 4000
MODELS_NDPM_NDPM_SLEEP_STEP_D = 1000
MODELS_NDPM_NDPM_SLEEP_SLEEP_VAL_SIZE = 0
MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP = 500
MODELS_NDPM_NDPM_WEIGHT_DECAY = 0.00001
MODELS_NDPM_NDPM_IMPLICIT_LR_DECAY = False
MODELS_NDPM_COMPONENT_LR_SCHEDULER_G = {'type': 'MultiStepLR', 'options': {'milestones': [1], 'gamma': 0.2}}
MODELS_NDPM_COMPONENT_LR_SCHEDULER_D = {'type': 'MultiStepLR', 'options': {'milestones': [1], 'gamma': 0.2}}
MODELS_NDPM_COMPONENT_CLIP_GRAD = {'type': 'value', 'options': {'clip_value': 0.5}}
26 |
--------------------------------------------------------------------------------
/continuum/dataset_scripts/dataset_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import os
3 |
class DatasetBase(ABC):
    """Abstract base for continual-learning dataset wrappers.

    Subclasses implement downloading, per-run setup, and per-task batching.
    """

    def __init__(self, dataset, scenario, task_nums, run, params):
        super().__init__()
        self.params = params
        self.scenario = scenario
        self.dataset = dataset
        self.task_nums = task_nums
        self.run = run
        self.root = './datasets'
        self.test_set = []
        self.val_set = []
        # Validate configuration before triggering any downloads.
        self._is_properly_setup()
        self.download_load()

    @abstractmethod
    def download_load(self):
        """Fetch the raw data and load it into memory."""

    @abstractmethod
    def setup(self, **kwargs):
        """Prepare task splits for the current configuration."""

    @abstractmethod
    def new_task(self, cur_task, **kwargs):
        """Return the training data and labels for task `cur_task`."""

    @abstractmethod
    def new_run(self, **kwargs):
        """Re-shuffle / re-split the data for a fresh run."""

    def _is_properly_setup(self):
        # Hook for subclasses to validate params; no-op by default.
        pass

    @property
    def dataset_info(self):
        return self.dataset

    def get_test_set(self):
        return self.test_set

    def clean_mem_test_set(self):
        # Drop references so large test arrays can be garbage-collected.
        self.test_set = None
        self.test_data = None
        self.test_label = None
--------------------------------------------------------------------------------
/utils/io.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import pandas as pd
3 | import os
4 | import psutil
5 | import torch
6 |
def load_yaml(path, key='parameters'):
    """Load section `key` from the YAML file at `path`.

    Prints the parser error and returns None when the file is not valid YAML.
    """
    with open(path, 'r') as stream:
        try:
            content = yaml.load(stream, Loader=yaml.FullLoader)
        except yaml.YAMLError as exc:
            print(exc)
        else:
            return content[key]
13 |
def save_dataframe_csv(df, path, name):
    """Write `df` as CSV (without the index column) to directory `path` under `name`."""
    target = path + '/' + name
    df.to_csv(target, index=False)
16 |
17 |
def load_dataframe_csv(path, name=None, delimiter=None, names=None):
    """Read a CSV into a DataFrame.

    `path` is the full file path unless `name` is given, in which case the
    two are concatenated (caller supplies any separator).
    """
    target = path if not name else path + name
    return pd.read_csv(target, delimiter=delimiter, names=names)
23 |
def check_ram_usage():
    """
    Compute the RAM usage of the current process.
    Returns:
        mem (float): Memory occupation in Megabytes
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 * 1024)
35 |
def save_model(model, optimizer, opt, epoch, save_file):
    """Checkpoint the model/optimizer state together with options and epoch."""
    print('==> Saving...')
    checkpoint = dict(
        opt=opt,
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=epoch,
    )
    torch.save(checkpoint, save_file)
    # Release the reference promptly; state dicts can be large.
    del checkpoint
46 |
--------------------------------------------------------------------------------
/process_tiny_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | from PIL import Image
5 |
6 |
# Convert the raw Tiny-ImageNet-200 directory layout into two pickles
# (train.pkl / val.pkl) holding CHW uint8 image arrays and integer targets.
# NOTE(review): assumes the current working directory is the
# tiny-imagenet-200 root (contains wnids.txt, train/, val/) — confirm
# before running.
with open('wnids.txt', 'r') as fp:
    wnids = fp.readlines()
wnids = [wnid.strip() for wnid in wnids]
# Map each WordNet id to a contiguous class index (order of wnids.txt).
wnid2index = {wnid: index for index, wnid in enumerate(wnids)}

train_data = {'data': [], 'target': []}
for wnid in wnids:
    image_list = os.listdir(f'train/{wnid}/images')
    for image_name in image_list:
        # Force RGB (some images are grayscale) and reorder HWC -> CHW.
        image = np.asarray(Image.open(f'train/{wnid}/images/{image_name}').convert('RGB')).transpose(2, 0, 1)
        train_data['data'].append(image)
        train_data['target'].append(wnid2index[wnid])
train_data['data'] = np.stack(train_data['data'])
train_data['target'] = np.array(train_data['target'])
pickle.dump(train_data, open('train.pkl', 'wb'))

# Validation labels live in one annotation file rather than per-class dirs;
# each line is "<image_name>\t<wnid>\t...".
val_data = {'data': [], 'target': []}
with open('val/val_annotations.txt', 'r') as fp:
    val_annos = fp.readlines()
for val_anno in val_annos:
    image_name, wnid = val_anno.split('\t')[:2]
    image = np.asarray(Image.open(f'val/images/{image_name}').convert('RGB')).transpose(2, 0, 1)
    val_data['data'].append(image)
    val_data['target'].append(wnid2index[wnid])
val_data['data'] = np.stack(val_data['data'])
val_data['target'] = np.array(val_data['target'])
pickle.dump(val_data, open('val.pkl', 'wb'))
34 |
--------------------------------------------------------------------------------
/models/ndpm/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import Tensor
4 | from torch.nn.functional import binary_cross_entropy
5 |
6 |
def gaussian_nll(x, mean, log_var, min_noise=0.001):
    """Per-element Gaussian negative log-likelihood with a small noise floor."""
    variance = log_var.exp()
    squared_error = (x - mean) ** 2
    # 1e-8 in the denominator guards against a collapsed (zero) variance.
    return (squared_error + min_noise) / (2 * variance + 1e-8) \
        + 0.5 * log_var + 0.5 * np.log(2 * np.pi)
12 |
13 |
def laplace_nll(x, median, log_scale, min_noise=0.01):
    """Per-element Laplace negative log-likelihood with a small noise floor."""
    scale = log_scale.exp()
    abs_error = (x - median).abs()
    # 1e-8 in the denominator guards against a collapsed (zero) scale.
    return (abs_error + min_noise) / (scale + 1e-8) + log_scale + np.log(2)
19 |
20 |
def bernoulli_nll(x, p):
    """Element-wise Bernoulli NLL, broadcasting x and p to a common shape.

    For each dimension the smaller of the two sizes is expanded to the
    larger one (-1 leaves a dimension unchanged in `expand`).
    """
    x_target, p_target = [], []
    for dim_x, dim_p in zip(x.size(), p.size()):
        x_target.append(dim_p if dim_p > dim_x else -1)
        p_target.append(dim_x if dim_x > dim_p else -1)
    x = x.expand(*x_target)
    p = p.expand(*p_target)

    # `p` is the prediction, `x` the target.
    return binary_cross_entropy(p, x, reduction='none')
38 |
39 |
def logistic_nll(x, mean, log_scale):
    """Negative log-likelihood under a discretized logistic with 1/256-wide bins."""
    bin_size = 1 / 256
    scale = log_scale.exp()
    low = (x - mean) / scale
    high = (x - mean + bin_size) / scale
    # Probability mass of the bin, floored at 1e-12 to keep log finite.
    prob = torch.sigmoid(high) - torch.sigmoid(low) + 1e-12
    return -prob.log()
48 |
--------------------------------------------------------------------------------
/main_config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from utils.io import load_yaml
3 | from types import SimpleNamespace
4 | from utils.utils import boolean_string
5 | import time
6 | import torch
7 | import random
8 | import numpy as np
9 | from experiment.run import multiple_run
10 |
def main(args):
    """Load the three YAML config layers, seed all RNGs, and launch the runs."""
    general_params = load_yaml(args.general)
    data_params = load_yaml(args.data)
    agent_params = load_yaml(args.agent)
    general_params['verbose'] = args.verbose
    general_params['cuda'] = torch.cuda.is_available()
    final_params = SimpleNamespace(**general_params, **data_params, **agent_params)
    time_start = time.time()
    print(final_params)

    # Reproducibility: seed every RNG source; pin cuDNN to deterministic
    # kernels when running on GPU.
    seed = final_params.seed
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if final_params.cuda:
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    multiple_run(final_params)
32 |
33 |
34 |
if __name__ == "__main__":
    # Commandline arguments
    parser = argparse.ArgumentParser('CVPR Continual Learning Challenge')
    # Paths to the three YAML config layers: general / dataset / agent.
    parser.add_argument('--general', dest='general', default='config/general.yml')
    parser.add_argument('--data', dest='data', default='config/data/cifar100/cifar100_nc.yml')
    parser.add_argument('--agent', dest='agent', default='config/agent/er.yml')

    parser.add_argument('--verbose', type=boolean_string, default=True,
                        help='print information or not')
    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/agents/cndpm.py:
--------------------------------------------------------------------------------
1 | from agents.base import ContinualLearner
2 | from continuum.data_utils import dataset_transform
3 | from models.ndpm.ndpm import Ndpm
4 | from utils.setup_elements import transforms_match
5 | from torch.utils import data
6 | from utils.utils import maybe_cuda, AverageMeter
7 | import torch
8 |
9 |
class Cndpm(ContinualLearner):
    """Agent wrapper around the CN-DPM model.

    Unlike gradient-based agents, optimization is delegated entirely to
    `model.learn`, which manages its own experts and short-term memory.
    """

    def __init__(self, model, opt, params):
        super(Cndpm, self).__init__(model, opt, params)
        self.model = model

    def train_learner(self, x_train, y_train):
        """Stream the task's data through the CN-DPM online learning loop."""
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # (The AverageMeter trackers previously created here were never
        # updated or read, so they have been removed.)
        self.model.train()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                # CN-DPM performs its own optimization internally.
                self.model.learn(batch_x, batch_y)
                if self.params.verbose:
                    print('\r[Step {:4}] STM: {:5}/{} | #Expert: {}'.format(
                        i,
                        len(self.model.stm_x), self.params.stm_capacity,
                        len(self.model.experts) - 1
                    ), end='')
        # Terminate the carriage-return progress line.
        print()
41 |
--------------------------------------------------------------------------------
/models/ndpm/priors.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import torch
3 |
4 | from utils.utils import maybe_cuda
5 |
6 |
class Prior(ABC):
    """Interface for priors over expert assignment in the NDPM mixture."""

    def __init__(self, params):
        self.params = params

    @abstractmethod
    def add_expert(self):
        """Extend the prior with a slot for a newly created expert."""

    @abstractmethod
    def record_usage(self, usage, index=None):
        """Accumulate observed usage counts for one or all experts."""

    @abstractmethod
    def nl_prior(self, normalize=False):
        """Return the negative log prior over experts."""
22 |
23 |
class CumulativePrior(Prior):
    """Usage-proportional prior over experts, tracked in log space.

    An expert's prior weight grows with its accumulated usage; everything
    is stored as log-counts so updates can use logsumexp for numerical
    stability.
    """

    def __init__(self, params):
        super().__init__(params)
        # log_counts[0] starts at log_alpha — presumably the concentration
        # pseudo-count for spawning a new expert (TODO confirm against the
        # CN-DPM paper); later entries are per-expert accumulated log-counts.
        self.log_counts = maybe_cuda(torch.tensor(
            params.log_alpha
        )).float().unsqueeze(0)

    def add_expert(self):
        # A new expert starts with log-count 0, i.e. a count of 1.
        self.log_counts = torch.cat(
            [self.log_counts, maybe_cuda(torch.zeros(1))],
            dim=0
        )

    def record_usage(self, usage, index=None):
        """Record expert usage

        Args:
            usage: Tensor of shape [K+1] if index is None else scalar
            index: expert index
        """
        if index is None:
            # Element-wise log(count + usage): stack each (log_count, log_usage)
            # pair along dim=1 and logsumexp over that pair dimension.
            self.log_counts = torch.logsumexp(torch.stack([
                self.log_counts,
                usage.log()
            ], dim=1), dim=1)
        else:
            # Same update for a single expert; usage is a plain scalar here.
            self.log_counts[index] = torch.logsumexp(torch.stack([
                self.log_counts[index],
                maybe_cuda(torch.tensor(usage)).float().log()
            ], dim=0), dim=0)

    def nl_prior(self, normalize=False):
        """Negative log prior; adds the log normalizer when `normalize` is set,
        making the result a proper negative log distribution."""
        nl_prior = -self.log_counts
        if normalize:
            nl_prior += torch.logsumexp(self.log_counts, dim=0)
        return nl_prior

    @property
    def counts(self):
        # Usage counts in linear space.
        return self.log_counts.exp()
64 |
--------------------------------------------------------------------------------
/experiment/tune_hyperparam.py:
--------------------------------------------------------------------------------
1 | from types import SimpleNamespace
2 | from sklearn.model_selection import ParameterGrid
3 | from utils.setup_elements import setup_opt, setup_architecture
4 | from utils.utils import maybe_cuda
5 | from utils.name_match import agents
6 | import numpy as np
7 | from experiment.metrics import compute_performance
8 |
9 |
def tune_hyper(tune_data, tune_test_loaders, default_params, tune_params):
    """Grid-search `tune_params` on validation runs; return the best setting.

    Args:
        tune_data: iterable of (x_train, y_train, labels) validation tasks.
        tune_test_loaders: test loaders forwarded to agent.evaluate.
        default_params: SimpleNamespace of baseline hyper-parameters.
        tune_params: dict (or list of dicts) accepted by sklearn ParameterGrid.

    Returns:
        The parameter dict from the grid with the highest average end accuracy.
    """
    param_grid_list = list(ParameterGrid(tune_params))
    print(len(param_grid_list))
    tune_accs = []
    tune_fgt = []
    for param_set in param_grid_list:
        # Copy the defaults before merging: vars() returns the namespace's own
        # __dict__, and calling update() on it directly (as the old code did)
        # mutated default_params, leaking settings from one grid point into
        # all subsequent ones.
        final_params = dict(vars(default_params))
        print(param_set)
        final_params.update(param_set)
        final_params = SimpleNamespace(**final_params)
        accuracy_list = []
        for run in range(final_params.num_runs_val):
            tmp_acc = []
            model = setup_architecture(final_params)
            model = maybe_cuda(model, final_params.cuda)
            opt = setup_opt(final_params.optimizer, model, final_params.learning_rate, final_params.weight_decay)
            agent = agents[final_params.agent](model, opt, final_params)
            for i, (x_train, y_train, labels) in enumerate(tune_data):
                print("-----------tune run {} task {}-------------".format(run, i))
                print('size: {}, {}'.format(x_train.shape, y_train.shape))
                agent.train_learner(x_train, y_train)
                acc_array = agent.evaluate(tune_test_loaders)
                tmp_acc.append(acc_array)
            print(
                "-----------tune run {}-----------avg_end_acc {}-----------".format(run, np.mean(tmp_acc[-1])))
            accuracy_list.append(np.array(tmp_acc))
        accuracy_list = np.array(accuracy_list)
        avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt = compute_performance(accuracy_list)
        tune_accs.append(avg_end_acc[0])
        tune_fgt.append(avg_end_fgt[0])
    # Pick the grid point with the best average end accuracy.
    best_tune = param_grid_list[tune_accs.index(max(tune_accs))]
    return best_tune
--------------------------------------------------------------------------------
/utils/name_match.py:
--------------------------------------------------------------------------------
1 | from agents.gdumb import Gdumb
2 | from continuum.dataset_scripts.cifar100 import CIFAR100
3 | from continuum.dataset_scripts.cifar10 import CIFAR10
4 | from continuum.dataset_scripts.core50 import CORE50
5 | from continuum.dataset_scripts.mini_imagenet import Mini_ImageNet
6 | from continuum.dataset_scripts.tiny_imagenet import TinyImageNet
7 | from continuum.dataset_scripts.openloris import OpenLORIS
8 | from agents.exp_replay import ExperienceReplay
9 | from agents.agem import AGEM
10 | from agents.ewc_pp import EWC_pp
11 | from agents.cndpm import Cndpm
12 | from agents.lwf import Lwf
13 | from agents.icarl import Icarl
14 | from agents.scr import SupContrastReplay
15 | from agents.summarize import SummarizeContrastReplay
16 | from utils.buffer.random_retrieve import Random_retrieve, BalancedRetrieve
17 | from utils.buffer.reservoir_update import Reservoir_update
18 | from utils.buffer.summarize_update import SummarizeUpdate
19 | from utils.buffer.mir_retrieve import MIR_retrieve
20 | from utils.buffer.gss_greedy_update import GSSGreedyUpdate
21 | from utils.buffer.aser_retrieve import ASER_retrieve
22 | from utils.buffer.aser_update import ASER_update
23 | from utils.buffer.sc_retrieve import Match_retrieve
24 | from utils.buffer.mem_match import MemMatch_retrieve
25 |
# Factory tables mapping command-line option strings to the classes that
# implement them.

# --data: dataset name -> dataset wrapper class
data_objects = {
    'cifar100': CIFAR100,
    'cifar10': CIFAR10,
    'core50': CORE50,
    'mini_imagenet': Mini_ImageNet,
    'tiny_imagenet': TinyImageNet,
    'openloris': OpenLORIS
}

# --agent: method name -> continual-learning agent class
agents = {
    'ER': ExperienceReplay,
    'EWC': EWC_pp,
    'AGEM': AGEM,
    'CNDPM': Cndpm,
    'LWF': Lwf,
    'ICARL': Icarl,
    'GDUMB': Gdumb,
    'SCR': SupContrastReplay,
    'SSCR': SummarizeContrastReplay,
}

# --retrieve: strategy name -> memory retrieval class
retrieve_methods = {
    'MIR': MIR_retrieve,
    'random': Random_retrieve,
    'ASER': ASER_retrieve,
    'match': Match_retrieve,
    'mem_match': MemMatch_retrieve,
}

# --update: strategy name -> memory update class
update_methods = {
    'random': Reservoir_update,
    'GSS': GSSGreedyUpdate,
    'ASER': ASER_update,
    'summarize': SummarizeUpdate,
}
61 |
62 |
--------------------------------------------------------------------------------
/models/ndpm/expert.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from models.ndpm.classifier import ResNetSharingClassifier
4 | from models.ndpm.vae import CnnSharingVae
5 | from utils.utils import maybe_cuda
6 |
7 | from utils.global_vars import *
8 |
9 |
class Expert(nn.Module):
    """One NDPM expert: a generative VAE (`g`) paired with an optional
    discriminative classifier (`d`), both sharing parameters with the
    previously created experts."""

    def __init__(self, params, experts=()):
        super().__init__()
        # Expert 0 acts as the frozen placeholder/background expert.
        self.id = len(experts)
        self.experts = experts

        self.g = maybe_cuda(CnnSharingVae(params, experts))
        self.d = None if MODELS_NDPM_NDPM_DISABLE_D \
            else maybe_cuda(ResNetSharingClassifier(params, experts))

        if self.id == 0:
            # The placeholder keeps its random initialization frozen.
            self.eval()
            for weight in self.g.parameters():
                weight.requires_grad = False
            if self.d is not None:
                for weight in self.d.parameters():
                    weight.requires_grad = False

    def forward(self, x):
        return self.d(x)

    def nll(self, x, y, step=None):
        """Negative log likelihood of (x, y): generative NLL plus, when the
        classifier is enabled, the discriminative NLL."""
        total = self.g.nll(x, step)
        if self.d is not None:
            total = total + self.d.nll(x, y, step)
        return total

    def collect_nll(self, x, y, step=None):
        """Per-component NLLs; the placeholder expert reports its single NLL."""
        if self.id == 0:
            return self.nll(x, y, step).unsqueeze(1)

        total = self.g.collect_nll(x, step)
        if self.d is not None:
            total = total + self.d.collect_nll(x, y, step)
        return total

    def lr_scheduler_step(self):
        # Components may expose NotImplemented instead of a real scheduler.
        if self.g.lr_scheduler is not NotImplemented:
            self.g.lr_scheduler.step()
        if self.d is not None and self.d.lr_scheduler is not NotImplemented:
            self.d.lr_scheduler.step()

    def clip_grad(self):
        self.g.clip_grad()
        if self.d is not None:
            self.d.clip_grad()

    def optimizer_step(self):
        self.g.optimizer.step()
        if self.d is not None:
            self.d.optimizer.step()
69 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Summarizing Stream Data for Memory-Restricted Online Continual Learning
2 |
3 | Official implementation of "[Summarizing Stream Data for Memory-Restricted Online Continual Learning](https://arxiv.org/abs/2305.16645)"
4 |
5 |
6 |

7 |
8 | ## Highlights :sparkles:
9 | - SSD is accepted by AAAI 2024!
10 | - SSD summarizes the knowledge in the stream data into informative images for the replay memory.
11 | - Through maintaining the consistency of training gradients and relationship to the past tasks, the summarized samples are more representative for the stream data compared with original images.
12 | - SSD significantly enhances the replay effects for online continual learning methods with limited extra computational overhead.
13 |
14 | ## Datasets
15 |
16 | ### Online Class Incremental
17 | - Split CIFAR100
18 | - Split Mini-ImageNet
19 | - Split Tiny-ImageNet
20 |
21 | ### Data preparation
22 | - CIFAR10 & CIFAR100 will be downloaded during the first run
23 | - Mini-ImageNet: Download from https://www.kaggle.com/whitemoon/miniimagenet/download, and place it in datasets/mini_imagenet/
24 | - Tiny-ImageNet: Download from http://cs231n.stanford.edu/tiny-imagenet-200.zip, place it in datasets/tiny-imagenet-200/. Copy `process_tiny_imagenet.py` to the directory and run it to compress the dataset into pickles
25 |
26 | ## Run commands
27 | Detailed descriptions of options can be found in the `SSD` section in [general_main.py](general_main.py)
28 |
29 | ### Sample commands to run algorithms on Split-CIFAR100
30 | ```shell
31 | python general_main.py --data cifar100 --cl_type nc --agent SSCR --retrieve random --update summarize --mem_size 1000 --images_per_class 10 --head mlp --temp 0.07 --eps_mem_batch 100 --lr_img 4e-3 --summarize_interval 6 --queue_size 64 --mem_weight 1 --num_runs 10
32 | ```
33 |
34 | ## Acknowledgement
35 |
36 | This project is mainly based on [online-continual-learning](https://github.com/RaptorMai/online-continual-learning)
37 |
38 | ## Citation
39 |
40 | If you find this work helpful, please cite:
41 | ```
42 | @article{gu2023summarizing,
43 | title={Summarizing Stream Data for Memory-Restricted Online Continual Learning},
44 | author={Gu, Jianyang and Wang, Kai and Jiang, Wei and You, Yang},
45 | journal={arXiv preprint arXiv:2305.16645},
46 | year={2023}
47 | }
48 | ```
49 |
--------------------------------------------------------------------------------
/agents/lwf.py:
--------------------------------------------------------------------------------
1 | from agents.base import ContinualLearner
2 | from continuum.data_utils import dataset_transform
3 | from utils.setup_elements import transforms_match
4 | from torch.utils import data
5 | from utils.utils import maybe_cuda, AverageMeter
6 | import torch
7 | import copy
8 |
9 |
class Lwf(ContinualLearner):
    """Learning-without-Forgetting agent: the new task's CE loss is blended
    with a distillation loss against the teacher snapshot taken before the
    task started."""

    def __init__(self, model, opt, params):
        super(Lwf, self).__init__(model, opt, params)

    def train_learner(self, x_train, y_train):
        self.before_train(x_train, y_train)

        # Data loader over the new task.
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)

        self.model = self.model.train()

        # Running statistics for periodic logging.
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                logits = self.forward(batch_x)
                loss_old = self.kd_manager.get_kd_loss(logits, batch_x)
                loss_new = self.criterion(logits, batch_y)
                # Weight the new-task loss by the fraction of tasks seen so
                # far; the remainder goes to the distillation term.
                new_weight = 1 / (self.task_seen + 1)
                loss = new_weight * loss_new + (1 - new_weight) * loss_old

                _, pred_label = torch.max(logits, 1)
                correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                acc_batch.update(correct_cnt, batch_y.size(0))
                losses_batch.update(loss, batch_y.size(0))

                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                        .format(i, losses_batch.avg(), acc_batch.avg())
                    )
        self.after_train()
57 |
--------------------------------------------------------------------------------
/main_tune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from utils.io import load_yaml
3 | from types import SimpleNamespace
4 | from utils.utils import boolean_string
5 | import time
6 | import torch
7 | import random
8 | import numpy as np
9 | from experiment.run import multiple_run_tune_separate
10 | from utils.setup_elements import default_trick
11 |
def main(args):
    """Load the four YAML configs, merge them into one namespace, seed all
    RNGs for reproducibility, and launch the tuning runs."""
    general_cfg = load_yaml(args.general)
    data_cfg = load_yaml(args.data)
    default_cfg = load_yaml(args.default)
    tune_cfg = load_yaml(args.tune)

    # runtime switches layered on top of the general config
    general_cfg['verbose'] = args.verbose
    general_cfg['cuda'] = torch.cuda.is_available()
    general_cfg['train_val'] = args.train_val
    if args.trick:
        default_trick[args.trick] = True
    general_cfg['trick'] = default_trick
    merged_params = SimpleNamespace(**general_cfg, **data_cfg, **default_cfg)

    time_start = time.time()
    print(merged_params)
    print()

    # reproducibility: seed every RNG the pipeline touches
    seed = merged_params.seed
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if merged_params.cuda:
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # run
    multiple_run_tune_separate(merged_params, tune_cfg, args.save_path)
40 |
41 |
42 |
if __name__ == "__main__":
    # Command-line arguments: paths to the four YAML config files plus a few
    # runtime switches forwarded to main().
    parser = argparse.ArgumentParser('Continual Learning')
    parser.add_argument('--general', dest='general', default='config/general_1.yml')
    parser.add_argument('--data', dest='data', default='config/data/cifar100/cifar100_nc.yml')
    parser.add_argument('--default', dest='default', default='config/agent/er/er_1k.yml')
    parser.add_argument('--tune', dest='tune', default='config/agent/er/er_tune.yml')
    parser.add_argument('--save-path', dest='save_path', default=None)
    parser.add_argument('--verbose', type=boolean_string, default=False,
                        help='print information or not')
    parser.add_argument('--train_val', type=boolean_string, default=False,
                        help='use the val batches to train')
    parser.add_argument('--trick', type=str, default=None)
    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/experiment/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.stats import sem
3 | import scipy.stats as stats
4 |
def compute_performance(end_task_acc_arr):
    """
    Summarize test-accuracy results from multiple runs.

    :param end_task_acc_arr: array of shape (num_run, num_task, num_task);
        entry [r, i, j] is run r's accuracy on task j after training task i.
    :return: (avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt), each a
        (mean, 95% CI half-width) pair aggregated over runs.
    """
    n_run, n_tasks = end_task_acc_arr.shape[:2]
    # t coefficient for 95% confidence intervals: mean +- t * SEM
    t_coef = stats.t.ppf((1+0.95) / 2, n_run-1)

    def _mean_ci(per_run):
        # (mean, CI half-width) over per-run statistics
        return np.mean(per_run), t_coef * sem(per_run)

    # average end accuracy: mean over tasks of the final row, per run
    end_acc = end_task_acc_arr[:, -1, :]          # shape: (num_run, num_task)
    avg_end_acc = _mean_ci(end_acc.mean(axis=1))

    # forgetting: best accuracy ever reached minus end accuracy, per task
    best_acc = end_task_acc_arr.max(axis=1)
    avg_end_fgt = _mean_ci((best_acc - end_acc).mean(axis=1))

    # ACC: running average accuracy over the tasks seen so far
    seen_counts = np.arange(n_tasks) + 1
    acc_per_run = (np.tril(end_task_acc_arr).sum(axis=2) / seen_counts).mean(axis=1)
    avg_acc = _mean_ci(acc_per_run)

    # BWT+: positive part of backward transfer
    n_pairs = n_tasks * (n_tasks - 1) / 2
    diag = np.diagonal(end_task_acc_arr, axis1=1, axis2=2)
    diag_weights = np.arange(n_tasks, 0, -1) - 1
    bwt_per_run = (np.tril(end_task_acc_arr, -1).sum(axis=(1, 2)) -
                   (diag * diag_weights).sum(axis=1)) / n_pairs
    avg_bwtp = _mean_ci(np.maximum(bwt_per_run, 0))

    # FWT: forward transfer from the strict upper triangle
    fwt_per_run = np.triu(end_task_acc_arr, 1).sum(axis=(1, 2)) / n_pairs
    avg_fwt = _mean_ci(fwt_per_run)

    return avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt
45 |
46 |
47 |
48 |
def single_run_avg_end_fgt(acc_array):
    """Average end forgetting for a single run.

    :param acc_array: 2-D array of shape (num_learned_tasks, num_task) where
        acc_array[i, j] is the accuracy on task j after training task i
        (one run of compute_performance's (run, time, task) layout).
    :return: mean over tasks of (best accuracy ever reached - end accuracy).
    """
    # max over the time axis (axis=0): the best accuracy each task ever
    # reached. axis=1 (as previously written) takes the max over tasks per
    # time step, which does not line up with the per-task end accuracies
    # subtracted below; compute_performance's analogous max is over time.
    best_acc = np.max(acc_array, axis=0)
    end_acc = acc_array[-1]
    final_forgets = best_acc - end_acc
    avg_fgt = np.mean(final_forgets)
    return avg_fgt
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/utils/buffer/reservoir_update.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class Reservoir_update(object):
    """Reservoir-sampling update strategy for a replay buffer.

    Fills the buffer sequentially while empty slots remain, then overwrites
    random slots: each remaining incoming sample draws a slot uniformly in
    [0, n_seen_so_far) and only draws landing inside the buffer range
    actually replace an entry (standard reservoir sampling).
    """

    def __init__(self, params):
        # `params` is unused here; kept for interface parity with the other
        # buffer-update strategies.
        super().__init__()

    def update(self, buffer, x, y, **kwargs):
        """Insert batch (x, y) into `buffer`; return the list of slot indices
        that were written (used by callers to exclude them from retrieval)."""
        batch_size = x.size(0)

        # add whatever still fits in the buffer
        place_left = max(0, buffer.buffer_img.size(0) - buffer.current_index)
        if place_left:
            offset = min(place_left, batch_size)
            buffer.buffer_img[buffer.current_index: buffer.current_index + offset].data.copy_(x[:offset])
            buffer.buffer_label[buffer.current_index: buffer.current_index + offset].data.copy_(y[:offset])


            buffer.current_index += offset
            buffer.n_seen_so_far += offset

            # everything was added: report the freshly filled slots
            if offset == x.size(0):
                filled_idx = list(range(buffer.current_index - offset, buffer.current_index, ))
                if buffer.params.buffer_tracker:
                    buffer.buffer_tracker.update_cache(buffer.buffer_label, y[:offset], filled_idx)
                return filled_idx


        #TODO: the buffer tracker will have bug when the mem size can't be divided by batch size

        # remove what is already in the buffer
        x, y = x[place_left:], y[place_left:]

        # reservoir step: one uniform draw in [0, n_seen_so_far) per remaining
        # sample; draws below the buffer size select a victim slot
        indices = torch.FloatTensor(x.size(0)).to(x.device).uniform_(0, buffer.n_seen_so_far).long()
        valid_indices = (indices < buffer.buffer_img.size(0)).long()

        idx_new_data = valid_indices.nonzero().squeeze(-1)
        idx_buffer = indices[idx_new_data]

        buffer.n_seen_so_far += x.size(0)

        if idx_buffer.numel() == 0:
            return []

        assert idx_buffer.max() < buffer.buffer_img.size(0)
        assert idx_buffer.max() < buffer.buffer_label.size(0)
        # assert idx_buffer.max() < self.buffer_task.size(0)

        assert idx_new_data.max() < x.size(0)
        assert idx_new_data.max() < y.size(0)

        # map each chosen slot to an incoming sample; if two samples drew the
        # same slot, the later one wins (dict overwrite)
        idx_map = {idx_buffer[i].item(): idx_new_data[i].item() for i in range(idx_buffer.size(0))}

        replace_y = y[list(idx_map.values())]
        if buffer.params.buffer_tracker:
            buffer.buffer_tracker.update_cache(buffer.buffer_label, replace_y, list(idx_map.keys()))
        # perform overwrite op
        buffer.buffer_img[list(idx_map.keys())] = x[list(idx_map.values())]
        buffer.buffer_label[list(idx_map.keys())] = replace_y
        return list(idx_map.keys())
--------------------------------------------------------------------------------
/utils/buffer/mir_retrieve.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.utils import maybe_cuda
3 | import torch.nn.functional as F
4 | from utils.buffer.buffer_utils import random_retrieve, get_grad_vector
5 | import copy
6 |
7 |
class MIR_retrieve(object):
    """Maximally Interfered Retrieval (MIR).

    Scores a random subsample of the buffer by how much each sample's loss
    would increase after a virtual (lookahead) SGD step on the incoming
    batch, and returns the most-interfered samples.
    """

    def __init__(self, params, **kwargs):
        super().__init__()
        self.params = params
        # number of candidates randomly drawn from the buffer before scoring
        self.subsample = params.subsample
        # number of samples actually returned per retrieval
        self.num_retrieve = params.eps_mem_batch

    def retrieve(self, buffer, **kwargs):
        """Return (x, y) of the num_retrieve most-interfered buffer samples.

        NOTE(review): the virtual step uses the gradients currently stored on
        buffer.model's parameters — presumably backward() on the incoming
        batch ran before this is called; confirm against the calling agent.
        """
        sub_x, sub_y = random_retrieve(buffer, self.subsample)
        grad_dims = []
        for param in buffer.model.parameters():
            grad_dims.append(param.data.numel())
        # flatten the current gradients, then take one virtual SGD step on a
        # deep copy of the model
        grad_vector = get_grad_vector(buffer.model.parameters, grad_dims)
        model_temp = self.get_future_step_parameters(buffer.model, grad_vector, grad_dims)
        if sub_x.size(0) > 0:
            with torch.no_grad():
                logits_pre = buffer.model.forward(sub_x)
                logits_post = model_temp.forward(sub_x)
                # interference score: per-sample loss increase caused by the
                # virtual update
                pre_loss = F.cross_entropy(logits_pre, sub_y, reduction='none')
                post_loss = F.cross_entropy(logits_post, sub_y, reduction='none')
                scores = post_loss - pre_loss
                big_ind = scores.sort(descending=True)[1][:self.num_retrieve]
            return sub_x[big_ind], sub_y[big_ind]
        else:
            return sub_x, sub_y

    def get_future_step_parameters(self, model, grad_vector, grad_dims):
        r"""
        computes \theta - \delta\theta: deep-copies the model and applies one
        SGD step with the supplied flattened gradient.
        :param model: current model
        :param grad_vector: flattened gradient of the incoming batch
        :param grad_dims: number of parameters per layer
        :return: the virtually-updated copy of the model
        """
        new_model = copy.deepcopy(model)
        self.overwrite_grad(new_model.parameters, grad_vector, grad_dims)
        with torch.no_grad():
            for param in new_model.parameters():
                if param.grad is not None:
                    param.data = param.data - self.params.learning_rate * param.grad.data
        return new_model

    def overwrite_grad(self, pp, new_grad, grad_dims):
        """
        This is used to overwrite the gradients with a new gradient
        vector, whenever violations occur.
        pp: parameters
        newgrad: corrected gradient
        grad_dims: list storing number of parameters at each layer
        """
        cnt = 0
        for param in pp():
            param.grad = torch.zeros_like(param.data)
            # slice the flat vector back into this layer's shape
            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
            en = sum(grad_dims[:cnt + 1])
            this_grad = new_grad[beg: en].contiguous().view(
                param.data.size())
            param.grad.data.copy_(this_grad)
            cnt += 1
--------------------------------------------------------------------------------
/utils/setup_elements.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models.resnet import Reduced_ResNet18, SupConResNet
3 | from torchvision import transforms
4 | import torch.nn as nn
5 |
6 |
# Ablation "tricks", all disabled by default; a single entry can be switched
# on via the --trick command-line option (see main_tune.py).
default_trick = {'labels_trick': False, 'kd_trick': False, 'separated_softmax': False,
                 'review_trick': False, 'ncm_trick': False, 'kd_trick_star': False}


# Input tensor shape per dataset: [channels, height, width].
input_size_match = {
    'cifar100': [3, 32, 32],
    'cifar10': [3, 32, 32],
    'core50': [3, 128, 128],
    'mini_imagenet': [3, 84, 84],
    'tiny_imagenet': [3, 64, 64],
    'openloris': [3, 50, 50]
}


# Total number of classes per dataset.
n_classes = {
    'cifar100': 100,
    'cifar10': 10,
    'core50': 50,
    'mini_imagenet': 100,
    'tiny_imagenet': 200,
    'openloris': 69
}


# Per-dataset transforms used when wrapping raw arrays into Datasets:
# only ToTensor (scales to [0, 1]); no normalization or augmentation here.
transforms_match = {
    'core50': transforms.Compose([
        transforms.ToTensor(),
        ]),
    'cifar100': transforms.Compose([
        transforms.ToTensor(),
        ]),
    'cifar10': transforms.Compose([
        transforms.ToTensor(),
        ]),
    'mini_imagenet': transforms.Compose([
        transforms.ToTensor()]),
    'tiny_imagenet': transforms.Compose([
        transforms.ToTensor()]),
    'openloris': transforms.Compose([
            transforms.ToTensor()])
}
49 |
def setup_architecture(params):
    """Build the backbone network for the given agent/dataset combination.

    :param params: experiment namespace; reads params.data, params.agent and
        (for contrastive agents) params.head.
    :return: an uninitialized-on-CPU torch.nn.Module.
    :raises KeyError: if params.data is not a known dataset.
    """
    nclass = n_classes[params.data]  # fails fast on an unknown dataset
    # contrastive agents use a projection-head ResNet instead of a classifier
    if params.agent in ['SCR', 'SCP', 'SSCR']:
        if params.data == 'mini_imagenet' or params.data == 'tiny_imagenet':
            return SupConResNet(640, head=params.head)
        return SupConResNet(head=params.head)
    if params.agent == 'CNDPM':
        # imported lazily so other agents don't pull in the NDPM stack
        from models.ndpm.ndpm import Ndpm
        return Ndpm(params)
    # classification backbone; swap the final linear layer where the feature
    # dimension differs from the CIFAR default (previously three duplicate
    # branches returned the same default model)
    model = Reduced_ResNet18(nclass)
    if params.data == 'core50':
        model.linear = nn.Linear(2560, nclass, bias=True)
    elif params.data == 'mini_imagenet' or params.data == 'tiny_imagenet':
        model.linear = nn.Linear(640, nclass, bias=True)
    return model
73 |
74 |
def setup_opt(optimizer, model, lr, wd):
    """Construct the optimizer named by `optimizer` over `model`'s parameters.

    :param optimizer: 'SGD' or 'Adam'
    :param lr: learning rate
    :param wd: weight decay
    :raises Exception: for any other optimizer name
    """
    if optimizer == 'SGD':
        return torch.optim.SGD(model.parameters(), lr=lr, weight_decay=wd)
    if optimizer == 'Adam':
        return torch.optim.Adam(model.parameters(), lr=lr,
                                betas=(0.9, 0.99), weight_decay=wd)
    raise Exception('wrong optimizer name')
--------------------------------------------------------------------------------
/continuum/dataset_scripts/cifar100.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision import datasets
3 | from continuum.data_utils import create_task_composition, load_task_with_labels
4 | from continuum.dataset_scripts.dataset_base import DatasetBase
5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns
6 |
7 |
class CIFAR100(DatasetBase):
    """CIFAR-100 wrapper for the 'ni' (new instance, non-stationary) and
    'nc' (new class, split) continual-learning scenarios."""

    def __init__(self, scenario, params):
        dataset = 'cifar100'
        # 'ni' derives one task per non-stationarity factor; otherwise the
        # task count comes from the config
        if scenario == 'ni':
            num_tasks = len(params.ns_factor)
        else:
            num_tasks = params.num_tasks
        super(CIFAR100, self).__init__(dataset, scenario, num_tasks, params.num_runs, params)


    def download_load(self):
        """Download CIFAR-100 if needed and keep the raw arrays in memory."""
        dataset_train = datasets.CIFAR100(root=self.root, train=True, download=True)
        self.train_data = dataset_train.data
        self.train_label = np.array(dataset_train.targets)
        dataset_test = datasets.CIFAR100(root=self.root, train=False, download=True)
        self.test_data = dataset_test.data
        self.test_label = np.array(dataset_test.targets)

    def setup(self):
        """Materialize the task streams / test sets for the current run."""
        if self.scenario == 'ni':
            self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data,
                                                                                        self.train_label,
                                                                                        self.test_data, self.test_label,
                                                                                        self.task_nums, 32,
                                                                                        self.params.val_size,
                                                                                        self.params.ns_type, self.params.ns_factor,
                                                                                        plot=self.params.plot_sample)
        elif self.scenario == 'nc':
            # split the 100 classes across tasks (random unless fix_order)
            self.task_labels = create_task_composition(class_nums=100, num_tasks=self.task_nums, fixed_order=self.params.fix_order)
            self.test_set = []
            for labels in self.task_labels:
                x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels)
                self.test_set.append((x_test, y_test))
        else:
            raise Exception('wrong scenario')

    def new_task(self, cur_task, **kwargs):
        """Return (x_train, y_train, labels) for task `cur_task`."""
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        else:
            # explicit failure (consistent with TinyImageNet.new_task) instead
            # of an UnboundLocalError on the return below
            raise Exception('unrecognized scenario')
        return x_train, y_train, labels

    def new_run(self, **kwargs):
        """Re-split the tasks for a new run; return the per-task test sets."""
        self.setup()
        return self.test_set

    def test_plot(self):
        # visualize the non-stationary transforms on a few training samples
        test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type,
                self.params.ns_factor)
60 |
--------------------------------------------------------------------------------
/continuum/dataset_scripts/cifar10.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision import datasets
3 | from continuum.data_utils import create_task_composition, load_task_with_labels, shuffle_data
4 | from continuum.dataset_scripts.dataset_base import DatasetBase
5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns
6 |
7 |
class CIFAR10(DatasetBase):
    """CIFAR-10 wrapper for the 'ni' (new instance, non-stationary) and
    'nc' (new class, split) continual-learning scenarios."""

    def __init__(self, scenario, params):
        dataset = 'cifar10'
        # 'ni' derives one task per non-stationarity factor; otherwise the
        # task count comes from the config
        if scenario == 'ni':
            num_tasks = len(params.ns_factor)
        else:
            num_tasks = params.num_tasks
        super(CIFAR10, self).__init__(dataset, scenario, num_tasks, params.num_runs, params)


    def download_load(self):
        """Download CIFAR-10 if needed and keep the raw arrays in memory."""
        dataset_train = datasets.CIFAR10(root=self.root, train=True, download=True)
        self.train_data = dataset_train.data
        self.train_label = np.array(dataset_train.targets)
        dataset_test = datasets.CIFAR10(root=self.root, train=False, download=True)
        self.test_data = dataset_test.data
        self.test_label = np.array(dataset_test.targets)

    def setup(self):
        """Materialize the task streams / test sets for the current run."""
        if self.scenario == 'ni':
            self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data,
                                                                                        self.train_label,
                                                                                        self.test_data, self.test_label,
                                                                                        self.task_nums, 32,
                                                                                        self.params.val_size,
                                                                                        self.params.ns_type, self.params.ns_factor,
                                                                                        plot=self.params.plot_sample)
        elif self.scenario == 'nc':
            # split the 10 classes across tasks (random unless fix_order)
            self.task_labels = create_task_composition(class_nums=10, num_tasks=self.task_nums, fixed_order=self.params.fix_order)
            self.test_set = []
            for labels in self.task_labels:
                x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels)
                self.test_set.append((x_test, y_test))
        else:
            raise Exception('wrong scenario')

    def new_task(self, cur_task, **kwargs):
        """Return (x_train, y_train, labels) for task `cur_task`."""
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        else:
            # explicit failure (consistent with TinyImageNet.new_task) instead
            # of an UnboundLocalError on the return below
            raise Exception('unrecognized scenario')
        return x_train, y_train, labels

    def new_run(self, **kwargs):
        """Re-split the tasks for a new run; return the per-task test sets."""
        self.setup()
        return self.test_set

    def test_plot(self):
        # visualize the non-stationary transforms on a few training samples
        test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type,
                self.params.ns_factor)
60 |
--------------------------------------------------------------------------------
/agents/icarl.py:
--------------------------------------------------------------------------------
1 | from agents.base import ContinualLearner
2 | from continuum.data_utils import dataset_transform
3 | from utils import utils
4 | from utils.buffer.buffer_utils import random_retrieve
5 | from utils.setup_elements import transforms_match
6 | from torch.utils import data
7 | import numpy as np
8 | from torch.nn import functional as F
9 | from utils.utils import maybe_cuda, AverageMeter
10 | from utils.buffer.buffer import Buffer
11 | import torch
12 | import copy
13 |
14 |
class Icarl(ContinualLearner):
    """iCaRL-style agent: replay buffer + distillation through sigmoid/BCE
    outputs. Once a snapshot of the previous model exists, old-class targets
    are replaced by that snapshot's sigmoid probabilities."""

    def __init__(self, model, opt, params):
        super(Icarl, self).__init__(model, opt, params)
        self.model = model
        # replay buffer capacity
        self.mem_size = params.mem_size
        self.buffer = Buffer(model, params)
        # snapshot of the model taken after the previous task (None on task 1)
        self.prev_model = None

    def train_learner(self, x_train, y_train):
        """Train on one task's data, then snapshot the model for distillation."""
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        self.model.train()
        self.update_representation(train_loader)
        # keep a frozen copy to distill from on the next task
        self.prev_model = copy.deepcopy(self.model)
        self.after_train()

    def update_representation(self, train_loader):
        """Representation update for one task: BCE over all seen classes with
        distillation targets on old classes, interleaved with buffer updates."""
        # buffer slots written during this task; excluded from replay so a
        # sample is not replayed in the same task it was stored
        updated_idx = []
        for ep in range(self.epoch):
            for i, train_data in enumerate(train_loader):
                # batch update
                train_x, train_y = train_data
                train_x = maybe_cuda(train_x, self.cuda)
                train_y = maybe_cuda(train_y, self.cuda)
                # remap new-task labels to contiguous ids appended after the
                # old classes
                train_y_copy = train_y.clone()
                for k, y in enumerate(train_y_copy):
                    train_y_copy[k] = len(self.old_labels) + self.new_labels.index(y)
                all_cls_num = len(self.new_labels) + len(self.old_labels)
                # one-hot targets over every class seen so far
                target_labels = utils.ohe_label(train_y_copy, all_cls_num, device=train_y_copy.device).float()
                if self.prev_model is not None:
                    # replay: append memory samples; their target rows start as
                    # zeros and old-class columns get filled by the previous
                    # model below.
                    # NOTE(review): torch.zeros_like(target_labels) assumes the
                    # memory batch has the same size as the incoming batch —
                    # confirm random_retrieve always returns self.batch samples
                    mem_x, mem_y = random_retrieve(self.buffer, self.batch,
                                                   excl_indices=updated_idx)
                    mem_x = maybe_cuda(mem_x, self.cuda)
                    batch_x = torch.cat([train_x, mem_x])
                    target_labels = torch.cat([target_labels, torch.zeros_like(target_labels)])
                else:
                    batch_x = train_x
                logits = self.forward(batch_x)
                self.opt.zero_grad()
                if self.prev_model is not None:
                    # distillation: overwrite old-class target columns (for all
                    # rows) with the previous model's sigmoid probabilities
                    with torch.no_grad():
                        q = torch.sigmoid(self.prev_model.forward(batch_x))
                    for k, y in enumerate(self.old_labels):
                        target_labels[:, k] = q[:, k]
                loss = F.binary_cross_entropy_with_logits(logits[:, :all_cls_num], target_labels, reduction='none').sum(dim=1).mean()
                loss.backward()
                self.opt.step()
                # store the new samples; remember which slots were touched
                updated_idx += self.buffer.update(train_x, train_y)
66 |
--------------------------------------------------------------------------------
/continuum/dataset_scripts/tiny_imagenet.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | from continuum.data_utils import create_task_composition, load_task_with_labels, shuffle_data
4 | from continuum.dataset_scripts.dataset_base import DatasetBase
5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns
6 |
7 | TEST_SPLIT = 1 / 6
8 |
9 |
class TinyImageNet(DatasetBase):
    """Tiny-ImageNet (200 classes, 64x64) wrapper for the 'ni' and 'nc'
    continual-learning scenarios; loads the pickles produced by
    process_tiny_imagenet.py."""

    def __init__(self, scenario, params):
        dataset = 'tiny_imagenet'
        # 'ni' derives one task per non-stationarity factor; otherwise the
        # task count comes from the config
        if scenario == 'ni':
            num_tasks = len(params.ns_factor)
        else:
            num_tasks = params.num_tasks
        super(TinyImageNet, self).__init__(dataset, scenario, num_tasks, params.num_runs, params)

    def download_load(self):
        """Load the pre-pickled train/val splits into memory."""
        train_dir = './datasets/tiny-imagenet-200/train.pkl'
        test_dir = './datasets/tiny-imagenet-200/val.pkl'

        # context managers so the pickle files are closed after loading
        # (previously pickle.load(open(...)) leaked the file handles)
        with open(train_dir, 'rb') as f:
            train = pickle.load(f)
        self.train_data = train['data'].reshape((100000, 64, 64, 3))
        self.train_label = train['target']

        with open(test_dir, 'rb') as f:
            test = pickle.load(f)
        self.test_data = test['data'].reshape((10000, 64, 64, 3))
        self.test_label = test['target']

    def new_run(self, **kwargs):
        """Re-split the tasks for a new run; return the per-task test sets."""
        self.setup()
        return self.test_set

    def new_task(self, cur_task, **kwargs):
        """Return (x_train, y_train, labels) for task `cur_task`."""
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        else:
            raise Exception('unrecognized scenario')
        return x_train, y_train, labels

    def setup(self):
        """Materialize the task streams / test sets for the current run."""
        if self.scenario == 'ni':
            self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data,
                                                                                        self.train_label,
                                                                                        self.test_data, self.test_label,
                                                                                        self.task_nums, 84,
                                                                                        self.params.val_size,
                                                                                        self.params.ns_type, self.params.ns_factor,
                                                                                        plot=self.params.plot_sample)

        elif self.scenario == 'nc':
            # split the 200 classes across tasks (random unless fix_order)
            self.task_labels = create_task_composition(class_nums=200, num_tasks=self.task_nums,
                                                       fixed_order=self.params.fix_order)
            self.test_set = []
            for labels in self.task_labels:
                x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels)
                self.test_set.append((x_test, y_test))

    def test_plot(self):
        # visualize the non-stationary transforms on a few training samples
        test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type,
                self.params.ns_factor)
67 |
--------------------------------------------------------------------------------
/agents/scr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import data
3 | from utils.buffer.buffer import Buffer, DynamicBuffer
4 | from agents.base import ContinualLearner
5 | from continuum.data_utils import dataset_transform, BalancedSampler
6 | from utils.setup_elements import transforms_match, input_size_match
7 | from utils.utils import maybe_cuda, AverageMeter
8 | from kornia.augmentation import RandomResizedCrop, RandomHorizontalFlip, ColorJitter, RandomGrayscale
9 | import torch.nn as nn
10 |
class SupContrastReplay(ContinualLearner):
    """Supervised Contrastive Replay (SCR).

    Each incoming batch is concatenated with a batch retrieved from the
    replay buffer; two views (original + augmented) are encoded and trained
    with the supervised-contrastive criterion in ``self.criterion``.
    """

    def __init__(self, model, opt, params):
        super(SupContrastReplay, self).__init__(model, opt, params)
        self.buffer = Buffer(model, params)
        self.mem_size = params.mem_size            # replay buffer capacity
        self.eps_mem_batch = params.eps_mem_batch  # memory batch size per retrieval
        self.mem_iters = params.mem_iters          # gradient steps per incoming batch
        # augmentation pipeline (kornia) producing the second view
        self.transform = nn.Sequential(
            RandomResizedCrop(size=(input_size_match[self.params.data][1], input_size_match[self.params.data][2]), scale=(0.2, 1.)),
            RandomHorizontalFlip(),
            ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
            RandomGrayscale(p=0.2)
        )

    def train_learner(self, x_train, y_train, **kwargs):
        """Train on one task's stream (x_train, y_train)."""
        self.before_train(x_train, y_train)
        #self.buffer.new_condense_task()
        self.buffer.new_task()
        # set up loader with class-balanced batches within the task
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_sampler = BalancedSampler(x_train, y_train, self.batch)
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, num_workers=0,
                                       drop_last=True, sampler=train_sampler)
        # set up model
        self.model = self.model.train()

        # setup tracker
        losses = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                for j in range(self.mem_iters):
                    mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y)

                    if mem_x.size(0) > 0:
                        mem_x = maybe_cuda(mem_x, self.cuda)
                        mem_y = maybe_cuda(mem_y, self.cuda)
                        combined_batch = torch.cat((mem_x, batch_x))
                        combined_labels = torch.cat((mem_y, batch_y))
                        combined_batch_aug = self.transform(combined_batch)
                        # two views per sample stacked on dim 1
                        features = torch.cat([self.model.forward(combined_batch).unsqueeze(1),
                                              self.model.forward(combined_batch_aug).unsqueeze(1)], dim=1)
                        loss = self.criterion(features, combined_labels)
                        # store a float, not the tensor, so the autograd graph
                        # of every batch is not kept alive by the meter
                        losses.update(loss.item(), batch_y.size(0))
                        self.opt.zero_grad()
                        loss.backward()
                        self.opt.step()

                # update mem
                self.buffer.update(batch_x, batch_y)
                if i % 100 == 1 and self.verbose:
                    # the previous version passed a never-updated accuracy
                    # meter to a format string with no matching placeholder;
                    # log the running loss only
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}'
                        .format(i, losses.avg())
                    )
        self.after_train()
73 |
--------------------------------------------------------------------------------
/agents/gdumb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import data
3 | import math
4 | from agents.base import ContinualLearner
5 | from continuum.data_utils import dataset_transform
6 | from utils.setup_elements import transforms_match, setup_architecture, setup_opt
7 | from utils.utils import maybe_cuda, EarlyStopping
8 | import numpy as np
9 | import random
10 |
11 |
class Gdumb(ContinualLearner):
    """GDumb: greedily maintain a class-balanced memory from the stream
    (no online learning on the stream itself), then retrain a freshly
    initialized network from scratch on the memory alone."""

    def __init__(self, model, opt, params):
        super(Gdumb, self).__init__(model, opt, params)
        # class label -> list of stored image tensors
        self.mem_img = {}
        # class label -> number of stored samples for that class
        self.mem_c = {}
        #self.early_stopping = EarlyStopping(self.params.min_delta, self.params.patience, self.params.cumulative_delta)

    def greedy_balancing_update(self, x, y):
        """Greedy balancing sampler: admit (x, y) only while class y is under
        its per-class quota; when memory is full, evict a random sample from
        the currently largest class to make room."""
        # per-class quota given the number of classes seen so far
        k_c = self.params.mem_size // max(1, len(self.mem_img))
        if y not in self.mem_img or self.mem_c[y] < k_c:
            if sum(self.mem_c.values()) >= self.params.mem_size:
                # evict a random element of the most-populated class
                cls_max = max(self.mem_c.items(), key=lambda k:k[1])[0]
                idx = random.randrange(self.mem_c[cls_max])
                self.mem_img[cls_max].pop(idx)
                self.mem_c[cls_max] -= 1
            if y not in self.mem_img:
                self.mem_img[y] = []
                self.mem_c[y] = 0
            self.mem_img[y].append(x)
            self.mem_c[y] += 1

    def train_learner(self, x_train, y_train):
        """Consume one task's stream: only the memory is updated online, then
        the model is retrained from scratch on the memory."""
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)

        for i, batch_data in enumerate(train_loader):
            # batch update
            batch_x, batch_y = batch_data
            batch_x = maybe_cuda(batch_x, self.cuda)
            batch_y = maybe_cuda(batch_y, self.cuda)
            # update mem
            for j in range(len(batch_x)):
                self.greedy_balancing_update(batch_x[j], batch_y[j].item())
        #self.early_stopping.reset()
        self.train_mem()
        self.after_train()

    def train_mem(self):
        """Re-initialize the network and optimizer, then train for mem_epoch
        epochs on the stored memory only."""
        # flatten the per-class memory into one (x, y) set
        mem_x = []
        mem_y = []
        for i in self.mem_img.keys():
            mem_x += self.mem_img[i]
            mem_y += [i] * self.mem_c[i]

        mem_x = torch.stack(mem_x)
        mem_y = torch.LongTensor(mem_y)
        # fresh model + optimizer: the streamed model is discarded
        self.model = setup_architecture(self.params)
        self.model = maybe_cuda(self.model, self.cuda)
        opt = setup_opt(self.params.optimizer, self.model, self.params.learning_rate, self.params.weight_decay)
        #scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=1, T_mult=2, eta_min=self.params.minlr)

        #loss = math.inf
        for i in range(self.params.mem_epoch):
            # reshuffle the memory each epoch
            idx = np.random.permutation(len(mem_x)).tolist()
            mem_x = maybe_cuda(mem_x[idx], self.cuda)
            mem_y = maybe_cuda(mem_y[idx], self.cuda)
            self.model = self.model.train()
            batch_size = self.params.batch
            #scheduler.step()
            #if opt.param_groups[0]['lr'] == self.params.learning_rate:
            #    if self.early_stopping.step(-loss):
            #        return
            for j in range(len(mem_y) // batch_size):
                opt.zero_grad()
                logits = self.model.forward(mem_x[batch_size * j:batch_size * (j + 1)])
                loss = self.criterion(logits, mem_y[batch_size * j:batch_size * (j + 1)])
                loss.backward()
                # clip gradients before the step
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip)
                opt.step()
84 |
85 |
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Yonglong Tian (yonglong@mit.edu)
3 | Date: May 07, 2020
4 | """
5 | from __future__ import print_function
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | class SupConLoss(nn.Module):
12 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
13 | It also supports the unsupervised contrastive loss in SimCLR"""
14 | def __init__(self, temperature=0.07, contrast_mode='all'):
15 | super(SupConLoss, self).__init__()
16 | self.temperature = temperature
17 | self.contrast_mode = contrast_mode
18 |
19 | def forward(self, features, labels=None, mask=None):
20 | """Compute loss for model. If both `labels` and `mask` are None,
21 | it degenerates to SimCLR unsupervised loss:
22 | https://arxiv.org/pdf/2002.05709.pdf
23 |
24 | Args:
25 | features: hidden vector of shape [bsz, n_views, ...].
26 | labels: ground truth of shape [bsz].
27 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
28 | has the same class as sample i. Can be asymmetric.
29 | Returns:
30 | A loss scalar.
31 | """
32 | device = (torch.device('cuda')
33 | if features.is_cuda
34 | else torch.device('cpu'))
35 |
36 | if len(features.shape) < 3:
37 | raise ValueError('`features` needs to be [bsz, n_views, ...],'
38 | 'at least 3 dimensions are required')
39 | if len(features.shape) > 3:
40 | features = features.view(features.shape[0], features.shape[1], -1)
41 |
42 | batch_size = features.shape[0]
43 | if labels is not None and mask is not None:
44 | raise ValueError('Cannot define both `labels` and `mask`')
45 | elif labels is None and mask is None:
46 | mask = torch.eye(batch_size, dtype=torch.float32).to(device)
47 | elif labels is not None:
48 | labels = labels.contiguous().view(-1, 1)
49 | if labels.shape[0] != batch_size:
50 | raise ValueError('Num of labels does not match num of features')
51 | mask = torch.eq(labels, labels.T).float().to(device)
52 | else:
53 | mask = mask.float().to(device)
54 |
55 | contrast_count = features.shape[1]
56 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
57 | if self.contrast_mode == 'one':
58 | anchor_feature = features[:, 0]
59 | anchor_count = 1
60 | elif self.contrast_mode == 'all':
61 | anchor_feature = contrast_feature
62 | anchor_count = contrast_count
63 | else:
64 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
65 |
66 | # compute logits
67 | anchor_dot_contrast = torch.div(
68 | torch.matmul(anchor_feature, contrast_feature.T),
69 | self.temperature)
70 | # for numerical stability
71 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
72 | logits = anchor_dot_contrast - logits_max.detach()
73 |
74 | # tile mask
75 | mask = mask.repeat(anchor_count, contrast_count)
76 | # mask-out self-contrast cases
77 | logits_mask = torch.scatter(
78 | torch.ones_like(mask),
79 | 1,
80 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
81 | 0
82 | )
83 | mask = mask * logits_mask
84 |
85 | # compute log_prob
86 | exp_logits = torch.exp(logits) * logits_mask
87 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
88 |
89 | # compute mean of log-likelihood over positive
90 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
91 |
92 | # loss
93 | loss = -1 * mean_log_prob_pos
94 | loss = loss.view(anchor_count, batch_size).mean()
95 |
96 | return loss
97 |
--------------------------------------------------------------------------------
/agents/summarize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils import data
4 | from utils.buffer.buffer import Buffer, DynamicBuffer
5 | from agents.base import ContinualLearner
6 | from continuum.data_utils import dataset_transform, BalancedSampler
7 | from utils.setup_elements import transforms_match, input_size_match
8 | from utils.utils import maybe_cuda, AverageMeter
9 | from kornia.augmentation import RandomResizedCrop, RandomHorizontalFlip, ColorJitter, RandomGrayscale
10 | import torch.nn as nn
11 | from torchvision.utils import make_grid, save_image
12 |
13 |
class SummarizeContrastReplay(ContinualLearner):
    """Contrastive experience replay on a summarizing DynamicBuffer.

    Each step trains on current-batch + retrieved memory samples using two
    views per image (original + augmented) and hands a sliding window of
    recent batches to the buffer's update method.
    """
    def __init__(self, model, opt, params):
        super(SummarizeContrastReplay, self).__init__(model, opt, params)
        self.buffer = DynamicBuffer(model, params)
        self.mem_size = params.mem_size
        self.eps_mem_batch = params.eps_mem_batch
        self.mem_iters = params.mem_iters
        # tensor-level augmentation pipeline (kornia modules in nn.Sequential)
        self.transform = nn.Sequential(
            RandomResizedCrop(size=(input_size_match[self.params.data][1], input_size_match[self.params.data][2]), scale=(0.2, 1.)),
            RandomHorizontalFlip(),
            ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
            RandomGrayscale(p=0.2)

        )
        # length of the recent-batch FIFO passed to the buffer update
        self.queue_size = params.queue_size

    def train_learner(self, x_train, y_train, labels):
        """Train on one task, given its data and the task's label set."""
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_sampler = BalancedSampler(x_train, y_train, self.batch)
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, num_workers=0,
                                       drop_last=True, sampler=train_sampler)
        # set up model
        self.model = self.model.train()
        # notify the buffer's update method of the new task's labels
        self.buffer.new_condense_task(labels)

        # setup tracker
        losses = AverageMeter()
        acc_batch = AverageMeter()

        aff_x = []
        aff_y = []
        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                for j in range(self.mem_iters):
                    mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y)

                    if mem_x.size(0) > 0:
                        mem_x = maybe_cuda(mem_x, self.cuda)
                        mem_y = maybe_cuda(mem_y, self.cuda)
                        combined_batch = torch.cat((mem_x, batch_x))
                        combined_labels = torch.cat((mem_y, batch_y))
                        # build two views per sample: original and augmented
                        combined_batch_aug = self.transform(combined_batch)
                        features = torch.cat([self.model.forward(combined_batch).unsqueeze(1), self.model.forward(combined_batch_aug).unsqueeze(1)], dim=1)
                        loss = self.criterion(features, combined_labels)
                        losses.update(loss, batch_y.size(0))
                        self.opt.zero_grad()
                        loss.backward()
                        self.opt.step()

                # update memory
                # bounded FIFO of recent batches for the summarizing update
                aff_x.append(batch_x)
                aff_y.append(batch_y)
                if len(aff_x) > self.queue_size:
                    aff_x.pop(0)
                    aff_y.pop(0)
                self.buffer.update(batch_x, batch_y, aff_x=aff_x, aff_y=aff_y, update_index=i, transform=self.transform)

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        .format(i, losses.avg(), acc_batch.avg())
                    )
        self.after_train()
84 |
--------------------------------------------------------------------------------
/continuum/dataset_scripts/mini_imagenet.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | from continuum.data_utils import create_task_composition, load_task_with_labels, shuffle_data
4 | from continuum.dataset_scripts.dataset_base import DatasetBase
5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns
6 |
7 | TEST_SPLIT = 1 / 6
8 |
9 |
class Mini_ImageNet(DatasetBase):
    """Mini-ImageNet (100 classes x 600 images of 84x84x3) for continual learning.

    The original train/val/test pickle caches (64/16/20 classes) are merged
    into one 100-class pool and re-split per class into train/test portions
    using TEST_SPLIT.
    """
    def __init__(self, scenario, params):
        dataset = 'mini_imagenet'
        if scenario == 'ni':
            # one task per non-stationarity factor
            num_tasks = len(params.ns_factor)
        else:
            num_tasks = params.num_tasks
        super(Mini_ImageNet, self).__init__(dataset, scenario, num_tasks, params.num_runs, params)

    def download_load(self):
        """Load the pickle caches and build per-class train/test splits.

        Fix: files are opened via context managers so the handles are always
        closed (the original `open(...)` calls leaked them).
        """
        with open("./datasets/mini_imagenet/mini-imagenet-cache-train.pkl", "rb") as train_in:
            train = pickle.load(train_in)
        train_x = train["image_data"].reshape([64, 600, 84, 84, 3])
        with open("./datasets/mini_imagenet/mini-imagenet-cache-val.pkl", "rb") as val_in:
            val = pickle.load(val_in)
        val_x = val['image_data'].reshape([16, 600, 84, 84, 3])
        with open("./datasets/mini_imagenet/mini-imagenet-cache-test.pkl", "rb") as test_in:
            test = pickle.load(test_in)
        test_x = test['image_data'].reshape([20, 600, 84, 84, 3])
        all_data = np.vstack((train_x, val_x, test_x))
        train_data = []
        train_label = []
        test_data = []
        test_label = []
        n_test = int(600 * TEST_SPLIT)
        for i in range(len(all_data)):
            # class i: label all 600 images, shuffle, then split off test part
            cur_x = all_data[i]
            cur_y = np.ones((600,)) * i
            rdm_x, rdm_y = shuffle_data(cur_x, cur_y)
            test_data.append(rdm_x[:n_test])
            test_label.append(rdm_y[:n_test])
            train_data.append(rdm_x[n_test:])
            train_label.append(rdm_y[n_test:])
        self.train_data = np.concatenate(train_data)
        self.train_label = np.concatenate(train_label)
        self.test_data = np.concatenate(test_data)
        self.test_label = np.concatenate(test_label)

    def new_run(self, **kwargs):
        """Prepare the task splits for a new run and return the test sets."""
        self.setup()
        return self.test_set

    def new_task(self, cur_task, **kwargs):
        """Return (x_train, y_train, labels) for task index `cur_task`."""
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        else:
            raise Exception('unrecognized scenario')
        return x_train, y_train, labels

    def setup(self):
        """Build train/test splits for the configured scenario ('ni' or 'nc')."""
        if self.scenario == 'ni':
            self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data,
                                                                                        self.train_label,
                                                                                        self.test_data, self.test_label,
                                                                                        self.task_nums, 84,
                                                                                        self.params.val_size,
                                                                                        self.params.ns_type, self.params.ns_factor,
                                                                                        plot=self.params.plot_sample)

        elif self.scenario == 'nc':
            self.task_labels = create_task_composition(class_nums=100, num_tasks=self.task_nums,
                                                       fixed_order=self.params.fix_order)
            self.test_set = []
            for labels in self.task_labels:
                x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels)
                self.test_set.append((x_test, y_test))

    def test_plot(self):
        """Visualize the non-stationarity on a few training samples."""
        test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type,
                self.params.ns_factor)
88 |
--------------------------------------------------------------------------------
/models/ndpm/component.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import torch
3 | import torch.nn as nn
4 | from typing import Tuple
5 |
6 | from utils.utils import maybe_cuda
7 | from utils.global_vars import *
8 |
9 |
class Component(nn.Module, ABC):
    """Base class for NDPM components.

    Holds the shared experiment params and the tuple of previously created
    experts so subclasses can aggregate outputs across experts.
    """
    def __init__(self, params, experts: Tuple):
        super().__init__()
        self.params = params
        self.experts = experts

        # populated later by a subclass's setup_optimizer()
        self.optimizer = NotImplemented
        self.lr_scheduler = NotImplemented

    @abstractmethod
    def nll(self, x, y, step=None):
        """Return NLL"""
        pass

    @abstractmethod
    def collect_nll(self, x, y, step=None):
        """Return NLLs including previous experts"""
        pass

    def _clip_grad_value(self, clip_value):
        # clip each parameter gradient elementwise to [-clip_value, clip_value]
        for group in self.optimizer.param_groups:
            nn.utils.clip_grad_value_(group['params'], clip_value)

    def _clip_grad_norm(self, max_norm, norm_type=2):
        # clip each param group's total gradient norm to max_norm
        for group in self.optimizer.param_groups:
            nn.utils.clip_grad_norm_(group['params'], max_norm, norm_type)

    def clip_grad(self):
        """Clip gradients as configured by MODELS_NDPM_COMPONENT_CLIP_GRAD."""
        clip_grad_config = MODELS_NDPM_COMPONENT_CLIP_GRAD
        if clip_grad_config['type'] == 'value':
            self._clip_grad_value(**clip_grad_config['options'])
        elif clip_grad_config['type'] == 'norm':
            self._clip_grad_norm(**clip_grad_config['options'])
        else:
            raise ValueError('Invalid clip_grad type: {}'
                             .format(clip_grad_config['type']))

    @staticmethod
    def build_optimizer(optim_config, params):
        # resolve the optimizer class by name from torch.optim
        return getattr(torch.optim, optim_config['type'])(
            params, **optim_config['options'])

    @staticmethod
    def build_lr_scheduler(lr_config, optimizer):
        # resolve the scheduler class by name from torch.optim.lr_scheduler
        return getattr(torch.optim.lr_scheduler, lr_config['type'])(
            optimizer, **lr_config['options'])

    def weight_decay_loss(self):
        """Return the sum of squared L2 norms of all parameters."""
        loss = maybe_cuda(torch.zeros([]))
        for param in self.parameters():
            loss += torch.norm(param) ** 2
        return loss
62 |
63 |
class ComponentG(Component, ABC):
    """Generative component: optimizer setup plus NLL collection over experts."""

    def setup_optimizer(self):
        """Build this component's optimizer and LR scheduler."""
        optim_cfg = {'type': self.params.optimizer, 'options': {'lr': self.params.learning_rate}}
        self.optimizer = self.build_optimizer(optim_cfg, self.parameters())
        self.lr_scheduler = self.build_lr_scheduler(
            MODELS_NDPM_COMPONENT_LR_SCHEDULER_G, self.optimizer)

    def collect_nll(self, x, y=None, step=None):
        """Default `collect_nll`: previous experts' g NLLs followed by this one's.

        Warning: Parameter-sharing components should implement their own
        `collect_nll`

        Returns:
            nll: Tensor of shape [B, 1+K]
        """
        nlls = [expert.g.nll(x, y, step) for expert in self.experts]
        nlls.append(self.nll(x, y, step))
        return torch.stack(nlls, dim=1)
85 |
86 |
87 |
class ComponentD(Component, ABC):
    """Discriminative component: optimizer setup plus expert-wide collection."""

    def setup_optimizer(self):
        """Build this component's optimizer and LR scheduler."""
        optim_cfg = {'type': self.params.optimizer, 'options': {'lr': self.params.learning_rate}}
        self.optimizer = self.build_optimizer(optim_cfg, self.parameters())
        self.lr_scheduler = self.build_lr_scheduler(
            MODELS_NDPM_COMPONENT_LR_SCHEDULER_D, self.optimizer)

    def collect_forward(self, x):
        """Default `collect_forward`: previous experts' d outputs then this one's.

        Warning: Parameter-sharing components should implement their own
        `collect_forward`

        Returns:
            output: Tensor of shape [B, 1+K, C]
        """
        outs = [expert.d(x) for expert in self.experts]
        outs.append(self.forward(x))
        return torch.stack(outs, 1)

    def collect_nll(self, x, y, step=None):
        """Default `collect_nll`: previous experts' d NLLs followed by this one's.

        Warning: Parameter-sharing components should implement their own
        `collect_nll`

        Returns:
            nll: Tensor of shape [B, 1+K]
        """
        nlls = [expert.d.nll(x, y, step) for expert in self.experts]
        nlls.append(self.nll(x, y, step))
        return torch.stack(nlls, dim=1)
122 |
--------------------------------------------------------------------------------
/utils/buffer/aser_retrieve.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.buffer.buffer_utils import random_retrieve, ClassBalancedRandomSampling
3 | from utils.buffer.aser_utils import compute_knn_sv
4 | from utils.utils import maybe_cuda
5 | from utils.setup_elements import n_classes
6 |
7 |
class ASER_retrieve(object):
    """ASER memory retrieval.

    Scores memory candidates with KNN-based Shapley values (SV) and returns
    the top-scoring ones; falls back to random retrieval until the buffer
    has been filled once.
    """
    def __init__(self, params, **kwargs):
        super().__init__()
        self.num_retrieve = params.eps_mem_batch
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # neighbourhood size for the KNN-SV computation
        self.k = params.k
        self.mem_size = params.mem_size
        # one of "asv", "asvm", "neg_sv" (see _retrieve_by_knn_sv)
        self.aser_type = params.aser_type
        # number of candidate samples drawn per class
        self.n_smp_cls = int(params.n_smp_cls)
        self.out_dim = n_classes[params.data]
        self.is_aser_upt = params.update == "ASER"
        # invalidate any stale class-index cache from a previous run
        ClassBalancedRandomSampling.class_index_cache = None

    def retrieve(self, buffer, **kwargs):
        """Return a batch (ret_x, ret_y) drawn from `buffer` for replay."""
        model = buffer.model

        if buffer.n_seen_so_far <= self.mem_size:
            # Use random retrieval until buffer is filled
            ret_x, ret_y = random_retrieve(buffer, self.num_retrieve)
        else:
            # Use ASER retrieval if buffer is filled
            cur_x, cur_y = kwargs['x'], kwargs['y']
            buffer_x, buffer_y = buffer.buffer_img, buffer.buffer_label
            ret_x, ret_y = self._retrieve_by_knn_sv(model, buffer_x, buffer_y, cur_x, cur_y, self.num_retrieve)
        return ret_x, ret_y

    def _retrieve_by_knn_sv(self, model, buffer_x, buffer_y, cur_x, cur_y, num_retrieve):
        """
        Retrieves data instances with top-N Shapley Values from candidate set.
        Args:
            model (object): neural network.
            buffer_x (tensor): data buffer.
            buffer_y (tensor): label buffer.
            cur_x (tensor): current input data tensor.
            cur_y (tensor): current input label tensor.
            num_retrieve (int): number of data instances to be retrieved.
        Returns
            ret_x (tensor): retrieved data tensor.
            ret_y (tensor): retrieved label tensor.
        """
        cur_x = maybe_cuda(cur_x)
        cur_y = maybe_cuda(cur_y)

        # Reset and update ClassBalancedRandomSampling cache if ASER update is not enabled
        if not self.is_aser_upt:
            ClassBalancedRandomSampling.update_cache(buffer_y, self.out_dim)

        # Get candidate data for retrieval (i.e., cand <- class balanced subsamples from memory)
        cand_x, cand_y, cand_ind = \
            ClassBalancedRandomSampling.sample(buffer_x, buffer_y, self.n_smp_cls, device=self.device)

        # Type 1 - Adversarial SV
        # Get evaluation data for type 1 (i.e., eval <- current input)
        eval_adv_x, eval_adv_y = cur_x, cur_y
        # Compute adversarial Shapley value of candidate data
        # (i.e., sv wrt current input)
        sv_matrix_adv = compute_knn_sv(model, eval_adv_x, eval_adv_y, cand_x, cand_y, self.k, device=self.device)

        if self.aser_type != "neg_sv":
            # Type 2 - Cooperative SV
            # Get evaluation data for type 2
            # (i.e., eval <- class balanced subsamples from memory excluding those already in candidate set)
            excl_indices = set(cand_ind.tolist())
            eval_coop_x, eval_coop_y, _ = \
                ClassBalancedRandomSampling.sample(buffer_x, buffer_y, self.n_smp_cls,
                                                   excl_indices=excl_indices, device=self.device)
            # Compute Shapley value
            sv_matrix_coop = \
                compute_knn_sv(model, eval_coop_x, eval_coop_y, cand_x, cand_y, self.k, device=self.device)
            if self.aser_type == "asv":
                # Use extremal SVs for computation
                sv = sv_matrix_coop.max(0).values - sv_matrix_adv.min(0).values
            else:
                # Use mean variation for aser_type == "asvm" or anything else
                sv = sv_matrix_coop.mean(0) - sv_matrix_adv.mean(0)
        else:
            # aser_type == "neg_sv"
            # No Type 1 - Cooperative SV; Use sum of Adversarial SV only
            sv = sv_matrix_adv.sum(0) * -1

        # highest-scoring candidates first
        ret_ind = sv.argsort(descending=True)

        ret_x = cand_x[ret_ind][:num_retrieve]
        ret_y = cand_y[ret_ind][:num_retrieve]
        return ret_x, ret_y
93 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def maybe_cuda(what, use_cuda=True, **kw):
    """
    Move `what` to the GPU and return it when `use_cuda` is enabled and a
    GPU is available; otherwise return it unchanged.
    Args:
        what (object): any object exposing a `.cuda()` method
        use_cuda (bool): whether GPU placement is desired
    Returns
        object: the same object, possibly moved to the GPU.
    """
    cuda_wanted = use_cuda is not False and torch.cuda.is_available()
    return what.cuda() if cuda_wanted else what
17 |
18 |
def boolean_string(s):
    """Parse the exact strings 'True'/'False' into booleans; reject all else."""
    if s == 'True':
        return True
    if s == 'False':
        return False
    raise ValueError('Not a valid boolean string')
23 |
24 |
class AverageMeter(object):
    """Tracks a running weighted sum and count and reports their mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.sum = 0
        self.count = 0

    def update(self, val, n):
        """Accumulate value `val` observed over `n` samples."""
        self.sum = self.sum + val * n
        self.count = self.count + n

    def avg(self):
        """Return the weighted mean, or 0 when nothing was recorded."""
        return float(self.sum) / self.count if self.count else 0
43 |
44 |
45 | def mini_batch_deep_features(model, total_x, num):
46 | """
47 | Compute deep features with mini-batches.
48 | Args:
49 | model (object): neural network.
50 | total_x (tensor): data tensor.
51 | num (int): number of data.
52 | Returns
53 | deep_features (tensor): deep feature representation of data tensor.
54 | """
55 | is_train = False
56 | if model.training:
57 | is_train = True
58 | model.eval()
59 | if hasattr(model, "features"):
60 | model_has_feature_extractor = True
61 | else:
62 | model_has_feature_extractor = False
63 | # delete the last fully connected layer
64 | modules = list(model.children())[:-1]
65 | # make feature extractor
66 | model_features = torch.nn.Sequential(*modules)
67 |
68 | with torch.no_grad():
69 | bs = 64
70 | num_itr = num // bs + int(num % bs > 0)
71 | sid = 0
72 | deep_features_list = []
73 | for i in range(num_itr):
74 | eid = sid + bs if i != num_itr - 1 else num
75 | batch_x = total_x[sid: eid]
76 |
77 | if model_has_feature_extractor:
78 | batch_deep_features_ = model.features(batch_x)
79 | else:
80 | batch_deep_features_ = torch.squeeze(model_features(batch_x))
81 |
82 | deep_features_list.append(batch_deep_features_.reshape((batch_x.size(0), -1)))
83 | sid = eid
84 | if num_itr == 1:
85 | deep_features_ = deep_features_list[0]
86 | else:
87 | deep_features_ = torch.cat(deep_features_list, 0)
88 | if is_train:
89 | model.train()
90 | return deep_features_
91 |
92 |
def euclidean_distance(u, v):
    """Return the row-wise squared Euclidean distance between u and v.

    Note: despite the name, no square root is taken.
    """
    diff = u - v
    return (diff * diff).sum(1)
96 |
97 |
def ohe_label(label_tensor, dim, device="cpu"):
    """Return the one-hot encoding of `label_tensor` as an [N, dim] long tensor."""
    count = label_tensor.size(0)
    onehot = torch.zeros((count, dim), device=device, dtype=torch.long)
    onehot.scatter_(1, label_tensor.reshape((count, 1)), 1)
    return onehot
103 |
104 |
def nonzero_indices(bool_mask_tensor):
    """Return a 1-D tensor of first-dimension indices of nonzero entries."""
    nz_per_dim = bool_mask_tensor.nonzero(as_tuple=True)
    return nz_per_dim[0]
108 |
109 |
class EarlyStopping():
    """Signals when a monitored score has stopped improving.

    An improvement means `score > best_score + min_delta`. With
    `cumulative_delta` disabled, small (sub-delta) gains still move the
    baseline forward without resetting the patience counter.
    """
    def __init__(self, min_delta, patience, cumulative_delta):
        self.min_delta = min_delta
        self.patience = patience
        self.cumulative_delta = cumulative_delta
        self.counter = 0
        self.best_score = None

    def step(self, score):
        """Record `score`; return True once patience is exhausted."""
        if self.best_score is None:
            # first observation becomes the baseline
            self.best_score = score
        elif score > self.best_score + self.min_delta:
            # clear improvement: move baseline, reset patience
            self.best_score = score
            self.counter = 0
        else:
            # insufficient improvement
            if score > self.best_score and not self.cumulative_delta:
                self.best_score = score
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

    def reset(self):
        """Forget the baseline and restart the patience counter."""
        self.counter = 0
        self.best_score = None
135 |
--------------------------------------------------------------------------------
/utils/buffer/buffer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import numpy as np
4 | from collections import defaultdict
5 |
6 | from utils.setup_elements import input_size_match
7 | from utils import name_match
8 | from utils.utils import maybe_cuda
9 | from utils.buffer.buffer_utils import BufferClassTracker
10 | from utils.setup_elements import n_classes
11 |
12 |
class Buffer(torch.nn.Module):
    """Fixed-size episodic memory with pluggable update/retrieve strategies."""
    def __init__(self, model, params):
        super().__init__()
        self.params = params
        self.model = model
        self.cuda = self.params.cuda
        # next free slot while the buffer is still filling up
        self.current_index = 0
        # total number of stream samples observed so far
        self.n_seen_so_far = 0
        self.device = "cuda" if self.params.cuda else "cpu"
        self.num_classes_per_task = self.params.num_classes_per_task
        self.num_classes = 0

        # define buffer
        buffer_size = params.mem_size
        print('buffer has %d slots' % buffer_size)
        input_size = input_size_match[params.data]
        buffer_img = maybe_cuda(torch.FloatTensor(buffer_size, *input_size).fill_(0))
        buffer_label = maybe_cuda(torch.LongTensor(buffer_size).fill_(0))

        # registering as buffer allows us to save the object using `torch.save`
        self.register_buffer('buffer_img', buffer_img)
        self.register_buffer('buffer_label', buffer_label)
        # label -> slot indices; per-class slot budget (used by update methods)
        self.labeldict = defaultdict(list)
        self.labelsize = params.images_per_class
        self.avail_indices = list(np.arange(buffer_size))

        # define update and retrieve method
        self.update_method = name_match.update_methods[params.update](params)
        self.retrieve_method = name_match.retrieve_methods[params.retrieve](params)

        if self.params.buffer_tracker:
            self.buffer_tracker = BufferClassTracker(n_classes[params.data], self.device)

    def update(self, x, y,**kwargs):
        """Delegate memory insertion of batch (x, y) to the update strategy."""
        return self.update_method.update(buffer=self, x=x, y=y, **kwargs)

    def retrieve(self, **kwargs):
        """Delegate memory sampling to the retrieve strategy."""
        return self.retrieve_method.retrieve(buffer=self, **kwargs)

    def new_task(self, **kwargs):
        """Register a new task and rebalance the per-class slot budget."""
        self.num_classes += self.num_classes_per_task
        self.labelsize = self.params.mem_size // self.num_classes

    def new_condense_task(self, **kwargs):
        """Register a new task with the (condensation-style) update method."""
        self.num_classes += self.num_classes_per_task
        self.update_method.new_task(self.num_classes)
59 |
60 |
class DynamicBuffer(torch.nn.Module):
    """Episodic memory for summarizing/condensing update methods.

    In addition to the raw image buffer it registers a replica buffer
    (`buffer_img_rep`) and keeps a per-class slot-index map (`condense_dict`)
    for the update strategy.
    """
    def __init__(self, model, params):
        super().__init__()
        self.params = params
        self.model = model
        self.cuda = self.params.cuda
        # next free slot while the buffer is still filling up
        self.current_index = 0
        # total number of stream samples observed so far
        self.n_seen_so_far = 0
        self.device = "cuda" if self.params.cuda else "cpu"
        self.num_classes_per_task = self.params.num_classes_per_task
        self.images_per_class = self.params.images_per_class
        self.num_classes = 0

        # define buffer
        buffer_size = params.mem_size
        print('buffer has %d slots' % buffer_size)
        input_size = input_size_match[params.data]
        buffer_img = maybe_cuda(torch.FloatTensor(buffer_size, *input_size).fill_(0))
        buffer_label = maybe_cuda(torch.LongTensor(buffer_size).fill_(0))

        # registering as buffer allows us to save the object using `torch.save`
        self.register_buffer('buffer_img', buffer_img)
        self.register_buffer('buffer_img_rep', copy.deepcopy(buffer_img))
        self.register_buffer('buffer_label', buffer_label)
        # label -> slot indices; exact use is defined by the update method
        self.condense_dict = defaultdict(list)
        self.labelsize = params.images_per_class
        self.avail_indices = list(np.arange(buffer_size))

        # define update and retrieve method
        self.update_method = name_match.update_methods[params.update](params)
        self.retrieve_method = name_match.retrieve_methods[params.retrieve](params)

    def update(self, x, y,**kwargs):
        """Delegate memory insertion of batch (x, y) to the update strategy."""
        return self.update_method.update(buffer=self, x=x, y=y, **kwargs)

    def retrieve(self, **kwargs):
        """Delegate memory sampling to the retrieve strategy."""
        return self.retrieve_method.retrieve(buffer=self, **kwargs)

    def new_task(self, **kwargs):
        """Register a new task and rebalance the per-class slot budget."""
        self.num_classes += self.num_classes_per_task
        self.labelsize = self.params.mem_size // self.num_classes

    def new_condense_task(self, labels, **kwargs):
        """Register a new task and pass its label set to the update method."""
        self.num_classes += self.num_classes_per_task
        self.update_method.new_task(self.num_classes, labels)
106 |
107 |
--------------------------------------------------------------------------------
/agents/agem.py:
--------------------------------------------------------------------------------
1 | from agents.base import ContinualLearner
2 | from continuum.data_utils import dataset_transform
3 | from utils.setup_elements import transforms_match
4 | from torch.utils import data
5 | from utils.buffer.buffer import Buffer
6 | from utils.utils import maybe_cuda, AverageMeter
7 | import torch
8 |
9 |
class AGEM(ContinualLearner):
    """Averaged Gradient Episodic Memory (A-GEM).

    After backprop on the current batch, if the gradient conflicts with the
    gradient on replayed memory samples (negative inner product), it is
    projected so the update does not increase the memory loss; then a single
    optimizer step is taken.
    """
    def __init__(self, model, opt, params):
        super(AGEM, self).__init__(model, opt, params)
        self.buffer = Buffer(model, params)
        self.mem_size = params.mem_size
        self.eps_mem_batch = params.eps_mem_batch
        self.mem_iters = params.mem_iters

    def train_learner(self, x_train, y_train):
        """Train on one task's data with A-GEM gradient projection."""
        self.before_train(x_train, y_train)

        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # set up model
        self.model = self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                for j in range(self.mem_iters):
                    logits = self.forward(batch_x)
                    loss = self.criterion(logits, batch_y)
                    if self.params.trick['kd_trick']:
                        # blend in knowledge-distillation loss, weighted by tasks seen
                        loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                               self.kd_manager.get_kd_loss(logits, batch_x)
                    if self.params.trick['kd_trick_star']:
                        loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \
                               (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x)
                    _, pred_label = torch.max(logits, 1)
                    correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                    # update tracker
                    acc_batch.update(correct_cnt, batch_y.size(0))
                    # fix: track the scalar value; accumulating the live loss
                    # tensor would retain the autograd graph for the whole epoch
                    losses_batch.update(loss.item(), batch_y.size(0))
                    # backward
                    self.opt.zero_grad()
                    loss.backward()

                    if self.task_seen > 0:
                        # sample from memory of previous tasks
                        mem_x, mem_y = self.buffer.retrieve()
                        if mem_x.size(0) > 0:
                            params = [p for p in self.model.parameters() if p.requires_grad]
                            # gradient computed using current batch
                            grad = [p.grad.clone() for p in params]
                            mem_x = maybe_cuda(mem_x, self.cuda)
                            mem_y = maybe_cuda(mem_y, self.cuda)
                            mem_logits = self.forward(mem_x)
                            loss_mem = self.criterion(mem_logits, mem_y)
                            self.opt.zero_grad()
                            loss_mem.backward()
                            # gradient computed using memory samples
                            grad_ref = [p.grad.clone() for p in params]

                            # inner product of grad and grad_ref
                            prod = sum([torch.sum(g * g_r) for g, g_r in zip(grad, grad_ref)])
                            if prod < 0:
                                prod_ref = sum([torch.sum(g_r ** 2) for g_r in grad_ref])
                                # project the current gradient onto the
                                # half-space that does not hurt the memory loss
                                grad = [g - prod / prod_ref * g_r for g, g_r in zip(grad, grad_ref)]
                            # replace params' grad with the (possibly projected) one
                            for g, p in zip(grad, params):
                                p.grad.data.copy_(g)
                    # fix: step on every iteration (canonical A-GEM); previously
                    # the step was nested so no update happened on the first
                    # task or when memory retrieval came back empty
                    self.opt.step()
                # update mem
                self.buffer.update(batch_x, batch_y)

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                        .format(i, losses_batch.avg(), acc_batch.avg())
                    )
        self.after_train()
--------------------------------------------------------------------------------
/continuum/dataset_scripts/openloris.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from PIL import Image
3 | import numpy as np
4 | from continuum.dataset_scripts.dataset_base import DatasetBase
5 | import time
6 | from continuum.data_utils import shuffle_data
7 |
8 |
class OpenLORIS(DatasetBase):
    """
    OpenLORIS dataset wrapper. The number of tasks is fixed per ns_type
    (see the module-level openloris_ntask table).
    """
    def __init__(self, scenario, params):  # scenario refers to "ni" or "nc"
        # ns_type is one of: illumination, occlusion, pixel, clutter, sequence
        self.ns_type = params.ns_type
        super(OpenLORIS, self).__init__('openloris', scenario,
                                        openloris_ntask[self.ns_type],
                                        params.num_runs, params)

    def download_load(self):
        """Read every task's train/test jpgs from disk into numpy arrays."""
        start = time.time()
        self.train_set = []
        for batch_num in range(1, self.task_nums + 1):
            train_x, train_y = [], []
            test_x, test_y = [], []
            # each entry of datapath is one object instance == one class label
            for cls_idx in range(len(datapath)):
                train_files = glob.glob('datasets/openloris/' + self.ns_type + '/train/task{}/{}/*.jpg'.format(batch_num, datapath[cls_idx]))
                train_x.extend(np.array(Image.open(p).convert('RGB').resize((50, 50))) for p in train_files)
                train_y.extend([cls_idx] * len(train_files))

                test_files = glob.glob(
                    'datasets/openloris/' + self.ns_type + '/test/task{}/{}/*.jpg'.format(batch_num, datapath[cls_idx]))
                test_x.extend(np.array(Image.open(p).convert('RGB').resize((50, 50))) for p in test_files)
                test_y.extend([cls_idx] * len(test_files))

            print(" --> batch{}'-dataset consisting of {} samples".format(batch_num, len(train_x)))
            print(" --> test'-dataset consisting of {} samples".format(len(test_x)))
            self.train_set.append((np.array(train_x), np.array(train_y)))
            self.test_set.append((np.array(test_x), np.array(test_y)))
        print('loading time: {}'.format(str(time.time() - start)))

    def new_run(self, **kwargs):
        pass

    def new_task(self, cur_task, **kwargs):
        """Return the task's training split after carving off a val chunk."""
        task_x, task_y = self.train_set[cur_task]
        shuf_x, shuf_y = shuffle_data(task_x, task_y)
        n_val = int(len(shuf_x) * self.params.val_size)
        self.val_set.append((shuf_x[:n_val], shuf_y[:n_val]))
        train_x, train_y = shuf_x[n_val:], shuf_y[n_val:]
        return train_x, train_y, set(train_y)

    def setup(self, **kwargs):
        pass
63 |
64 |
65 |
# Number of tasks per non-stationarity type; consumed by OpenLORIS.__init__.
openloris_ntask = {
    'illumination': 9,
    'occlusion': 9,
    'pixel': 9,
    'clutter': 9,
    'sequence': 12
}

# Object-instance folder names under datasets/openloris/<ns_type>/{train,test}/taskN/.
# The position of each name in this list is used as its integer class label.
datapath = ['bottle_01', 'bottle_02', 'bottle_03', 'bottle_04', 'bowl_01', 'bowl_02', 'bowl_03', 'bowl_04', 'bowl_05',
            'corkscrew_01', 'cottonswab_01', 'cottonswab_02', 'cup_01', 'cup_02', 'cup_03', 'cup_04', 'cup_05',
            'cup_06', 'cup_07', 'cup_08', 'cup_10', 'cushion_01', 'cushion_02', 'cushion_03', 'glasses_01',
            'glasses_02', 'glasses_03', 'glasses_04', 'knife_01', 'ladle_01', 'ladle_02', 'ladle_03', 'ladle_04',
            'mask_01', 'mask_02', 'mask_03', 'mask_04', 'mask_05', 'paper_cutter_01', 'paper_cutter_02',
            'paper_cutter_03', 'paper_cutter_04', 'pencil_01', 'pencil_02', 'pencil_03', 'pencil_04', 'pencil_05',
            'plasticbag_01', 'plasticbag_02', 'plasticbag_03', 'plug_01', 'plug_02', 'plug_03', 'plug_04', 'pot_01',
            'scissors_01', 'scissors_02', 'scissors_03', 'stapler_01', 'stapler_02', 'stapler_03', 'thermometer_01',
            'thermometer_02', 'thermometer_03', 'toy_01', 'toy_02', 'toy_03', 'toy_04', 'toy_05','nail_clippers_01','nail_clippers_02',
            'nail_clippers_03', 'bracelet_01', 'bracelet_02','bracelet_03', 'comb_01','comb_02',
            'comb_03', 'umbrella_01','umbrella_02','umbrella_03','socks_01','socks_02','socks_03',
            'toothpaste_01','toothpaste_02','toothpaste_03','wallet_01','wallet_02','wallet_03',
            'headphone_01','headphone_02','headphone_03', 'key_01','key_02','key_03',
            'battery_01', 'battery_02', 'mouse_01', 'pencilcase_01', 'pencilcase_02', 'tape_01',
            'chopsticks_01', 'chopsticks_02', 'chopsticks_03',
            'notebook_01', 'notebook_02', 'notebook_03',
            'spoon_01', 'spoon_02', 'spoon_03',
            'tissue_01', 'tissue_02', 'tissue_03',
            'clamp_01', 'clamp_02', 'hat_01', 'hat_02', 'u_disk_01', 'u_disk_02', 'swimming_glasses_01'
            ]
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/agents/ewc_pp.py:
--------------------------------------------------------------------------------
1 | from agents.base import ContinualLearner
2 | from continuum.data_utils import dataset_transform
3 | from utils.setup_elements import transforms_match
4 | from torch.utils import data
5 | from utils.utils import maybe_cuda, AverageMeter
6 | import torch
7 |
class EWC_pp(ContinualLearner):
    """EWC++ (online EWC) continual-learning agent.

    Maintains an exponentially averaged estimate of the Fisher information
    (computed from squared gradients) and regularizes training with a
    quadratic penalty anchoring parameters to their values at the end of the
    previous task, weighted by a min-max-normalized Fisher.
    """
    def __init__(self, model, opt, params):
        super(EWC_pp, self).__init__(model, opt, params)
        # Trainable parameters keyed by name; both the regularizer and the
        # Fisher accumulation iterate over this dict.
        self.weights = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self.lambda_ = params.lambda_  # strength of the EWC penalty
        self.alpha = params.alpha  # EMA coefficient for the running Fisher
        self.fisher_update_after = params.fisher_update_after  # iterations between running-Fisher refreshes
        self.prev_params = {}  # parameter snapshot from the end of the previous task
        self.running_fisher = self.init_fisher()  # EMA of squared gradients
        self.tmp_fisher = self.init_fisher()  # per-window squared-gradient accumulator
        self.normalized_fisher = self.init_fisher()  # min-max-normalized running Fisher

    def train_learner(self, x_train, y_train):
        """Train on one task's data while updating the Fisher estimates."""
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        # set up model
        self.model.train()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                # Periodically fold the accumulated squared gradients into the
                # running Fisher (this also resets the accumulator).
                if (ep * len(train_loader) + i + 1) % self.fisher_update_after == 0:
                    self.update_running_fisher()

                out = self.forward(batch_x)
                loss = self.total_loss(out, batch_y)
                # optional knowledge-distillation blends, weighted by tasks seen
                if self.params.trick['kd_trick']:
                    loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                           self.kd_manager.get_kd_loss(out, batch_x)
                if self.params.trick['kd_trick_star']:
                    loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \
                           (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(out, batch_x)
                # update tracker
                losses_batch.update(loss.item(), batch_y.size(0))
                _, pred_label = torch.max(out, 1)
                acc = (pred_label == batch_y).sum().item() / batch_y.size(0)
                acc_batch.update(acc, batch_y.size(0))
                # backward
                self.opt.zero_grad()
                loss.backward()

                # accumulate the fisher of current batch (requires grads from
                # the backward() call just above)
                self.accum_fisher()
                self.opt.step()

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                        .format(i, losses_batch.avg(), acc_batch.avg())
                    )

        # save params for current task (anchor point for the next task's penalty)
        for n, p in self.weights.items():
            self.prev_params[n] = p.clone().detach()

        # update normalized fisher of current task (min-max over all tensors;
        # the 1e-32 guards against division by zero when max == min)
        max_fisher = max([torch.max(m) for m in self.running_fisher.values()])
        min_fisher = min([torch.min(m) for m in self.running_fisher.values()])
        for n, p in self.running_fisher.items():
            self.normalized_fisher[n] = (p - min_fisher) / (max_fisher - min_fisher + 1e-32)
        self.after_train()

    def total_loss(self, inputs, targets):
        """Cross-entropy plus the Fisher-weighted quadratic EWC penalty."""
        # cross entropy loss
        loss = self.criterion(inputs, targets)
        if len(self.prev_params) > 0:
            # add regularization loss (only after the first task)
            reg_loss = 0
            for n, p in self.weights.items():
                reg_loss += (self.normalized_fisher[n] * (p - self.prev_params[n]) ** 2).sum()
            loss += self.lambda_ * reg_loss
        return loss

    def init_fisher(self):
        """Return a zeroed, detached per-parameter tensor dict."""
        return {n: p.clone().detach().fill_(0) for n, p in self.model.named_parameters() if p.requires_grad}

    def update_running_fisher(self):
        """Fold the windowed accumulator into the running Fisher (EMA) and reset it."""
        for n, p in self.running_fisher.items():
            self.running_fisher[n] = (1. - self.alpha) * p \
                                     + 1. / self.fisher_update_after * self.alpha * self.tmp_fisher[n]
        # reset the accumulated fisher
        self.tmp_fisher = self.init_fisher()

    def accum_fisher(self):
        """Add the current batch's squared gradients to the accumulator.

        Assumes .grad is populated, i.e. called right after loss.backward().
        """
        for n, p in self.tmp_fisher.items():
            p += self.weights[n].grad ** 2
--------------------------------------------------------------------------------
/continuum/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils import data
4 | from utils.setup_elements import transforms_match
5 | from collections import defaultdict
6 | import copy
7 |
def create_task_composition(class_nums, num_tasks, fixed_order=False):
    """Partition class labels evenly across tasks.

    Labels beyond `class_nums // num_tasks * num_tasks` are dropped. When
    fixed_order is False the label order is shuffled before partitioning.
    Returns a list of per-task label lists (also printed for logging).
    """
    per_task = class_nums // num_tasks
    labels = np.arange(0, per_task * num_tasks)
    if not fixed_order:
        np.random.shuffle(labels)

    task_labels = [list(labels[t * per_task:(t + 1) * per_task])
                   for t in range(num_tasks)]
    for t, chunk in enumerate(task_labels):
        print('Task: {}, Labels:{}'.format(t, chunk))
    return task_labels
21 |
22 |
def load_task_with_labels_torch(x, y, labels):
    """Return the (x, y) pairs whose label appears in `labels` (torch tensors).

    Selection order follows `labels`: all samples of the first label, then
    all samples of the second, and so on.
    """
    idx = torch.cat([(y == lbl).nonzero().view(-1) for lbl in labels])
    return x[idx], y[idx]
29 |
30 |
def load_task_with_labels(x, y, labels):
    """Return the (x, y) pairs whose label appears in `labels` (numpy arrays).

    Selection order follows `labels`: all samples of the first label, then
    all samples of the second, and so on.
    """
    idx = np.concatenate([np.where(y == lbl)[0] for lbl in labels], axis=None)
    return x[idx], y[idx]
37 |
38 |
39 |
class dataset_transform(data.Dataset):
    """Thin Dataset wrapper around in-memory samples and labels.

    Args:
        x: indexable collection of samples (e.g. numpy array of images).
        y (np.ndarray): integer labels, converted once to a LongTensor.
        transform: optional callable applied to each sample on access;
            expected to return a torch tensor.
    """
    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = torch.from_numpy(y).type(torch.LongTensor)
        self.transform = transform  # save the transform

    def __len__(self):
        # number of samples == number of labels
        return len(self.y)

    def __getitem__(self, idx):
        """Return (float32 sample, label) for the given index."""
        if self.transform:
            x = self.transform(self.x[idx])
        else:
            # Fix: a raw numpy sample has no .float(); convert it to a
            # tensor first (no-op when the sample is already a tensor).
            x = torch.as_tensor(self.x[idx])

        return x.float(), self.y[idx]
57 |
58 |
class BalancedSampler(data.Sampler):
    """Sampler yielding indices so each batch is (approximately) class-balanced.

    Each batch of size `batch_size` is assembled from
    `num_instances = batch_size // num_classes` indices per class.  Classes
    with fewer than `num_instances` samples are oversampled with
    replacement.  Iteration stops as soon as any class runs out of prepared
    groups, so later samples of larger classes may not be yielded in a pass.
    """
    def __init__(self, x, y, batch_size):
        # x is stored only for reference; sampling is driven by the labels y
        self.x = x
        self.y = y
        self.batch_size = batch_size
        # label -> list of sample indices carrying that label
        self.labeldict = defaultdict(list)
        for idx, label in enumerate(y):
            self.labeldict[label].append(idx)
        self.labelset = set(self.y)
        self.num_classes = len(set(self.y))
        # indices drawn per class per batch; integer division means the
        # remainder slots of each batch simply stay unfilled
        self.num_instances = batch_size // self.num_classes

    def __iter__(self):
        # Pre-chunk each class's shuffled indices into groups of
        # num_instances; a trailing partial group is dropped.
        batch_idx_dict = defaultdict(list)
        for label in self.labelset:
            indices = copy.deepcopy(self.labeldict[label])
            if len(indices) < self.num_instances:
                # not enough samples for even one group: oversample with replacement
                indices = np.random.choice(indices, size=self.num_instances, replace=True)
            np.random.shuffle(indices)
            batch_idx = []
            for idx in indices:
                batch_idx.append(idx)
                if len(batch_idx) == self.num_instances:
                    batch_idx_dict[label].append(batch_idx)
                    batch_idx = []

        # labels that still have unconsumed groups
        avail_labels = copy.deepcopy(self.labelset)
        final_indices = []

        # Emit one group per class per round, shuffling within each round.
        # The condition fails as soon as any label is removed below, so
        # iteration ends when the smallest class is exhausted.
        while len(avail_labels) >= self.num_classes:
            batch_indices = []
            for label in self.labelset:
                batch_idx = batch_idx_dict[label].pop(0)
                batch_indices.extend(batch_idx)
                if len(batch_idx_dict[label]) == 0:
                    avail_labels.remove(label)
            np.random.shuffle(batch_indices)
            final_indices.extend(batch_indices)

        return iter(final_indices)

    def __len__(self):
        # number of full batches available in one pass over the data
        return len(self.y) // self.batch_size
102 |
103 |
def setup_test_loader(test_data, params):
    """Wrap each (x_test, y_test) pair of `test_data` in a DataLoader.

    The per-dataset transform is looked up once from transforms_match.
    Returns the loaders in the same order as `test_data`.
    """
    transform = transforms_match[params.data]
    return [
        data.DataLoader(dataset_transform(x_test, y_test, transform=transform),
                        batch_size=params.test_batch, shuffle=True, num_workers=0)
        for (x_test, y_test) in test_data
    ]
112 |
113 |
def shuffle_data(x, y):
    """Return x and y jointly permuted by a single random permutation,
    so the pairing between samples and labels is preserved."""
    order = np.random.permutation(x.shape[0])
    return x[order], y[order]
120 |
121 |
def train_val_test_split_ni(train_data, train_label, test_data, test_label, task_nums, img_size, val_size=0.1):
    """Shuffle the data, carve off a validation fraction, and reshape each
    split into per-task chunks of shape (task_nums, -1, img_size, img_size, 3).

    Note: the split sizes must divide evenly by task_nums for the reshapes
    to succeed.
    """
    def _shuffled(x, y):
        # joint random permutation keeping samples and labels aligned
        order = np.arange(0, x.shape[0])
        np.random.shuffle(order)
        return x[order], y[order]

    def _per_task(arr):
        return arr.reshape(task_nums, -1, img_size, img_size, 3)

    def _per_task_labels(arr):
        return arr.reshape(task_nums, -1)

    x, y = _shuffled(train_data, train_label)
    n_val = int(len(x) * val_size)
    val_x, val_y = x[:n_val], y[:n_val]
    train_x, train_y = x[n_val:], y[n_val:]
    test_x, test_y = _shuffled(test_data, test_label)
    return (_per_task(train_x), _per_task_labels(train_y),
            _per_task(val_x), _per_task_labels(val_y),
            _per_task(test_x), _per_task_labels(test_y))
135 |
--------------------------------------------------------------------------------
/utils/buffer/aser_update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.buffer.reservoir_update import Reservoir_update
3 | from utils.buffer.buffer_utils import ClassBalancedRandomSampling, random_retrieve
4 | from utils.buffer.aser_utils import compute_knn_sv, add_minority_class_input
5 | from utils.setup_elements import n_classes
6 | from utils.utils import nonzero_indices, maybe_cuda
7 |
8 |
class ASER_update(object):
    """Buffer-update strategy for ASER (Adversarial Shapley value Experience
    Replay): while the buffer has free space it is filled reservoir-style;
    once full, incoming samples compete with buffered ones via KNN-based
    Shapley values (SV) and low-SV buffered samples are overwritten.
    """
    def __init__(self, params, **kwargs):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.k = params.k  # K used by compute_knn_sv
        self.mem_size = params.mem_size  # buffer capacity
        self.num_tasks = params.num_tasks
        self.out_dim = n_classes[params.data]  # number of classes for this dataset
        self.n_smp_cls = int(params.n_smp_cls)  # evaluation samples per class
        self.n_total_smp = int(params.n_smp_cls * self.out_dim)  # candidate draw size
        self.reservoir_update = Reservoir_update(params)  # used while the buffer has free slots
        # reset the sampler's shared class-index cache between runs
        ClassBalancedRandomSampling.class_index_cache = None

    def update(self, buffer, x, y, **kwargs):
        """Insert batch (x, y) into `buffer`.

        Free slots are filled first (reservoir update); any remaining
        samples are placed by Shapley-value comparison.
        """
        model = buffer.model

        place_left = self.mem_size - buffer.current_index

        # If buffer is not filled, use available space to store whole or part of batch
        if place_left:
            x_fit = x[:place_left]
            y_fit = y[:place_left]

            # keep the class-index cache in sync with the slots we are about to fill
            ind = torch.arange(start=buffer.current_index, end=buffer.current_index + x_fit.size(0), device=self.device)
            ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim,
                                                     new_y=y_fit, ind=ind, device=self.device)
            self.reservoir_update.update(buffer, x_fit, y_fit)

        # If buffer is filled, update buffer by sv
        if buffer.current_index == self.mem_size:
            # remove what is already in the buffer
            cur_x, cur_y = x[place_left:], y[place_left:]
            self._update_by_knn_sv(model, buffer, cur_x, cur_y)

    def _update_by_knn_sv(self, model, buffer, cur_x, cur_y):
        """
        Returns indices for replacement.
        Buffered instances with smallest SV are replaced by current input with higher SV.
        Args:
            model (object): neural network.
            buffer (object): buffer object.
            cur_x (tensor): current input data tensor.
            cur_y (tensor): current input label tensor.
        Returns
            ind_buffer (tensor): indices of buffered instances to be replaced.
            ind_cur (tensor): indices of current data to do replacement.
        """
        cur_x = maybe_cuda(cur_x)
        cur_y = maybe_cuda(cur_y)

        # Find minority class samples from current input batch
        minority_batch_x, minority_batch_y = add_minority_class_input(cur_x, cur_y, self.mem_size, self.out_dim)

        # Evaluation set: a class-balanced draw from the buffer
        eval_x, eval_y, eval_indices = \
            ClassBalancedRandomSampling.sample(buffer.buffer_img, buffer.buffer_label, self.n_smp_cls,
                                               device=self.device)

        # Concatenate minority class samples from current input batch to evaluation set
        eval_x = torch.cat((eval_x, minority_batch_x))
        eval_y = torch.cat((eval_y, minority_batch_y))

        # Candidate set: buffered samples not already used for evaluation
        cand_excl_indices = set(eval_indices.tolist())
        cand_x, cand_y, cand_ind = random_retrieve(buffer, self.n_total_smp, cand_excl_indices, return_indices=True)

        # Concatenate current input batch to candidate set
        cand_x = torch.cat((cand_x, cur_x))
        cand_y = torch.cat((cand_y, cur_y))

        # per-candidate SV = column sum of the (eval x cand) SV matrix
        sv_matrix = compute_knn_sv(model, eval_x, eval_y, cand_x, cand_y, self.k, device=self.device)
        sv = sv_matrix.sum(0)

        n_cur = cur_x.size(0)
        n_cand = cand_x.size(0)

        # Number of previously buffered instances in candidate set
        n_cand_buf = n_cand - n_cur

        sv_arg_sort = sv.argsort(descending=True)

        # Divide SV array into two segments
        # - large: candidate args to be retained; small: candidate args to be discarded
        sv_arg_large = sv_arg_sort[:n_cand_buf]
        sv_arg_small = sv_arg_sort[n_cand_buf:]

        # Extract args relevant to replacement operation
        # If current data instances are in 'large' segment, they are added to buffer
        # If buffered instances are in 'small' segment, they are discarded from buffer
        # Replacement happens between these two sets
        # Retrieve original indices from candidate args
        ind_cur = sv_arg_large[nonzero_indices(sv_arg_large >= n_cand_buf)] - n_cand_buf
        arg_buffer = sv_arg_small[nonzero_indices(sv_arg_small < n_cand_buf)]
        ind_buffer = cand_ind[arg_buffer]

        buffer.n_seen_so_far += n_cur

        # perform overwrite op (ind_cur and ind_buffer have equal length by
        # construction: both count elements that crossed the segment boundary)
        y_upt = cur_y[ind_cur]
        x_upt = cur_x[ind_cur]
        ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim,
                                                 new_y=y_upt, ind=ind_buffer, device=self.device)
        buffer.buffer_img[ind_buffer] = x_upt
        buffer.buffer_label[ind_buffer] = y_upt
113 |
--------------------------------------------------------------------------------
/agents/exp_replay.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import data
3 | from utils.buffer.buffer import Buffer
4 | from agents.base import ContinualLearner
5 | from continuum.data_utils import dataset_transform
6 | from utils.setup_elements import transforms_match
7 | from utils.utils import maybe_cuda, AverageMeter
8 |
9 |
class ExperienceReplay(ContinualLearner):
    """Experience-replay agent: each incoming batch is trained together with
    a batch retrieved from a bounded replay buffer, and the buffer is
    updated with the incoming data afterwards.
    """
    def __init__(self, model, opt, params):
        super(ExperienceReplay, self).__init__(model, opt, params)
        self.buffer = Buffer(model, params)  # replay memory (update/retrieve strategies set by params)
        self.mem_size = params.mem_size  # buffer capacity
        self.eps_mem_batch = params.eps_mem_batch  # replay batch size
        self.mem_iters = params.mem_iters  # inner optimization steps per incoming batch

    def train_learner(self, x_train, y_train, **kwargs):
        """Train on one task's data with replay.

        Args:
            x_train: task input samples (numpy array).
            y_train: task labels (numpy array).
        """
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # set up model
        self.model = self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        losses_mem = AverageMeter()
        acc_batch = AverageMeter()
        acc_mem = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                for j in range(self.mem_iters):
                    logits = self.model.forward(batch_x)
                    loss = self.criterion(logits, batch_y)
                    # optional knowledge-distillation blends, weighted by tasks seen
                    if self.params.trick['kd_trick']:
                        loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                               self.kd_manager.get_kd_loss(logits, batch_x)
                    if self.params.trick['kd_trick_star']:
                        loss = 1/((self.task_seen + 1) ** 0.5) * loss + \
                               (1 - 1/((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x)
                    _, pred_label = torch.max(logits, 1)
                    correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                    # update tracker with a plain float: storing the live loss
                    # tensor would keep its autograd graph alive across
                    # iterations (memory growth); matches the other agents.
                    acc_batch.update(correct_cnt, batch_y.size(0))
                    losses_batch.update(loss.item(), batch_y.size(0))
                    # backward
                    self.opt.zero_grad()
                    loss.backward()

                    # mem update
                    mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y)
                    if mem_x.size(0) > 0:
                        mem_x = maybe_cuda(mem_x, self.cuda)
                        mem_y = maybe_cuda(mem_y, self.cuda)
                        mem_logits = self.model.forward(mem_x)
                        loss_mem = self.criterion(mem_logits, mem_y)
                        if self.params.trick['kd_trick']:
                            loss_mem = 1 / (self.task_seen + 1) * loss_mem + (1 - 1 / (self.task_seen + 1)) * \
                                       self.kd_manager.get_kd_loss(mem_logits, mem_x)
                        if self.params.trick['kd_trick_star']:
                            loss_mem = 1 / ((self.task_seen + 1) ** 0.5) * loss_mem + \
                                       (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(mem_logits,
                                                                                                             mem_x)
                        # update tracker (float, not tensor — see note above)
                        losses_mem.update(loss_mem.item(), mem_y.size(0))
                        _, pred_label = torch.max(mem_logits, 1)
                        correct_cnt = (pred_label == mem_y).sum().item() / mem_y.size(0)
                        acc_mem.update(correct_cnt, mem_y.size(0))

                        # accumulate replay gradients on top of the batch gradients
                        loss_mem.backward()

                    if self.params.update == 'ASER' or self.params.retrieve == 'ASER':
                        # ASER steps on a fresh joint gradient over memory + current batch
                        self.opt.zero_grad()
                        combined_batch = torch.cat((mem_x, batch_x))
                        combined_labels = torch.cat((mem_y, batch_y))
                        combined_logits = self.model.forward(combined_batch)
                        loss_combined = self.criterion(combined_logits, combined_labels)
                        loss_combined.backward()
                        self.opt.step()
                    else:
                        self.opt.step()

                # update mem
                self.buffer.update(batch_x, batch_y)

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                        .format(i, losses_batch.avg(), acc_batch.avg())
                    )
                    print(
                        '==>>> it: {}, mem avg. loss: {:.6f}, '
                        'running mem acc: {:.3f}'
                        .format(i, losses_mem.avg(), acc_mem.avg())
                    )
        self.after_train()
106 |
--------------------------------------------------------------------------------
/models/convnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 |
class ConvNet(nn.Module):
    """Configurable shallow ConvNet: `net_depth` blocks of
    conv -> (norm) -> activation -> (pool), followed by a linear classifier.

    Args:
        num_classes: size of the classifier output.
        net_norm: 'batch' | 'layer' | 'instance' | 'group' | 'none'.
        net_depth: number of conv blocks.
        net_width: channels per conv layer.
        channel: input channels.
        net_act: 'sigmoid' | 'relu' | 'leakyrelu'.
        net_pooling: 'maxpooling' | 'avgpooling' | 'none'.
        im_size: input spatial size; 28 is promoted to 32 (padded MNIST-style inputs).
    """
    def __init__(self,
                 num_classes,
                 net_norm='instance',
                 net_depth=3,
                 net_width=128,
                 channel=3,
                 net_act='relu',
                 net_pooling='avgpooling',
                 im_size=(32, 32)):
        super(ConvNet, self).__init__()
        if net_act == 'sigmoid':
            self.net_act = nn.Sigmoid()
        elif net_act == 'relu':
            self.net_act = nn.ReLU()
        elif net_act == 'leakyrelu':
            self.net_act = nn.LeakyReLU(negative_slope=0.01)
        else:
            # NOTE: exit() (SystemExit) kept for backward compatibility
            exit('unknown activation function: %s' % net_act)

        if net_pooling == 'maxpooling':
            self.net_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'avgpooling':
            self.net_pooling = nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'none':
            self.net_pooling = None
        else:
            exit('unknown net_pooling: %s' % net_pooling)

        self.depth = net_depth
        self.net_norm = net_norm

        self.layers, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm,
                                                    net_pooling, im_size)
        # flattened feature size feeding the classifier
        num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
        self.num_feat = num_feat
        self.classifier = nn.Linear(num_feat, num_classes)

    def forward(self, x, return_features=False):
        """Run the conv stack and classifier.

        Returns logits, or (logits, flattened features) if return_features.
        """
        for d in range(self.depth):
            x = self.layers['conv'][d](x)
            if len(self.layers['norm']) > 0:
                x = self.layers['norm'][d](x)
            x = self.layers['act'][d](x)
            if len(self.layers['pool']) > 0:
                x = self.layers['pool'][d](x)

        out = x.view(x.shape[0], -1)
        logit = self.classifier(out)

        if return_features:
            return logit, out
        else:
            return logit

    def features(self, x):
        """Return the flattened features before the classifier."""
        for d in range(self.depth):
            x = self.layers['conv'][d](x)
            if len(self.layers['norm']) > 0:
                x = self.layers['norm'][d](x)
            x = self.layers['act'][d](x)
            if len(self.layers['pool']) > 0:
                x = self.layers['pool'][d](x)

        out = x.view(x.shape[0], -1)

        return out

    def get_feature(self, x, idx_from, idx_to=-1, return_prob=False, return_logit=False):
        """Return per-block feature maps from block idx_from to idx_to
        (inclusive); idx_to == -1 means a single block. When the requested
        range ends before the last block, only those maps are returned;
        otherwise all maps plus probabilities/logits if requested.
        """
        if idx_to == -1:
            idx_to = idx_from
        features = []

        for d in range(self.depth):
            x = self.layers['conv'][d](x)
            # Fix: check the actual layer lists like forward() does.
            # The old `if self.net_norm:` was truthy for the string 'none'
            # and indexed an empty ModuleList (IndexError).
            if len(self.layers['norm']) > 0:
                x = self.layers['norm'][d](x)
            x = self.layers['act'][d](x)
            if len(self.layers['pool']) > 0:
                x = self.layers['pool'][d](x)
            features.append(x)
            if idx_to < len(features):
                return features[idx_from:idx_to + 1]

        if return_prob:
            out = x.view(x.size(0), -1)
            logit = self.classifier(out)
            prob = torch.softmax(logit, dim=-1)
            return features, prob
        elif return_logit:
            out = x.view(x.size(0), -1)
            logit = self.classifier(out)
            return features, logit
        else:
            return features[idx_from:idx_to + 1]

    def _get_normlayer(self, net_norm, shape_feat):
        """Build the normalization layer for a block; shape_feat is [C, H, W]."""
        if net_norm == 'batch':
            norm = nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == 'layer':
            norm = nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == 'instance':
            # one group per channel == instance norm
            norm = nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == 'group':
            norm = nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == 'none':
            norm = None
        else:
            exit('unknown net_norm: %s' % net_norm)
        return norm

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_pooling, im_size):
        """Build the conv/norm/act/pool ModuleLists and track the feature shape."""
        layers = {'conv': [], 'norm': [], 'act': [], 'pool': []}

        in_channels = channel
        # 28x28 inputs (MNIST-sized) are padded up to 32x32 by the first conv
        if im_size[0] == 28:
            im_size = (32, 32)
        shape_feat = [in_channels, im_size[0], im_size[1]]

        for d in range(net_depth):
            layers['conv'] += [
                nn.Conv2d(in_channels,
                          net_width,
                          kernel_size=3,
                          padding=3 if channel == 1 and d == 0 else 1)
            ]
            shape_feat[0] = net_width
            if net_norm != 'none':
                layers['norm'] += [self._get_normlayer(net_norm, shape_feat)]
            layers['act'] += [self.net_act]
            in_channels = net_width
            if net_pooling != 'none':
                layers['pool'] += [self.net_pooling]
                shape_feat[1] //= 2
                shape_feat[2] //= 2

        layers['conv'] = nn.ModuleList(layers['conv'])
        layers['norm'] = nn.ModuleList(layers['norm'])
        layers['act'] = nn.ModuleList(layers['act'])
        layers['pool'] = nn.ModuleList(layers['pool'])
        layers = nn.ModuleDict(layers)

        return layers, shape_feat
151 |
--------------------------------------------------------------------------------
/continuum/dataset_scripts/core50.py:
--------------------------------------------------------------------------------
1 | import os
2 | from continuum.dataset_scripts.dataset_base import DatasetBase
3 | import pickle as pkl
4 | import logging
5 | from hashlib import md5
6 | import numpy as np
7 | from PIL import Image
8 | from continuum.data_utils import shuffle_data, load_task_with_labels
9 | import time
10 |
# Number of tasks per CORe50 scenario ('ni' = new instances, 'nc' = new
# classes, 'nic'/'nicv2_*' = new instances and classes); consumed by
# CORE50.__init__.
core50_ntask = {
    'ni': 8,
    'nc': 9,
    'nic': 79,
    'nicv2_79': 79,
    'nicv2_196': 196,
    'nicv2_391': 391
}
19 |
20 |
class CORE50(DatasetBase):
    """CORe50 dataset wrapper.

    Loads the precomputed path list (`paths.pkl`), the per-run/per-task
    sample-index lookup (`LUP.pkl`) and the label lists (`labels.pkl`) from
    `self.root`, then serves per-task train/val/test splits for the
    requested scenario.
    """
    def __init__(self, scenario, params):
        # the official CORe50 benchmark ships exactly 10 predefined runs
        if isinstance(params.num_runs, int) and params.num_runs > 10:
            raise Exception('the max number of runs for CORE50 is 10')
        dataset = 'core50'
        task_nums = core50_ntask[scenario]
        super(CORE50, self).__init__(dataset, scenario, task_nums, params.num_runs, params)

    def download_load(self):
        """Load the pickled metadata (paths, LUP index, labels) from disk."""
        print("Loading paths...")
        with open(os.path.join(self.root, 'paths.pkl'), 'rb') as f:
            self.paths = pkl.load(f)

        print("Loading LUP...")
        with open(os.path.join(self.root, 'LUP.pkl'), 'rb') as f:
            self.LUP = pkl.load(f)

        print("Loading labels...")
        with open(os.path.join(self.root, 'labels.pkl'), 'rb') as f:
            self.labels = pkl.load(f)

    def setup(self, cur_run):
        """Prepare the test set(s) for the given run.

        The last entry of LUP/labels for a run is the test split; for the
        'nc' scenario the test data is additionally partitioned by each
        task's label set.
        """
        self.val_set = []
        self.test_set = []
        print('Loading test set...')
        test_idx_list = self.LUP[self.scenario][cur_run][-1]

        # resolve test sample paths
        test_paths = []
        for idx in test_idx_list:
            test_paths.append(os.path.join(self.root, self.paths[idx]))

        # load test images
        self.test_data = self.get_batch_from_paths(test_paths)
        self.test_label = np.asarray(self.labels[self.scenario][cur_run][-1])

        if self.scenario == 'nc':
            self.task_labels = self.labels[self.scenario][cur_run][:-1]
            for labels in self.task_labels:
                labels = list(set(labels))  # unique labels of this task
                x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels)
                self.test_set.append((x_test, y_test))
        elif self.scenario == 'ni':
            self.test_set = [(self.test_data, self.test_label)]

    def new_task(self, cur_task, **kwargs):
        """Load one task's training images, carve off a validation chunk,
        and return (train_x, train_y, labels)."""
        cur_run = kwargs['cur_run']
        s = time.time()
        train_idx_list = self.LUP[self.scenario][cur_run][cur_task]
        print("Loading data...")
        # Getting the actual paths
        train_paths = []
        for idx in train_idx_list:
            train_paths.append(os.path.join(self.root, self.paths[idx]))
        # loading imgs
        train_x = self.get_batch_from_paths(train_paths)
        train_y = self.labels[self.scenario][cur_run][cur_task]
        train_y = np.asarray(train_y)
        # shuffle, then carve out the validation split
        train_x_rdm, train_y_rdm = shuffle_data(train_x, train_y)
        val_size = int(len(train_x_rdm) * self.params.val_size)
        val_data_rdm, val_label_rdm = train_x_rdm[:val_size], train_y_rdm[:val_size]
        train_data_rdm, train_label_rdm = train_x_rdm[val_size:], train_y_rdm[val_size:]
        self.val_set.append((val_data_rdm, val_label_rdm))
        e = time.time()
        print('loading time {}'.format(str(e - s)))
        return train_data_rdm, train_label_rdm, set(train_label_rdm)

    def new_run(self, **kwargs):
        """Reset val/test state for a new run."""
        cur_run = kwargs['cur_run']
        self.setup(cur_run)

    @staticmethod
    def get_batch_from_paths(paths, compress=False, snap_dir='',
                             on_the_fly=True, verbose=False):
        """ Given a number of abs. paths it returns the numpy array
        of all the images (uint8, one 128x128x3 image per path). """

        # Getting root logger
        log = logging.getLogger('mylogger')

        # If we do not process data on the fly we check if the same train
        # filelist has been already processed and saved. If so, we load it
        # directly. In either case we end up returning x, the full image
        # array for the given paths.
        num_imgs = len(paths)
        hexdigest = md5(''.join(paths).encode('utf-8')).hexdigest()
        log.debug("Paths Hex: " + str(hexdigest))
        loaded = False
        x = None
        file_path = None

        if compress:
            file_path = snap_dir + hexdigest + ".npz"
            if os.path.exists(file_path) and not on_the_fly:
                loaded = True
                with open(file_path, 'rb') as f:
                    npzfile = np.load(f)
                    # Fix: the archive stores only 'x'. The previous
                    # `x, y = npzfile['x']` wrongly unpacked the image
                    # array's first axis into two variables.
                    x = npzfile['x']
        else:
            x_file_path = snap_dir + hexdigest + "_x.bin"
            if os.path.exists(x_file_path) and not on_the_fly:
                loaded = True
                with open(x_file_path, 'rb') as f:
                    x = np.fromfile(f, dtype=np.uint8) \
                        .reshape(num_imgs, 128, 128, 3)

        # Here we actually load the images.
        if not loaded:
            # Pre-allocate numpy arrays
            x = np.zeros((num_imgs, 128, 128, 3), dtype=np.uint8)

            for i, path in enumerate(paths):
                if verbose:
                    print("\r" + path + " processed: " + str(i + 1), end='')
                x[i] = np.array(Image.open(path))

            if verbose:
                print()

            if not on_the_fly:
                # Then we save x
                if compress:
                    with open(file_path, 'wb') as g:
                        np.savez_compressed(g, x=x)
                else:
                    x.tofile(snap_dir + hexdigest + "_x.bin")

        assert (x is not None), 'Problems loading data. x is None!'

        return x
159 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from https://github.com/facebookresearch/GradientEpisodicMemory
3 | &
4 | https://github.com/kuangliu/pytorch-cifar
5 | """
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 | from torch.nn.functional import relu, avg_pool2d
9 |
def conv3x3(in_planes, out_planes, stride=1):
    """Build a 3x3 convolution with padding 1 and no bias term."""
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return conv
13 |
class BasicBlock(nn.Module):
    """Two-layer residual block (ResNet-18/34 style).

    Two 3x3 convs with BatchNorm; the input is added back through an
    identity shortcut, or through a 1x1 projection when the stride or
    channel count changes.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut by default; 1x1 projection if shapes differ.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        # conv-bn-relu, conv-bn, residual add, final relu.
        out = relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = relu(out)
        return out
38 |
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand (x4)."""

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        width_out = self.expansion * planes
        # 1x1 channel reduction.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 spatial convolution; this is the layer that carries the stride.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 channel expansion back to expansion * planes.
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        # Identity shortcut unless the shape changes; then project with 1x1.
        if stride == 1 and in_planes == width_out:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride,
                          bias=False),
                nn.BatchNorm2d(width_out),
            )

    def forward(self, x):
        y = relu(self.bn1(self.conv1(x)))
        y = relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y = y + self.shortcut(x)
        return relu(y)
68 |
class ResNet(nn.Module):
    """Generic four-stage ResNet backbone with a linear classifier head.

    Args:
        block: residual block class (BasicBlock or Bottleneck).
        num_blocks: list of 4 ints, number of blocks per stage.
        num_classes: output size of the final linear layer.
        nf: base channel width (stage widths are nf, 2nf, 4nf, 8nf).
        bias: whether the final linear layer has a bias term.
    """
    def __init__(self, block, num_blocks, num_classes, nf, bias):
        super(ResNet, self).__init__()
        self.in_planes = nf
        self.conv1 = conv3x3(3, nf * 1)
        self.bn1 = nn.BatchNorm2d(nf * 1)
        self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)
        self.linear = nn.Linear(nf * 8 * block.expansion, num_classes, bias=bias)


    def _make_layer(self, block, planes, num_blocks, stride):
        # The first block of a stage may downsample; the rest keep stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def features(self, x):
        '''Features before FC layers'''
        out = relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # NOTE(review): the fixed 4x4 average pool assumes layer4 outputs
        # a 4x4 spatial map (i.e. CIFAR-style 32x32 inputs) -- confirm for
        # other input resolutions.
        out = avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return out

    def logits(self, x):
        '''Apply the last FC linear mapping to get logits'''
        x = self.linear(x)
        return x

    def forward(self, x):
        out = self.features(x)
        logits = self.logits(out)
        return logits
110 |
111 |
def Reduced_ResNet18(nclasses, nf=20, bias=True):
    """
    Reduced ResNet18 as in GEM MIR(note that nf=20).
    Same BasicBlock layout as ResNet-18, but base width 20 instead of 64.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf, bias)
117 |
def ResNet18(nclasses, nf=64, bias=True):
    """Standard ResNet-18: BasicBlock, stages [2, 2, 2, 2], base width 64."""
    return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf, bias)
120 |
121 | '''
122 | See https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
123 | '''
124 |
def ResNet34(nclasses, nf=64, bias=True):
    """ResNet-34: BasicBlock, stages [3, 4, 6, 3]."""
    return ResNet(BasicBlock, [3, 4, 6, 3], nclasses, nf, bias)
127 |
def ResNet50(nclasses, nf=64, bias=True):
    """ResNet-50: Bottleneck, stages [3, 4, 6, 3]."""
    return ResNet(Bottleneck, [3, 4, 6, 3], nclasses, nf, bias)
130 |
131 |
def ResNet101(nclasses, nf=64, bias=True):
    """ResNet-101: Bottleneck, stages [3, 4, 23, 3]."""
    return ResNet(Bottleneck, [3, 4, 23, 3], nclasses, nf, bias)
134 |
135 |
def ResNet152(nclasses, nf=64, bias=True):
    """ResNet-152: Bottleneck, stages [3, 8, 36, 3]."""
    return ResNet(Bottleneck, [3, 8, 36, 3], nclasses, nf, bias)
138 |
139 |
class SupConResNet(nn.Module):
    """backbone + projection head.

    Reduced_ResNet18 encoder followed by a projection head that maps
    features to an L2-normalized embedding (supervised-contrastive style).
    """
    def __init__(self, dim_in=160, head='mlp', feat_dim=128):
        # dim_in=160 matches Reduced_ResNet18's feature size
        # (nf=20 -> 20 * 8 channels after layer4).
        # head: 'linear', 'mlp', or the string 'None' for no head.
        super(SupConResNet, self).__init__()
        self.encoder = Reduced_ResNet18(100)
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        elif head == 'None':
            self.head = None
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        """Return the L2-normalized (projected) embedding of x."""
        feat = self.encoder.features(x)
        if self.head:
            feat = F.normalize(self.head(feat), dim=1)
        else:
            feat = F.normalize(feat, dim=1)
        return feat

    def features(self, x):
        """Raw encoder features (no projection, no normalization)."""
        return self.encoder.features(x)
--------------------------------------------------------------------------------
/utils/buffer/gss_greedy_update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from utils.buffer.buffer_utils import get_grad_vector, cosine_similarity
5 | from utils.utils import maybe_cuda
6 |
class GSSGreedyUpdate(object):
    """Gradient-based Sample Selection (GSS-greedy) buffer update.

    New samples are scored by the cosine similarity of their gradient to
    gradients computed on random memory subsets; dissimilar (diverse)
    samples are preferentially kept in the buffer.
    """
    def __init__(self, params):
        super().__init__()
        # the number of gradient vectors to estimate new samples similarity, line 5 in alg.2
        self.mem_strength = params.gss_mem_strength
        self.gss_batch_size = params.gss_batch_size
        # Per-slot similarity score for every buffered sample.
        self.buffer_score = maybe_cuda(torch.FloatTensor(params.mem_size).fill_(0))

    def update(self, buffer, x, y, **kwargs):
        """Insert batch (x, y) into `buffer`.

        While space remains, samples are appended with a score from random
        memory gradients. Once full, replacement candidates are drawn
        proportionally to their normalized scores and replaced
        stochastically based on relative gradient similarity.
        """
        buffer.model.eval()

        # Per-parameter gradient sizes, used to flatten grads into vectors.
        grad_dims = []
        for param in buffer.model.parameters():
            grad_dims.append(param.data.numel())

        place_left = buffer.buffer_img.size(0) - buffer.current_index
        if place_left <= 0:  # buffer is full
            batch_sim, mem_grads = self.get_batch_sim(buffer, grad_dims, x, y)
            # Only attempt replacement when the batch gradient is
            # dissimilar (negative cosine) to memory gradients.
            if batch_sim < 0:
                buffer_score = self.buffer_score[:buffer.current_index].cpu()
                buffer_sim = (buffer_score - torch.min(buffer_score)) / \
                             ((torch.max(buffer_score) - torch.min(buffer_score)) + 0.01)
                # draw candidates for replacement from the buffer
                index = torch.multinomial(buffer_sim, x.size(0), replacement=False)
                # estimate the similarity of each sample in the received batch
                # to the randomly drawn samples from the buffer.
                batch_item_sim = self.get_each_batch_sample_sim(buffer, grad_dims, mem_grads, x, y)
                # normalize to [0,1]
                scaled_batch_item_sim = ((batch_item_sim + 1) / 2).unsqueeze(1)
                buffer_repl_batch_sim = ((self.buffer_score[index] + 1) / 2).unsqueeze(1)
                # draw an event to decide on replacement decision
                outcome = torch.multinomial(torch.cat((scaled_batch_item_sim, buffer_repl_batch_sim), dim=1), 1,
                                            replacement=False)
                # replace samples with outcome = 1
                added_indx = torch.arange(end=batch_item_sim.size(0))
                sub_index = outcome.squeeze(1).bool()
                buffer.buffer_img[index[sub_index]] = x[added_indx[sub_index]].clone()
                buffer.buffer_label[index[sub_index]] = y[added_indx[sub_index]].clone()
                self.buffer_score[index[sub_index]] = batch_item_sim[added_indx[sub_index]].clone()
        else:
            # Buffer not full yet: append as many samples as fit.
            offset = min(place_left, x.size(0))
            x = x[:offset]
            y = y[:offset]
            # first buffer insertion
            if buffer.current_index == 0:
                batch_sample_memory_cos = torch.zeros(x.size(0)) + 0.1
            else:
                # draw random samples from buffer
                mem_grads = self.get_rand_mem_grads(buffer, grad_dims)
                # estimate a score for each added sample
                batch_sample_memory_cos = self.get_each_batch_sample_sim(buffer, grad_dims, mem_grads, x, y)
            buffer.buffer_img[buffer.current_index:buffer.current_index + offset].data.copy_(x)
            buffer.buffer_label[buffer.current_index:buffer.current_index + offset].data.copy_(y)
            self.buffer_score[buffer.current_index:buffer.current_index + offset] \
                .data.copy_(batch_sample_memory_cos)
            buffer.current_index += offset
        buffer.model.train()

    def get_batch_sim(self, buffer, grad_dims, batch_x, batch_y):
        """
        Args:
            buffer: memory buffer
            grad_dims: gradient dimensions
            batch_x: batch images
            batch_y: batch labels
        Returns: score of current batch, gradient from memory subsets
        """
        mem_grads = self.get_rand_mem_grads(buffer, grad_dims)
        buffer.model.zero_grad()
        loss = F.cross_entropy(buffer.model.forward(batch_x), batch_y)
        loss.backward()
        # Maximum cosine similarity between the batch gradient and any
        # memory-subset gradient.
        batch_grad = get_grad_vector(buffer.model.parameters, grad_dims).unsqueeze(0)
        batch_sim = max(cosine_similarity(mem_grads, batch_grad))
        return batch_sim, mem_grads

    def get_rand_mem_grads(self, buffer, grad_dims):
        """
        Args:
            buffer: memory buffer
            grad_dims: gradient dimensions
        Returns: gradient from memory subsets
        """
        gss_batch_size = min(self.gss_batch_size, buffer.current_index)
        num_mem_subs = min(self.mem_strength, buffer.current_index // gss_batch_size)
        mem_grads = maybe_cuda(torch.zeros(num_mem_subs, sum(grad_dims), dtype=torch.float32))
        # Disjoint random subsets of the filled part of the buffer.
        shuffeled_inds = torch.randperm(buffer.current_index)
        for i in range(num_mem_subs):
            random_batch_inds = shuffeled_inds[
                i * gss_batch_size:i * gss_batch_size + gss_batch_size]
            batch_x = buffer.buffer_img[random_batch_inds]
            batch_y = buffer.buffer_label[random_batch_inds]
            buffer.model.zero_grad()
            loss = F.cross_entropy(buffer.model.forward(batch_x), batch_y)
            loss.backward()
            mem_grads[i].data.copy_(get_grad_vector(buffer.model.parameters, grad_dims))
        return mem_grads

    def get_each_batch_sample_sim(self, buffer, grad_dims, mem_grads, batch_x, batch_y):
        """
        Args:
            buffer: memory buffer
            grad_dims: gradient dimensions
            mem_grads: gradient from memory subsets
            batch_x: batch images
            batch_y: batch labels
        Returns: score of each sample from current batch
        """
        cosine_sim = maybe_cuda(torch.zeros(batch_x.size(0)))
        for i, (x, y) in enumerate(zip(batch_x, batch_y)):
            buffer.model.zero_grad()
            ptloss = F.cross_entropy(buffer.model.forward(x.unsqueeze(0)), y.unsqueeze(0))
            ptloss.backward()
            # compute this sample's gradient and its max cosine similarity
            # to the memory-subset gradients
            this_grad = get_grad_vector(buffer.model.parameters, grad_dims).unsqueeze(0)
            cosine_sim[i] = max(cosine_similarity(mem_grads, this_grad))
        return cosine_sim
123 |
--------------------------------------------------------------------------------
/utils/buffer/aser_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.utils import maybe_cuda, mini_batch_deep_features, euclidean_distance, nonzero_indices, ohe_label
3 | from utils.setup_elements import n_classes
4 | from utils.buffer.buffer_utils import ClassBalancedRandomSampling
5 |
6 |
def compute_knn_sv(model, eval_x, eval_y, cand_x, cand_y, k, device="cpu"):
    """
    Compute KNN SV of candidate data w.r.t. evaluation data.
    Args:
        model (object): neural network.
        eval_x (tensor): evaluation data tensor.
        eval_y (tensor): evaluation label tensor.
        cand_x (tensor): candidate data tensor.
        cand_y (tensor): candidate label tensor.
        k (int): number of nearest neighbours.
        device (str): device for tensor allocation.
    Returns
        sv_matrix (tensor): KNN Shapley value matrix of candidate data w.r.t. evaluation data.
    """
    # Compute KNN SV score for candidate samples w.r.t. evaluation samples
    n_eval = eval_x.size(0)
    n_cand = cand_x.size(0)
    # Initialize SV matrix to zeros (note: a previous comment claimed -1)
    sv_matrix = torch.zeros((n_eval, n_cand), device=device)
    # Get deep features
    eval_df, cand_df = deep_features(model, eval_x, n_eval, cand_x, n_cand)
    # Sort indices based on distance in deep feature space
    sorted_ind_mat = sorted_cand_ind(eval_df, cand_df, n_eval, n_cand)

    # Evaluation set labels
    el = eval_y
    el_vec = el.repeat([n_cand, 1]).T
    # Sorted candidate set labels
    cl = cand_y[sorted_ind_mat]

    # Indicator function matrix: 1 where candidate label matches the
    # evaluation label, shifted copy used for the recursive difference.
    indicator = (el_vec == cl).float()
    indicator_next = torch.zeros_like(indicator, device=device)
    indicator_next[:, 0:n_cand - 1] = indicator[:, 1:]
    indicator_diff = indicator - indicator_next

    # Position-dependent factor of the closed-form KNN-SV recursion.
    cand_ind = torch.arange(n_cand, dtype=torch.float, device=device) + 1
    denom_factor = cand_ind.clone()
    denom_factor[:n_cand - 1] = denom_factor[:n_cand - 1] * k
    numer_factor = cand_ind.clone()
    numer_factor[k:n_cand - 1] = k
    numer_factor[n_cand - 1] = 1
    factor = numer_factor / denom_factor

    # Reverse cumulative sum implements the recursion from farthest to
    # nearest neighbour.
    indicator_factor = indicator_diff * factor
    indicator_factor_cumsum = indicator_factor.flip(1).cumsum(1).flip(1)

    # Row indices
    row_ind = torch.arange(n_eval, device=device)
    row_mat = torch.repeat_interleave(row_ind, n_cand).reshape([n_eval, n_cand])

    # Compute SV recursively: scatter values back to original candidate order
    sv_matrix[row_mat, sorted_ind_mat] = indicator_factor_cumsum

    return sv_matrix
62 |
63 |
def deep_features(model, eval_x, n_eval, cand_x, n_cand):
    """
    Compute deep features of evaluation and candidate data.
    Args:
        model (object): neural network.
        eval_x (tensor): evaluation data tensor.
        n_eval (int): number of evaluation data.
        cand_x (tensor or None): candidate data tensor, if any.
        n_cand (int): number of candidate data.
    Returns
        eval_df (tensor): deep features of evaluation data.
        cand_df (tensor): deep features of candidate data (empty when
        cand_x is None).
    """
    # Stack evaluation and candidate data so the model runs a single pass.
    if cand_x is None:
        total_x, num = eval_x, n_eval
    else:
        total_x, num = torch.cat((eval_x, cand_x), 0), n_eval + n_cand

    # compute deep features with mini-batches
    feats = mini_batch_deep_features(model, maybe_cuda(total_x), num)

    # Split the stacked features back into the two original sets.
    return feats[0:n_eval], feats[n_eval:]
92 |
93 |
def sorted_cand_ind(eval_df, cand_df, n_eval, n_cand):
    """
    Sort indices of candidate data according to
    their Euclidean distance to each evaluation data in deep feature space.
    Args:
        eval_df (tensor): deep features of evaluation data.
        cand_df (tensor): deep features of candidate data.
        n_eval (int): number of evaluation data.
        n_cand (int): number of candidate data.
    Returns
        sorted_cand_ind (tensor): sorted indices of candidate set w.r.t. each evaluation data.
    """
    # Pair every evaluation vector with every candidate vector: repeat each
    # eval row n_cand times and tile the candidate rows n_eval times.
    eval_rows = eval_df.repeat([1, n_cand]).reshape([n_eval * n_cand, eval_df.shape[1]])
    cand_rows = cand_df.repeat([n_eval, 1])
    # Pairwise distances, reshaped from a flat vector into a matrix.
    dist_mat = euclidean_distance(eval_rows, cand_rows).reshape((n_eval, n_cand))
    # Ascending distance => nearest candidates first.
    return dist_mat.argsort(1)
117 |
118 |
def add_minority_class_input(cur_x, cur_y, mem_size, num_class):
    """
    Find input instances from minority classes, and concatenate them to evaluation data/label tensors later.
    This facilitates the inclusion of minority class samples into memory when ASER's update method is used under online-class incremental setting.

    More details:

    Evaluation set may not contain any samples from minority classes (i.e., those classes with very few number of corresponding samples stored in the memory).
    This happens after task changes in online-class incremental setting.
    Minority class samples can then get very low or negative KNN-SV, making it difficult to store any of them in the memory.

    By identifying minority class samples in the current input batch, and concatenating them to the evaluation set,
    KNN-SV of the minority class samples can be artificially boosted (i.e., positive value with larger magnitude).
    This allows to quickly accommodate new class samples in the memory right after task changes.

    Threshold for being a minority class is a hyper-parameter related to the class proportion.
    In this implementation, it is randomly selected between 0 and 1 / number of all classes for each current input batch.


    Args:
        cur_x (tensor): current input data tensor.
        cur_y (tensor): current input label tensor.
        mem_size (int): memory size.
        num_class (int): number of classes in dataset.
    Returns
        minority_batch_x (tensor): subset of current input data from minority class.
        minority_batch_y (tensor): subset of current input label from minority class.
    """
    # Select input instances from minority classes that will be concatenated to pre-selected data
    # (scalar tensor filled uniformly at random in [0, 1/num_class])
    threshold = torch.tensor(1).float().uniform_(0, 1 / num_class).item()

    # If number of buffered samples from certain class is lower than random threshold,
    # that class is minority class
    cls_proportion = ClassBalancedRandomSampling.class_num_cache.float() / mem_size
    minority_ind = nonzero_indices(cls_proportion[cur_y] < threshold)

    minority_batch_x = cur_x[minority_ind]
    minority_batch_y = cur_y[minority_ind]
    return minority_batch_x, minority_batch_y
158 |
--------------------------------------------------------------------------------
/models/ndpm/ndpm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler
4 |
5 | from utils.utils import maybe_cuda
6 | from utils.global_vars import *
7 | from .expert import Expert
8 | from .priors import CumulativePrior
9 |
10 |
class Ndpm(nn.Module):
    """Neural Dirichlet Process Mixture of experts (CN-DPM).

    Maintains a growing list of experts plus a short-term memory (STM).
    Incoming data either trains existing experts or is parked in the STM;
    when the STM reaches capacity, sleep() creates a new expert and trains
    it on the accumulated STM data.
    """
    def __init__(self, params):
        super().__init__()
        self.params = params
        # experts[0] acts as a placeholder; real experts are appended by
        # sleep().  NOTE(review): inferred from forward()'s
        # `len(self.experts) == 1` check and the [:, 1:] slicing -- confirm.
        self.experts = nn.ModuleList([Expert(params)])
        self.stm_capacity = params.stm_capacity
        self.stm_x = []  # short-term memory: data points (CPU tensors)
        self.stm_y = []  # short-term memory: labels (CPU tensors)
        self.prior = CumulativePrior(params)

    def get_experts(self):
        """Return the current experts as a tuple."""
        return tuple(self.experts.children())

    def forward(self, x):
        """Compute the log joint over classes, marginalized over experts.

        Raises:
            RuntimeError: if no real expert has been created yet.
        """
        with torch.no_grad():
            if len(self.experts) == 1:
                raise RuntimeError('There\'s no expert to run on the input')
            x = maybe_cuda(x)
            log_evid = -self.experts[-1].g.collect_nll(x)  # [B, 1+K]
            log_evid = log_evid[:, 1:].unsqueeze(2)  # [B, K, 1]
            log_prior = -self.prior.nl_prior()[1:]  # [K]
            # Normalize the prior over the K real experts.
            log_prior -= torch.logsumexp(log_prior, dim=0)
            log_prior = log_prior.unsqueeze(0).unsqueeze(2)  # [1, K, 1]
            log_joint = log_prior + log_evid  # [B, K, 1]
            if not MODELS_NDPM_NDPM_DISABLE_D:
                log_pred = self.experts[-1].d.collect_forward(x)  # [B, 1+K, C]
                log_pred = log_pred[:, 1:, :]  # [B, K, C]
                log_joint = log_joint + log_pred  # [B, K, C]

            # Marginalize over experts.
            log_joint = log_joint.logsumexp(dim=1).squeeze()  # [B,] or [B, C]
            return log_joint


    def learn(self, x, y):
        """Route batch (x, y) to the STM and/or train existing experts,
        triggering sleep() once the STM is full."""
        x, y = maybe_cuda(x), maybe_cuda(y)

        if MODELS_NDPM_NDPM_SEND_TO_STM_ALWAYS:
            self.stm_x.extend(torch.unbind(x.cpu()))
            self.stm_y.extend(torch.unbind(y.cpu()))
        else:
            # Determine the destination of each data point
            nll = self.experts[-1].collect_nll(x, y)  # [B, 1+K]
            nl_prior = self.prior.nl_prior()  # [1+K]
            nl_joint = nll + nl_prior.unsqueeze(0).expand(
                nll.size(0), -1)  # [B, 1+K]

            # Save to short-term memory (samples whose best fit is index 0)
            destination = maybe_cuda(torch.argmin(nl_joint, dim=1))  # [B]
            to_stm = destination == 0  # [B]
            self.stm_x.extend(torch.unbind(x[to_stm].cpu()))
            self.stm_y.extend(torch.unbind(y[to_stm].cpu()))

            # Train expert: softmax-like responsibilities over experts,
            # with the STM slot (index 0) zeroed out.
            with torch.no_grad():
                min_joint = nl_joint.min(dim=1)[0].view(-1, 1)
                to_expert = torch.exp(-nl_joint + min_joint)  # [B, 1+K]
                to_expert[:, 0] = 0.  # [B, 1+K]
                to_expert = \
                    to_expert / (to_expert.sum(dim=1).view(-1, 1) + 1e-7)

            # Compute losses per expert; STM-bound samples contribute zero.
            nll_for_train = nll * (1. - to_stm.float()).unsqueeze(1)  # [B,1+K]
            losses = (nll_for_train * to_expert).sum(0)  # [1+K]

            # Record expert usage
            expert_usage = to_expert.sum(dim=0)  # [K+1]
            self.prior.record_usage(expert_usage)

            # Do lr_decay implicitly
            if MODELS_NDPM_NDPM_IMPLICIT_LR_DECAY:
                losses = losses \
                    * self.params.stm_capacity / (self.prior.counts + 1e-8)
            loss = losses.sum()

            if loss.requires_grad:
                update_threshold = 0
                # Only experts that received responsibility are updated.
                for k, usage in enumerate(expert_usage):
                    if usage > update_threshold:
                        self.experts[k].zero_grad()
                loss.backward()
                for k, usage in enumerate(expert_usage):
                    if usage > update_threshold:
                        self.experts[k].clip_grad()
                        self.experts[k].optimizer_step()
                        self.experts[k].lr_scheduler_step()

        # Sleep
        if len(self.stm_x) >= self.stm_capacity:
            # NOTE(review): sleep() rebuilds its own dataset from
            # self.stm_x/self.stm_y, so this argument is effectively
            # overwritten there -- confirm before relying on it.
            dream_dataset = TensorDataset(
                torch.stack(self.stm_x), torch.stack(self.stm_y))
            self.sleep(dream_dataset)
            self.stm_x = []
            self.stm_y = []

    def sleep(self, dream_dataset):
        """Create a new expert and train its generative (and optionally
        discriminative) component on the short-term memory contents."""
        print('\nGoing to sleep...')
        # Add new expert and optimizer
        expert = Expert(self.params, self.get_experts())
        self.experts.append(expert)
        self.prior.add_expert()

        # Shuffle the STM and drop a fixed-size validation tail.
        stacked_stm_x = torch.stack(self.stm_x)
        stacked_stm_y = torch.stack(self.stm_y)
        indices = torch.randperm(stacked_stm_x.size(0))
        train_size = stacked_stm_x.size(0) - MODELS_NDPM_NDPM_SLEEP_SLEEP_VAL_SIZE
        dream_dataset = TensorDataset(
            stacked_stm_x[indices[:train_size]],
            stacked_stm_y[indices[:train_size]])

        # Prepare data iterator (sampling with replacement for a fixed
        # number of generative training steps)
        self.prior.record_usage(len(dream_dataset), index=-1)
        dream_iterator = iter(DataLoader(
            dream_dataset,
            batch_size=MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE,
            num_workers=MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS,
            sampler=RandomSampler(
                dream_dataset,
                replacement=True,
                num_samples=(
                    MODELS_NDPM_NDPM_SLEEP_STEP_G *
                    MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE
                ))
        ))

        # Train generative component
        for step, (x, y) in enumerate(dream_iterator):
            step += 1
            x, y = maybe_cuda(x), maybe_cuda(y)
            g_loss = expert.g.nll(x, y, step=step)
            g_loss = (g_loss + MODELS_NDPM_NDPM_WEIGHT_DECAY
                      * expert.g.weight_decay_loss())
            expert.g.zero_grad()
            g_loss.mean().backward()
            expert.g.clip_grad()
            expert.g.optimizer.step()

            if step % MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP == 0:
                print('\r [Sleep-G %6d] loss: %5.1f' % (
                    step, g_loss.mean()
                ), end='')
        print()

        dream_iterator = iter(DataLoader(
            dream_dataset,
            batch_size=MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE,
            num_workers=MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS,
            sampler=RandomSampler(
                dream_dataset,
                replacement=True,
                num_samples=(
                    MODELS_NDPM_NDPM_SLEEP_STEP_D *
                    MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE)
            )
        ))

        # Train discriminative component
        if not MODELS_NDPM_NDPM_DISABLE_D:
            for step, (x, y) in enumerate(dream_iterator):
                step += 1
                x, y = maybe_cuda(x), maybe_cuda(y)
                d_loss = expert.d.nll(x, y, step=step)
                d_loss = (d_loss + MODELS_NDPM_NDPM_WEIGHT_DECAY
                          * expert.d.weight_decay_loss())
                expert.d.zero_grad()
                d_loss.mean().backward()
                expert.d.clip_grad()
                expert.d.optimizer.step()

                if step % MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP == 0:
                    print('\r [Sleep-D %6d] loss: %5.1f' % (
                        step, d_loss.mean()
                    ), end='')

        # NOTE(review): lr_scheduler_step() is invoked twice back to back --
        # presumably intentional, but worth confirming against Expert.
        expert.lr_scheduler_step()
        expert.lr_scheduler_step()
        expert.eval()
        print()

    @staticmethod
    def _nl_joint(nl_prior, nll):
        """Broadcast-add the negative log prior to per-sample NLLs."""
        batch = nll.size(0)
        nl_prior = nl_prior.unsqueeze(0).expand(batch, -1)  # [B, 1+K]
        return nll + nl_prior

    def train(self, mode=True):
        # Intentionally overridden as a no-op ("Disabled" in the original):
        # nn.Module.train()/eval() calls must not flip this module's mode.
        pass
198 |
--------------------------------------------------------------------------------
/continuum/non_stationary.py:
--------------------------------------------------------------------------------
1 | import random
2 | from copy import deepcopy
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | from skimage.filters import gaussian
6 | from continuum.data_utils import train_val_test_split_ni
7 |
8 |
class Original(object):
    """Identity 'transformation' over a data split.

    Base class for the non-stationary transforms below; yields the data
    unchanged via next_task().
    """

    def __init__(self, x, y, unroll=False, color=False):
        # Color images arrive in [0, 255]; normalize them to [0, 1].
        self.x = x / 255.0 if color else x
        self.y = y
        self.next_x = self.x
        self.next_y = y
        self.unroll = unroll

    def get_dims(self):
        """Print and return (input dim, output dim) along axis 1."""
        in_dim, out_dim = self.x.shape[1], self.y.shape[1]
        print("input size {}\noutput size {}".format(in_dim, out_dim))
        return in_dim, out_dim

    def show_sample(self, num_plot=1):
        """Plot the first `num_plot` originals next to their transforms."""
        for i in range(num_plot):
            plt.subplot(1, 2, 1)
            img = self.x[i]
            plt.imshow(np.squeeze(img) if img.shape[2] == 1 else img)
            plt.title("original task image")
            plt.subplot(1, 2, 2)
            nxt = self.next_x[i]
            plt.imshow(np.squeeze(nxt) if nxt.shape[2] == 1 else nxt)
            plt.title(self.get_name())
            plt.axis('off')
            plt.show()

    def create_output(self):
        """Return (next_x, next_y), flattening images when `unroll` is set."""
        if self.unroll:
            flat = self.next_x.reshape((-1, self.x.shape[1] ** 2))
            return flat, self.next_y
        return self.next_x, self.next_y

    @staticmethod
    def clip_minmax(l, min_=0., max_=1.):
        """Clamp `l` element-wise into [min_, max_]."""
        return np.clip(l, min_, max_)

    def get_name(self):
        """Transform name; includes the factor once a task was generated."""
        if hasattr(self, 'factor'):
            return '{}_{}'.format(str(self.__class__.__name__), str(self.factor))

    def next_task(self, *args):
        """Identity task: the next batch is simply the original data."""
        self.next_x = self.x
        self.next_y = self.y
        return self.create_output()
62 |
63 |
class Noisy(Original):
    """Non-stationary transform that adds pixel noise to the data."""

    def __init__(self, x, y, full=False, color=False):
        super(Noisy, self).__init__(x, y, full, color)

    def next_task(self, noise_factor=0.8, sig=0.1, noise_type='Gaussian'):
        """Create the next task by adding noise to a copy of the data.

        Args:
            noise_factor (float): scale applied to the noise before adding.
            sig (float): std-dev of the Gaussian noise.
            noise_type (str): 'Gaussian' (implemented) or 'S&P' (TODO).
        Returns:
            (next_x, next_y) from Original.create_output().
        """
        next_x = deepcopy(self.x)
        self.factor = noise_factor
        if noise_type == 'Gaussian':
            self.next_x = next_x + noise_factor * np.random.normal(loc=0.0, scale=sig, size=next_x.shape)
        elif noise_type == 'S&P':
            # BUG FIX: previously compared `noise_factor` (a float) against
            # the string 'S&P', making this branch unreachable.
            # TODO implement S&P
            pass

        self.next_x = super().clip_minmax(self.next_x, 0, 1)

        return super().create_output()
80 |
81 |
class Blurring(Original):
    """Non-stationary transform that blurs the data with a Gaussian filter."""

    def __init__(self, x, y, full=False, color=False):
        super(Blurring, self).__init__(x, y, full, color)

    def next_task(self, blurry_factor=0.6, blurry_type='Gaussian'):
        """Create the next task by blurring a copy of the data.

        Args:
            blurry_factor (float): sigma of the Gaussian kernel.
            blurry_type (str): 'Gaussian' (implemented) or 'Average' (TODO).
        Returns:
            (next_x, next_y) from Original.create_output().
        """
        next_x = deepcopy(self.x)
        self.factor = blurry_factor
        if blurry_type == 'Gaussian':
            # NOTE(review): `multichannel` is deprecated/removed in newer
            # scikit-image (replaced by channel_axis) -- confirm the pinned
            # version in requirements.txt still accepts it.
            self.next_x = gaussian(next_x, sigma=blurry_factor, multichannel=True)
        elif blurry_type == 'Average':
            pass
            # TODO implement average

        self.next_x = super().clip_minmax(self.next_x, 0, 1)

        return super().create_output()
98 |
99 |
class Occlusion(Original):
    """Non-stationary transform that whites out a random square patch."""

    def __init__(self, x, y, full=False, color=False):
        super(Occlusion, self).__init__(x, y, full, color)

    def next_task(self, occlusion_factor=0.2):
        """Create the next task by occluding a random square of each image.

        Args:
            occlusion_factor (float): patch side length as a fraction of the
                image side length.
        Returns:
            (next_x, next_y) from Original.create_output().
        """
        next_x = deepcopy(self.x)
        self.factor = occlusion_factor
        self.image_size = next_x.shape[1]

        occlusion_size = int(occlusion_factor * self.image_size)
        half_size = occlusion_size // 2
        # Random patch centre, kept at least half_size from the borders.
        occlusion_x = random.randint(min(half_size, self.image_size - half_size),
                                     max(half_size, self.image_size - half_size))
        occlusion_y = random.randint(min(half_size, self.image_size - half_size),
                                     max(half_size, self.image_size - half_size))

        # White out the patch (same location for every image in the batch).
        next_x[:, max((occlusion_x - half_size), 0):min((occlusion_x + half_size), self.image_size), \
               max((occlusion_y - half_size), 0):min((occlusion_y + half_size), self.image_size)] = 1

        # BUG FIX: np.clip is not in-place; the old code called clip_minmax
        # and discarded its return value, so the clip never took effect
        # (harmless while data is already in [0, 1], but the intent now
        # actually executes).
        self.next_x = super().clip_minmax(next_x, 0, 1)

        return super().create_output()
125 |
126 |
def test_ns(x, y, ns_type, change_factor):
    """Visual sanity check: apply one non-stationary transform and plot samples."""
    generator_cls = {'noise': Noisy, 'occlusion': Occlusion, 'blur': Blurring}[ns_type]
    generator = generator_cls(x, y, color=True)
    generator.next_task(change_factor)
    generator.show_sample(10)
133 |
134 |
135 | ns_match = {'noise': Noisy, 'occlusion': Occlusion, 'blur': Blurring}
136 |
137 |
def construct_ns_single(train_x_split, train_y_split, test_x_split, test_y_split, ns_type, change_factor, ns_task,
                        plot=True):
    """Build train/test task lists alternating normal and transformed blocks.

    `ns_task` is read as run lengths: entries at even positions give how many
    consecutive tasks stay unchanged (Original), entries at odd positions how
    many receive the `ns_type` transform with strength `change_factor`.

    Args:
        train_x_split, train_y_split: per-task training data/label splits.
        test_x_split, test_y_split: per-task test data/label splits.
        ns_type (str): one of 'noise', 'occlusion', 'blur'.
        change_factor: transform strength; a 1-element sequence is unwrapped.
        ns_task: sequence of run lengths (see above).
        plot (bool): show sample images for each generated task.
    Returns:
        (train_list, test_list): lists of (x, y) tuples, one per task.
    """
    # Data splits
    train_list = []
    test_list = []
    change = ns_match[ns_type]
    i = 0
    if len(change_factor) == 1:
        change_factor = change_factor[0]
    for idx, val in enumerate(ns_task):
        if idx % 2 == 0:
            # Even position: `val` unchanged tasks.
            for _ in range(val):
                print(i, 'normal')
                # train
                tmp = Original(train_x_split[i], train_y_split[i], color=True)
                train_list.append(tmp.next_task())
                if plot:
                    tmp.show_sample()

                # test
                tmp_test = Original(test_x_split[i], test_y_split[i], color=True)
                test_list.append(tmp_test.next_task())
                if plot:
                    tmp_test.show_sample()

                i += 1
        else:
            # Odd position: `val` transformed tasks.
            for _ in range(val):
                print(i, 'change')
                # train
                tmp = change(train_x_split[i], train_y_split[i], color=True)
                train_list.append(tmp.next_task(change_factor))
                if plot:
                    tmp.show_sample()
                # test
                tmp_test = change(test_x_split[i], test_y_split[i], color=True)
                test_list.append(tmp_test.next_task(change_factor))
                if plot:
                    tmp_test.show_sample()

                i += 1
    return train_list, test_list
180 |
181 |
def construct_ns_multiple(train_x_split, train_y_split, val_x_rdm_split, val_y_rdm_split, test_x_split,
                          test_y_split, ns_type, change_factors, plot):
    """Apply a per-task change factor to the train/val/test splits.

    A factor of 0 leaves the task unmodified (Original); any other value
    uses the ns_type transform with that factor. Returns the three lists
    (train, val, test) of per-task outputs.
    """
    train_list, val_list, test_list = [], [], []
    for task_id, factor in enumerate(change_factors):
        generator_cls = Original if factor == 0 else ns_match[ns_type]
        print(task_id, factor)
        # train split (optionally plotted)
        train_gen = generator_cls(train_x_split[task_id], train_y_split[task_id], color=True)
        train_list.append(train_gen.next_task(factor))
        if plot:
            train_gen.show_sample()
        # validation split
        val_gen = generator_cls(val_x_rdm_split[task_id], val_y_rdm_split[task_id], color=True)
        val_list.append(val_gen.next_task(factor))
        # test split
        test_gen = generator_cls(test_x_split[task_id], test_y_split[task_id], color=True)
        test_list.append(test_gen.next_task(factor))
    return train_list, val_list, test_list
207 |
208 |
def construct_ns_multiple_wrapper(train_data, train_label, test_data, est_label, task_nums, img_size,
                                  val_size, ns_type, ns_factor, plot):
    """Split raw data into NI tasks, then apply per-task non-stationarity.

    Note: `est_label` holds the test labels (parameter name kept for
    caller compatibility). Returns (train_set, val_set, test_set).
    """
    splits = train_val_test_split_ni(
        train_data, train_label, test_data, est_label, task_nums, img_size, val_size)
    train_x, train_y, val_x, val_y, test_x, test_y = splits
    return construct_ns_multiple(train_x, train_y,
                                 val_x, val_y,
                                 test_x, test_y,
                                 ns_type, ns_factor, plot=plot)
221 |
--------------------------------------------------------------------------------
/models/ndpm/classifier.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from utils.utils import maybe_cuda
6 | from .component import ComponentD
7 | from utils.global_vars import *
8 | from utils.setup_elements import n_classes
9 |
10 |
class Classifier(ComponentD, ABC):
    """Discriminative NDPM component modelling log P(y|x).

    Subclasses implement `forward` to return per-class log-probabilities;
    this base turns them into a per-sample NLL with "chilling"
    (temperature scaling) applied to the reported loss value only.
    """

    def __init__(self, params, experts):
        super().__init__(params, experts)
        # reduction='none': keep a per-sample loss vector (needed when
        # weighting losses across experts).
        self.ce_loss = nn.NLLLoss(reduction='none')

    @abstractmethod
    def forward(self, x):
        """Output log P(y|x)"""
        pass

    def nll(self, x, y, step=None):
        """Per-sample negative log-likelihood of labels `y` given `x`.

        The returned tensor's *value* is the chilled (temperature-scaled)
        loss, while its *gradient* flows through the unchilled loss.
        """
        x, y = maybe_cuda(x), maybe_cuda(y)
        log_softmax = self.forward(x)
        loss_pred = self.ce_loss(log_softmax, y)

        # Classifier chilling: divide the log-probabilities by the chill
        # temperature and re-normalize.
        chilled_log_softmax = F.log_softmax(
            log_softmax / self.params.classifier_chill, dim=1)
        chilled_loss_pred = self.ce_loss(chilled_log_softmax, y)

        # Value with chill & gradient without chill: subtracting the
        # detached unchilled loss and adding the detached chilled one keeps
        # the unchilled gradient but reports the chilled value.
        loss_pred = loss_pred - loss_pred.detach() \
            + chilled_loss_pred.detach()

        return loss_pred
36 |
37 |
class SharingClassifier(Classifier, ABC):
    """Classifier whose experts share features; can collect the predictions
    of every expert in a single forward pass."""

    @abstractmethod
    def forward(self, x, collect=False):
        pass

    def collect_forward(self, x):
        """Stack the dummy expert's prediction with all real experts'.

        Returns a tensor of shape [B, 1+K, C]; index 0 along dim 1 is the
        dummy (placeholder) expert.
        """
        dummy_pred = self.experts[0](x)
        preds, _ = self.forward(x, collect=True)
        return torch.stack([dummy_pred] + preds, dim=1)

    def collect_nll(self, x, y, step=None):
        """Per-expert, per-sample NLL with chilling; returns shape [B, 1+K].

        Uses the same value-chilled / gradient-unchilled trick as
        Classifier.nll, applied to each expert's prediction.
        """
        preds = self.collect_forward(x)  # [B, 1+K, C]
        loss_preds = []
        for log_softmax in preds.unbind(dim=1):
            loss_pred = self.ce_loss(log_softmax, y)

            # Classifier chilling (temperature-scaled re-normalization)
            chilled_log_softmax = F.log_softmax(
                log_softmax / self.params.classifier_chill, dim=1)
            chilled_loss_pred = self.ce_loss(chilled_log_softmax, y)

            # Value with chill & gradient without chill
            loss_pred = loss_pred - loss_pred.detach() \
                + chilled_loss_pred.detach()

            loss_preds.append(loss_pred)
        return torch.stack(loss_preds, dim=1)
65 |
66 |
class BasicBlock(nn.Module):
    """ResNet basic block that can optionally down- or up-sample its input."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1,
                 downsample=None, upsample=None,
                 dilation=1, norm_layer=nn.BatchNorm2d):
        super(BasicBlock, self).__init__()
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock"
            )
        # A strided block with an upsample shortcut decodes, so the first
        # convolution becomes a transposed convolution.
        use_transpose = upsample is not None and stride != 1
        if use_transpose:
            self.conv1 = conv4x4t(inplanes, planes, stride)
        else:
            self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.upsample = upsample
        self.stride = stride

    def forward(self, x):
        shortcut = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Match the shortcut's resolution/channels to the residual branch.
        if self.downsample is not None:
            shortcut = self.downsample(x)
        elif self.upsample is not None:
            shortcut = self.upsample(x)

        out += shortcut
        return self.relu(out)
110 |
111 |
def conv4x4t(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """4x4 transposed convolution with padding equal to the dilation, no bias."""
    return nn.ConvTranspose2d(
        in_planes,
        out_planes,
        kernel_size=4,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
118 |
119 |
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding equal to the dilation, no bias."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
126 |
127 |
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution, no bias."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
133 |
134 |
class ResNetSharingClassifier(SharingClassifier):
    """ResNet-style classifier whose layers consume the concatenation of its
    own features and the (detached) features of all precursor experts.

    The first component uses a base channel width; later components use a
    smaller extension width and widen their inputs by concatenation.
    """

    # Class-level defaults; __init__ may override them from global config.
    block = BasicBlock
    num_blocks = [2, 2, 2, 2]
    norm_layer = nn.InstanceNorm2d

    def __init__(self, params, experts):
        super().__init__(params, experts)
        # Discriminative parts of all previously trained experts
        # (experts[0] is the dummy expert and is skipped).
        self.precursors = [expert.d for expert in self.experts[1:]]
        first = len(self.precursors) == 0

        # Global config overrides the class-level defaults when set.
        if MODELS_NDPM_CLASSIFIER_NUM_BLOCKS is not None:
            num_blocks = MODELS_NDPM_CLASSIFIER_NUM_BLOCKS
        else:
            num_blocks = self.num_blocks
        if MODELS_NDPM_CLASSIFIER_NORM_LAYER is not None:
            self.norm_layer = getattr(nn, MODELS_NDPM_CLASSIFIER_NORM_LAYER)
        else:
            # NOTE: falls back to BatchNorm2d, not the class-level
            # InstanceNorm2d default.
            self.norm_layer = nn.BatchNorm2d

        num_classes = n_classes[params.data]
        # nf: channel width of this component's own layers;
        # nf_cat: width after concatenating all precursor features.
        nf = MODELS_NDPM_CLASSIFIER_CLS_NF_BASE if first else MODELS_NDPM_CLASSIFIER_CLS_NF_EXT
        nf_cat = MODELS_NDPM_CLASSIFIER_CLS_NF_BASE \
            + len(self.precursors) * MODELS_NDPM_CLASSIFIER_CLS_NF_EXT
        self.nf = MODELS_NDPM_CLASSIFIER_CLS_NF_BASE if first else MODELS_NDPM_CLASSIFIER_CLS_NF_EXT
        self.nf_cat = nf_cat

        self.layer0 = nn.Sequential(
            nn.Conv2d(
                3, nf * 1, kernel_size=3, stride=1, padding=1, bias=False
            ),
            self.norm_layer(nf * 1),
            nn.ReLU()
        )
        # Each layer's input width is the concatenated width at that depth.
        self.layer1 = self._make_layer(
            nf_cat * 1, nf * 1, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(
            nf_cat * 1, nf * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(
            nf_cat * 2, nf * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(
            nf_cat * 4, nf * 8, num_blocks[3], stride=2)
        self.predict = nn.Sequential(
            nn.Linear(nf_cat * 8, num_classes),
            nn.LogSoftmax(dim=1)
        )
        self.setup_optimizer()

    def _make_layer(self, nf_in, nf_out, num_blocks, stride):
        """Build a stage of `num_blocks` BasicBlocks; the first block may
        downsample via a 1x1-conv shortcut when shape changes."""
        norm_layer = self.norm_layer
        block = self.block
        downsample = None
        if stride != 1 or nf_in != nf_out:
            downsample = nn.Sequential(
                conv1x1(nf_in, nf_out, stride),
                norm_layer(nf_out),
            )
        layers = [block(
            nf_in, nf_out, stride,
            downsample=downsample,
            norm_layer=norm_layer
        )]
        for _ in range(1, num_blocks):
            layers.append(block(nf_out, nf_out, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x, collect=False):
        """Return log P(y|x); with collect=True also return all experts'
        predictions and this component's (detached) intermediate features."""
        x = maybe_cuda(x)

        # First component: no precursor features to concatenate.
        if len(self.precursors) == 0:
            h1 = self.layer0(x)
            h2 = self.layer1(h1)
            h3 = self.layer2(h2)
            h4 = self.layer3(h3)
            h5 = self.layer4(h4)
            h5 = F.avg_pool2d(h5, h5.size(2)).view(h5.size(0), -1)
            pred = self.predict(h5)

            if collect:
                # Features are detached so later components cannot
                # backpropagate into this one.
                return [pred], [
                    h1.detach(), h2.detach(), h3.detach(),
                    h4.detach(), h5.detach()]
            else:
                return pred

        # Second or layer component: concatenate the most recent
        # precursor's features (which already include earlier ones).
        preds, features = self.precursors[-1](x, collect=True)
        h1 = self.layer0(x)
        h1_cat = torch.cat([features[0], h1], dim=1)
        h2 = self.layer1(h1_cat)
        h2_cat = torch.cat([features[1], h2], dim=1)
        h3 = self.layer2(h2_cat)
        h3_cat = torch.cat([features[2], h3], dim=1)
        h4 = self.layer3(h3_cat)
        h4_cat = torch.cat([features[3], h4], dim=1)
        h5 = self.layer4(h4_cat)
        h5 = F.avg_pool2d(h5, h5.size(2)).view(h5.size(0), -1)
        h5_cat = torch.cat([features[4], h5], dim=1)
        pred = self.predict(h5_cat)

        if collect:
            preds.append(pred)
            return preds, [
                h1_cat.detach(), h2_cat.detach(), h3_cat.detach(),
                h4_cat.detach(), h5_cat.detach(),
            ]
        else:
            return pred
244 |
--------------------------------------------------------------------------------
/utils/buffer/buffer_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from utils.utils import maybe_cuda
4 | from collections import defaultdict
5 | from collections import Counter
6 | import random
7 |
8 |
def random_retrieve(buffer, num_retrieve, excl_indices=None, return_indices=False):
    """Uniformly sample up to num_retrieve items from the filled buffer.

    excl_indices are buffer positions that must not be sampled. Returns
    (x, y) or (x, y, indices) when return_indices is True.
    """
    excluded = [] if excl_indices is None else list(excl_indices)
    filled = np.arange(buffer.current_index)
    candidates = np.setdiff1d(filled, np.array(excluded))
    # Never ask for more samples than are actually available.
    n = min(num_retrieve, candidates.shape[0])
    indices = torch.from_numpy(np.random.choice(candidates, n, replace=False)).long()

    x = buffer.buffer_img[indices]
    y = buffer.buffer_label[indices]

    if return_indices:
        return x, y, indices
    return x, y
27 |
28 |
def balanced_retrieve(buffer, num_retrieve, excl_labels=None, return_indices=False):
    """Sample ~num_retrieve items spread evenly over (up to 10) buffer classes.

    Args:
        buffer: memory buffer exposing `labeldict` ({label: [indices]}),
            `buffer_img` and `buffer_label` tensors.
        num_retrieve (int): total number of samples requested.
        excl_labels (iterable): class labels to skip entirely.
        return_indices (bool): also return the sampled buffer indices.

    Returns:
        (x, y) or (x, y, indices) when return_indices is True.

    Fixes vs. the previous revision:
    - no longer raises ValueError when the buffer holds fewer than 10
      classes (np.random.choice with replace=False needed size <= population);
    - empty index arrays are explicitly int-typed so tensor indexing stays
      valid instead of failing on a float64 array.
    """
    avail_labels = list(buffer.labeldict.keys())
    if excl_labels is not None:
        avail_labels = np.array(list(set(avail_labels) - set(excl_labels)))
    indices = []
    if len(avail_labels) > 0:
        # Draw at most 10 distinct classes; fewer when the buffer has fewer.
        n_cls = min(10, len(avail_labels))
        num_instances = num_retrieve // n_cls
        random_labels = np.random.choice(avail_labels, n_cls, replace=False)
        for label in random_labels:
            # Each class contributes at most num_instances samples.
            tmp_instances = min(num_instances, len(buffer.labeldict[label]))
            if tmp_instances:
                tmp_indices = np.random.choice(buffer.labeldict[label], tmp_instances, replace=False)
                indices.extend(tmp_indices)

    # astype keeps an empty selection integer-typed for tensor indexing.
    indices = np.random.permutation(indices).astype(np.int64)
    x = buffer.buffer_img[indices]
    y = buffer.buffer_label[indices]

    if return_indices:
        return x, y, indices
    else:
        return x, y
53 |
54 |
def match_retrieve(buffer, cur_y, exclud_idx=None):
    """Retrieve one buffered sample per element of cur_y with matching label.

    Positions in the result line up with cur_y. exclud_idx (tensor) lists
    buffer positions that must not be drawn. Returns empty tensors when any
    class has too few non-excluded samples in the buffer.
    """
    label_counts = Counter(cur_y.tolist())
    positions_by_label = defaultdict(list)
    for pos, label in enumerate(cur_y.tolist()):
        positions_by_label[label].append(pos)

    select = [None] * len(cur_y)
    for label in label_counts:
        candidates = buffer.buffer_tracker.class_index_cache[label]
        if exclud_idx is not None:
            candidates = candidates - set(exclud_idx.tolist())
        if not candidates or len(candidates) < label_counts[label]:
            print('match retrieve attempt fail')
            return torch.tensor([]), torch.tensor([])
        chosen = random.sample(list(candidates), label_counts[label])
        for pos, buf_idx in zip(positions_by_label[label], chosen):
            select[pos] = buf_idx

    indices = torch.tensor(select)
    return buffer.buffer_img[indices], buffer.buffer_label[indices]
75 |
def cosine_similarity(x1, x2=None, eps=1e-8):
    """Pairwise cosine similarity between rows of x1 and rows of x2.

    If x2 is None, the similarity of x1 with itself is returned. Row norms
    are clamped from below by eps to avoid division by zero.
    """
    if x2 is None:
        x2 = x1
    norm1 = x1.norm(p=2, dim=1, keepdim=True)
    # Reuse the norm when comparing x1 against itself.
    norm2 = norm1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
    return torch.mm(x1, x2.t()) / (norm1 * norm2.t()).clamp(min=eps)
82 |
83 |
def get_grad_vector(pp, grad_dims):
    """Gather all parameter gradients into one flat vector.

    pp is a callable returning the parameter iterable (e.g.
    model.parameters); grad_dims[i] is the number of elements of the i-th
    parameter. Slots of parameters without a gradient stay zero.
    """
    grads = maybe_cuda(torch.Tensor(sum(grad_dims)))
    grads.fill_(0.0)
    # Precompute the start offset of every parameter's slot.
    offsets = [0]
    for dim in grad_dims:
        offsets.append(offsets[-1] + dim)
    for slot, param in enumerate(pp()):
        if param.grad is not None:
            beg, en = offsets[slot], offsets[slot + 1]
            grads[beg: en].copy_(param.grad.data.view(-1))
    return grads
98 |
99 |
class ClassBalancedRandomSampling:
    # For faster label-based sampling (e.g., class balanced sampling), cache class-index via auxiliary dictionary
    # Store {class: set of memory sample indices from class} key-value pairs to speed up label-based sampling
    # e.g., {0: {1, 2}, 1: {3}, 2: {6}, ...}
    # Shared class-level state: maps class label -> set of buffer indices,
    # and class label -> count of buffered samples of that class.
    class_index_cache = None
    class_num_cache = None

    @classmethod
    def sample(cls, buffer_x, buffer_y, n_smp_cls, excl_indices=None, device="cpu"):
        """
        Take same number of random samples from each class from buffer.
        Args:
            buffer_x (tensor): data buffer.
            buffer_y (tensor): label buffer.
            n_smp_cls (int): number of samples to take from each class.
            excl_indices (set): indices of buffered instances to be excluded from sampling.
            device (str): device for tensor allocation.
        Returns
            x (tensor): class balanced random sample data tensor.
            y (tensor): class balanced random sample label tensor.
            sample_ind (tensor): class balanced random sample index tensor.
        """
        if excl_indices is None:
            excl_indices = set()

        # Get indices for class balanced random samples
        # cls_ind_cache = class_index_tensor_list_cache(buffer_y, num_class, excl_indices, device=device)

        sample_ind = torch.tensor([], device=device, dtype=torch.long)

        # Use cache to retrieve indices belonging to each class in buffer
        for ind_set in cls.class_index_cache.values():
            if ind_set:
                # Exclude some indices
                valid_ind = ind_set - excl_indices
                # Auxiliary indices for permutation
                perm_ind = torch.randperm(len(valid_ind), device=device)
                # Apply permutation, and select indices
                # (at most n_smp_cls per class)
                ind = torch.tensor(list(valid_ind), device=device, dtype=torch.long)[perm_ind][:n_smp_cls]
                sample_ind = torch.cat((sample_ind, ind))

        x = buffer_x[sample_ind]
        y = buffer_y[sample_ind]

        x = maybe_cuda(x)
        y = maybe_cuda(y)

        return x, y, sample_ind

    @classmethod
    def update_cache(cls, buffer_y, num_class, new_y=None, ind=None, device="cpu"):
        """
        Collect indices of buffered data from each class in set.
        Update class_index_cache with list of such sets.
        Args:
            buffer_y (tensor): label buffer.
            num_class (int): total number of unique class labels.
            new_y (tensor): label tensor for replacing memory samples at ind in buffer.
            ind (tensor): indices of memory samples to be updated.
            device (str): device for tensor allocation.
        """
        if cls.class_index_cache is None:
            # Initialize caches
            cls.class_index_cache = defaultdict(set)
            cls.class_num_cache = torch.zeros(num_class, dtype=torch.long, device=device)

        if new_y is not None:
            # If ASER update is being used, keep updating existing caches
            # Get labels of memory samples to be replaced
            orig_y = buffer_y[ind]
            # Update caches
            for i, ny, oy in zip(ind, new_y, orig_y):
                oy_int = oy.item()
                ny_int = ny.item()
                i_int = i.item()
                # Update dictionary according to new class label of index i
                if oy_int in cls.class_index_cache and i_int in cls.class_index_cache[oy_int]:
                    cls.class_index_cache[oy_int].remove(i_int)
                    cls.class_num_cache[oy_int] -= 1
                cls.class_index_cache[ny_int].add(i_int)
                cls.class_num_cache[ny_int] += 1
        else:
            # If only ASER retrieve is being used, reset cache and update it based on buffer
            # NOTE(review): this branch rebuilds class_index_cache only;
            # class_num_cache is left stale — presumably unused on this
            # path, but confirm against the ASER retrieve code.
            cls_ind_cache = defaultdict(set)
            for i, c in enumerate(buffer_y):
                cls_ind_cache[c.item()].add(i)
            cls.class_index_cache = cls_ind_cache
187 |
188 |
class BufferClassTracker(object):
    # For faster label-based sampling (e.g., class balanced sampling), cache class-index via auxiliary dictionary
    # Store {class: set of memory sample indices from class} key-value pairs to speed up label-based sampling
    # e.g., {0: {1, 2}, 1: {3}, 2: {6}, ...}

    def __init__(self, num_class, device="cpu"):
        """Instance-level (non-shared) variant of the class-index cache.

        Args:
            num_class (int): total number of unique class labels.
            device (str): kept for API symmetry; not used here.
        """
        super().__init__()
        # Initialize caches: label -> set of buffer indices, and a per-label
        # sample counter.
        self.class_index_cache = defaultdict(set)
        self.class_num_cache = np.zeros(num_class)


    def update_cache(self, buffer_y, new_y=None, ind=None, ):
        """
        Collect indices of buffered data from each class in set.
        Update class_index_cache with list of such sets.
        Args:
            buffer_y (tensor): label buffer.
            new_y (tensor): label tensor for replacing memory samples at ind in buffer.
            ind (tensor): indices of memory samples to be updated.
        """

        # Get labels of memory samples to be replaced
        orig_y = buffer_y[ind]
        # Update caches
        for i, ny, oy in zip(ind, new_y, orig_y):
            oy_int = oy.item()
            ny_int = ny.item()
            # Update dictionary according to new class label of index i
            if oy_int in self.class_index_cache and i in self.class_index_cache[oy_int]:
                self.class_index_cache[oy_int].remove(i)
                self.class_num_cache[oy_int] -= 1

            self.class_index_cache[ny_int].add(i)
            self.class_num_cache[ny_int] += 1


    def check_tracker(self):
        """Debug helper: both printed numbers should match the number of
        tracked samples if the two caches are consistent."""
        print(self.class_num_cache.sum())
        print(len([k for i in self.class_index_cache.values() for k in i]))
231 |
--------------------------------------------------------------------------------
/models/ndpm/vae.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from itertools import accumulate
3 | import torch
4 | import torch.nn as nn
5 |
6 | from utils.utils import maybe_cuda
7 | from .loss import bernoulli_nll, logistic_nll, gaussian_nll, laplace_nll
8 | from .component import ComponentG
9 | from .utils import Lambda
10 | from utils.global_vars import *
11 | from utils.setup_elements import input_size_match
12 |
class Vae(ComponentG, ABC):
    """Generative NDPM component: a VAE whose NLL is the (per-sample) ELBO.

    Subclasses implement `encode` and `decode`. The reconstruction loss
    family and the observation log-variance come from global config.
    """

    def __init__(self, params, experts):
        super().__init__(params, experts)
        x_c, x_h, x_w = input_size_match[params.data]
        bernoulli = MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli'
        if bernoulli:
            # Bernoulli likelihood has no variance parameter.
            self.log_var_param = None
        elif MODELS_NDPM_VAE_LEARN_X_LOG_VAR:
            # Per-channel learnable observation log-variance.
            self.log_var_param = nn.Parameter(
                torch.ones([x_c]) * MODELS_NDPM_VAE_X_LOG_VAR_PARAM,
                requires_grad=True
            )
        else:
            # Fixed per-channel observation log-variance.
            self.log_var_param = (
                maybe_cuda(torch.ones([x_c])) *
                MODELS_NDPM_VAE_X_LOG_VAR_PARAM
            )

    def forward(self, x):
        """Encode, sample one latent, and return the decoded reconstruction."""
        x = maybe_cuda(x)
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var, 1)
        return self.decode(z)

    def nll(self, x, y=None, step=None):
        """Per-sample negative ELBO: reconstruction (averaged over latent
        samples) plus KL to the standard-normal prior."""
        x = maybe_cuda(x)
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var, MODELS_NDPM_VAE_Z_SAMPLES)
        x_mean = self.decode(z)
        x_mean = x_mean.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES, *x.shape[1:])
        x_log_var = (
            None if MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli' else
            self.log_var.view(1, 1, -1, 1, 1)
        )
        loss_recon = self.reconstruction_loss(x, x_mean, x_log_var)
        loss_recon = loss_recon.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES, -1)
        # Sum over pixels, average over latent samples.
        loss_recon = loss_recon.sum(2).mean(1)
        loss_kl = self.gaussian_kl(z_mean, z_log_var)
        loss_vae = loss_recon + loss_kl

        return loss_vae

    def sample(self, n=1):
        """Decode n latents drawn from the standard-normal prior."""
        z = maybe_cuda(torch.randn(n, MODELS_NDPM_VAE_Z_DIM))
        x_mean = self.decode(z)
        return x_mean

    def reconstruction_loss(self, x, x_mean, x_log_var=None):
        """Element-wise NLL of x under the configured observation model."""
        loss_type = MODELS_NDPM_VAE_RECON_LOSS
        loss = (
            bernoulli_nll if loss_type == 'bernoulli' else
            gaussian_nll if loss_type == 'gaussian' else
            laplace_nll if loss_type == 'laplace' else
            logistic_nll if loss_type == 'logistic' else None
        )
        if loss is None:
            raise ValueError('Unknown recon_loss type: {}'.format(loss_type))

        # Broadcast x over the latent-sample dimension of x_mean.
        if len(x_mean.size()) > len(x.size()):
            x = x.unsqueeze(1)

        return (
            loss(x, x_mean) if x_log_var is None else
            loss(x, x_mean, x_log_var)
        )

    @staticmethod
    def gaussian_kl(q_mean, q_log_var, p_mean=None, p_log_var=None):
        """Per-sample KL(q || p) between diagonal Gaussians, summed over
        latent dimensions."""
        # p defaults to N(0, 1)
        zeros = torch.zeros_like(q_mean)
        p_mean = p_mean if p_mean is not None else zeros
        p_log_var = p_log_var if p_log_var is not None else zeros
        # calculate KL(q, p)
        kld = 0.5 * (
            p_log_var - q_log_var +
            (q_log_var.exp() + (q_mean - p_mean) ** 2) / p_log_var.exp() - 1
        )
        kld = kld.sum(1)
        return kld

    @staticmethod
    def reparameterize(z_mean, z_log_var, num_samples=1):
        """Draw num_samples latents per input via the reparameterization
        trick; returns shape [B * num_samples, z_dim]."""
        z_std = (z_log_var * 0.5).exp()
        z_std = z_std.unsqueeze(1).expand(-1, num_samples, -1)
        z_mean = z_mean.unsqueeze(1).expand(-1, num_samples, -1)
        unit_normal = torch.randn_like(z_std)
        z = z_mean + unit_normal * z_std
        z = z.view(-1, z_std.size(2))
        return z

    @abstractmethod
    def encode(self, x):
        """Return (z_mean, z_log_var) of the approximate posterior."""
        pass

    @abstractmethod
    def decode(self, x):
        """Map latents back to observation space."""
        pass

    @property
    def log_var(self):
        # None in the Bernoulli case (no observation variance).
        return (
            None if self.log_var_param is None else
            self.log_var_param
        )
117 |
118 |
class SharingVae(Vae, ABC):
    """VAE whose encoder shares features with precursor experts; can compute
    the NLL of every expert in one pass."""

    def collect_nll(self, x, y=None, step=None):
        """Collect NLL values

        Returns:
            loss_vae: Tensor of shape [B, 1+K]
        """
        x = maybe_cuda(x)

        # Dummy VAE (index 0): placeholder expert's NLL.
        dummy_nll = self.experts[0].g.nll(x, y, step)

        # Encode once; each precursor contributes its own (z_mean, z_log_var).
        z_means, z_log_vars, features = self.encode(x, collect=True)

        # Decode per component and accumulate the ELBO losses.
        loss_vaes = [dummy_nll]
        vaes = [expert.g for expert in self.experts[1:]] + [self]
        x_logits = []
        for z_mean, z_log_var, vae in zip(z_means, z_log_vars, vaes):
            z = self.reparameterize(z_mean, z_log_var, MODELS_NDPM_VAE_Z_SAMPLES)
            if MODELS_NDPM_VAE_PRECURSOR_CONDITIONED_DECODER:
                # Defer the loss: logits are accumulated across components
                # below before computing reconstruction losses.
                x_logit = vae.decode(z, as_logit=True)
                x_logits.append(x_logit)
                continue
            x_mean = vae.decode(z)
            x_mean = x_mean.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES,
                                 *x.shape[1:])
            x_log_var = (
                None if MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli' else
                self.log_var.view(1, 1, -1, 1, 1)
            )
            loss_recon = self.reconstruction_loss(x, x_mean, x_log_var)
            loss_recon = loss_recon.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES,
                                         -1)
            loss_recon = loss_recon.sum(2).mean(1)
            loss_kl = self.gaussian_kl(z_mean, z_log_var)
            loss_vae = loss_recon + loss_kl

            loss_vaes.append(loss_vae)

        # Precursor-conditioned decoding: component k's logit is the sum of
        # all earlier (detached) logits plus its own.
        x_logits = list(accumulate(
            x_logits, func=(lambda x, y: x.detach() + y)
        ))
        for x_logit in x_logits:
            x_mean = torch.sigmoid(x_logit)
            x_mean = x_mean.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES,
                                 *x.shape[1:])
            x_log_var = (
                None if MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli' else
                self.log_var.view(1, 1, -1, 1, 1)
            )
            loss_recon = self.reconstruction_loss(x, x_mean, x_log_var)
            loss_recon = loss_recon.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES,
                                         -1)
            loss_recon = loss_recon.sum(2).mean(1)
            # NOTE(review): z_mean/z_log_var here are leftovers from the last
            # iteration of the loop above, so every component gets the *last*
            # component's KL term. Presumably each x_logit should pair with
            # its own z_means[k]/z_log_vars[k] — confirm against upstream
            # CN-DPM before changing (possible bug).
            loss_kl = self.gaussian_kl(z_mean, z_log_var)
            loss_vae = loss_recon + loss_kl
            loss_vaes.append(loss_vae)

        return torch.stack(loss_vaes, dim=1)

    @abstractmethod
    def encode(self, x, collect=False):
        """Return (z_means, z_log_vars, features) when collect=True."""
        pass

    @abstractmethod
    def decode(self, z, as_logit=False):
        """
        Decode do not share parameters
        """
        pass
191 |
192 |
class CnnSharingVae(SharingVae):
    """Convolutional sharing VAE: the encoder widens each layer's input by
    concatenating the (detached) features of all precursor experts; the
    decoder is private to each component (no sharing).
    """

    def __init__(self, params, experts):
        super().__init__(params, experts)
        # Generative parts of previously trained experts (experts[0] is the
        # dummy expert and is skipped).
        self.precursors = [expert.g for expert in self.experts[1:]]
        first = len(self.precursors) == 0
        # NOTE(review): 'MODLES_' spelling follows utils.global_vars
        # (imported via *); presumably a historical typo there — verify
        # before renaming.
        nf_base, nf_ext = MODLES_NDPM_VAE_NF_BASE, MODELS_NDPM_VAE_NF_EXT
        # nf: this component's own channel width; nf_cat: width after
        # concatenating all precursor features.
        nf = nf_base if first else nf_ext
        nf_cat = nf_base + len(self.precursors) * nf_ext

        h1_dim = 1 * nf
        h2_dim = 2 * nf
        fc_dim = 4 * nf
        h1_cat_dim = 1 * nf_cat
        h2_cat_dim = 2 * nf_cat
        fc_cat_dim = 4 * nf_cat

        x_c, x_h, x_w = input_size_match[params.data]

        self.fc_dim = fc_dim
        # Flattened size after two 2x max-pools (spatial / 4 each side).
        feature_volume = ((x_h // 4) * (x_w // 4) *
                          h2_cat_dim)

        self.enc1 = nn.Sequential(
            nn.Conv2d(x_c, h1_dim, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU()
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(h1_cat_dim, h2_dim, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU(),
            Lambda(lambda x: x.view(x.size(0), -1))
        )
        self.enc3 = nn.Sequential(
            nn.Linear(feature_volume, fc_dim),
            nn.ReLU()
        )
        # Posterior heads consume the concatenated fc features.
        self.enc_z_mean = nn.Linear(fc_cat_dim, MODELS_NDPM_VAE_Z_DIM)
        self.enc_z_log_var = nn.Linear(fc_cat_dim, MODELS_NDPM_VAE_Z_DIM)

        # Decoder uses base widths only: it is not shared with precursors.
        self.dec_z = nn.Sequential(
            nn.Linear(MODELS_NDPM_VAE_Z_DIM, 4 * nf_base),
            nn.ReLU()
        )
        self.dec3 = nn.Sequential(
            nn.Linear(
                4 * nf_base,
                (x_h // 4) * (x_w // 4) * 2 * nf_base),
            nn.ReLU()
        )
        self.dec2 = nn.Sequential(
            Lambda(lambda x: x.view(
                x.size(0), 2 * nf_base,
                x_h // 4, x_w // 4)),
            nn.ConvTranspose2d(2 * nf_base, 1 * nf_base,
                               kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        )
        self.dec1 = nn.ConvTranspose2d(1 * nf_base, x_c,
                                       kernel_size=4, stride=2, padding=1)

        self.setup_optimizer()

    def encode(self, x, collect=False):
        """Return (z_mean, z_log_var); with collect=True also return all
        precursors' posteriors and this component's detached features."""
        # When first component
        if len(self.precursors) == 0:
            h1 = self.enc1(x)
            h2 = self.enc2(h1)
            h3 = self.enc3(h2)
            z_mean = self.enc_z_mean(h3)
            z_log_var = self.enc_z_log_var(h3)

            if collect:
                # Detached so later components cannot backprop into this one.
                return [z_mean], [z_log_var], \
                    [h1.detach(), h2.detach(), h3.detach()]
            else:
                return z_mean, z_log_var

        # Second or later component: concatenate the most recent precursor's
        # features (which already include all earlier ones).
        z_means, z_log_vars, features = \
            self.precursors[-1].encode(x, collect=True)

        h1 = self.enc1(x)
        h1_cat = torch.cat([features[0], h1], dim=1)
        h2 = self.enc2(h1_cat)
        h2_cat = torch.cat([features[1], h2], dim=1)
        h3 = self.enc3(h2_cat)
        h3_cat = torch.cat([features[2], h3], dim=1)
        z_mean = self.enc_z_mean(h3_cat)
        z_log_var = self.enc_z_log_var(h3_cat)

        if collect:
            z_means.append(z_mean)
            z_log_vars.append(z_log_var)
            features = [h1_cat.detach(), h2_cat.detach(), h3_cat.detach()]
            return z_means, z_log_vars, features
        else:
            return z_mean, z_log_var

    def decode(self, z, as_logit=False):
        """Decode latents to image space; return raw logits when as_logit."""
        h3 = self.dec_z(z)
        h2 = self.dec3(h3)
        h1 = self.dec2(h2)
        x_logit = self.dec1(h1)
        return x_logit if as_logit else torch.sigmoid(x_logit)
298 |
--------------------------------------------------------------------------------
/agents/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | import abc
3 | import numpy as np
4 | import torch
5 | from torch.nn import functional as F
6 | from utils.kd_manager import KdManager
7 | from utils.utils import maybe_cuda, AverageMeter
8 | from torch.utils.data import TensorDataset, DataLoader
9 | import copy
10 | from utils.loss import SupConLoss
11 | import pickle
12 |
13 |
14 | class ContinualLearner(torch.nn.Module, metaclass=abc.ABCMeta):
15 | '''
16 | Abstract module which is inherited by each and every continual learning algorithm.
17 | '''
18 |
19 | def __init__(self, model, opt, params):
20 | super(ContinualLearner, self).__init__()
21 | self.params = params
22 | self.model = model
23 | self.opt = opt
24 | self.data = params.data
25 | self.cuda = params.cuda
26 | self.epoch = params.epoch
27 | self.batch = params.batch
28 | self.verbose = params.verbose
29 | self.old_labels = []
30 | self.new_labels = []
31 | self.task_seen = 0
32 | self.kd_manager = KdManager()
33 | self.error_list = []
34 | self.new_class_score = []
35 | self.old_class_score = []
36 | self.fc_norm_new = []
37 | self.fc_norm_old = []
38 | self.bias_norm_new = []
39 | self.bias_norm_old = []
40 | self.lbl_inv_map = {}
41 | self.class_task_map = {}
42 |
43 | def before_train(self, x_train, y_train):
44 | new_labels = list(set(y_train.tolist()))
45 | self.new_labels += new_labels
46 | for i, lbl in enumerate(new_labels):
47 | self.lbl_inv_map[lbl] = len(self.old_labels) + i
48 |
49 | for i in new_labels:
50 | self.class_task_map[i] = self.task_seen
51 |
52 | @abstractmethod
53 | def train_learner(self, x_train, y_train):
54 | pass
55 |
56 | def after_train(self):
57 | #self.old_labels = list(set(self.old_labels + self.new_labels))
58 | self.old_labels += self.new_labels
59 | self.new_labels_zombie = copy.deepcopy(self.new_labels)
60 | self.new_labels.clear()
61 | self.task_seen += 1
62 | if self.params.trick['review_trick'] and hasattr(self, 'buffer'):
63 | self.model.train()
64 | mem_x = self.buffer.buffer_img[:self.buffer.current_index]
65 | mem_y = self.buffer.buffer_label[:self.buffer.current_index]
66 | # criterion = torch.nn.CrossEntropyLoss(reduction='mean')
67 | if mem_x.size(0) > 0:
68 | rv_dataset = TensorDataset(mem_x, mem_y)
69 | rv_loader = DataLoader(rv_dataset, batch_size=self.params.eps_mem_batch, shuffle=True, num_workers=0,
70 | drop_last=True)
71 | for ep in range(1):
72 | for i, batch_data in enumerate(rv_loader):
73 | # batch update
74 | batch_x, batch_y = batch_data
75 | batch_x = maybe_cuda(batch_x, self.cuda)
76 | batch_y = maybe_cuda(batch_y, self.cuda)
77 | logits = self.model.forward(batch_x)
78 | if self.params.agent in ['SCR', 'SSCR']:
79 | logits = torch.cat([self.model.forward(batch_x).unsqueeze(1),
80 | self.model.forward(self.transform(batch_x)).unsqueeze(1)], dim=1)
81 | loss = self.criterion(logits, batch_y)
82 | self.opt.zero_grad()
83 | loss.backward()
84 | params = [p for p in self.model.parameters() if p.requires_grad and p.grad is not None]
85 | grad = [p.grad.clone()/10. for p in params]
86 | for g, p in zip(grad, params):
87 | p.grad.data.copy_(g)
88 | self.opt.step()
89 |
90 | if self.params.trick['kd_trick'] or self.params.agent == 'LWF':
91 | self.kd_manager.update_teacher(self.model)
92 |
93 | def criterion(self, logits, labels):
94 | labels = labels.clone()
95 | ce = torch.nn.CrossEntropyLoss(reduction='mean')
96 | if self.params.trick['labels_trick']:
97 | unq_lbls = labels.unique().sort()[0]
98 | for lbl_idx, lbl in enumerate(unq_lbls):
99 | labels[labels == lbl] = lbl_idx
100 | # Calcualte loss only over the heads appear in the batch:
101 | return ce(logits[:, unq_lbls], labels)
102 | elif self.params.trick['separated_softmax']:
103 | old_ss = F.log_softmax(logits[:, self.old_labels], dim=1)
104 | new_ss = F.log_softmax(logits[:, self.new_labels], dim=1)
105 | ss = torch.cat([old_ss, new_ss], dim=1)
106 | for i, lbl in enumerate(labels):
107 | labels[i] = self.lbl_inv_map[lbl.item()]
108 | return F.nll_loss(ss, labels)
109 | elif self.params.agent in ['SCR', 'SCP', 'SSCR']:
110 | SC = SupConLoss(temperature=self.params.temp)
111 | return SC(logits, labels)
112 | else:
113 | return ce(logits, labels)
114 |
115 | def forward(self, x):
116 | return self.model.forward(x)
117 |
    def evaluate(self, test_loaders):
        """Evaluate the model on one test loader per task seen so far.

        Uses nearest-class-mean (NCM) classification when the ncm trick or an
        NCM-based agent (ICARL/SCR/SCP/SSCR) is active, otherwise argmax over
        the logits.  When ``params.error_analysis`` is set, also accumulates
        old/new-task confusion statistics and dumps them to a 'confusion'
        pickle file.

        Args:
            test_loaders: iterable of DataLoaders, one per seen task.

        Returns:
            numpy array of per-task average accuracies.
        """
        self.model.eval()
        acc_array = np.zeros(len(test_loaders))
        if self.params.trick['ncm_trick'] or self.params.agent in ['ICARL', 'SCR', 'SCP', 'SSCR']:
            # Build one normalized mean feature vector (class prototype) per
            # old class from the exemplars currently stored in the buffer.
            exemplar_means = {}
            cls_exemplar = {cls: [] for cls in self.old_labels}
            buffer_filled = self.buffer.current_index
            for x, y in zip(self.buffer.buffer_img[:buffer_filled], self.buffer.buffer_label[:buffer_filled]):
                cls_exemplar[y.item()].append(x)
            for cls, exemplar in cls_exemplar.items():
                features = []
                # Extract feature for each exemplar in p_y
                for ex in exemplar:
                    feature = self.model.features(ex.unsqueeze(0)).detach().clone()
                    feature = feature.squeeze()
                    feature.data = feature.data / feature.data.norm() # Normalize
                    features.append(feature)
                if len(features) == 0:
                    # No exemplars for this class: fall back to a random unit
                    # prototype.  NOTE(review): `x` here is the leaked loop
                    # variable from the buffer loop above (only used to infer
                    # the feature size); this breaks if the buffer is empty.
                    mu_y = maybe_cuda(torch.normal(0, 1, size=tuple(self.model.features(x.unsqueeze(0)).detach().size())), self.cuda)
                    mu_y = mu_y.squeeze()
                else:
                    features = torch.stack(features)
                    mu_y = features.mean(0).squeeze()
                mu_y.data = mu_y.data / mu_y.data.norm() # Normalize
                exemplar_means[cls] = mu_y
        with torch.no_grad():
            if self.params.error_analysis:
                # Confusion counters: first letter = true task age, second =
                # predicted class age (o = old task/class, n = new).
                error = 0
                no = 0
                nn = 0
                oo = 0
                on = 0
                new_class_score = AverageMeter()
                old_class_score = AverageMeter()
                correct_lb = []
                predict_lb = []
            for task, test_loader in enumerate(test_loaders):
                acc = AverageMeter()
                for i, (batch_x, batch_y) in enumerate(test_loader):
                    batch_x = maybe_cuda(batch_x, self.cuda)
                    batch_y = maybe_cuda(batch_y, self.cuda)
                    if self.params.trick['ncm_trick'] or self.params.agent in ['ICARL', 'SCR', 'SCP', 'SSCR']:
                        # NCM prediction: nearest prototype by squared
                        # Euclidean distance in normalized feature space.
                        feature = self.model.features(batch_x) # (batch_size, feature_size)
                        for j in range(feature.size(0)): # Normalize
                            feature.data[j] = feature.data[j] / feature.data[j].norm()
                        feature = feature.unsqueeze(2) # (batch_size, feature_size, 1)
                        means = torch.stack([exemplar_means[cls] for cls in self.old_labels]) # (n_classes, feature_size)

                        #old ncm
                        means = torch.stack([means] * batch_x.size(0)) # (batch_size, n_classes, feature_size)
                        means = means.transpose(1, 2)
                        feature = feature.expand_as(means) # (batch_size, feature_size, n_classes)
                        # NOTE(review): .squeeze() collapses the batch dim when
                        # batch_size == 1, which would make dists.min(1) fail —
                        # confirm the loaders never yield size-1 batches.
                        dists = (feature - means).pow(2).sum(1).squeeze() # (batch_size, n_classes)
                        _, pred_label = dists.min(1)
                        # may be faster
                        # feature = feature.squeeze(2).T
                        # _, preds = torch.matmul(means, feature).max(0)
                        # pred_label indexes into old_labels, so map back to
                        # real class ids before comparing against batch_y.
                        correct_cnt = (np.array(self.old_labels)[
                            pred_label.tolist()] == batch_y.cpu().numpy()).sum().item() / batch_y.size(0)
                    else:
                        logits = self.model.forward(batch_x)
                        _, pred_label = torch.max(logits, 1)
                        correct_cnt = (pred_label == batch_y).sum().item()/batch_y.size(0)

                    if self.params.error_analysis:
                        correct_lb += [task] * len(batch_y)
                        # NOTE(review): this loop reuses `i`, shadowing the
                        # enumerate index above (harmless as written, fragile).
                        for i in pred_label:
                            predict_lb.append(self.class_task_map[i.item()])
                        if task < self.task_seen-1:
                            # old test
                            total = (pred_label != batch_y).sum().item()
                            wrong = pred_label[pred_label != batch_y]
                            error += total
                            on_tmp = sum([(wrong == i).sum().item() for i in self.new_labels_zombie])
                            oo += total - on_tmp
                            on += on_tmp
                            # NOTE(review): `logits` is only assigned on the
                            # non-NCM path above; ncm_trick + error_analysis
                            # would raise NameError here — confirm intended.
                            old_class_score.update(logits[:, list(set(self.old_labels) - set(self.new_labels_zombie))].mean().item(), batch_y.size(0))
                        elif task == self.task_seen -1:
                            # new test
                            total = (pred_label != batch_y).sum().item()
                            error += total
                            wrong = pred_label[pred_label != batch_y]
                            no_tmp = sum([(wrong == i).sum().item() for i in list(set(self.old_labels) - set(self.new_labels_zombie))])
                            no += no_tmp
                            nn += total - no_tmp
                            new_class_score.update(logits[:, self.new_labels_zombie].mean().item(), batch_y.size(0))
                        else:
                            pass
                    acc.update(correct_cnt, batch_y.size(0))
                acc_array[task] = acc.avg()
        print(acc_array)
        if self.params.error_analysis:
            self.error_list.append((no, nn, oo, on))
            self.new_class_score.append(new_class_score.avg())
            self.old_class_score.append(old_class_score.avg())
            # +0.1 smoothing avoids division by zero when a bucket is empty.
            print("no ratio: {}\non ratio: {}".format(no/(no+nn+0.1), on/(oo+on+0.1)))
            print(self.error_list)
            print(self.new_class_score)
            print(self.old_class_score)
            # Track classifier-head weight/bias magnitudes for new vs old
            # classes (bias/recency diagnostics).
            self.fc_norm_new.append(self.model.linear.weight[self.new_labels_zombie].mean().item())
            self.fc_norm_old.append(self.model.linear.weight[list(set(self.old_labels) - set(self.new_labels_zombie))].mean().item())
            self.bias_norm_new.append(self.model.linear.bias[self.new_labels_zombie].mean().item())
            self.bias_norm_old.append(self.model.linear.bias[list(set(self.old_labels) - set(self.new_labels_zombie))].mean().item())
            print(self.fc_norm_old)
            print(self.fc_norm_new)
            print(self.bias_norm_old)
            print(self.bias_norm_new)
            with open('confusion', 'wb') as fp:
                pickle.dump([correct_lb, predict_lb], fp)
        return acc_array
228 |
--------------------------------------------------------------------------------