├── agents ├── __init__.py ├── cndpm.py ├── lwf.py ├── icarl.py ├── scr.py ├── gdumb.py ├── summarize.py ├── agem.py ├── ewc_pp.py ├── exp_replay.py └── base.py ├── utils ├── __init__.py ├── buffer │ ├── __init__.py │ ├── sc_retrieve.py │ ├── random_retrieve.py │ ├── mem_match.py │ ├── reservoir_update.py │ ├── mir_retrieve.py │ ├── aser_retrieve.py │ ├── buffer.py │ ├── aser_update.py │ ├── gss_greedy_update.py │ ├── aser_utils.py │ └── buffer_utils.py ├── kd_manager.py ├── logging.py ├── global_vars.py ├── io.py ├── name_match.py ├── setup_elements.py ├── loss.py └── utils.py ├── continuum ├── __init__.py ├── dataset_scripts │ ├── __init__.py │ ├── dataset_base.py │ ├── cifar100.py │ ├── cifar10.py │ ├── tiny_imagenet.py │ ├── mini_imagenet.py │ ├── openloris.py │ └── core50.py ├── continuum.py ├── data_utils.py └── non_stationary.py ├── experiment ├── __init__.py ├── tune_hyperparam.py └── metrics.py ├── models ├── ndpm │ ├── __init__.py │ ├── utils.py │ ├── loss.py │ ├── priors.py │ ├── expert.py │ ├── component.py │ ├── ndpm.py │ ├── classifier.py │ └── vae.py ├── __init__.py ├── pretrained.py ├── convnet.py └── resnet.py ├── figs └── pipeline-ssd.png ├── requirements.txt ├── process_tiny_imagenet.py ├── main_config.py ├── README.md ├── main_tune.py └── .gitignore /agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /continuum/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiment/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /models/ndpm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/buffer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import resnet -------------------------------------------------------------------------------- /figs/pipeline-ssd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vimar-gu/SSD/HEAD/figs/pipeline-ssd.png -------------------------------------------------------------------------------- /models/pretrained.py: -------------------------------------------------------------------------------- 1 | import torchvision.models as models 2 | import torch 3 | 4 | def ResNet18_pretrained(n_classes): 5 | classifier = models.resnet18(pretrained=True) 6 | classifier.fc = torch.nn.Linear(512, n_classes) 7 | return classifier -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://download.pytorch.org/whl/torch_stable.html 2 | torch==1.7.1 3 | torchvision==0.8.2 4 | matplotlib==3.2.1 5 | scipy==1.4.1 6 | scikit-image==0.14.2 7 | scikit-learn==0.23.0 8 | pandas==1.0.5 9 | PyYAML==5.3.1 10 | psutil==5.7.0 11 | kornia==0.4.1 12 | -------------------------------------------------------------------------------- 
/models/ndpm/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Lambda(nn.Module): 5 | def __init__(self, f=None): 6 | super().__init__() 7 | self.f = f if f is not None else (lambda x: x) 8 | 9 | def forward(self, *args, **kwargs): 10 | return self.f(*args, **kwargs) 11 | -------------------------------------------------------------------------------- /utils/buffer/sc_retrieve.py: -------------------------------------------------------------------------------- 1 | from utils.buffer.buffer_utils import match_retrieve 2 | import torch 3 | 4 | class Match_retrieve(object): 5 | def __init__(self, params): 6 | super().__init__() 7 | self.num_retrieve = params.eps_mem_batch 8 | self.warmup = params.warmup 9 | 10 | def retrieve(self, buffer, **kwargs): 11 | if buffer.n_seen_so_far > self.num_retrieve * self.warmup: 12 | cur_x, cur_y = kwargs['x'], kwargs['y'] 13 | return match_retrieve(buffer, cur_y) 14 | else: 15 | return torch.tensor([]), torch.tensor([]) -------------------------------------------------------------------------------- /utils/buffer/random_retrieve.py: -------------------------------------------------------------------------------- 1 | from utils.buffer.buffer_utils import random_retrieve, balanced_retrieve 2 | 3 | 4 | class Random_retrieve(object): 5 | def __init__(self, params): 6 | super().__init__() 7 | self.num_retrieve = params.eps_mem_batch 8 | 9 | def retrieve(self, buffer, **kwargs): 10 | return random_retrieve(buffer, self.num_retrieve) 11 | 12 | 13 | class BalancedRetrieve(object): 14 | def __init__(self, params): 15 | super().__init__() 16 | self.num_retrieve = params.eps_mem_batch 17 | 18 | def retrieve(self, buffer, **kwargs): 19 | return balanced_retrieve(buffer, self.num_retrieve) 20 | 21 | -------------------------------------------------------------------------------- /utils/kd_manager.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def loss_fn_kd(scores, target_scores, T=2.): 7 | log_scores_norm = F.log_softmax(scores / T, dim=1) 8 | targets_norm = F.softmax(target_scores / T, dim=1) 9 | # Calculate distillation loss (see e.g., Li and Hoiem, 2017) 10 | kd_loss = (-1 * targets_norm * log_scores_norm).sum(dim=1).mean() * T ** 2 11 | return kd_loss 12 | 13 | 14 | class KdManager: 15 | def __init__(self): 16 | self.teacher_model = None 17 | 18 | def update_teacher(self, model): 19 | self.teacher_model = copy.deepcopy(model) 20 | 21 | def get_kd_loss(self, cur_model_logits, x): 22 | if self.teacher_model is not None: 23 | with torch.no_grad(): 24 | prev_model_logits = self.teacher_model.forward(x) 25 | dist_loss = loss_fn_kd(cur_model_logits, prev_model_logits) 26 | else: 27 | dist_loss = 0 28 | return dist_loss 29 | -------------------------------------------------------------------------------- /utils/buffer/mem_match.py: -------------------------------------------------------------------------------- 1 | from utils.buffer.buffer_utils import match_retrieve 2 | from utils.buffer.buffer_utils import random_retrieve 3 | import torch 4 | 5 | class MemMatch_retrieve(object): 6 | def __init__(self, params): 7 | super().__init__() 8 | self.num_retrieve = params.eps_mem_batch 9 | self.warmup = params.warmup 10 | 11 | 12 | def retrieve(self, buffer, **kwargs): 13 | match_x, match_y = torch.tensor([]), torch.tensor([]) 14 | candidate_x, candidate_y = torch.tensor([]), torch.tensor([]) 15 | if buffer.n_seen_so_far > self.num_retrieve * self.warmup: 16 | while match_x.size(0) == 0: 17 | candidate_x, candidate_y, indices = random_retrieve(buffer, self.num_retrieve,return_indices=True) 18 | if candidate_x.size(0) == 0: 19 | return candidate_x, candidate_y, match_x, match_y 20 | match_x, match_y = match_retrieve(buffer, candidate_y, indices) 21 
| return candidate_x, candidate_y, match_x, match_y -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | import errno 5 | 6 | 7 | def mkdir_if_missing(dir_path): 8 | try: 9 | os.makedirs(dir_path) 10 | except OSError as e: 11 | if e.errno != errno.EEXIST: 12 | raise 13 | 14 | class Logger(object): 15 | def __init__(self, fpath=None): 16 | self.console = sys.stdout 17 | self.file = None 18 | if fpath is not None: 19 | mkdir_if_missing(os.path.dirname(fpath)) 20 | self.file = open(fpath, 'w') 21 | 22 | def __del__(self): 23 | self.close() 24 | 25 | def __enter__(self): 26 | pass 27 | 28 | def __exit__(self, *args): 29 | self.close() 30 | 31 | def write(self, msg): 32 | self.console.write(msg) 33 | if self.file is not None: 34 | self.file.write(msg) 35 | 36 | def flush(self): 37 | self.console.flush() 38 | if self.file is not None: 39 | self.file.flush() 40 | os.fsync(self.file.fileno()) 41 | 42 | def close(self): 43 | self.console.close() 44 | if self.file is not None: 45 | self.file.close() 46 | 47 | -------------------------------------------------------------------------------- /continuum/continuum.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # other imports 4 | from utils.name_match import data_objects 5 | 6 | class continuum(object): 7 | def __init__(self, dataset, scenario, params): 8 | """" Initialize Object """ 9 | self.data_object = data_objects[dataset](scenario, params) 10 | self.run = params.num_runs 11 | self.task_nums = self.data_object.task_nums 12 | self.cur_task = 0 13 | self.cur_run = -1 14 | 15 | def __iter__(self): 16 | return self 17 | 18 | def __next__(self): 19 | if self.cur_task == self.data_object.task_nums: 20 | raise StopIteration 21 | x_train, 
y_train, labels = self.data_object.new_task(self.cur_task, cur_run=self.cur_run) 22 | self.cur_task += 1 23 | return x_train, y_train, labels 24 | 25 | def test_data(self): 26 | return self.data_object.get_test_set() 27 | 28 | def clean_mem_test_set(self): 29 | self.data_object.clean_mem_test_set() 30 | 31 | def reset_run(self): 32 | self.cur_task = 0 33 | 34 | def new_run(self): 35 | self.cur_task = 0 36 | self.cur_run += 1 37 | self.data_object.new_run(cur_run=self.cur_run) 38 | 39 | 40 | -------------------------------------------------------------------------------- /utils/global_vars.py: -------------------------------------------------------------------------------- 1 | MODLES_NDPM_VAE_NF_BASE = 32 2 | MODELS_NDPM_VAE_NF_EXT = 4 3 | MODELS_NDPM_VAE_PRECURSOR_CONDITIONED_DECODER = False 4 | MODELS_NDPM_VAE_Z_DIM = 64 5 | MODELS_NDPM_VAE_RECON_LOSS = 'gaussian' 6 | MODELS_NDPM_VAE_LEARN_X_LOG_VAR = False 7 | MODELS_NDPM_VAE_X_LOG_VAR_PARAM = 0 8 | MODELS_NDPM_VAE_Z_SAMPLES = 16 9 | MODELS_NDPM_CLASSIFIER_NUM_BLOCKS = [1, 1, 1, 1] 10 | MODELS_NDPM_CLASSIFIER_NORM_LAYER = 'InstanceNorm2d' 11 | MODELS_NDPM_CLASSIFIER_CLS_NF_BASE = 20 12 | MODELS_NDPM_CLASSIFIER_CLS_NF_EXT = 4 13 | MODELS_NDPM_NDPM_DISABLE_D = False 14 | MODELS_NDPM_NDPM_SEND_TO_STM_ALWAYS = False 15 | MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE = 50 16 | MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS = 0 17 | MODELS_NDPM_NDPM_SLEEP_STEP_G = 4000 18 | MODELS_NDPM_NDPM_SLEEP_STEP_D = 1000 19 | MODELS_NDPM_NDPM_SLEEP_SLEEP_VAL_SIZE = 0 20 | MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP = 500 21 | MODELS_NDPM_NDPM_WEIGHT_DECAY = 0.00001 22 | MODELS_NDPM_NDPM_IMPLICIT_LR_DECAY = False 23 | MODELS_NDPM_COMPONENT_LR_SCHEDULER_G = {'type': 'MultiStepLR', 'options': {'milestones': [1], 'gamma': 0.2}} 24 | MODELS_NDPM_COMPONENT_LR_SCHEDULER_D = {'type': 'MultiStepLR', 'options': {'milestones': [1], 'gamma': 0.2}} 25 | MODELS_NDPM_COMPONENT_CLIP_GRAD = {'type': 'value', 'options': {'clip_value': 0.5}} 26 | 
-------------------------------------------------------------------------------- /continuum/dataset_scripts/dataset_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import os 3 | 4 | class DatasetBase(ABC): 5 | def __init__(self, dataset, scenario, task_nums, run, params): 6 | super(DatasetBase, self).__init__() 7 | self.params = params 8 | self.scenario = scenario 9 | self.dataset = dataset 10 | self.task_nums = task_nums 11 | self.run = run 12 | self.root = './datasets' 13 | self.test_set = [] 14 | self.val_set = [] 15 | self._is_properly_setup() 16 | self.download_load() 17 | 18 | 19 | @abstractmethod 20 | def download_load(self): 21 | pass 22 | 23 | @abstractmethod 24 | def setup(self, **kwargs): 25 | pass 26 | 27 | @abstractmethod 28 | def new_task(self, cur_task, **kwargs): 29 | pass 30 | 31 | def _is_properly_setup(self): 32 | pass 33 | 34 | @abstractmethod 35 | def new_run(self, **kwargs): 36 | pass 37 | 38 | @property 39 | def dataset_info(self): 40 | return self.dataset 41 | 42 | def get_test_set(self): 43 | return self.test_set 44 | 45 | def clean_mem_test_set(self): 46 | self.test_set = None 47 | self.test_data = None 48 | self.test_label = None 49 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import pandas as pd 3 | import os 4 | import psutil 5 | import torch 6 | 7 | def load_yaml(path, key='parameters'): 8 | with open(path, 'r') as stream: 9 | try: 10 | return yaml.load(stream, Loader=yaml.FullLoader)[key] 11 | except yaml.YAMLError as exc: 12 | print(exc) 13 | 14 | def save_dataframe_csv(df, path, name): 15 | df.to_csv(path + '/' + name, index=False) 16 | 17 | 18 | def load_dataframe_csv(path, name=None, delimiter=None, names=None): 19 | if not name: 20 | return pd.read_csv(path, delimiter=delimiter, names=names) 21 
| else: 22 | return pd.read_csv(path+name, delimiter=delimiter, names=names) 23 | 24 | def check_ram_usage(): 25 | """ 26 | Compute the RAM usage of the current process. 27 | Returns: 28 | mem (float): Memory occupation in Megabytes 29 | """ 30 | 31 | process = psutil.Process(os.getpid()) 32 | mem = process.memory_info().rss / (1024 * 1024) 33 | 34 | return mem 35 | 36 | def save_model(model, optimizer, opt, epoch, save_file): 37 | print('==> Saving...') 38 | state = { 39 | 'opt': opt, 40 | 'model': model.state_dict(), 41 | 'optimizer': optimizer.state_dict(), 42 | 'epoch': epoch, 43 | } 44 | torch.save(state, save_file) 45 | del state 46 | -------------------------------------------------------------------------------- /process_tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | with open('wnids.txt', 'r') as fp: 8 | wnids = fp.readlines() 9 | wnids = [wnid.strip() for wnid in wnids] 10 | wnid2index = {wnid: index for index, wnid in enumerate(wnids)} 11 | 12 | train_data = {'data': [], 'target': []} 13 | for wnid in wnids: 14 | image_list = os.listdir(f'train/{wnid}/images') 15 | for image_name in image_list: 16 | image = np.asarray(Image.open(f'train/{wnid}/images/{image_name}').convert('RGB')).transpose(2, 0, 1) 17 | train_data['data'].append(image) 18 | train_data['target'].append(wnid2index[wnid]) 19 | train_data['data'] = np.stack(train_data['data']) 20 | train_data['target'] = np.array(train_data['target']) 21 | pickle.dump(train_data, open('train.pkl', 'wb')) 22 | 23 | val_data = {'data': [], 'target': []} 24 | with open('val/val_annotations.txt', 'r') as fp: 25 | val_annos = fp.readlines() 26 | for val_anno in val_annos: 27 | image_name, wnid = val_anno.split('\t')[:2] 28 | image = np.asarray(Image.open(f'val/images/{image_name}').convert('RGB')).transpose(2, 0, 1) 29 | val_data['data'].append(image) 30 | 
val_data['target'].append(wnid2index[wnid]) 31 | val_data['data'] = np.stack(val_data['data']) 32 | val_data['target'] = np.array(val_data['target']) 33 | pickle.dump(val_data, open('val.pkl', 'wb')) 34 | -------------------------------------------------------------------------------- /models/ndpm/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import Tensor 4 | from torch.nn.functional import binary_cross_entropy 5 | 6 | 7 | def gaussian_nll(x, mean, log_var, min_noise=0.001): 8 | return ( 9 | ((x - mean) ** 2 + min_noise) / (2 * log_var.exp() + 1e-8) 10 | + 0.5 * log_var + 0.5 * np.log(2 * np.pi) 11 | ) 12 | 13 | 14 | def laplace_nll(x, median, log_scale, min_noise=0.01): 15 | return ( 16 | ((x - median).abs() + min_noise) / (log_scale.exp() + 1e-8) 17 | + log_scale + np.log(2) 18 | ) 19 | 20 | 21 | def bernoulli_nll(x, p): 22 | # Broadcast 23 | x_exp, p_exp = [], [] 24 | for x_size, p_size in zip(x.size(), p.size()): 25 | if x_size > p_size: 26 | x_exp.append(-1) 27 | p_exp.append(x_size) 28 | elif x_size < p_size: 29 | x_exp.append(p_size) 30 | p_exp.append(-1) 31 | else: 32 | x_exp.append(-1) 33 | p_exp.append(-1) 34 | x = x.expand(*x_exp) 35 | p = p.expand(*p_exp) 36 | 37 | return binary_cross_entropy(p, x, reduction='none') 38 | 39 | 40 | def logistic_nll(x, mean, log_scale): 41 | bin_size = 1 / 256 42 | scale = log_scale.exp() 43 | x_centered = x - mean 44 | cdf1 = x_centered / scale 45 | cdf2 = (x_centered + bin_size) / scale 46 | p = torch.sigmoid(cdf2) - torch.sigmoid(cdf1) + 1e-12 47 | return -p.log() 48 | -------------------------------------------------------------------------------- /main_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils.io import load_yaml 3 | from types import SimpleNamespace 4 | from utils.utils import boolean_string 5 | import time 6 | import torch 7 | import 
random 8 | import numpy as np 9 | from experiment.run import multiple_run 10 | 11 | def main(args): 12 | genereal_params = load_yaml(args.general) 13 | data_params = load_yaml(args.data) 14 | agent_params = load_yaml(args.agent) 15 | genereal_params['verbose'] = args.verbose 16 | genereal_params['cuda'] = torch.cuda.is_available() 17 | final_params = SimpleNamespace(**genereal_params, **data_params, **agent_params) 18 | time_start = time.time() 19 | print(final_params) 20 | 21 | #reproduce 22 | np.random.seed(final_params.seed) 23 | random.seed(final_params.seed) 24 | torch.manual_seed(final_params.seed) 25 | if final_params.cuda: 26 | torch.cuda.manual_seed(final_params.seed) 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = False 29 | 30 | #run 31 | multiple_run(final_params) 32 | 33 | 34 | 35 | if __name__ == "__main__": 36 | # Commandline arguments 37 | parser = argparse.ArgumentParser('CVPR Continual Learning Challenge') 38 | parser.add_argument('--general', dest='general', default='config/general.yml') 39 | parser.add_argument('--data', dest='data', default='config/data/cifar100/cifar100_nc.yml') 40 | parser.add_argument('--agent', dest='agent', default='config/agent/er.yml') 41 | 42 | parser.add_argument('--verbose', type=boolean_string, default=True, 43 | help='print information or not') 44 | args = parser.parse_args() 45 | main(args) -------------------------------------------------------------------------------- /agents/cndpm.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from models.ndpm.ndpm import Ndpm 4 | from utils.setup_elements import transforms_match 5 | from torch.utils import data 6 | from utils.utils import maybe_cuda, AverageMeter 7 | import torch 8 | 9 | 10 | class Cndpm(ContinualLearner): 11 | def __init__(self, model, opt, params): 12 | super(Cndpm, 
self).__init__(model, opt, params) 13 | self.model = model 14 | 15 | 16 | def train_learner(self, x_train, y_train): 17 | # set up loader 18 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 19 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 20 | drop_last=True) 21 | # setup tracker 22 | losses_batch = AverageMeter() 23 | acc_batch = AverageMeter() 24 | 25 | self.model.train() 26 | 27 | for ep in range(self.epoch): 28 | for i, batch_data in enumerate(train_loader): 29 | # batch update 30 | batch_x, batch_y = batch_data 31 | batch_x = maybe_cuda(batch_x, self.cuda) 32 | batch_y = maybe_cuda(batch_y, self.cuda) 33 | self.model.learn(batch_x, batch_y) 34 | if self.params.verbose: 35 | print('\r[Step {:4}] STM: {:5}/{} | #Expert: {}'.format( 36 | i, 37 | len(self.model.stm_x), self.params.stm_capacity, 38 | len(self.model.experts) - 1 39 | ), end='') 40 | print() 41 | -------------------------------------------------------------------------------- /models/ndpm/priors.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | 4 | from utils.utils import maybe_cuda 5 | 6 | 7 | class Prior(ABC): 8 | def __init__(self, params): 9 | self.params = params 10 | 11 | @abstractmethod 12 | def add_expert(self): 13 | pass 14 | 15 | @abstractmethod 16 | def record_usage(self, usage, index=None): 17 | pass 18 | 19 | @abstractmethod 20 | def nl_prior(self, normalize=False): 21 | pass 22 | 23 | 24 | class CumulativePrior(Prior): 25 | def __init__(self, params): 26 | super().__init__(params) 27 | self.log_counts = maybe_cuda(torch.tensor( 28 | params.log_alpha 29 | )).float().unsqueeze(0) 30 | 31 | def add_expert(self): 32 | self.log_counts = torch.cat( 33 | [self.log_counts, maybe_cuda(torch.zeros(1))], 34 | dim=0 35 | ) 36 | 37 | def record_usage(self, usage, index=None): 38 | """Record expert usage 39 | 
40 | Args: 41 | usage: Tensor of shape [K+1] if index is None else scalar 42 | index: expert index 43 | """ 44 | if index is None: 45 | self.log_counts = torch.logsumexp(torch.stack([ 46 | self.log_counts, 47 | usage.log() 48 | ], dim=1), dim=1) 49 | else: 50 | self.log_counts[index] = torch.logsumexp(torch.stack([ 51 | self.log_counts[index], 52 | maybe_cuda(torch.tensor(usage)).float().log() 53 | ], dim=0), dim=0) 54 | 55 | def nl_prior(self, normalize=False): 56 | nl_prior = -self.log_counts 57 | if normalize: 58 | nl_prior += torch.logsumexp(self.log_counts, dim=0) 59 | return nl_prior 60 | 61 | @property 62 | def counts(self): 63 | return self.log_counts.exp() 64 | -------------------------------------------------------------------------------- /experiment/tune_hyperparam.py: -------------------------------------------------------------------------------- 1 | from types import SimpleNamespace 2 | from sklearn.model_selection import ParameterGrid 3 | from utils.setup_elements import setup_opt, setup_architecture 4 | from utils.utils import maybe_cuda 5 | from utils.name_match import agents 6 | import numpy as np 7 | from experiment.metrics import compute_performance 8 | 9 | 10 | def tune_hyper(tune_data, tune_test_loaders, default_params, tune_params): 11 | param_grid_list = list(ParameterGrid(tune_params)) 12 | print(len(param_grid_list)) 13 | tune_accs = [] 14 | tune_fgt = [] 15 | for param_set in param_grid_list: 16 | final_params = vars(default_params) 17 | print(param_set) 18 | final_params.update(param_set) 19 | final_params = SimpleNamespace(**final_params) 20 | accuracy_list = [] 21 | for run in range(final_params.num_runs_val): 22 | tmp_acc = [] 23 | model = setup_architecture(final_params) 24 | model = maybe_cuda(model, final_params.cuda) 25 | opt = setup_opt(final_params.optimizer, model, final_params.learning_rate, final_params.weight_decay) 26 | agent = agents[final_params.agent](model, opt, final_params) 27 | for i, (x_train, y_train, labels) in 
enumerate(tune_data): 28 | print("-----------tune run {} task {}-------------".format(run, i)) 29 | print('size: {}, {}'.format(x_train.shape, y_train.shape)) 30 | agent.train_learner(x_train, y_train) 31 | acc_array = agent.evaluate(tune_test_loaders) 32 | tmp_acc.append(acc_array) 33 | print( 34 | "-----------tune run {}-----------avg_end_acc {}-----------".format(run, np.mean(tmp_acc[-1]))) 35 | accuracy_list.append(np.array(tmp_acc)) 36 | accuracy_list = np.array(accuracy_list) 37 | avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt = compute_performance(accuracy_list) 38 | tune_accs.append(avg_end_acc[0]) 39 | tune_fgt.append(avg_end_fgt[0]) 40 | best_tune = param_grid_list[tune_accs.index(max(tune_accs))] 41 | return best_tune -------------------------------------------------------------------------------- /utils/name_match.py: -------------------------------------------------------------------------------- 1 | from agents.gdumb import Gdumb 2 | from continuum.dataset_scripts.cifar100 import CIFAR100 3 | from continuum.dataset_scripts.cifar10 import CIFAR10 4 | from continuum.dataset_scripts.core50 import CORE50 5 | from continuum.dataset_scripts.mini_imagenet import Mini_ImageNet 6 | from continuum.dataset_scripts.tiny_imagenet import TinyImageNet 7 | from continuum.dataset_scripts.openloris import OpenLORIS 8 | from agents.exp_replay import ExperienceReplay 9 | from agents.agem import AGEM 10 | from agents.ewc_pp import EWC_pp 11 | from agents.cndpm import Cndpm 12 | from agents.lwf import Lwf 13 | from agents.icarl import Icarl 14 | from agents.scr import SupContrastReplay 15 | from agents.summarize import SummarizeContrastReplay 16 | from utils.buffer.random_retrieve import Random_retrieve, BalancedRetrieve 17 | from utils.buffer.reservoir_update import Reservoir_update 18 | from utils.buffer.summarize_update import SummarizeUpdate 19 | from utils.buffer.mir_retrieve import MIR_retrieve 20 | from utils.buffer.gss_greedy_update import GSSGreedyUpdate 21 
| from utils.buffer.aser_retrieve import ASER_retrieve 22 | from utils.buffer.aser_update import ASER_update 23 | from utils.buffer.sc_retrieve import Match_retrieve 24 | from utils.buffer.mem_match import MemMatch_retrieve 25 | 26 | data_objects = { 27 | 'cifar100': CIFAR100, 28 | 'cifar10': CIFAR10, 29 | 'core50': CORE50, 30 | 'mini_imagenet': Mini_ImageNet, 31 | 'tiny_imagenet': TinyImageNet, 32 | 'openloris': OpenLORIS 33 | } 34 | 35 | agents = { 36 | 'ER': ExperienceReplay, 37 | 'EWC': EWC_pp, 38 | 'AGEM': AGEM, 39 | 'CNDPM': Cndpm, 40 | 'LWF': Lwf, 41 | 'ICARL': Icarl, 42 | 'GDUMB': Gdumb, 43 | 'SCR': SupContrastReplay, 44 | 'SSCR': SummarizeContrastReplay, 45 | } 46 | 47 | retrieve_methods = { 48 | 'MIR': MIR_retrieve, 49 | 'random': Random_retrieve, 50 | 'ASER': ASER_retrieve, 51 | 'match': Match_retrieve, 52 | 'mem_match': MemMatch_retrieve, 53 | } 54 | 55 | update_methods = { 56 | 'random': Reservoir_update, 57 | 'GSS': GSSGreedyUpdate, 58 | 'ASER': ASER_update, 59 | 'summarize': SummarizeUpdate, 60 | } 61 | 62 | -------------------------------------------------------------------------------- /models/ndpm/expert.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from models.ndpm.classifier import ResNetSharingClassifier 4 | from models.ndpm.vae import CnnSharingVae 5 | from utils.utils import maybe_cuda 6 | 7 | from utils.global_vars import * 8 | 9 | 10 | class Expert(nn.Module): 11 | def __init__(self, params, experts=()): 12 | super().__init__() 13 | self.id = len(experts) 14 | self.experts = experts 15 | 16 | self.g = maybe_cuda(CnnSharingVae(params, experts)) 17 | self.d = maybe_cuda(ResNetSharingClassifier(params, experts)) if not MODELS_NDPM_NDPM_DISABLE_D else None 18 | 19 | 20 | # use random initialized g if it's a placeholder 21 | if self.id == 0: 22 | self.eval() 23 | for p in self.g.parameters(): 24 | p.requires_grad = False 25 | 26 | # use random initialized d if it's a 
placeholder 27 | if self.id == 0 and self.d is not None: 28 | for p in self.d.parameters(): 29 | p.requires_grad = False 30 | 31 | def forward(self, x): 32 | return self.d(x) 33 | 34 | def nll(self, x, y, step=None): 35 | """Negative log likelihood""" 36 | nll = self.g.nll(x, step) 37 | if self.d is not None: 38 | d_nll = self.d.nll(x, y, step) 39 | nll = nll + d_nll 40 | return nll 41 | 42 | def collect_nll(self, x, y, step=None): 43 | if self.id == 0: 44 | nll = self.nll(x, y, step) 45 | return nll.unsqueeze(1) 46 | 47 | nll = self.g.collect_nll(x, step) 48 | if self.d is not None: 49 | d_nll = self.d.collect_nll(x, y, step) 50 | nll = nll + d_nll 51 | 52 | return nll 53 | 54 | def lr_scheduler_step(self): 55 | if self.g.lr_scheduler is not NotImplemented: 56 | self.g.lr_scheduler.step() 57 | if self.d is not None and self.d.lr_scheduler is not NotImplemented: 58 | self.d.lr_scheduler.step() 59 | 60 | def clip_grad(self): 61 | self.g.clip_grad() 62 | if self.d is not None: 63 | self.d.clip_grad() 64 | 65 | def optimizer_step(self): 66 | self.g.optimizer.step() 67 | if self.d is not None: 68 | self.d.optimizer.step() 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Summarizing Stream Data for Memory-Restricted Online Continual Learning 2 | 3 | Official implementation of "[Summarizing Stream Data for Memory-Restricted Online Continual Learning](https://arxiv.org/abs/2305.16645)" 4 | 5 | 6 |

7 | 8 | ## Highlights :sparkles: 9 | - SSD has been accepted to AAAI 2024! 10 | - SSD summarizes the knowledge in the stream data into informative images for the replay memory. 11 | - By maintaining the consistency of training gradients and of the relationship to past tasks, the summarized samples are more representative of the stream data than the original images. 12 | - SSD significantly improves the effectiveness of replay for online continual learning methods with only limited extra computational overhead. 13 | 14 | ## Datasets 15 | 16 | ### Online Class Incremental 17 | - Split CIFAR100 18 | - Split Mini-ImageNet 19 | - Split Tiny-ImageNet 20 | 21 | ### Data preparation 22 | - CIFAR10 & CIFAR100 will be downloaded automatically during the first run 23 | - Mini-ImageNet: Download it from https://www.kaggle.com/whitemoon/miniimagenet/download and place it in datasets/mini_imagenet/ 24 | - Tiny-ImageNet: Download it from http://cs231n.stanford.edu/tiny-imagenet-200.zip and extract it to datasets/tiny-imagenet-200/. Copy `process_tiny_imagenet.py` into that directory and run it there to convert the dataset into pickle files (`train.pkl` and `val.pkl`) 25 | 26 | ## Run commands 27 | Detailed descriptions of the options can be found in the `SSD` section of [general_main.py](general_main.py) 28 | 29 | ### Sample commands to run algorithms on Split-CIFAR100 30 | ```shell 31 | python general_main.py --data cifar100 --cl_type nc --agent SSCR --retrieve random --update summarize --mem_size 1000 --images_per_class 10 --head mlp --temp 0.07 --eps_mem_batch 100 --lr_img 4e-3 --summarize_interval 6 --queue_size 64 --mem_weight 1 --num_runs 10 32 | ``` 33 | 34 | ## Acknowledgement 35 | 36 | This project is mainly based on [online-continual-learning](https://github.com/RaptorMai/online-continual-learning) 37 | 38 | ## Citation 39 | 40 | If you find this work helpful, please cite: 41 | ``` 42 | @article{gu2023summarizing, 43 | title={Summarizing Stream Data for Memory-Restricted Online Continual Learning}, 44 | author={Gu, Jianyang and Wang, Kai and Jiang,
Wei and You, Yang}, 45 | journal={arXiv preprint arXiv:2305.16645}, 46 | year={2023} 47 | } 48 | ``` 49 | -------------------------------------------------------------------------------- /agents/lwf.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from utils.setup_elements import transforms_match 4 | from torch.utils import data 5 | from utils.utils import maybe_cuda, AverageMeter 6 | import torch 7 | import copy 8 | 9 | 10 | class Lwf(ContinualLearner): 11 | def __init__(self, model, opt, params): 12 | super(Lwf, self).__init__(model, opt, params) 13 | 14 | def train_learner(self, x_train, y_train): 15 | self.before_train(x_train, y_train) 16 | 17 | # set up loader 18 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 19 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 20 | drop_last=True) 21 | 22 | # set up model 23 | self.model = self.model.train() 24 | 25 | # setup tracker 26 | losses_batch = AverageMeter() 27 | acc_batch = AverageMeter() 28 | 29 | for ep in range(self.epoch): 30 | for i, batch_data in enumerate(train_loader): 31 | # batch update 32 | batch_x, batch_y = batch_data 33 | batch_x = maybe_cuda(batch_x, self.cuda) 34 | batch_y = maybe_cuda(batch_y, self.cuda) 35 | 36 | logits = self.forward(batch_x) 37 | loss_old = self.kd_manager.get_kd_loss(logits, batch_x) 38 | loss_new = self.criterion(logits, batch_y) 39 | loss = 1/(self.task_seen + 1) * loss_new + (1 - 1/(self.task_seen + 1)) * loss_old 40 | _, pred_label = torch.max(logits, 1) 41 | correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0) 42 | # update tracker 43 | acc_batch.update(correct_cnt, batch_y.size(0)) 44 | losses_batch.update(loss, batch_y.size(0)) 45 | # backward 46 | self.opt.zero_grad() 47 | loss.backward() 48 | self.opt.step() 49 | 50 | if 
i % 100 == 1 and self.verbose: 51 | print( 52 | '==>>> it: {}, avg. loss: {:.6f}, ' 53 | 'running train acc: {:.3f}' 54 | .format(i, losses_batch.avg(), acc_batch.avg()) 55 | ) 56 | self.after_train() 57 | -------------------------------------------------------------------------------- /main_tune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils.io import load_yaml 3 | from types import SimpleNamespace 4 | from utils.utils import boolean_string 5 | import time 6 | import torch 7 | import random 8 | import numpy as np 9 | from experiment.run import multiple_run_tune_separate 10 | from utils.setup_elements import default_trick 11 | 12 | def main(args): 13 | genereal_params = load_yaml(args.general) 14 | data_params = load_yaml(args.data) 15 | default_params = load_yaml(args.default) 16 | tune_params = load_yaml(args.tune) 17 | genereal_params['verbose'] = args.verbose 18 | genereal_params['cuda'] = torch.cuda.is_available() 19 | genereal_params['train_val'] = args.train_val 20 | if args.trick: 21 | default_trick[args.trick] = True 22 | genereal_params['trick'] = default_trick 23 | final_default_params = SimpleNamespace(**genereal_params, **data_params, **default_params) 24 | 25 | time_start = time.time() 26 | print(final_default_params) 27 | print() 28 | 29 | #reproduce 30 | np.random.seed(final_default_params.seed) 31 | random.seed(final_default_params.seed) 32 | torch.manual_seed(final_default_params.seed) 33 | if final_default_params.cuda: 34 | torch.cuda.manual_seed(final_default_params.seed) 35 | torch.backends.cudnn.deterministic = True 36 | torch.backends.cudnn.benchmark = False 37 | 38 | #run 39 | multiple_run_tune_separate(final_default_params, tune_params, args.save_path) 40 | 41 | 42 | 43 | if __name__ == "__main__": 44 | # Commandline arguments 45 | parser = argparse.ArgumentParser('Continual Learning') 46 | parser.add_argument('--general', dest='general', 
default='config/general_1.yml') 47 | parser.add_argument('--data', dest='data', default='config/data/cifar100/cifar100_nc.yml') 48 | parser.add_argument('--default', dest='default', default='config/agent/er/er_1k.yml') 49 | parser.add_argument('--tune', dest='tune', default='config/agent/er/er_tune.yml') 50 | parser.add_argument('--save-path', dest='save_path', default=None) 51 | parser.add_argument('--verbose', type=boolean_string, default=False, 52 | help='print information or not') 53 | parser.add_argument('--train_val', type=boolean_string, default=False, 54 | help='use tha val batches to train') 55 | parser.add_argument('--trick', type=str, default=None) 56 | args = parser.parse_args() 57 | main(args) -------------------------------------------------------------------------------- /experiment/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import sem 3 | import scipy.stats as stats 4 | 5 | def compute_performance(end_task_acc_arr): 6 | """ 7 | Given test accuracy results from multiple runs saved in end_task_acc_arr, 8 | compute the average accuracy, forgetting, and task accuracies as well as their confidence intervals. 
9 | 10 | :param end_task_acc_arr: (np.ndarray) array of shape (n_run, n_tasks, n_tasks); entry [r, i, j] is the accuracy on task j measured after the i-th training stage of run r (a plain list is not accepted: ``.shape`` is read) 11 | Every returned metric is a (mean, 95% CI half-width) tuple over runs; the CI needs n_run >= 2, otherwise it is nan. 12 | :return: (avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt) 13 | """ 14 | n_run, n_tasks = end_task_acc_arr.shape[:2] 15 | t_coef = stats.t.ppf((1+0.95) / 2, n_run-1) # t coefficient used to compute 95% CIs: mean +- t * sem 16 | 17 | # compute average test accuracy and CI 18 | end_acc = end_task_acc_arr[:, -1, :] # shape: (num_run, num_task) 19 | avg_acc_per_run = np.mean(end_acc, axis=1) # mean of end task accuracies per run 20 | avg_end_acc = (np.mean(avg_acc_per_run), t_coef * sem(avg_acc_per_run)) 21 | 22 | # compute forgetting 23 | best_acc = np.max(end_task_acc_arr, axis=1) 24 | final_forgets = best_acc - end_acc 25 | avg_fgt = np.mean(final_forgets, axis=1) 26 | avg_end_fgt = (np.mean(avg_fgt), t_coef * sem(avg_fgt)) 27 | 28 | # compute ACC 29 | acc_per_run = np.mean((np.sum(np.tril(end_task_acc_arr), axis=2) / 30 | (np.arange(n_tasks) + 1)), axis=1) 31 | avg_acc = (np.mean(acc_per_run), t_coef * sem(acc_per_run)) 32 | 33 | 34 | # compute BWT+ 35 | bwt_per_run = (np.sum(np.tril(end_task_acc_arr, -1), axis=(1,2)) - 36 | np.sum(np.diagonal(end_task_acc_arr, axis1=1, axis2=2) * 37 | (np.arange(n_tasks, 0, -1) - 1), axis=1)) / (n_tasks * (n_tasks - 1) / 2) 38 | bwtp_per_run = np.maximum(bwt_per_run, 0) 39 | avg_bwtp = (np.mean(bwtp_per_run), t_coef * sem(bwtp_per_run)) 40 | 41 | # compute FWT 42 | fwt_per_run = np.sum(np.triu(end_task_acc_arr, 1), axis=(1,2)) / (n_tasks * (n_tasks - 1) / 2) 43 | avg_fwt = (np.mean(fwt_per_run), t_coef * sem(fwt_per_run)) 44 | return avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt 45 | 46 | 47 | 48 | 49 | def single_run_avg_end_fgt(acc_array): 50 | best_acc = np.max(acc_array, axis=1) 51 | end_acc = acc_array[-1] 52 | final_forgets = best_acc - end_acc 53 | avg_fgt = np.mean(final_forgets) 54 | return avg_fgt 55 | -------------------------------------------------------------------------------- /.gitignore:
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /utils/buffer/reservoir_update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Reservoir_update(object): 5 | def __init__(self, params): 6 | super().__init__() 7 | 8 | def update(self, buffer, x, y, **kwargs): 9 | batch_size = x.size(0) 10 | 11 | # add whatever still fits in the buffer 12 | place_left = max(0, buffer.buffer_img.size(0) - buffer.current_index) 13 | if place_left: 14 | offset = min(place_left, batch_size) 15 | buffer.buffer_img[buffer.current_index: buffer.current_index + offset].data.copy_(x[:offset]) 16 | buffer.buffer_label[buffer.current_index: buffer.current_index + offset].data.copy_(y[:offset]) 17 | 18 | 19 | buffer.current_index += offset 20 | buffer.n_seen_so_far += offset 21 | 22 | # everything was added 23 | if offset == x.size(0): 24 | filled_idx = list(range(buffer.current_index - offset, buffer.current_index, )) 25 | if buffer.params.buffer_tracker: 26 | buffer.buffer_tracker.update_cache(buffer.buffer_label, y[:offset], filled_idx) 27 | return filled_idx 28 | 29 | 30 | #TODO: the buffer tracker will have bug when the mem size can't be divided by batch size 31 | 32 | # remove what is already in the buffer 33 | x, y = x[place_left:], y[place_left:] 34 | 35 | indices = 
torch.FloatTensor(x.size(0)).to(x.device).uniform_(0, buffer.n_seen_so_far).long() 36 | valid_indices = (indices < buffer.buffer_img.size(0)).long() 37 | 38 | idx_new_data = valid_indices.nonzero().squeeze(-1) 39 | idx_buffer = indices[idx_new_data] 40 | 41 | buffer.n_seen_so_far += x.size(0) 42 | 43 | if idx_buffer.numel() == 0: 44 | return [] 45 | 46 | assert idx_buffer.max() < buffer.buffer_img.size(0) 47 | assert idx_buffer.max() < buffer.buffer_label.size(0) 48 | # assert idx_buffer.max() < self.buffer_task.size(0) 49 | 50 | assert idx_new_data.max() < x.size(0) 51 | assert idx_new_data.max() < y.size(0) 52 | 53 | idx_map = {idx_buffer[i].item(): idx_new_data[i].item() for i in range(idx_buffer.size(0))} 54 | 55 | replace_y = y[list(idx_map.values())] 56 | if buffer.params.buffer_tracker: 57 | buffer.buffer_tracker.update_cache(buffer.buffer_label, replace_y, list(idx_map.keys())) 58 | # perform overwrite op 59 | buffer.buffer_img[list(idx_map.keys())] = x[list(idx_map.values())] 60 | buffer.buffer_label[list(idx_map.keys())] = replace_y 61 | return list(idx_map.keys()) -------------------------------------------------------------------------------- /utils/buffer/mir_retrieve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.utils import maybe_cuda 3 | import torch.nn.functional as F 4 | from utils.buffer.buffer_utils import random_retrieve, get_grad_vector 5 | import copy 6 | 7 | 8 | class MIR_retrieve(object): 9 | def __init__(self, params, **kwargs): 10 | super().__init__() 11 | self.params = params 12 | self.subsample = params.subsample 13 | self.num_retrieve = params.eps_mem_batch 14 | 15 | def retrieve(self, buffer, **kwargs): 16 | sub_x, sub_y = random_retrieve(buffer, self.subsample) 17 | grad_dims = [] 18 | for param in buffer.model.parameters(): 19 | grad_dims.append(param.data.numel()) 20 | grad_vector = get_grad_vector(buffer.model.parameters, grad_dims) 21 | model_temp = 
self.get_future_step_parameters(buffer.model, grad_vector, grad_dims) 22 | if sub_x.size(0) > 0: 23 | with torch.no_grad(): 24 | logits_pre = buffer.model.forward(sub_x) 25 | logits_post = model_temp.forward(sub_x) 26 | pre_loss = F.cross_entropy(logits_pre, sub_y, reduction='none') 27 | post_loss = F.cross_entropy(logits_post, sub_y, reduction='none') 28 | scores = post_loss - pre_loss 29 | big_ind = scores.sort(descending=True)[1][:self.num_retrieve] 30 | return sub_x[big_ind], sub_y[big_ind] 31 | else: 32 | return sub_x, sub_y 33 | 34 | def get_future_step_parameters(self, model, grad_vector, grad_dims): 35 | """ 36 | computes \theta-\delta\theta 37 | :param this_net: 38 | :param grad_vector: 39 | :return: 40 | """ 41 | new_model = copy.deepcopy(model) 42 | self.overwrite_grad(new_model.parameters, grad_vector, grad_dims) 43 | with torch.no_grad(): 44 | for param in new_model.parameters(): 45 | if param.grad is not None: 46 | param.data = param.data - self.params.learning_rate * param.grad.data 47 | return new_model 48 | 49 | def overwrite_grad(self, pp, new_grad, grad_dims): 50 | """ 51 | This is used to overwrite the gradients with a new gradient 52 | vector, whenever violations occur. 
53 | pp: parameters 54 | newgrad: corrected gradient 55 | grad_dims: list storing number of parameters at each layer 56 | """ 57 | cnt = 0 58 | for param in pp(): 59 | param.grad = torch.zeros_like(param.data) 60 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) 61 | en = sum(grad_dims[:cnt + 1]) 62 | this_grad = new_grad[beg: en].contiguous().view( 63 | param.data.size()) 64 | param.grad.data.copy_(this_grad) 65 | cnt += 1 -------------------------------------------------------------------------------- /utils/setup_elements.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.resnet import Reduced_ResNet18, SupConResNet 3 | from torchvision import transforms 4 | import torch.nn as nn 5 | 6 | 7 | default_trick = {'labels_trick': False, 'kd_trick': False, 'separated_softmax': False, 8 | 'review_trick': False, 'ncm_trick': False, 'kd_trick_star': False} 9 | 10 | 11 | input_size_match = { 12 | 'cifar100': [3, 32, 32], 13 | 'cifar10': [3, 32, 32], 14 | 'core50': [3, 128, 128], 15 | 'mini_imagenet': [3, 84, 84], 16 | 'tiny_imagenet': [3, 64, 64], 17 | 'openloris': [3, 50, 50] 18 | } 19 | 20 | 21 | n_classes = { 22 | 'cifar100': 100, 23 | 'cifar10': 10, 24 | 'core50': 50, 25 | 'mini_imagenet': 100, 26 | 'tiny_imagenet': 200, 27 | 'openloris': 69 28 | } 29 | 30 | 31 | transforms_match = { 32 | 'core50': transforms.Compose([ 33 | transforms.ToTensor(), 34 | ]), 35 | 'cifar100': transforms.Compose([ 36 | transforms.ToTensor(), 37 | ]), 38 | 'cifar10': transforms.Compose([ 39 | transforms.ToTensor(), 40 | ]), 41 | 'mini_imagenet': transforms.Compose([ 42 | transforms.ToTensor()]), 43 | 'tiny_imagenet': transforms.Compose([ 44 | transforms.ToTensor()]), 45 | 'openloris': transforms.Compose([ 46 | transforms.ToTensor()]) 47 | } 48 | 49 | 50 | def setup_architecture(params): 51 | nclass = n_classes[params.data] 52 | if params.agent in ['SCR', 'SCP', 'SSCR']: 53 | if params.data == 'mini_imagenet' or params.data == 
'tiny_imagenet': 54 | return SupConResNet(640, head=params.head) 55 | return SupConResNet(head=params.head) 56 | if params.agent == 'CNDPM': 57 | from models.ndpm.ndpm import Ndpm 58 | return Ndpm(params) 59 | if params.data == 'cifar100': 60 | return Reduced_ResNet18(nclass) 61 | elif params.data == 'cifar10': 62 | return Reduced_ResNet18(nclass) 63 | elif params.data == 'core50': 64 | model = Reduced_ResNet18(nclass) 65 | model.linear = nn.Linear(2560, nclass, bias=True) 66 | return model 67 | elif params.data == 'mini_imagenet' or params.data == 'tiny_imagenet': 68 | model = Reduced_ResNet18(nclass) 69 | model.linear = nn.Linear(640, nclass, bias=True) 70 | return model 71 | elif params.data == 'openloris': 72 | return Reduced_ResNet18(nclass) 73 | 74 | 75 | def setup_opt(optimizer, model, lr, wd): 76 | if optimizer == 'SGD': 77 | optim = torch.optim.SGD(model.parameters(), 78 | lr=lr, 79 | weight_decay=wd) 80 | elif optimizer == 'Adam': 81 | optim = torch.optim.Adam(model.parameters(), 82 | lr=lr, 83 | betas=(0.9, 0.99), 84 | weight_decay=wd) 85 | else: 86 | raise Exception('wrong optimizer name') 87 | return optim 88 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/cifar100.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import datasets 3 | from continuum.data_utils import create_task_composition, load_task_with_labels 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns 6 | 7 | 8 | class CIFAR100(DatasetBase): 9 | def __init__(self, scenario, params): 10 | dataset = 'cifar100' 11 | if scenario == 'ni': 12 | num_tasks = len(params.ns_factor) 13 | else: 14 | num_tasks = params.num_tasks 15 | super(CIFAR100, self).__init__(dataset, scenario, num_tasks, params.num_runs, params) 16 | 17 | 18 | def download_load(self): 19 | 
dataset_train = datasets.CIFAR100(root=self.root, train=True, download=True) 20 | self.train_data = dataset_train.data 21 | self.train_label = np.array(dataset_train.targets) 22 | dataset_test = datasets.CIFAR100(root=self.root, train=False, download=True) 23 | self.test_data = dataset_test.data 24 | self.test_label = np.array(dataset_test.targets) 25 | 26 | def setup(self): 27 | if self.scenario == 'ni': 28 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 29 | self.train_label, 30 | self.test_data, self.test_label, 31 | self.task_nums, 32, 32 | self.params.val_size, 33 | self.params.ns_type, self.params.ns_factor, 34 | plot=self.params.plot_sample) 35 | elif self.scenario == 'nc': 36 | self.task_labels = create_task_composition(class_nums=100, num_tasks=self.task_nums, fixed_order=self.params.fix_order) 37 | self.test_set = [] 38 | for labels in self.task_labels: 39 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 40 | self.test_set.append((x_test, y_test)) 41 | else: 42 | raise Exception('wrong scenario') 43 | 44 | def new_task(self, cur_task, **kwargs): 45 | if self.scenario == 'ni': 46 | x_train, y_train = self.train_set[cur_task] 47 | labels = set(y_train) 48 | elif self.scenario == 'nc': 49 | labels = self.task_labels[cur_task] 50 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 51 | return x_train, y_train, labels 52 | 53 | def new_run(self, **kwargs): 54 | self.setup() 55 | return self.test_set 56 | 57 | def test_plot(self): 58 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 59 | self.params.ns_factor) 60 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import datasets 3 | from continuum.data_utils import 
create_task_composition, load_task_with_labels, shuffle_data 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns 6 | 7 | 8 | class CIFAR10(DatasetBase): 9 | def __init__(self, scenario, params): 10 | dataset = 'cifar10' 11 | if scenario == 'ni': 12 | num_tasks = len(params.ns_factor) 13 | else: 14 | num_tasks = params.num_tasks 15 | super(CIFAR10, self).__init__(dataset, scenario, num_tasks, params.num_runs, params) 16 | 17 | 18 | def download_load(self): 19 | dataset_train = datasets.CIFAR10(root=self.root, train=True, download=True) 20 | self.train_data = dataset_train.data 21 | self.train_label = np.array(dataset_train.targets) 22 | dataset_test = datasets.CIFAR10(root=self.root, train=False, download=True) 23 | self.test_data = dataset_test.data 24 | self.test_label = np.array(dataset_test.targets) 25 | 26 | def setup(self): 27 | if self.scenario == 'ni': 28 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 29 | self.train_label, 30 | self.test_data, self.test_label, 31 | self.task_nums, 32, 32 | self.params.val_size, 33 | self.params.ns_type, self.params.ns_factor, 34 | plot=self.params.plot_sample) 35 | elif self.scenario == 'nc': 36 | self.task_labels = create_task_composition(class_nums=10, num_tasks=self.task_nums, fixed_order=self.params.fix_order) 37 | self.test_set = [] 38 | for labels in self.task_labels: 39 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 40 | self.test_set.append((x_test, y_test)) 41 | else: 42 | raise Exception('wrong scenario') 43 | 44 | def new_task(self, cur_task, **kwargs): 45 | if self.scenario == 'ni': 46 | x_train, y_train = self.train_set[cur_task] 47 | labels = set(y_train) 48 | elif self.scenario == 'nc': 49 | labels = self.task_labels[cur_task] 50 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 51 | return x_train, 
y_train, labels 52 | 53 | def new_run(self, **kwargs): 54 | self.setup() 55 | return self.test_set 56 | 57 | def test_plot(self): 58 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 59 | self.params.ns_factor) 60 | -------------------------------------------------------------------------------- /agents/icarl.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from utils import utils 4 | from utils.buffer.buffer_utils import random_retrieve 5 | from utils.setup_elements import transforms_match 6 | from torch.utils import data 7 | import numpy as np 8 | from torch.nn import functional as F 9 | from utils.utils import maybe_cuda, AverageMeter 10 | from utils.buffer.buffer import Buffer 11 | import torch 12 | import copy 13 | 14 | 15 | class Icarl(ContinualLearner): 16 | def __init__(self, model, opt, params): 17 | super(Icarl, self).__init__(model, opt, params) 18 | self.model = model 19 | self.mem_size = params.mem_size 20 | self.buffer = Buffer(model, params) 21 | self.prev_model = None 22 | 23 | def train_learner(self, x_train, y_train): 24 | self.before_train(x_train, y_train) 25 | # set up loader 26 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 27 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 28 | drop_last=True) 29 | self.model.train() 30 | self.update_representation(train_loader) 31 | self.prev_model = copy.deepcopy(self.model) 32 | self.after_train() 33 | 34 | def update_representation(self, train_loader): 35 | updated_idx = [] 36 | for ep in range(self.epoch): 37 | for i, train_data in enumerate(train_loader): 38 | # batch update 39 | train_x, train_y = train_data 40 | train_x = maybe_cuda(train_x, self.cuda) 41 | train_y = maybe_cuda(train_y, self.cuda) 42 | train_y_copy = train_y.clone() 43 | for 
k, y in enumerate(train_y_copy): 44 | train_y_copy[k] = len(self.old_labels) + self.new_labels.index(y) 45 | all_cls_num = len(self.new_labels) + len(self.old_labels) 46 | target_labels = utils.ohe_label(train_y_copy, all_cls_num, device=train_y_copy.device).float() 47 | if self.prev_model is not None: 48 | mem_x, mem_y = random_retrieve(self.buffer, self.batch, 49 | excl_indices=updated_idx) 50 | mem_x = maybe_cuda(mem_x, self.cuda) 51 | batch_x = torch.cat([train_x, mem_x]) 52 | target_labels = torch.cat([target_labels, torch.zeros_like(target_labels)]) 53 | else: 54 | batch_x = train_x 55 | logits = self.forward(batch_x) 56 | self.opt.zero_grad() 57 | if self.prev_model is not None: 58 | with torch.no_grad(): 59 | q = torch.sigmoid(self.prev_model.forward(batch_x)) 60 | for k, y in enumerate(self.old_labels): 61 | target_labels[:, k] = q[:, k] 62 | loss = F.binary_cross_entropy_with_logits(logits[:, :all_cls_num], target_labels, reduction='none').sum(dim=1).mean() 63 | loss.backward() 64 | self.opt.step() 65 | updated_idx += self.buffer.update(train_x, train_y) 66 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from continuum.data_utils import create_task_composition, load_task_with_labels, shuffle_data 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns 6 | 7 | TEST_SPLIT = 1 / 6 8 | 9 | 10 | class TinyImageNet(DatasetBase): 11 | def __init__(self, scenario, params): 12 | dataset = 'tiny_imagenet' 13 | if scenario == 'ni': 14 | num_tasks = len(params.ns_factor) 15 | else: 16 | num_tasks = params.num_tasks 17 | super(TinyImageNet, self).__init__(dataset, scenario, num_tasks, params.num_runs, params) 18 | 19 | def download_load(self): 20 | train_dir = 
'./datasets/tiny-imagenet-200/train.pkl' 21 | test_dir = './datasets/tiny-imagenet-200/val.pkl' 22 | 23 | train = pickle.load(open(train_dir, 'rb')) 24 | self.train_data = train['data'].reshape((100000, 64, 64, 3)) 25 | self.train_label = train['target'] 26 | 27 | test = pickle.load(open(test_dir, 'rb')) 28 | self.test_data = test['data'].reshape((10000, 64, 64, 3)) 29 | self.test_label = test['target'] 30 | 31 | def new_run(self, **kwargs): 32 | self.setup() 33 | return self.test_set 34 | 35 | def new_task(self, cur_task, **kwargs): 36 | if self.scenario == 'ni': 37 | x_train, y_train = self.train_set[cur_task] 38 | labels = set(y_train) 39 | elif self.scenario == 'nc': 40 | labels = self.task_labels[cur_task] 41 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 42 | else: 43 | raise Exception('unrecognized scenario') 44 | return x_train, y_train, labels 45 | 46 | def setup(self): 47 | if self.scenario == 'ni': 48 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 49 | self.train_label, 50 | self.test_data, self.test_label, 51 | self.task_nums, 84, 52 | self.params.val_size, 53 | self.params.ns_type, self.params.ns_factor, 54 | plot=self.params.plot_sample) 55 | 56 | elif self.scenario == 'nc': 57 | self.task_labels = create_task_composition(class_nums=200, num_tasks=self.task_nums, 58 | fixed_order=self.params.fix_order) 59 | self.test_set = [] 60 | for labels in self.task_labels: 61 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 62 | self.test_set.append((x_test, y_test)) 63 | 64 | def test_plot(self): 65 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 66 | self.params.ns_factor) 67 | -------------------------------------------------------------------------------- /agents/scr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from 
utils.buffer.buffer import Buffer, DynamicBuffer 4 | from agents.base import ContinualLearner 5 | from continuum.data_utils import dataset_transform, BalancedSampler 6 | from utils.setup_elements import transforms_match, input_size_match 7 | from utils.utils import maybe_cuda, AverageMeter 8 | from kornia.augmentation import RandomResizedCrop, RandomHorizontalFlip, ColorJitter, RandomGrayscale 9 | import torch.nn as nn 10 | 11 | class SupContrastReplay(ContinualLearner): 12 | def __init__(self, model, opt, params): 13 | super(SupContrastReplay, self).__init__(model, opt, params) 14 | self.buffer = Buffer(model, params) 15 | self.mem_size = params.mem_size 16 | self.eps_mem_batch = params.eps_mem_batch 17 | self.mem_iters = params.mem_iters 18 | self.transform = nn.Sequential( 19 | RandomResizedCrop(size=(input_size_match[self.params.data][1], input_size_match[self.params.data][2]), scale=(0.2, 1.)), 20 | RandomHorizontalFlip(), 21 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8), 22 | RandomGrayscale(p=0.2) 23 | 24 | ) 25 | 26 | def train_learner(self, x_train, y_train, **kwargs): 27 | self.before_train(x_train, y_train) 28 | #self.buffer.new_condense_task() 29 | self.buffer.new_task() 30 | # set up loader 31 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 32 | train_sampler = BalancedSampler(x_train, y_train, self.batch) 33 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, num_workers=0, 34 | drop_last=True, sampler=train_sampler) 35 | # set up model 36 | self.model = self.model.train() 37 | 38 | # setup tracker 39 | losses = AverageMeter() 40 | acc_batch = AverageMeter() 41 | 42 | for ep in range(self.epoch): 43 | for i, batch_data in enumerate(train_loader): 44 | # batch update 45 | batch_x, batch_y = batch_data 46 | batch_x = maybe_cuda(batch_x, self.cuda) 47 | batch_y = maybe_cuda(batch_y, self.cuda) 48 | 49 | for j in range(self.mem_iters): 50 | mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y) 
51 | 52 | if mem_x.size(0) > 0: 53 | mem_x = maybe_cuda(mem_x, self.cuda) 54 | mem_y = maybe_cuda(mem_y, self.cuda) 55 | combined_batch = torch.cat((mem_x, batch_x)) 56 | combined_labels = torch.cat((mem_y, batch_y)) 57 | combined_batch_aug = self.transform(combined_batch) 58 | features = torch.cat([self.model.forward(combined_batch).unsqueeze(1), self.model.forward(combined_batch_aug).unsqueeze(1)], dim=1) 59 | loss = self.criterion(features, combined_labels) 60 | losses.update(loss, batch_y.size(0)) 61 | self.opt.zero_grad() 62 | loss.backward() 63 | self.opt.step() 64 | 65 | # update mem 66 | self.buffer.update(batch_x, batch_y) 67 | if i % 100 == 1 and self.verbose: 68 | print( 69 | '==>>> it: {}, avg. loss: {:.6f}, ' 70 | .format(i, losses.avg(), acc_batch.avg()) 71 | ) 72 | self.after_train() 73 | -------------------------------------------------------------------------------- /agents/gdumb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import math 4 | from agents.base import ContinualLearner 5 | from continuum.data_utils import dataset_transform 6 | from utils.setup_elements import transforms_match, setup_architecture, setup_opt 7 | from utils.utils import maybe_cuda, EarlyStopping 8 | import numpy as np 9 | import random 10 | 11 | 12 | class Gdumb(ContinualLearner): 13 | def __init__(self, model, opt, params): 14 | super(Gdumb, self).__init__(model, opt, params) 15 | self.mem_img = {} 16 | self.mem_c = {} 17 | #self.early_stopping = EarlyStopping(self.params.min_delta, self.params.patience, self.params.cumulative_delta) 18 | 19 | def greedy_balancing_update(self, x, y): 20 | k_c = self.params.mem_size // max(1, len(self.mem_img)) 21 | if y not in self.mem_img or self.mem_c[y] < k_c: 22 | if sum(self.mem_c.values()) >= self.params.mem_size: 23 | cls_max = max(self.mem_c.items(), key=lambda k:k[1])[0] 24 | idx = random.randrange(self.mem_c[cls_max]) 25 | 
self.mem_img[cls_max].pop(idx) 26 | self.mem_c[cls_max] -= 1 27 | if y not in self.mem_img: 28 | self.mem_img[y] = [] 29 | self.mem_c[y] = 0 30 | self.mem_img[y].append(x) 31 | self.mem_c[y] += 1 32 | 33 | def train_learner(self, x_train, y_train): 34 | self.before_train(x_train, y_train) 35 | # set up loader 36 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 37 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 38 | drop_last=True) 39 | 40 | for i, batch_data in enumerate(train_loader): 41 | # batch update 42 | batch_x, batch_y = batch_data 43 | batch_x = maybe_cuda(batch_x, self.cuda) 44 | batch_y = maybe_cuda(batch_y, self.cuda) 45 | # update mem 46 | for j in range(len(batch_x)): 47 | self.greedy_balancing_update(batch_x[j], batch_y[j].item()) 48 | #self.early_stopping.reset() 49 | self.train_mem() 50 | self.after_train() 51 | 52 | def train_mem(self): 53 | mem_x = [] 54 | mem_y = [] 55 | for i in self.mem_img.keys(): 56 | mem_x += self.mem_img[i] 57 | mem_y += [i] * self.mem_c[i] 58 | 59 | mem_x = torch.stack(mem_x) 60 | mem_y = torch.LongTensor(mem_y) 61 | self.model = setup_architecture(self.params) 62 | self.model = maybe_cuda(self.model, self.cuda) 63 | opt = setup_opt(self.params.optimizer, self.model, self.params.learning_rate, self.params.weight_decay) 64 | #scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=1, T_mult=2, eta_min=self.params.minlr) 65 | 66 | #loss = math.inf 67 | for i in range(self.params.mem_epoch): 68 | idx = np.random.permutation(len(mem_x)).tolist() 69 | mem_x = maybe_cuda(mem_x[idx], self.cuda) 70 | mem_y = maybe_cuda(mem_y[idx], self.cuda) 71 | self.model = self.model.train() 72 | batch_size = self.params.batch 73 | #scheduler.step() 74 | #if opt.param_groups[0]['lr'] == self.params.learning_rate: 75 | # if self.early_stopping.step(-loss): 76 | # return 77 | for j in range(len(mem_y) // batch_size): 78 | 
opt.zero_grad() 79 | logits = self.model.forward(mem_x[batch_size * j:batch_size * (j + 1)]) 80 | loss = self.criterion(logits, mem_y[batch_size * j:batch_size * (j + 1)]) 81 | loss.backward() 82 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip) 83 | opt.step() 84 | 85 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Yonglong Tian (yonglong@mit.edu) 3 | Date: May 07, 2020 4 | """ 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class SupConLoss(nn.Module): 12 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 13 | It also supports the unsupervised contrastive loss in SimCLR""" 14 | def __init__(self, temperature=0.07, contrast_mode='all'): 15 | super(SupConLoss, self).__init__() 16 | self.temperature = temperature 17 | self.contrast_mode = contrast_mode 18 | 19 | def forward(self, features, labels=None, mask=None): 20 | """Compute loss for model. If both `labels` and `mask` are None, 21 | it degenerates to SimCLR unsupervised loss: 22 | https://arxiv.org/pdf/2002.05709.pdf 23 | 24 | Args: 25 | features: hidden vector of shape [bsz, n_views, ...]. 26 | labels: ground truth of shape [bsz]. 27 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 28 | has the same class as sample i. Can be asymmetric. 29 | Returns: 30 | A loss scalar. 
31 | """ 32 | device = (torch.device('cuda') 33 | if features.is_cuda 34 | else torch.device('cpu')) 35 | 36 | if len(features.shape) < 3: 37 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 38 | 'at least 3 dimensions are required') 39 | if len(features.shape) > 3: 40 | features = features.view(features.shape[0], features.shape[1], -1) 41 | 42 | batch_size = features.shape[0] 43 | if labels is not None and mask is not None: 44 | raise ValueError('Cannot define both `labels` and `mask`') 45 | elif labels is None and mask is None: 46 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 47 | elif labels is not None: 48 | labels = labels.contiguous().view(-1, 1) 49 | if labels.shape[0] != batch_size: 50 | raise ValueError('Num of labels does not match num of features') 51 | mask = torch.eq(labels, labels.T).float().to(device) 52 | else: 53 | mask = mask.float().to(device) 54 | 55 | contrast_count = features.shape[1] 56 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 57 | if self.contrast_mode == 'one': 58 | anchor_feature = features[:, 0] 59 | anchor_count = 1 60 | elif self.contrast_mode == 'all': 61 | anchor_feature = contrast_feature 62 | anchor_count = contrast_count 63 | else: 64 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 65 | 66 | # compute logits 67 | anchor_dot_contrast = torch.div( 68 | torch.matmul(anchor_feature, contrast_feature.T), 69 | self.temperature) 70 | # for numerical stability 71 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 72 | logits = anchor_dot_contrast - logits_max.detach() 73 | 74 | # tile mask 75 | mask = mask.repeat(anchor_count, contrast_count) 76 | # mask-out self-contrast cases 77 | logits_mask = torch.scatter( 78 | torch.ones_like(mask), 79 | 1, 80 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 81 | 0 82 | ) 83 | mask = mask * logits_mask 84 | 85 | # compute log_prob 86 | exp_logits = torch.exp(logits) * logits_mask 87 | 
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 88 | 89 | # compute mean of log-likelihood over positive 90 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 91 | 92 | # loss 93 | loss = -1 * mean_log_prob_pos 94 | loss = loss.view(anchor_count, batch_size).mean() 95 | 96 | return loss 97 | -------------------------------------------------------------------------------- /agents/summarize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils import data 4 | from utils.buffer.buffer import Buffer, DynamicBuffer 5 | from agents.base import ContinualLearner 6 | from continuum.data_utils import dataset_transform, BalancedSampler 7 | from utils.setup_elements import transforms_match, input_size_match 8 | from utils.utils import maybe_cuda, AverageMeter 9 | from kornia.augmentation import RandomResizedCrop, RandomHorizontalFlip, ColorJitter, RandomGrayscale 10 | import torch.nn as nn 11 | from torchvision.utils import make_grid, save_image 12 | 13 | 14 | class SummarizeContrastReplay(ContinualLearner): 15 | def __init__(self, model, opt, params): 16 | super(SummarizeContrastReplay, self).__init__(model, opt, params) 17 | self.buffer = DynamicBuffer(model, params) 18 | self.mem_size = params.mem_size 19 | self.eps_mem_batch = params.eps_mem_batch 20 | self.mem_iters = params.mem_iters 21 | self.transform = nn.Sequential( 22 | RandomResizedCrop(size=(input_size_match[self.params.data][1], input_size_match[self.params.data][2]), scale=(0.2, 1.)), 23 | RandomHorizontalFlip(), 24 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8), 25 | RandomGrayscale(p=0.2) 26 | 27 | ) 28 | self.queue_size = params.queue_size 29 | 30 | def train_learner(self, x_train, y_train, labels): 31 | self.before_train(x_train, y_train) 32 | # set up loader 33 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 34 | train_sampler = BalancedSampler(x_train, 
y_train, self.batch) 35 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, num_workers=0, 36 | drop_last=True, sampler=train_sampler) 37 | # set up model 38 | self.model = self.model.train() 39 | self.buffer.new_condense_task(labels) 40 | 41 | # setup tracker 42 | losses = AverageMeter() 43 | acc_batch = AverageMeter() 44 | 45 | aff_x = [] 46 | aff_y = [] 47 | for ep in range(self.epoch): 48 | for i, batch_data in enumerate(train_loader): 49 | # batch update 50 | batch_x, batch_y = batch_data 51 | batch_x = maybe_cuda(batch_x, self.cuda) 52 | batch_y = maybe_cuda(batch_y, self.cuda) 53 | 54 | for j in range(self.mem_iters): 55 | mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y) 56 | 57 | if mem_x.size(0) > 0: 58 | mem_x = maybe_cuda(mem_x, self.cuda) 59 | mem_y = maybe_cuda(mem_y, self.cuda) 60 | combined_batch = torch.cat((mem_x, batch_x)) 61 | combined_labels = torch.cat((mem_y, batch_y)) 62 | combined_batch_aug = self.transform(combined_batch) 63 | features = torch.cat([self.model.forward(combined_batch).unsqueeze(1), self.model.forward(combined_batch_aug).unsqueeze(1)], dim=1) 64 | loss = self.criterion(features, combined_labels) 65 | losses.update(loss, batch_y.size(0)) 66 | self.opt.zero_grad() 67 | loss.backward() 68 | self.opt.step() 69 | 70 | # update memory 71 | aff_x.append(batch_x) 72 | aff_y.append(batch_y) 73 | if len(aff_x) > self.queue_size: 74 | aff_x.pop(0) 75 | aff_y.pop(0) 76 | self.buffer.update(batch_x, batch_y, aff_x=aff_x, aff_y=aff_y, update_index=i, transform=self.transform) 77 | 78 | if i % 100 == 1 and self.verbose: 79 | print( 80 | '==>>> it: {}, avg. 
loss: {:.6f}, ' 81 | .format(i, losses.avg(), acc_batch.avg()) 82 | ) 83 | self.after_train() 84 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from continuum.data_utils import create_task_composition, load_task_with_labels, shuffle_data 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns 6 | 7 | TEST_SPLIT = 1 / 6 8 | 9 | 10 | class Mini_ImageNet(DatasetBase): 11 | def __init__(self, scenario, params): 12 | dataset = 'mini_imagenet' 13 | if scenario == 'ni': 14 | num_tasks = len(params.ns_factor) 15 | else: 16 | num_tasks = params.num_tasks 17 | super(Mini_ImageNet, self).__init__(dataset, scenario, num_tasks, params.num_runs, params) 18 | 19 | 20 | def download_load(self): 21 | train_in = open("./datasets/mini_imagenet/mini-imagenet-cache-train.pkl", "rb") 22 | train = pickle.load(train_in) 23 | train_x = train["image_data"].reshape([64, 600, 84, 84, 3]) 24 | val_in = open("./datasets/mini_imagenet/mini-imagenet-cache-val.pkl", "rb") 25 | val = pickle.load(val_in) 26 | val_x = val['image_data'].reshape([16, 600, 84, 84, 3]) 27 | test_in = open("./datasets/mini_imagenet/mini-imagenet-cache-test.pkl", "rb") 28 | test = pickle.load(test_in) 29 | test_x = test['image_data'].reshape([20, 600, 84, 84, 3]) 30 | all_data = np.vstack((train_x, val_x, test_x)) 31 | train_data = [] 32 | train_label = [] 33 | test_data = [] 34 | test_label = [] 35 | for i in range(len(all_data)): 36 | cur_x = all_data[i] 37 | cur_y = np.ones((600,)) * i 38 | rdm_x, rdm_y = shuffle_data(cur_x, cur_y) 39 | x_test = rdm_x[: int(600 * TEST_SPLIT)] 40 | y_test = rdm_y[: int(600 * TEST_SPLIT)] 41 | x_train = rdm_x[int(600 * TEST_SPLIT):] 42 | y_train = rdm_y[int(600 * TEST_SPLIT):] 43 | 
train_data.append(x_train) 44 | train_label.append(y_train) 45 | test_data.append(x_test) 46 | test_label.append(y_test) 47 | self.train_data = np.concatenate(train_data) 48 | self.train_label = np.concatenate(train_label) 49 | self.test_data = np.concatenate(test_data) 50 | self.test_label = np.concatenate(test_label) 51 | 52 | def new_run(self, **kwargs): 53 | self.setup() 54 | return self.test_set 55 | 56 | def new_task(self, cur_task, **kwargs): 57 | if self.scenario == 'ni': 58 | x_train, y_train = self.train_set[cur_task] 59 | labels = set(y_train) 60 | elif self.scenario == 'nc': 61 | labels = self.task_labels[cur_task] 62 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 63 | else: 64 | raise Exception('unrecognized scenario') 65 | return x_train, y_train, labels 66 | 67 | def setup(self): 68 | if self.scenario == 'ni': 69 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 70 | self.train_label, 71 | self.test_data, self.test_label, 72 | self.task_nums, 84, 73 | self.params.val_size, 74 | self.params.ns_type, self.params.ns_factor, 75 | plot=self.params.plot_sample) 76 | 77 | elif self.scenario == 'nc': 78 | self.task_labels = create_task_composition(class_nums=100, num_tasks=self.task_nums, 79 | fixed_order=self.params.fix_order) 80 | self.test_set = [] 81 | for labels in self.task_labels: 82 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 83 | self.test_set.append((x_test, y_test)) 84 | 85 | def test_plot(self): 86 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 87 | self.params.ns_factor) 88 | -------------------------------------------------------------------------------- /models/ndpm/component.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | import torch.nn as nn 4 | from typing import Tuple 5 | 6 | from utils.utils 
import maybe_cuda 7 | from utils.global_vars import * 8 | 9 | 10 | class Component(nn.Module, ABC): 11 | def __init__(self, params, experts: Tuple): 12 | super().__init__() 13 | self.params = params 14 | self.experts = experts 15 | 16 | self.optimizer = NotImplemented 17 | self.lr_scheduler = NotImplemented 18 | 19 | @abstractmethod 20 | def nll(self, x, y, step=None): 21 | """Return NLL""" 22 | pass 23 | 24 | @abstractmethod 25 | def collect_nll(self, x, y, step=None): 26 | """Return NLLs including previous experts""" 27 | pass 28 | 29 | def _clip_grad_value(self, clip_value): 30 | for group in self.optimizer.param_groups: 31 | nn.utils.clip_grad_value_(group['params'], clip_value) 32 | 33 | def _clip_grad_norm(self, max_norm, norm_type=2): 34 | for group in self.optimizer.param_groups: 35 | nn.utils.clip_grad_norm_(group['params'], max_norm, norm_type) 36 | 37 | def clip_grad(self): 38 | clip_grad_config = MODELS_NDPM_COMPONENT_CLIP_GRAD 39 | if clip_grad_config['type'] == 'value': 40 | self._clip_grad_value(**clip_grad_config['options']) 41 | elif clip_grad_config['type'] == 'norm': 42 | self._clip_grad_norm(**clip_grad_config['options']) 43 | else: 44 | raise ValueError('Invalid clip_grad type: {}' 45 | .format(clip_grad_config['type'])) 46 | 47 | @staticmethod 48 | def build_optimizer(optim_config, params): 49 | return getattr(torch.optim, optim_config['type'])( 50 | params, **optim_config['options']) 51 | 52 | @staticmethod 53 | def build_lr_scheduler(lr_config, optimizer): 54 | return getattr(torch.optim.lr_scheduler, lr_config['type'])( 55 | optimizer, **lr_config['options']) 56 | 57 | def weight_decay_loss(self): 58 | loss = maybe_cuda(torch.zeros([])) 59 | for param in self.parameters(): 60 | loss += torch.norm(param) ** 2 61 | return loss 62 | 63 | 64 | class ComponentG(Component, ABC): 65 | def setup_optimizer(self): 66 | self.optimizer = self.build_optimizer( 67 | {'type': self.params.optimizer, 'options': {'lr': self.params.learning_rate}}, 
self.parameters()) 68 | self.lr_scheduler = self.build_lr_scheduler( 69 | MODELS_NDPM_COMPONENT_LR_SCHEDULER_G, self.optimizer) 70 | 71 | def collect_nll(self, x, y=None, step=None): 72 | """Default `collect_nll` 73 | 74 | Warning: Parameter-sharing components should implement their own 75 | `collect_nll` 76 | 77 | Returns: 78 | nll: Tensor of shape [B, 1+K] 79 | """ 80 | outputs = [expert.g.nll(x, y, step) for expert in self.experts] 81 | nll = outputs 82 | output = self.nll(x, y, step) 83 | nll.append(output) 84 | return torch.stack(nll, dim=1) 85 | 86 | 87 | 88 | class ComponentD(Component, ABC): 89 | def setup_optimizer(self): 90 | self.optimizer = self.build_optimizer( 91 | {'type': self.params.optimizer, 'options': {'lr': self.params.learning_rate}}, self.parameters()) 92 | self.lr_scheduler = self.build_lr_scheduler( 93 | MODELS_NDPM_COMPONENT_LR_SCHEDULER_D, self.optimizer) 94 | 95 | def collect_forward(self, x): 96 | """Default `collect_forward` 97 | 98 | Warning: Parameter-sharing components should implement their own 99 | `collect_forward` 100 | 101 | Returns: 102 | output: Tensor of shape [B, 1+K, C] 103 | """ 104 | outputs = [expert.d(x) for expert in self.experts] 105 | outputs.append(self.forward(x)) 106 | return torch.stack(outputs, 1) 107 | 108 | def collect_nll(self, x, y, step=None): 109 | """Default `collect_nll` 110 | 111 | Warning: Parameter-sharing components should implement their own 112 | `collect_nll` 113 | 114 | Returns: 115 | nll: Tensor of shape [B, 1+K] 116 | """ 117 | outputs = [expert.d.nll(x, y, step) for expert in self.experts] 118 | nll = outputs 119 | output = self.nll(x, y, step) 120 | nll.append(output) 121 | return torch.stack(nll, dim=1) 122 | -------------------------------------------------------------------------------- /utils/buffer/aser_retrieve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.buffer.buffer_utils import random_retrieve, 
ClassBalancedRandomSampling 3 | from utils.buffer.aser_utils import compute_knn_sv 4 | from utils.utils import maybe_cuda 5 | from utils.setup_elements import n_classes 6 | 7 | 8 | class ASER_retrieve(object): 9 | def __init__(self, params, **kwargs): 10 | super().__init__() 11 | self.num_retrieve = params.eps_mem_batch 12 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 13 | self.k = params.k 14 | self.mem_size = params.mem_size 15 | self.aser_type = params.aser_type 16 | self.n_smp_cls = int(params.n_smp_cls) 17 | self.out_dim = n_classes[params.data] 18 | self.is_aser_upt = params.update == "ASER" 19 | ClassBalancedRandomSampling.class_index_cache = None 20 | 21 | def retrieve(self, buffer, **kwargs): 22 | model = buffer.model 23 | 24 | if buffer.n_seen_so_far <= self.mem_size: 25 | # Use random retrieval until buffer is filled 26 | ret_x, ret_y = random_retrieve(buffer, self.num_retrieve) 27 | else: 28 | # Use ASER retrieval if buffer is filled 29 | cur_x, cur_y = kwargs['x'], kwargs['y'] 30 | buffer_x, buffer_y = buffer.buffer_img, buffer.buffer_label 31 | ret_x, ret_y = self._retrieve_by_knn_sv(model, buffer_x, buffer_y, cur_x, cur_y, self.num_retrieve) 32 | return ret_x, ret_y 33 | 34 | def _retrieve_by_knn_sv(self, model, buffer_x, buffer_y, cur_x, cur_y, num_retrieve): 35 | """ 36 | Retrieves data instances with top-N Shapley Values from candidate set. 37 | Args: 38 | model (object): neural network. 39 | buffer_x (tensor): data buffer. 40 | buffer_y (tensor): label buffer. 41 | cur_x (tensor): current input data tensor. 42 | cur_y (tensor): current input label tensor. 43 | num_retrieve (int): number of data instances to be retrieved. 44 | Returns 45 | ret_x (tensor): retrieved data tensor. 46 | ret_y (tensor): retrieved label tensor. 
47 | """ 48 | cur_x = maybe_cuda(cur_x) 49 | cur_y = maybe_cuda(cur_y) 50 | 51 | # Reset and update ClassBalancedRandomSampling cache if ASER update is not enabled 52 | if not self.is_aser_upt: 53 | ClassBalancedRandomSampling.update_cache(buffer_y, self.out_dim) 54 | 55 | # Get candidate data for retrieval (i.e., cand <- class balanced subsamples from memory) 56 | cand_x, cand_y, cand_ind = \ 57 | ClassBalancedRandomSampling.sample(buffer_x, buffer_y, self.n_smp_cls, device=self.device) 58 | 59 | # Type 1 - Adversarial SV 60 | # Get evaluation data for type 1 (i.e., eval <- current input) 61 | eval_adv_x, eval_adv_y = cur_x, cur_y 62 | # Compute adversarial Shapley value of candidate data 63 | # (i.e., sv wrt current input) 64 | sv_matrix_adv = compute_knn_sv(model, eval_adv_x, eval_adv_y, cand_x, cand_y, self.k, device=self.device) 65 | 66 | if self.aser_type != "neg_sv": 67 | # Type 2 - Cooperative SV 68 | # Get evaluation data for type 2 69 | # (i.e., eval <- class balanced subsamples from memory excluding those already in candidate set) 70 | excl_indices = set(cand_ind.tolist()) 71 | eval_coop_x, eval_coop_y, _ = \ 72 | ClassBalancedRandomSampling.sample(buffer_x, buffer_y, self.n_smp_cls, 73 | excl_indices=excl_indices, device=self.device) 74 | # Compute Shapley value 75 | sv_matrix_coop = \ 76 | compute_knn_sv(model, eval_coop_x, eval_coop_y, cand_x, cand_y, self.k, device=self.device) 77 | if self.aser_type == "asv": 78 | # Use extremal SVs for computation 79 | sv = sv_matrix_coop.max(0).values - sv_matrix_adv.min(0).values 80 | else: 81 | # Use mean variation for aser_type == "asvm" or anything else 82 | sv = sv_matrix_coop.mean(0) - sv_matrix_adv.mean(0) 83 | else: 84 | # aser_type == "neg_sv" 85 | # No Type 1 - Cooperative SV; Use sum of Adversarial SV only 86 | sv = sv_matrix_adv.sum(0) * -1 87 | 88 | ret_ind = sv.argsort(descending=True) 89 | 90 | ret_x = cand_x[ret_ind][:num_retrieve] 91 | ret_y = cand_y[ret_ind][:num_retrieve] 92 | return ret_x, 
ret_y 93 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def maybe_cuda(what, use_cuda=True, **kw): 5 | """ 6 | Moves `what` to CUDA and returns it, if `use_cuda` and it's available. 7 | Args: 8 | what (object): any object to move to eventually gpu 9 | use_cuda (bool): if we want to use gpu or cpu. 10 | Returns 11 | object: the same object but eventually moved to gpu. 12 | """ 13 | 14 | if use_cuda is not False and torch.cuda.is_available(): 15 | what = what.cuda() 16 | return what 17 | 18 | 19 | def boolean_string(s): 20 | if s not in {'False', 'True'}: 21 | raise ValueError('Not a valid boolean string') 22 | return s == 'True' 23 | 24 | 25 | class AverageMeter(object): 26 | """Computes and stores the average and current value""" 27 | 28 | def __init__(self): 29 | self.reset() 30 | 31 | def reset(self): 32 | self.sum = 0 33 | self.count = 0 34 | 35 | def update(self, val, n): 36 | self.sum += val * n 37 | self.count += n 38 | 39 | def avg(self): 40 | if self.count == 0: 41 | return 0 42 | return float(self.sum) / self.count 43 | 44 | 45 | def mini_batch_deep_features(model, total_x, num): 46 | """ 47 | Compute deep features with mini-batches. 48 | Args: 49 | model (object): neural network. 50 | total_x (tensor): data tensor. 51 | num (int): number of data. 52 | Returns 53 | deep_features (tensor): deep feature representation of data tensor. 
54 | """ 55 | is_train = False 56 | if model.training: 57 | is_train = True 58 | model.eval() 59 | if hasattr(model, "features"): 60 | model_has_feature_extractor = True 61 | else: 62 | model_has_feature_extractor = False 63 | # delete the last fully connected layer 64 | modules = list(model.children())[:-1] 65 | # make feature extractor 66 | model_features = torch.nn.Sequential(*modules) 67 | 68 | with torch.no_grad(): 69 | bs = 64 70 | num_itr = num // bs + int(num % bs > 0) 71 | sid = 0 72 | deep_features_list = [] 73 | for i in range(num_itr): 74 | eid = sid + bs if i != num_itr - 1 else num 75 | batch_x = total_x[sid: eid] 76 | 77 | if model_has_feature_extractor: 78 | batch_deep_features_ = model.features(batch_x) 79 | else: 80 | batch_deep_features_ = torch.squeeze(model_features(batch_x)) 81 | 82 | deep_features_list.append(batch_deep_features_.reshape((batch_x.size(0), -1))) 83 | sid = eid 84 | if num_itr == 1: 85 | deep_features_ = deep_features_list[0] 86 | else: 87 | deep_features_ = torch.cat(deep_features_list, 0) 88 | if is_train: 89 | model.train() 90 | return deep_features_ 91 | 92 | 93 | def euclidean_distance(u, v): 94 | euclidean_distance_ = (u - v).pow(2).sum(1) 95 | return euclidean_distance_ 96 | 97 | 98 | def ohe_label(label_tensor, dim, device="cpu"): 99 | # Returns one-hot-encoding of input label tensor 100 | n_labels = label_tensor.size(0) 101 | zero_tensor = torch.zeros((n_labels, dim), device=device, dtype=torch.long) 102 | return zero_tensor.scatter_(1, label_tensor.reshape((n_labels, 1)), 1) 103 | 104 | 105 | def nonzero_indices(bool_mask_tensor): 106 | # Returns tensor which contains indices of nonzero elements in bool_mask_tensor 107 | return bool_mask_tensor.nonzero(as_tuple=True)[0] 108 | 109 | 110 | class EarlyStopping(): 111 | def __init__(self, min_delta, patience, cumulative_delta): 112 | self.min_delta = min_delta 113 | self.patience = patience 114 | self.cumulative_delta = cumulative_delta 115 | self.counter = 0 116 | 
self.best_score = None 117 | 118 | def step(self, score): 119 | if self.best_score is None: 120 | self.best_score = score 121 | elif score <= self.best_score + self.min_delta: 122 | if not self.cumulative_delta and score > self.best_score: 123 | self.best_score = score 124 | self.counter += 1 125 | if self.counter >= self.patience: 126 | return True 127 | else: 128 | self.best_score = score 129 | self.counter = 0 130 | return False 131 | 132 | def reset(self): 133 | self.counter = 0 134 | self.best_score = None 135 | -------------------------------------------------------------------------------- /utils/buffer/buffer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | from utils.setup_elements import input_size_match 7 | from utils import name_match 8 | from utils.utils import maybe_cuda 9 | from utils.buffer.buffer_utils import BufferClassTracker 10 | from utils.setup_elements import n_classes 11 | 12 | 13 | class Buffer(torch.nn.Module): 14 | def __init__(self, model, params): 15 | super().__init__() 16 | self.params = params 17 | self.model = model 18 | self.cuda = self.params.cuda 19 | self.current_index = 0 20 | self.n_seen_so_far = 0 21 | self.device = "cuda" if self.params.cuda else "cpu" 22 | self.num_classes_per_task = self.params.num_classes_per_task 23 | self.num_classes = 0 24 | 25 | # define buffer 26 | buffer_size = params.mem_size 27 | print('buffer has %d slots' % buffer_size) 28 | input_size = input_size_match[params.data] 29 | buffer_img = maybe_cuda(torch.FloatTensor(buffer_size, *input_size).fill_(0)) 30 | buffer_label = maybe_cuda(torch.LongTensor(buffer_size).fill_(0)) 31 | 32 | # registering as buffer allows us to save the object using `torch.save` 33 | self.register_buffer('buffer_img', buffer_img) 34 | self.register_buffer('buffer_label', buffer_label) 35 | self.labeldict = defaultdict(list) 36 | 
self.labelsize = params.images_per_class 37 | self.avail_indices = list(np.arange(buffer_size)) 38 | 39 | # define update and retrieve method 40 | self.update_method = name_match.update_methods[params.update](params) 41 | self.retrieve_method = name_match.retrieve_methods[params.retrieve](params) 42 | 43 | if self.params.buffer_tracker: 44 | self.buffer_tracker = BufferClassTracker(n_classes[params.data], self.device) 45 | 46 | def update(self, x, y,**kwargs): 47 | return self.update_method.update(buffer=self, x=x, y=y, **kwargs) 48 | 49 | def retrieve(self, **kwargs): 50 | return self.retrieve_method.retrieve(buffer=self, **kwargs) 51 | 52 | def new_task(self, **kwargs): 53 | self.num_classes += self.num_classes_per_task 54 | self.labelsize = self.params.mem_size // self.num_classes 55 | 56 | def new_condense_task(self, **kwargs): 57 | self.num_classes += self.num_classes_per_task 58 | self.update_method.new_task(self.num_classes) 59 | 60 | 61 | class DynamicBuffer(torch.nn.Module): 62 | def __init__(self, model, params): 63 | super().__init__() 64 | self.params = params 65 | self.model = model 66 | self.cuda = self.params.cuda 67 | self.current_index = 0 68 | self.n_seen_so_far = 0 69 | self.device = "cuda" if self.params.cuda else "cpu" 70 | self.num_classes_per_task = self.params.num_classes_per_task 71 | self.images_per_class = self.params.images_per_class 72 | self.num_classes = 0 73 | 74 | # define buffer 75 | buffer_size = params.mem_size 76 | print('buffer has %d slots' % buffer_size) 77 | input_size = input_size_match[params.data] 78 | buffer_img = maybe_cuda(torch.FloatTensor(buffer_size, *input_size).fill_(0)) 79 | buffer_label = maybe_cuda(torch.LongTensor(buffer_size).fill_(0)) 80 | 81 | # registering as buffer allows us to save the object using `torch.save` 82 | self.register_buffer('buffer_img', buffer_img) 83 | self.register_buffer('buffer_img_rep', copy.deepcopy(buffer_img)) 84 | self.register_buffer('buffer_label', buffer_label) 85 | 
self.condense_dict = defaultdict(list) 86 | self.labelsize = params.images_per_class 87 | self.avail_indices = list(np.arange(buffer_size)) 88 | 89 | # define update and retrieve method 90 | self.update_method = name_match.update_methods[params.update](params) 91 | self.retrieve_method = name_match.retrieve_methods[params.retrieve](params) 92 | 93 | def update(self, x, y,**kwargs): 94 | return self.update_method.update(buffer=self, x=x, y=y, **kwargs) 95 | 96 | def retrieve(self, **kwargs): 97 | return self.retrieve_method.retrieve(buffer=self, **kwargs) 98 | 99 | def new_task(self, **kwargs): 100 | self.num_classes += self.num_classes_per_task 101 | self.labelsize = self.params.mem_size // self.num_classes 102 | 103 | def new_condense_task(self, labels, **kwargs): 104 | self.num_classes += self.num_classes_per_task 105 | self.update_method.new_task(self.num_classes, labels) 106 | 107 | -------------------------------------------------------------------------------- /agents/agem.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from utils.setup_elements import transforms_match 4 | from torch.utils import data 5 | from utils.buffer.buffer import Buffer 6 | from utils.utils import maybe_cuda, AverageMeter 7 | import torch 8 | 9 | 10 | class AGEM(ContinualLearner): 11 | def __init__(self, model, opt, params): 12 | super(AGEM, self).__init__(model, opt, params) 13 | self.buffer = Buffer(model, params) 14 | self.mem_size = params.mem_size 15 | self.eps_mem_batch = params.eps_mem_batch 16 | self.mem_iters = params.mem_iters 17 | 18 | def train_learner(self, x_train, y_train): 19 | self.before_train(x_train, y_train) 20 | 21 | # set up loader 22 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 23 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, 
num_workers=0, 24 | drop_last=True) 25 | # set up model 26 | self.model = self.model.train() 27 | 28 | # setup tracker 29 | losses_batch = AverageMeter() 30 | acc_batch = AverageMeter() 31 | 32 | for ep in range(self.epoch): 33 | for i, batch_data in enumerate(train_loader): 34 | # batch update 35 | batch_x, batch_y = batch_data 36 | batch_x = maybe_cuda(batch_x, self.cuda) 37 | batch_y = maybe_cuda(batch_y, self.cuda) 38 | for j in range(self.mem_iters): 39 | logits = self.forward(batch_x) 40 | loss = self.criterion(logits, batch_y) 41 | if self.params.trick['kd_trick']: 42 | loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \ 43 | self.kd_manager.get_kd_loss(logits, batch_x) 44 | if self.params.trick['kd_trick_star']: 45 | loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \ 46 | (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x) 47 | _, pred_label = torch.max(logits, 1) 48 | correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0) 49 | # update tracker 50 | acc_batch.update(correct_cnt, batch_y.size(0)) 51 | losses_batch.update(loss, batch_y.size(0)) 52 | # backward 53 | self.opt.zero_grad() 54 | loss.backward() 55 | 56 | if self.task_seen > 0: 57 | # sample from memory of previous tasks 58 | mem_x, mem_y = self.buffer.retrieve() 59 | if mem_x.size(0) > 0: 60 | params = [p for p in self.model.parameters() if p.requires_grad] 61 | # gradient computed using current batch 62 | grad = [p.grad.clone() for p in params] 63 | mem_x = maybe_cuda(mem_x, self.cuda) 64 | mem_y = maybe_cuda(mem_y, self.cuda) 65 | mem_logits = self.forward(mem_x) 66 | loss_mem = self.criterion(mem_logits, mem_y) 67 | self.opt.zero_grad() 68 | loss_mem.backward() 69 | # gradient computed using memory samples 70 | grad_ref = [p.grad.clone() for p in params] 71 | 72 | # inner product of grad and grad_ref 73 | prod = sum([torch.sum(g * g_r) for g, g_r in zip(grad, grad_ref)]) 74 | if prod < 0: 75 | prod_ref = 
sum([torch.sum(g_r ** 2) for g_r in grad_ref]) 76 | # do projection 77 | grad = [g - prod / prod_ref * g_r for g, g_r in zip(grad, grad_ref)] 78 | # replace params' grad 79 | for g, p in zip(grad, params): 80 | p.grad.data.copy_(g) 81 | self.opt.step() 82 | # update mem 83 | self.buffer.update(batch_x, batch_y) 84 | 85 | if i % 100 == 1 and self.verbose: 86 | print( 87 | '==>>> it: {}, avg. loss: {:.6f}, ' 88 | 'running train acc: {:.3f}' 89 | .format(i, losses_batch.avg(), acc_batch.avg()) 90 | ) 91 | self.after_train() -------------------------------------------------------------------------------- /continuum/dataset_scripts/openloris.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from PIL import Image 3 | import numpy as np 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | import time 6 | from continuum.data_utils import shuffle_data 7 | 8 | 9 | class OpenLORIS(DatasetBase): 10 | """ 11 | tasks_nums is predefined and it depends on the ns_type. 
12 | """ 13 | def __init__(self, scenario, params): # scenario refers to "ni" or "nc" 14 | dataset = 'openloris' 15 | self.ns_type = params.ns_type 16 | task_nums = openloris_ntask[self.ns_type] # ns_type can be (illumination, occlusion, pixel, clutter, sequence) 17 | super(OpenLORIS, self).__init__(dataset, scenario, task_nums, params.num_runs, params) 18 | 19 | 20 | def download_load(self): 21 | s = time.time() 22 | self.train_set = [] 23 | for batch_num in range(1, self.task_nums+1): 24 | train_x = [] 25 | train_y = [] 26 | test_x = [] 27 | test_y = [] 28 | for i in range(len(datapath)): 29 | train_temp = glob.glob('datasets/openloris/' + self.ns_type + '/train/task{}/{}/*.jpg'.format(batch_num, datapath[i])) 30 | 31 | train_x.extend([np.array(Image.open(x).convert('RGB').resize((50, 50))) for x in train_temp]) 32 | train_y.extend([i] * len(train_temp)) 33 | 34 | test_temp = glob.glob( 35 | 'datasets/openloris/' + self.ns_type + '/test/task{}/{}/*.jpg'.format(batch_num, datapath[i])) 36 | 37 | test_x.extend([np.array(Image.open(x).convert('RGB').resize((50, 50))) for x in test_temp]) 38 | test_y.extend([i] * len(test_temp)) 39 | 40 | print(" --> batch{}'-dataset consisting of {} samples".format(batch_num, len(train_x))) 41 | print(" --> test'-dataset consisting of {} samples".format(len(test_x))) 42 | self.train_set.append((np.array(train_x), np.array(train_y))) 43 | self.test_set.append((np.array(test_x), np.array(test_y))) 44 | e = time.time() 45 | print('loading time: {}'.format(str(e - s))) 46 | 47 | def new_run(self, **kwargs): 48 | pass 49 | 50 | def new_task(self, cur_task, **kwargs): 51 | train_x, train_y = self.train_set[cur_task] 52 | # get val set 53 | train_x_rdm, train_y_rdm = shuffle_data(train_x, train_y) 54 | val_size = int(len(train_x_rdm) * self.params.val_size) 55 | val_data_rdm, val_label_rdm = train_x_rdm[:val_size], train_y_rdm[:val_size] 56 | train_data_rdm, train_label_rdm = train_x_rdm[val_size:], train_y_rdm[val_size:] 57 | 
self.val_set.append((val_data_rdm, val_label_rdm)) 58 | labels = set(train_label_rdm) 59 | return train_data_rdm, train_label_rdm, labels 60 | 61 | def setup(self, **kwargs): 62 | pass 63 | 64 | 65 | 66 | openloris_ntask = { 67 | 'illumination': 9, 68 | 'occlusion': 9, 69 | 'pixel': 9, 70 | 'clutter': 9, 71 | 'sequence': 12 72 | } 73 | 74 | datapath = ['bottle_01', 'bottle_02', 'bottle_03', 'bottle_04', 'bowl_01', 'bowl_02', 'bowl_03', 'bowl_04', 'bowl_05', 75 | 'corkscrew_01', 'cottonswab_01', 'cottonswab_02', 'cup_01', 'cup_02', 'cup_03', 'cup_04', 'cup_05', 76 | 'cup_06', 'cup_07', 'cup_08', 'cup_10', 'cushion_01', 'cushion_02', 'cushion_03', 'glasses_01', 77 | 'glasses_02', 'glasses_03', 'glasses_04', 'knife_01', 'ladle_01', 'ladle_02', 'ladle_03', 'ladle_04', 78 | 'mask_01', 'mask_02', 'mask_03', 'mask_04', 'mask_05', 'paper_cutter_01', 'paper_cutter_02', 79 | 'paper_cutter_03', 'paper_cutter_04', 'pencil_01', 'pencil_02', 'pencil_03', 'pencil_04', 'pencil_05', 80 | 'plasticbag_01', 'plasticbag_02', 'plasticbag_03', 'plug_01', 'plug_02', 'plug_03', 'plug_04', 'pot_01', 81 | 'scissors_01', 'scissors_02', 'scissors_03', 'stapler_01', 'stapler_02', 'stapler_03', 'thermometer_01', 82 | 'thermometer_02', 'thermometer_03', 'toy_01', 'toy_02', 'toy_03', 'toy_04', 'toy_05','nail_clippers_01','nail_clippers_02', 83 | 'nail_clippers_03', 'bracelet_01', 'bracelet_02','bracelet_03', 'comb_01','comb_02', 84 | 'comb_03', 'umbrella_01','umbrella_02','umbrella_03','socks_01','socks_02','socks_03', 85 | 'toothpaste_01','toothpaste_02','toothpaste_03','wallet_01','wallet_02','wallet_03', 86 | 'headphone_01','headphone_02','headphone_03', 'key_01','key_02','key_03', 87 | 'battery_01', 'battery_02', 'mouse_01', 'pencilcase_01', 'pencilcase_02', 'tape_01', 88 | 'chopsticks_01', 'chopsticks_02', 'chopsticks_03', 89 | 'notebook_01', 'notebook_02', 'notebook_03', 90 | 'spoon_01', 'spoon_02', 'spoon_03', 91 | 'tissue_01', 'tissue_02', 'tissue_03', 92 | 'clamp_01', 'clamp_02', 
'hat_01', 'hat_02', 'u_disk_01', 'u_disk_02', 'swimming_glasses_01' 93 | ] 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /agents/ewc_pp.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from utils.setup_elements import transforms_match 4 | from torch.utils import data 5 | from utils.utils import maybe_cuda, AverageMeter 6 | import torch 7 | 8 | class EWC_pp(ContinualLearner): 9 | def __init__(self, model, opt, params): 10 | super(EWC_pp, self).__init__(model, opt, params) 11 | self.weights = {n: p for n, p in self.model.named_parameters() if p.requires_grad} 12 | self.lambda_ = params.lambda_ 13 | self.alpha = params.alpha 14 | self.fisher_update_after = params.fisher_update_after 15 | self.prev_params = {} 16 | self.running_fisher = self.init_fisher() 17 | self.tmp_fisher = self.init_fisher() 18 | self.normalized_fisher = self.init_fisher() 19 | 20 | def train_learner(self, x_train, y_train): 21 | self.before_train(x_train, y_train) 22 | # set up loader 23 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 24 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 25 | drop_last=True) 26 | # setup tracker 27 | losses_batch = AverageMeter() 28 | acc_batch = AverageMeter() 29 | 30 | # set up model 31 | self.model.train() 32 | 33 | for ep in range(self.epoch): 34 | for i, batch_data in enumerate(train_loader): 35 | # batch update 36 | batch_x, batch_y = batch_data 37 | batch_x = maybe_cuda(batch_x, self.cuda) 38 | batch_y = maybe_cuda(batch_y, self.cuda) 39 | 40 | # update the running fisher 41 | if (ep * len(train_loader) + i + 1) % self.fisher_update_after == 0: 42 | self.update_running_fisher() 43 | 44 | out = self.forward(batch_x) 45 | loss = self.total_loss(out, batch_y) 46 | if 
self.params.trick['kd_trick']: 47 | loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \ 48 | self.kd_manager.get_kd_loss(out, batch_x) 49 | if self.params.trick['kd_trick_star']: 50 | loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \ 51 | (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(out, batch_x) 52 | # update tracker 53 | losses_batch.update(loss.item(), batch_y.size(0)) 54 | _, pred_label = torch.max(out, 1) 55 | acc = (pred_label == batch_y).sum().item() / batch_y.size(0) 56 | acc_batch.update(acc, batch_y.size(0)) 57 | # backward 58 | self.opt.zero_grad() 59 | loss.backward() 60 | 61 | # accumulate the fisher of current batch 62 | self.accum_fisher() 63 | self.opt.step() 64 | 65 | if i % 100 == 1 and self.verbose: 66 | print( 67 | '==>>> it: {}, avg. loss: {:.6f}, ' 68 | 'running train acc: {:.3f}' 69 | .format(i, losses_batch.avg(), acc_batch.avg()) 70 | ) 71 | 72 | # save params for current task 73 | for n, p in self.weights.items(): 74 | self.prev_params[n] = p.clone().detach() 75 | 76 | # update normalized fisher of current task 77 | max_fisher = max([torch.max(m) for m in self.running_fisher.values()]) 78 | min_fisher = min([torch.min(m) for m in self.running_fisher.values()]) 79 | for n, p in self.running_fisher.items(): 80 | self.normalized_fisher[n] = (p - min_fisher) / (max_fisher - min_fisher + 1e-32) 81 | self.after_train() 82 | 83 | def total_loss(self, inputs, targets): 84 | # cross entropy loss 85 | loss = self.criterion(inputs, targets) 86 | if len(self.prev_params) > 0: 87 | # add regularization loss 88 | reg_loss = 0 89 | for n, p in self.weights.items(): 90 | reg_loss += (self.normalized_fisher[n] * (p - self.prev_params[n]) ** 2).sum() 91 | loss += self.lambda_ * reg_loss 92 | return loss 93 | 94 | def init_fisher(self): 95 | return {n: p.clone().detach().fill_(0) for n, p in self.model.named_parameters() if p.requires_grad} 96 | 97 | def update_running_fisher(self): 98 | for n, p in 
self.running_fisher.items(): 99 | self.running_fisher[n] = (1. - self.alpha) * p \ 100 | + 1. / self.fisher_update_after * self.alpha * self.tmp_fisher[n] 101 | # reset the accumulated fisher 102 | self.tmp_fisher = self.init_fisher() 103 | 104 | def accum_fisher(self): 105 | for n, p in self.tmp_fisher.items(): 106 | p += self.weights[n].grad ** 2 -------------------------------------------------------------------------------- /continuum/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils import data 4 | from utils.setup_elements import transforms_match 5 | from collections import defaultdict 6 | import copy 7 | 8 | def create_task_composition(class_nums, num_tasks, fixed_order=False): 9 | classes_per_task = class_nums // num_tasks 10 | total_classes = classes_per_task * num_tasks 11 | label_array = np.arange(0, total_classes) 12 | if not fixed_order: 13 | np.random.shuffle(label_array) 14 | 15 | task_labels = [] 16 | for tt in range(num_tasks): 17 | tt_offset = tt * classes_per_task 18 | task_labels.append(list(label_array[tt_offset:tt_offset + classes_per_task])) 19 | print('Task: {}, Labels:{}'.format(tt, task_labels[tt])) 20 | return task_labels 21 | 22 | 23 | def load_task_with_labels_torch(x, y, labels): 24 | tmp = [] 25 | for i in labels: 26 | tmp.append((y == i).nonzero().view(-1)) 27 | idx = torch.cat(tmp) 28 | return x[idx], y[idx] 29 | 30 | 31 | def load_task_with_labels(x, y, labels): 32 | tmp = [] 33 | for i in labels: 34 | tmp.append((np.where(y == i)[0])) 35 | idx = np.concatenate(tmp, axis=None) 36 | return x[idx], y[idx] 37 | 38 | 39 | 40 | class dataset_transform(data.Dataset): 41 | def __init__(self, x, y, transform=None): 42 | self.x = x 43 | self.y = torch.from_numpy(y).type(torch.LongTensor) 44 | self.transform = transform # save the transform 45 | 46 | def __len__(self): 47 | return len(self.y)#self.x.shape[0] # return 1 as we have only 
one image 48 | 49 | def __getitem__(self, idx): 50 | # return the augmented image 51 | if self.transform: 52 | x = self.transform(self.x[idx]) 53 | else: 54 | x = self.x[idx] 55 | 56 | return x.float(), self.y[idx] 57 | 58 | 59 | class BalancedSampler(data.Sampler): 60 | def __init__(self, x, y, batch_size): 61 | self.x = x 62 | self.y = y 63 | self.batch_size = batch_size 64 | self.labeldict = defaultdict(list) 65 | for idx, label in enumerate(y): 66 | self.labeldict[label].append(idx) 67 | self.labelset = set(self.y) 68 | self.num_classes = len(set(self.y)) 69 | self.num_instances = batch_size // self.num_classes 70 | 71 | def __iter__(self): 72 | batch_idx_dict = defaultdict(list) 73 | for label in self.labelset: 74 | indices = copy.deepcopy(self.labeldict[label]) 75 | if len(indices) < self.num_instances: 76 | indices = np.random.choice(indices, size=self.num_instances, replace=True) 77 | np.random.shuffle(indices) 78 | batch_idx = [] 79 | for idx in indices: 80 | batch_idx.append(idx) 81 | if len(batch_idx) == self.num_instances: 82 | batch_idx_dict[label].append(batch_idx) 83 | batch_idx = [] 84 | 85 | avail_labels = copy.deepcopy(self.labelset) 86 | final_indices = [] 87 | 88 | while len(avail_labels) >= self.num_classes: 89 | batch_indices = [] 90 | for label in self.labelset: 91 | batch_idx = batch_idx_dict[label].pop(0) 92 | batch_indices.extend(batch_idx) 93 | if len(batch_idx_dict[label]) == 0: 94 | avail_labels.remove(label) 95 | np.random.shuffle(batch_indices) 96 | final_indices.extend(batch_indices) 97 | 98 | return iter(final_indices) 99 | 100 | def __len__(self): 101 | return len(self.y) // self.batch_size 102 | 103 | 104 | def setup_test_loader(test_data, params): 105 | test_loaders = [] 106 | 107 | for (x_test, y_test) in test_data: 108 | test_dataset = dataset_transform(x_test, y_test, transform=transforms_match[params.data]) 109 | test_loader = data.DataLoader(test_dataset, batch_size=params.test_batch, shuffle=True, num_workers=0) 110 | 
test_loaders.append(test_loader) 111 | return test_loaders 112 | 113 | 114 | def shuffle_data(x, y): 115 | perm_inds = np.arange(0, x.shape[0]) 116 | np.random.shuffle(perm_inds) 117 | rdm_x = x[perm_inds] 118 | rdm_y = y[perm_inds] 119 | return rdm_x, rdm_y 120 | 121 | 122 | def train_val_test_split_ni(train_data, train_label, test_data, test_label, task_nums, img_size, val_size=0.1): 123 | train_data_rdm, train_label_rdm = shuffle_data(train_data, train_label) 124 | val_size = int(len(train_data_rdm) * val_size) 125 | val_data_rdm, val_label_rdm = train_data_rdm[:val_size], train_label_rdm[:val_size] 126 | train_data_rdm, train_label_rdm = train_data_rdm[val_size:], train_label_rdm[val_size:] 127 | test_data_rdm, test_label_rdm = shuffle_data(test_data, test_label) 128 | train_data_rdm_split = train_data_rdm.reshape(task_nums, -1, img_size, img_size, 3) 129 | train_label_rdm_split = train_label_rdm.reshape(task_nums, -1) 130 | val_data_rdm_split = val_data_rdm.reshape(task_nums, -1, img_size, img_size, 3) 131 | val_label_rdm_split = val_label_rdm.reshape(task_nums, -1) 132 | test_data_rdm_split = test_data_rdm.reshape(task_nums, -1, img_size, img_size, 3) 133 | test_label_rdm_split = test_label_rdm.reshape(task_nums, -1) 134 | return train_data_rdm_split, train_label_rdm_split, val_data_rdm_split, val_label_rdm_split, test_data_rdm_split, test_label_rdm_split 135 | -------------------------------------------------------------------------------- /utils/buffer/aser_update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.buffer.reservoir_update import Reservoir_update 3 | from utils.buffer.buffer_utils import ClassBalancedRandomSampling, random_retrieve 4 | from utils.buffer.aser_utils import compute_knn_sv, add_minority_class_input 5 | from utils.setup_elements import n_classes 6 | from utils.utils import nonzero_indices, maybe_cuda 7 | 8 | 9 | class ASER_update(object): 10 | def __init__(self, 
params, **kwargs): 11 | super().__init__() 12 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 13 | self.k = params.k 14 | self.mem_size = params.mem_size 15 | self.num_tasks = params.num_tasks 16 | self.out_dim = n_classes[params.data] 17 | self.n_smp_cls = int(params.n_smp_cls) 18 | self.n_total_smp = int(params.n_smp_cls * self.out_dim) 19 | self.reservoir_update = Reservoir_update(params) 20 | ClassBalancedRandomSampling.class_index_cache = None 21 | 22 | def update(self, buffer, x, y, **kwargs): 23 | model = buffer.model 24 | 25 | place_left = self.mem_size - buffer.current_index 26 | 27 | # If buffer is not filled, use available space to store whole or part of batch 28 | if place_left: 29 | x_fit = x[:place_left] 30 | y_fit = y[:place_left] 31 | 32 | ind = torch.arange(start=buffer.current_index, end=buffer.current_index + x_fit.size(0), device=self.device) 33 | ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim, 34 | new_y=y_fit, ind=ind, device=self.device) 35 | self.reservoir_update.update(buffer, x_fit, y_fit) 36 | 37 | # If buffer is filled, update buffer by sv 38 | if buffer.current_index == self.mem_size: 39 | # remove what is already in the buffer 40 | cur_x, cur_y = x[place_left:], y[place_left:] 41 | self._update_by_knn_sv(model, buffer, cur_x, cur_y) 42 | 43 | def _update_by_knn_sv(self, model, buffer, cur_x, cur_y): 44 | """ 45 | Returns indices for replacement. 46 | Buffered instances with smallest SV are replaced by current input with higher SV. 47 | Args: 48 | model (object): neural network. 49 | buffer (object): buffer object. 50 | cur_x (tensor): current input data tensor. 51 | cur_y (tensor): current input label tensor. 52 | Returns 53 | ind_buffer (tensor): indices of buffered instances to be replaced. 54 | ind_cur (tensor): indices of current data to do replacement. 
55 | """ 56 | cur_x = maybe_cuda(cur_x) 57 | cur_y = maybe_cuda(cur_y) 58 | 59 | # Find minority class samples from current input batch 60 | minority_batch_x, minority_batch_y = add_minority_class_input(cur_x, cur_y, self.mem_size, self.out_dim) 61 | 62 | # Evaluation set 63 | eval_x, eval_y, eval_indices = \ 64 | ClassBalancedRandomSampling.sample(buffer.buffer_img, buffer.buffer_label, self.n_smp_cls, 65 | device=self.device) 66 | 67 | # Concatenate minority class samples from current input batch to evaluation set 68 | eval_x = torch.cat((eval_x, minority_batch_x)) 69 | eval_y = torch.cat((eval_y, minority_batch_y)) 70 | 71 | # Candidate set 72 | cand_excl_indices = set(eval_indices.tolist()) 73 | cand_x, cand_y, cand_ind = random_retrieve(buffer, self.n_total_smp, cand_excl_indices, return_indices=True) 74 | 75 | # Concatenate current input batch to candidate set 76 | cand_x = torch.cat((cand_x, cur_x)) 77 | cand_y = torch.cat((cand_y, cur_y)) 78 | 79 | sv_matrix = compute_knn_sv(model, eval_x, eval_y, cand_x, cand_y, self.k, device=self.device) 80 | sv = sv_matrix.sum(0) 81 | 82 | n_cur = cur_x.size(0) 83 | n_cand = cand_x.size(0) 84 | 85 | # Number of previously buffered instances in candidate set 86 | n_cand_buf = n_cand - n_cur 87 | 88 | sv_arg_sort = sv.argsort(descending=True) 89 | 90 | # Divide SV array into two segments 91 | # - large: candidate args to be retained; small: candidate args to be discarded 92 | sv_arg_large = sv_arg_sort[:n_cand_buf] 93 | sv_arg_small = sv_arg_sort[n_cand_buf:] 94 | 95 | # Extract args relevant to replacement operation 96 | # If current data instances are in 'large' segment, they are added to buffer 97 | # If buffered instances are in 'small' segment, they are discarded from buffer 98 | # Replacement happens between these two sets 99 | # Retrieve original indices from candidate args 100 | ind_cur = sv_arg_large[nonzero_indices(sv_arg_large >= n_cand_buf)] - n_cand_buf 101 | arg_buffer = 
sv_arg_small[nonzero_indices(sv_arg_small < n_cand_buf)] 102 | ind_buffer = cand_ind[arg_buffer] 103 | 104 | buffer.n_seen_so_far += n_cur 105 | 106 | # perform overwrite op 107 | y_upt = cur_y[ind_cur] 108 | x_upt = cur_x[ind_cur] 109 | ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim, 110 | new_y=y_upt, ind=ind_buffer, device=self.device) 111 | buffer.buffer_img[ind_buffer] = x_upt 112 | buffer.buffer_label[ind_buffer] = y_upt 113 | -------------------------------------------------------------------------------- /agents/exp_replay.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from utils.buffer.buffer import Buffer 4 | from agents.base import ContinualLearner 5 | from continuum.data_utils import dataset_transform 6 | from utils.setup_elements import transforms_match 7 | from utils.utils import maybe_cuda, AverageMeter 8 | 9 | 10 | class ExperienceReplay(ContinualLearner): 11 | def __init__(self, model, opt, params): 12 | super(ExperienceReplay, self).__init__(model, opt, params) 13 | self.buffer = Buffer(model, params) 14 | self.mem_size = params.mem_size 15 | self.eps_mem_batch = params.eps_mem_batch 16 | self.mem_iters = params.mem_iters 17 | 18 | def train_learner(self, x_train, y_train, **kwargs): 19 | self.before_train(x_train, y_train) 20 | # set up loader 21 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 22 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 23 | drop_last=True) 24 | # set up model 25 | self.model = self.model.train() 26 | 27 | # setup tracker 28 | losses_batch = AverageMeter() 29 | losses_mem = AverageMeter() 30 | acc_batch = AverageMeter() 31 | acc_mem = AverageMeter() 32 | 33 | for ep in range(self.epoch): 34 | for i, batch_data in enumerate(train_loader): 35 | # batch update 36 | batch_x, batch_y = batch_data 37 | batch_x = 
maybe_cuda(batch_x, self.cuda) 38 | batch_y = maybe_cuda(batch_y, self.cuda) 39 | for j in range(self.mem_iters): 40 | logits = self.model.forward(batch_x) 41 | loss = self.criterion(logits, batch_y) 42 | if self.params.trick['kd_trick']: 43 | loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \ 44 | self.kd_manager.get_kd_loss(logits, batch_x) 45 | if self.params.trick['kd_trick_star']: 46 | loss = 1/((self.task_seen + 1) ** 0.5) * loss + \ 47 | (1 - 1/((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x) 48 | _, pred_label = torch.max(logits, 1) 49 | correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0) 50 | # update tracker 51 | acc_batch.update(correct_cnt, batch_y.size(0)) 52 | losses_batch.update(loss, batch_y.size(0)) 53 | # backward 54 | self.opt.zero_grad() 55 | loss.backward() 56 | 57 | # mem update 58 | mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y) 59 | if mem_x.size(0) > 0: 60 | mem_x = maybe_cuda(mem_x, self.cuda) 61 | mem_y = maybe_cuda(mem_y, self.cuda) 62 | mem_logits = self.model.forward(mem_x) 63 | loss_mem = self.criterion(mem_logits, mem_y) 64 | if self.params.trick['kd_trick']: 65 | loss_mem = 1 / (self.task_seen + 1) * loss_mem + (1 - 1 / (self.task_seen + 1)) * \ 66 | self.kd_manager.get_kd_loss(mem_logits, mem_x) 67 | if self.params.trick['kd_trick_star']: 68 | loss_mem = 1 / ((self.task_seen + 1) ** 0.5) * loss_mem + \ 69 | (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(mem_logits, 70 | mem_x) 71 | # update tracker 72 | losses_mem.update(loss_mem, mem_y.size(0)) 73 | _, pred_label = torch.max(mem_logits, 1) 74 | correct_cnt = (pred_label == mem_y).sum().item() / mem_y.size(0) 75 | acc_mem.update(correct_cnt, mem_y.size(0)) 76 | 77 | loss_mem.backward() 78 | 79 | if self.params.update == 'ASER' or self.params.retrieve == 'ASER': 80 | # opt update 81 | self.opt.zero_grad() 82 | combined_batch = torch.cat((mem_x, batch_x)) 83 | combined_labels = 
torch.cat((mem_y, batch_y)) 84 | combined_logits = self.model.forward(combined_batch) 85 | loss_combined = self.criterion(combined_logits, combined_labels) 86 | loss_combined.backward() 87 | self.opt.step() 88 | else: 89 | self.opt.step() 90 | 91 | # update mem 92 | self.buffer.update(batch_x, batch_y) 93 | 94 | if i % 100 == 1 and self.verbose: 95 | print( 96 | '==>>> it: {}, avg. loss: {:.6f}, ' 97 | 'running train acc: {:.3f}' 98 | .format(i, losses_batch.avg(), acc_batch.avg()) 99 | ) 100 | print( 101 | '==>>> it: {}, mem avg. loss: {:.6f}, ' 102 | 'running mem acc: {:.3f}' 103 | .format(i, losses_mem.avg(), acc_mem.avg()) 104 | ) 105 | self.after_train() 106 | -------------------------------------------------------------------------------- /models/convnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | class ConvNet(nn.Module): 8 | def __init__(self, 9 | num_classes, 10 | net_norm='instance', 11 | net_depth=3, 12 | net_width=128, 13 | channel=3, 14 | net_act='relu', 15 | net_pooling='avgpooling', 16 | im_size=(32, 32)): 17 | super(ConvNet, self).__init__() 18 | if net_act == 'sigmoid': 19 | self.net_act = nn.Sigmoid() 20 | elif net_act == 'relu': 21 | self.net_act = nn.ReLU() 22 | elif net_act == 'leakyrelu': 23 | self.net_act = nn.LeakyReLU(negative_slope=0.01) 24 | else: 25 | exit('unknown activation function: %s' % net_act) 26 | 27 | if net_pooling == 'maxpooling': 28 | self.net_pooling = nn.MaxPool2d(kernel_size=2, stride=2) 29 | elif net_pooling == 'avgpooling': 30 | self.net_pooling = nn.AvgPool2d(kernel_size=2, stride=2) 31 | elif net_pooling == 'none': 32 | self.net_pooling = None 33 | else: 34 | exit('unknown net_pooling: %s' % net_pooling) 35 | 36 | self.depth = net_depth 37 | self.net_norm = net_norm 38 | 39 | self.layers, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, 40 | 
net_pooling, im_size) 41 | num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2] 42 | self.num_feat = num_feat 43 | self.classifier = nn.Linear(num_feat, num_classes) 44 | 45 | def forward(self, x, return_features=False): 46 | for d in range(self.depth): 47 | x = self.layers['conv'][d](x) 48 | if len(self.layers['norm']) > 0: 49 | x = self.layers['norm'][d](x) 50 | x = self.layers['act'][d](x) 51 | if len(self.layers['pool']) > 0: 52 | x = self.layers['pool'][d](x) 53 | 54 | out = x.view(x.shape[0], -1) 55 | logit = self.classifier(out) 56 | 57 | if return_features: 58 | return logit, out 59 | else: 60 | return logit 61 | 62 | def features(self, x): 63 | for d in range(self.depth): 64 | x = self.layers['conv'][d](x) 65 | if len(self.layers['norm']) > 0: 66 | x = self.layers['norm'][d](x) 67 | x = self.layers['act'][d](x) 68 | if len(self.layers['pool']) > 0: 69 | x = self.layers['pool'][d](x) 70 | 71 | out = x.view(x.shape[0], -1) 72 | 73 | return out 74 | 75 | def get_feature(self, x, idx_from, idx_to=-1, return_prob=False, return_logit=False): 76 | if idx_to == -1: 77 | idx_to = idx_from 78 | features = [] 79 | 80 | for d in range(self.depth): 81 | x = self.layers['conv'][d](x) 82 | if self.net_norm: 83 | x = self.layers['norm'][d](x) 84 | x = self.layers['act'][d](x) 85 | if self.net_pooling: 86 | x = self.layers['pool'][d](x) 87 | features.append(x) 88 | if idx_to < len(features): 89 | return features[idx_from:idx_to + 1] 90 | 91 | if return_prob: 92 | out = x.view(x.size(0), -1) 93 | logit = self.classifier(out) 94 | prob = torch.softmax(logit, dim=-1) 95 | return features, prob 96 | elif return_logit: 97 | out = x.view(x.size(0), -1) 98 | logit = self.classifier(out) 99 | return features, logit 100 | else: 101 | return features[idx_from:idx_to + 1] 102 | 103 | def _get_normlayer(self, net_norm, shape_feat): 104 | if net_norm == 'batch': 105 | norm = nn.BatchNorm2d(shape_feat[0], affine=True) 106 | elif net_norm == 'layer': 107 | norm = 
nn.LayerNorm(shape_feat, elementwise_affine=True) 108 | elif net_norm == 'instance': 109 | norm = nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 110 | elif net_norm == 'group': 111 | norm = nn.GroupNorm(4, shape_feat[0], affine=True) 112 | elif net_norm == 'none': 113 | norm = None 114 | else: 115 | norm = None 116 | exit('unknown net_norm: %s' % net_norm) 117 | return norm 118 | 119 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_pooling, im_size): 120 | layers = {'conv': [], 'norm': [], 'act': [], 'pool': []} 121 | 122 | in_channels = channel 123 | if im_size[0] == 28: 124 | im_size = (32, 32) 125 | shape_feat = [in_channels, im_size[0], im_size[1]] 126 | 127 | for d in range(net_depth): 128 | layers['conv'] += [ 129 | nn.Conv2d(in_channels, 130 | net_width, 131 | kernel_size=3, 132 | padding=3 if channel == 1 and d == 0 else 1) 133 | ] 134 | shape_feat[0] = net_width 135 | if net_norm != 'none': 136 | layers['norm'] += [self._get_normlayer(net_norm, shape_feat)] 137 | layers['act'] += [self.net_act] 138 | in_channels = net_width 139 | if net_pooling != 'none': 140 | layers['pool'] += [self.net_pooling] 141 | shape_feat[1] //= 2 142 | shape_feat[2] //= 2 143 | 144 | layers['conv'] = nn.ModuleList(layers['conv']) 145 | layers['norm'] = nn.ModuleList(layers['norm']) 146 | layers['act'] = nn.ModuleList(layers['act']) 147 | layers['pool'] = nn.ModuleList(layers['pool']) 148 | layers = nn.ModuleDict(layers) 149 | 150 | return layers, shape_feat 151 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/core50.py: -------------------------------------------------------------------------------- 1 | import os 2 | from continuum.dataset_scripts.dataset_base import DatasetBase 3 | import pickle as pkl 4 | import logging 5 | from hashlib import md5 6 | import numpy as np 7 | from PIL import Image 8 | from continuum.data_utils import shuffle_data, load_task_with_labels 9 | import time 10 | 
11 | core50_ntask = { 12 | 'ni': 8, 13 | 'nc': 9, 14 | 'nic': 79, 15 | 'nicv2_79': 79, 16 | 'nicv2_196': 196, 17 | 'nicv2_391': 391 18 | } 19 | 20 | 21 | class CORE50(DatasetBase): 22 | def __init__(self, scenario, params): 23 | if isinstance(params.num_runs, int) and params.num_runs > 10: 24 | raise Exception('the max number of runs for CORE50 is 10') 25 | dataset = 'core50' 26 | task_nums = core50_ntask[scenario] 27 | super(CORE50, self).__init__(dataset, scenario, task_nums, params.num_runs, params) 28 | 29 | 30 | def download_load(self): 31 | 32 | print("Loading paths...") 33 | with open(os.path.join(self.root, 'paths.pkl'), 'rb') as f: 34 | self.paths = pkl.load(f) 35 | 36 | print("Loading LUP...") 37 | with open(os.path.join(self.root, 'LUP.pkl'), 'rb') as f: 38 | self.LUP = pkl.load(f) 39 | 40 | print("Loading labels...") 41 | with open(os.path.join(self.root, 'labels.pkl'), 'rb') as f: 42 | self.labels = pkl.load(f) 43 | 44 | 45 | 46 | def setup(self, cur_run): 47 | self.val_set = [] 48 | self.test_set = [] 49 | print('Loading test set...') 50 | test_idx_list = self.LUP[self.scenario][cur_run][-1] 51 | 52 | #test paths 53 | test_paths = [] 54 | for idx in test_idx_list: 55 | test_paths.append(os.path.join(self.root, self.paths[idx])) 56 | 57 | # test imgs 58 | self.test_data = self.get_batch_from_paths(test_paths) 59 | self.test_label = np.asarray(self.labels[self.scenario][cur_run][-1]) 60 | 61 | 62 | if self.scenario == 'nc': 63 | self.task_labels = self.labels[self.scenario][cur_run][:-1] 64 | for labels in self.task_labels: 65 | labels = list(set(labels)) 66 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 67 | self.test_set.append((x_test, y_test)) 68 | elif self.scenario == 'ni': 69 | self.test_set = [(self.test_data, self.test_label)] 70 | 71 | def new_task(self, cur_task, **kwargs): 72 | cur_run = kwargs['cur_run'] 73 | s = time.time() 74 | train_idx_list = self.LUP[self.scenario][cur_run][cur_task] 75 | 
print("Loading data...") 76 | # Getting the actual paths 77 | train_paths = [] 78 | for idx in train_idx_list: 79 | train_paths.append(os.path.join(self.root, self.paths[idx])) 80 | # loading imgs 81 | train_x = self.get_batch_from_paths(train_paths) 82 | train_y = self.labels[self.scenario][cur_run][cur_task] 83 | train_y = np.asarray(train_y) 84 | # get val set 85 | train_x_rdm, train_y_rdm = shuffle_data(train_x, train_y) 86 | val_size = int(len(train_x_rdm) * self.params.val_size) 87 | val_data_rdm, val_label_rdm = train_x_rdm[:val_size], train_y_rdm[:val_size] 88 | train_data_rdm, train_label_rdm = train_x_rdm[val_size:], train_y_rdm[val_size:] 89 | self.val_set.append((val_data_rdm, val_label_rdm)) 90 | e = time.time() 91 | print('loading time {}'.format(str(e-s))) 92 | return train_data_rdm, train_label_rdm, set(train_label_rdm) 93 | 94 | 95 | def new_run(self, **kwargs): 96 | cur_run = kwargs['cur_run'] 97 | self.setup(cur_run) 98 | 99 | 100 | @staticmethod 101 | def get_batch_from_paths(paths, compress=False, snap_dir='', 102 | on_the_fly=True, verbose=False): 103 | """ Given a number of abs. paths it returns the numpy array 104 | of all the images. """ 105 | 106 | # Getting root logger 107 | log = logging.getLogger('mylogger') 108 | 109 | # If we do not process data on the fly we check if the same train 110 | # filelist has been already processed and saved. If so, we load it 111 | # directly. In either case we end up returning x and y, as the full 112 | # training set and respective labels. 
113 | num_imgs = len(paths) 114 | hexdigest = md5(''.join(paths).encode('utf-8')).hexdigest() 115 | log.debug("Paths Hex: " + str(hexdigest)) 116 | loaded = False 117 | x = None 118 | file_path = None 119 | 120 | if compress: 121 | file_path = snap_dir + hexdigest + ".npz" 122 | if os.path.exists(file_path) and not on_the_fly: 123 | loaded = True 124 | with open(file_path, 'rb') as f: 125 | npzfile = np.load(f) 126 | x, y = npzfile['x'] 127 | else: 128 | x_file_path = snap_dir + hexdigest + "_x.bin" 129 | if os.path.exists(x_file_path) and not on_the_fly: 130 | loaded = True 131 | with open(x_file_path, 'rb') as f: 132 | x = np.fromfile(f, dtype=np.uint8) \ 133 | .reshape(num_imgs, 128, 128, 3) 134 | 135 | # Here we actually load the images. 136 | if not loaded: 137 | # Pre-allocate numpy arrays 138 | x = np.zeros((num_imgs, 128, 128, 3), dtype=np.uint8) 139 | 140 | for i, path in enumerate(paths): 141 | if verbose: 142 | print("\r" + path + " processed: " + str(i + 1), end='') 143 | x[i] = np.array(Image.open(path)) 144 | 145 | if verbose: 146 | print() 147 | 148 | if not on_the_fly: 149 | # Then we save x 150 | if compress: 151 | with open(file_path, 'wb') as g: 152 | np.savez_compressed(g, x=x) 153 | else: 154 | x.tofile(snap_dir + hexdigest + "_x.bin") 155 | 156 | assert (x is not None), 'Problems loading data. x is None!' 
157 | 158 | return x 159 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from https://github.com/facebookresearch/GradientEpisodicMemory 3 | & 4 | https://github.com/kuangliu/pytorch-cifar 5 | """ 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | from torch.nn.functional import relu, avg_pool2d 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = conv3x3(in_planes, planes, stride) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion * planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, 28 | stride=stride, bias=False), 29 | nn.BatchNorm2d(self.expansion * planes) 30 | ) 31 | 32 | def forward(self, x): 33 | out = relu(self.bn1(self.conv1(x))) 34 | out = self.bn2(self.conv2(out)) 35 | out += self.shortcut(x) 36 | out = relu(out) 37 | return out 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 47 | stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion * 50 | planes, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 52 | 53 | 
self.shortcut = nn.Sequential() 54 | if stride != 1 or in_planes != self.expansion * planes: 55 | self.shortcut = nn.Sequential( 56 | nn.Conv2d(in_planes, self.expansion * planes, 57 | kernel_size=1, stride=stride, bias=False), 58 | nn.BatchNorm2d(self.expansion * planes) 59 | ) 60 | 61 | def forward(self, x): 62 | out = relu(self.bn1(self.conv1(x))) 63 | out = relu(self.bn2(self.conv2(out))) 64 | out = self.bn3(self.conv3(out)) 65 | out += self.shortcut(x) 66 | out = relu(out) 67 | return out 68 | 69 | class ResNet(nn.Module): 70 | def __init__(self, block, num_blocks, num_classes, nf, bias): 71 | super(ResNet, self).__init__() 72 | self.in_planes = nf 73 | self.conv1 = conv3x3(3, nf * 1) 74 | self.bn1 = nn.BatchNorm2d(nf * 1) 75 | self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1) 76 | self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2) 77 | self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2) 78 | self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2) 79 | self.linear = nn.Linear(nf * 8 * block.expansion, num_classes, bias=bias) 80 | 81 | 82 | def _make_layer(self, block, planes, num_blocks, stride): 83 | strides = [stride] + [1] * (num_blocks - 1) 84 | layers = [] 85 | for stride in strides: 86 | layers.append(block(self.in_planes, planes, stride)) 87 | self.in_planes = planes * block.expansion 88 | return nn.Sequential(*layers) 89 | 90 | def features(self, x): 91 | '''Features before FC layers''' 92 | out = relu(self.bn1(self.conv1(x))) 93 | out = self.layer1(out) 94 | out = self.layer2(out) 95 | out = self.layer3(out) 96 | out = self.layer4(out) 97 | out = avg_pool2d(out, 4) 98 | out = out.view(out.size(0), -1) 99 | return out 100 | 101 | def logits(self, x): 102 | '''Apply the last FC linear mapping to get logits''' 103 | x = self.linear(x) 104 | return x 105 | 106 | def forward(self, x): 107 | out = self.features(x) 108 | logits = self.logits(out) 109 | return logits 110 | 111 | 112 | 
def Reduced_ResNet18(nclasses, nf=20, bias=True): 113 | """ 114 | Reduced ResNet18 as in GEM MIR(note that nf=20). 115 | """ 116 | return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf, bias) 117 | 118 | def ResNet18(nclasses, nf=64, bias=True): 119 | return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf, bias) 120 | 121 | ''' 122 | See https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 123 | ''' 124 | 125 | def ResNet34(nclasses, nf=64, bias=True): 126 | return ResNet(BasicBlock, [3, 4, 6, 3], nclasses, nf, bias) 127 | 128 | def ResNet50(nclasses, nf=64, bias=True): 129 | return ResNet(Bottleneck, [3, 4, 6, 3], nclasses, nf, bias) 130 | 131 | 132 | def ResNet101(nclasses, nf=64, bias=True): 133 | return ResNet(Bottleneck, [3, 4, 23, 3], nclasses, nf, bias) 134 | 135 | 136 | def ResNet152(nclasses, nf=64, bias=True): 137 | return ResNet(Bottleneck, [3, 8, 36, 3], nclasses, nf, bias) 138 | 139 | 140 | class SupConResNet(nn.Module): 141 | """backbone + projection head""" 142 | def __init__(self, dim_in=160, head='mlp', feat_dim=128): 143 | super(SupConResNet, self).__init__() 144 | self.encoder = Reduced_ResNet18(100) 145 | if head == 'linear': 146 | self.head = nn.Linear(dim_in, feat_dim) 147 | elif head == 'mlp': 148 | self.head = nn.Sequential( 149 | nn.Linear(dim_in, dim_in), 150 | nn.ReLU(inplace=True), 151 | nn.Linear(dim_in, feat_dim) 152 | ) 153 | elif head == 'None': 154 | self.head = None 155 | else: 156 | raise NotImplementedError( 157 | 'head not supported: {}'.format(head)) 158 | 159 | def forward(self, x): 160 | feat = self.encoder.features(x) 161 | if self.head: 162 | feat = F.normalize(self.head(feat), dim=1) 163 | else: 164 | feat = F.normalize(feat, dim=1) 165 | return feat 166 | 167 | def features(self, x): 168 | return self.encoder.features(x) -------------------------------------------------------------------------------- /utils/buffer/gss_greedy_update.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from utils.buffer.buffer_utils import get_grad_vector, cosine_similarity 5 | from utils.utils import maybe_cuda 6 | 7 | class GSSGreedyUpdate(object): 8 | def __init__(self, params): 9 | super().__init__() 10 | # the number of gradient vectors to estimate new samples similarity, line 5 in alg.2 11 | self.mem_strength = params.gss_mem_strength 12 | self.gss_batch_size = params.gss_batch_size 13 | self.buffer_score = maybe_cuda(torch.FloatTensor(params.mem_size).fill_(0)) 14 | 15 | def update(self, buffer, x, y, **kwargs): 16 | buffer.model.eval() 17 | 18 | grad_dims = [] 19 | for param in buffer.model.parameters(): 20 | grad_dims.append(param.data.numel()) 21 | 22 | place_left = buffer.buffer_img.size(0) - buffer.current_index 23 | if place_left <= 0: # buffer is full 24 | batch_sim, mem_grads = self.get_batch_sim(buffer, grad_dims, x, y) 25 | if batch_sim < 0: 26 | buffer_score = self.buffer_score[:buffer.current_index].cpu() 27 | buffer_sim = (buffer_score - torch.min(buffer_score)) / \ 28 | ((torch.max(buffer_score) - torch.min(buffer_score)) + 0.01) 29 | # draw candidates for replacement from the buffer 30 | index = torch.multinomial(buffer_sim, x.size(0), replacement=False) 31 | # estimate the similarity of each sample in the recieved batch 32 | # to the randomly drawn samples from the buffer. 
33 | batch_item_sim = self.get_each_batch_sample_sim(buffer, grad_dims, mem_grads, x, y) 34 | # normalize to [0,1] 35 | scaled_batch_item_sim = ((batch_item_sim + 1) / 2).unsqueeze(1) 36 | buffer_repl_batch_sim = ((self.buffer_score[index] + 1) / 2).unsqueeze(1) 37 | # draw an event to decide on replacement decision 38 | outcome = torch.multinomial(torch.cat((scaled_batch_item_sim, buffer_repl_batch_sim), dim=1), 1, 39 | replacement=False) 40 | # replace samples with outcome =1 41 | added_indx = torch.arange(end=batch_item_sim.size(0)) 42 | sub_index = outcome.squeeze(1).bool() 43 | buffer.buffer_img[index[sub_index]] = x[added_indx[sub_index]].clone() 44 | buffer.buffer_label[index[sub_index]] = y[added_indx[sub_index]].clone() 45 | self.buffer_score[index[sub_index]] = batch_item_sim[added_indx[sub_index]].clone() 46 | else: 47 | offset = min(place_left, x.size(0)) 48 | x = x[:offset] 49 | y = y[:offset] 50 | # first buffer insertion 51 | if buffer.current_index == 0: 52 | batch_sample_memory_cos = torch.zeros(x.size(0)) + 0.1 53 | else: 54 | # draw random samples from buffer 55 | mem_grads = self.get_rand_mem_grads(buffer, grad_dims) 56 | # estimate a score for each added sample 57 | batch_sample_memory_cos = self.get_each_batch_sample_sim(buffer, grad_dims, mem_grads, x, y) 58 | buffer.buffer_img[buffer.current_index:buffer.current_index + offset].data.copy_(x) 59 | buffer.buffer_label[buffer.current_index:buffer.current_index + offset].data.copy_(y) 60 | self.buffer_score[buffer.current_index:buffer.current_index + offset] \ 61 | .data.copy_(batch_sample_memory_cos) 62 | buffer.current_index += offset 63 | buffer.model.train() 64 | 65 | def get_batch_sim(self, buffer, grad_dims, batch_x, batch_y): 66 | """ 67 | Args: 68 | buffer: memory buffer 69 | grad_dims: gradient dimensions 70 | batch_x: batch images 71 | batch_y: batch labels 72 | Returns: score of current batch, gradient from memory subsets 73 | """ 74 | mem_grads = self.get_rand_mem_grads(buffer, 
grad_dims) 75 | buffer.model.zero_grad() 76 | loss = F.cross_entropy(buffer.model.forward(batch_x), batch_y) 77 | loss.backward() 78 | batch_grad = get_grad_vector(buffer.model.parameters, grad_dims).unsqueeze(0) 79 | batch_sim = max(cosine_similarity(mem_grads, batch_grad)) 80 | return batch_sim, mem_grads 81 | 82 | def get_rand_mem_grads(self, buffer, grad_dims): 83 | """ 84 | Args: 85 | buffer: memory buffer 86 | grad_dims: gradient dimensions 87 | Returns: gradient from memory subsets 88 | """ 89 | gss_batch_size = min(self.gss_batch_size, buffer.current_index) 90 | num_mem_subs = min(self.mem_strength, buffer.current_index // gss_batch_size) 91 | mem_grads = maybe_cuda(torch.zeros(num_mem_subs, sum(grad_dims), dtype=torch.float32)) 92 | shuffeled_inds = torch.randperm(buffer.current_index) 93 | for i in range(num_mem_subs): 94 | random_batch_inds = shuffeled_inds[ 95 | i * gss_batch_size:i * gss_batch_size + gss_batch_size] 96 | batch_x = buffer.buffer_img[random_batch_inds] 97 | batch_y = buffer.buffer_label[random_batch_inds] 98 | buffer.model.zero_grad() 99 | loss = F.cross_entropy(buffer.model.forward(batch_x), batch_y) 100 | loss.backward() 101 | mem_grads[i].data.copy_(get_grad_vector(buffer.model.parameters, grad_dims)) 102 | return mem_grads 103 | 104 | def get_each_batch_sample_sim(self, buffer, grad_dims, mem_grads, batch_x, batch_y): 105 | """ 106 | Args: 107 | buffer: memory buffer 108 | grad_dims: gradient dimensions 109 | mem_grads: gradient from memory subsets 110 | batch_x: batch images 111 | batch_y: batch labels 112 | Returns: score of each sample from current batch 113 | """ 114 | cosine_sim = maybe_cuda(torch.zeros(batch_x.size(0))) 115 | for i, (x, y) in enumerate(zip(batch_x, batch_y)): 116 | buffer.model.zero_grad() 117 | ptloss = F.cross_entropy(buffer.model.forward(x.unsqueeze(0)), y.unsqueeze(0)) 118 | ptloss.backward() 119 | # add the new grad to the memory grads and add it is cosine similarity 120 | this_grad = 
get_grad_vector(buffer.model.parameters, grad_dims).unsqueeze(0) 121 | cosine_sim[i] = max(cosine_similarity(mem_grads, this_grad)) 122 | return cosine_sim 123 | -------------------------------------------------------------------------------- /utils/buffer/aser_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.utils import maybe_cuda, mini_batch_deep_features, euclidean_distance, nonzero_indices, ohe_label 3 | from utils.setup_elements import n_classes 4 | from utils.buffer.buffer_utils import ClassBalancedRandomSampling 5 | 6 | 7 | def compute_knn_sv(model, eval_x, eval_y, cand_x, cand_y, k, device="cpu"): 8 | """ 9 | Compute KNN SV of candidate data w.r.t. evaluation data. 10 | Args: 11 | model (object): neural network. 12 | eval_x (tensor): evaluation data tensor. 13 | eval_y (tensor): evaluation label tensor. 14 | cand_x (tensor): candidate data tensor. 15 | cand_y (tensor): candidate label tensor. 16 | k (int): number of nearest neighbours. 17 | device (str): device for tensor allocation. 18 | Returns 19 | sv_matrix (tensor): KNN Shapley value matrix of candidate data w.r.t. evaluation data. 20 | """ 21 | # Compute KNN SV score for candidate samples w.r.t. 
evaluation samples 22 | n_eval = eval_x.size(0) 23 | n_cand = cand_x.size(0) 24 | # Initialize SV matrix to matrix of -1 25 | sv_matrix = torch.zeros((n_eval, n_cand), device=device) 26 | # Get deep features 27 | eval_df, cand_df = deep_features(model, eval_x, n_eval, cand_x, n_cand) 28 | # Sort indices based on distance in deep feature space 29 | sorted_ind_mat = sorted_cand_ind(eval_df, cand_df, n_eval, n_cand) 30 | 31 | # Evaluation set labels 32 | el = eval_y 33 | el_vec = el.repeat([n_cand, 1]).T 34 | # Sorted candidate set labels 35 | cl = cand_y[sorted_ind_mat] 36 | 37 | # Indicator function matrix 38 | indicator = (el_vec == cl).float() 39 | indicator_next = torch.zeros_like(indicator, device=device) 40 | indicator_next[:, 0:n_cand - 1] = indicator[:, 1:] 41 | indicator_diff = indicator - indicator_next 42 | 43 | cand_ind = torch.arange(n_cand, dtype=torch.float, device=device) + 1 44 | denom_factor = cand_ind.clone() 45 | denom_factor[:n_cand - 1] = denom_factor[:n_cand - 1] * k 46 | numer_factor = cand_ind.clone() 47 | numer_factor[k:n_cand - 1] = k 48 | numer_factor[n_cand - 1] = 1 49 | factor = numer_factor / denom_factor 50 | 51 | indicator_factor = indicator_diff * factor 52 | indicator_factor_cumsum = indicator_factor.flip(1).cumsum(1).flip(1) 53 | 54 | # Row indices 55 | row_ind = torch.arange(n_eval, device=device) 56 | row_mat = torch.repeat_interleave(row_ind, n_cand).reshape([n_eval, n_cand]) 57 | 58 | # Compute SV recursively 59 | sv_matrix[row_mat, sorted_ind_mat] = indicator_factor_cumsum 60 | 61 | return sv_matrix 62 | 63 | 64 | def deep_features(model, eval_x, n_eval, cand_x, n_cand): 65 | """ 66 | Compute deep features of evaluation and candidate data. 67 | Args: 68 | model (object): neural network. 69 | eval_x (tensor): evaluation data tensor. 70 | n_eval (int): number of evaluation data. 71 | cand_x (tensor): candidate data tensor. 72 | n_cand (int): number of candidate data. 
73 | Returns 74 | eval_df (tensor): deep features of evaluation data. 75 | cand_df (tensor): deep features of evaluation data. 76 | """ 77 | # Get deep features 78 | if cand_x is None: 79 | num = n_eval 80 | total_x = eval_x 81 | else: 82 | num = n_eval + n_cand 83 | total_x = torch.cat((eval_x, cand_x), 0) 84 | 85 | # compute deep features with mini-batches 86 | total_x = maybe_cuda(total_x) 87 | deep_features_ = mini_batch_deep_features(model, total_x, num) 88 | 89 | eval_df = deep_features_[0:n_eval] 90 | cand_df = deep_features_[n_eval:] 91 | return eval_df, cand_df 92 | 93 | 94 | def sorted_cand_ind(eval_df, cand_df, n_eval, n_cand): 95 | """ 96 | Sort indices of candidate data according to 97 | their Euclidean distance to each evaluation data in deep feature space. 98 | Args: 99 | eval_df (tensor): deep features of evaluation data. 100 | cand_df (tensor): deep features of evaluation data. 101 | n_eval (int): number of evaluation data. 102 | n_cand (int): number of candidate data. 103 | Returns 104 | sorted_cand_ind (tensor): sorted indices of candidate set w.r.t. each evaluation data. 105 | """ 106 | # Sort indices of candidate set according to distance w.r.t. 
evaluation set in deep feature space 107 | # Preprocess feature vectors to facilitate vector-wise distance computation 108 | eval_df_repeat = eval_df.repeat([1, n_cand]).reshape([n_eval * n_cand, eval_df.shape[1]]) 109 | cand_df_tile = cand_df.repeat([n_eval, 1]) 110 | # Compute distance between evaluation and candidate feature vectors 111 | distance_vector = euclidean_distance(eval_df_repeat, cand_df_tile) 112 | # Turn distance vector into distance matrix 113 | distance_matrix = distance_vector.reshape((n_eval, n_cand)) 114 | # Sort candidate set indices based on distance 115 | sorted_cand_ind_ = distance_matrix.argsort(1) 116 | return sorted_cand_ind_ 117 | 118 | 119 | def add_minority_class_input(cur_x, cur_y, mem_size, num_class): 120 | """ 121 | Find input instances from minority classes, and concatenate them to evaluation data/label tensors later. 122 | This facilitates the inclusion of minority class samples into memory when ASER's update method is used under online-class incremental setting. 123 | 124 | More details: 125 | 126 | Evaluation set may not contain any samples from minority classes (i.e., those classes with very few number of corresponding samples stored in the memory). 127 | This happens after task changes in online-class incremental setting. 128 | Minority class samples can then get very low or negative KNN-SV, making it difficult to store any of them in the memory. 129 | 130 | By identifying minority class samples in the current input batch, and concatenating them to the evaluation set, 131 | KNN-SV of the minority class samples can be artificially boosted (i.e., positive value with larger magnitude). 132 | This allows to quickly accomodate new class samples in the memory right after task changes. 133 | 134 | Threshold for being a minority class is a hyper-parameter related to the class proportion. 135 | In this implementation, it is randomly selected between 0 and 1 / number of all classes for each current input batch. 
136 | 137 | 138 | Args: 139 | cur_x (tensor): current input data tensor. 140 | cur_y (tensor): current input label tensor. 141 | mem_size (int): memory size. 142 | num_class (int): number of classes in dataset. 143 | Returns 144 | minority_batch_x (tensor): subset of current input data from minority class. 145 | minority_batch_y (tensor): subset of current input label from minority class. 146 | """ 147 | # Select input instances from minority classes that will be concatenated to pre-selected data 148 | threshold = torch.tensor(1).float().uniform_(0, 1 / num_class).item() 149 | 150 | # If number of buffered samples from certain class is lower than random threshold, 151 | # that class is minority class 152 | cls_proportion = ClassBalancedRandomSampling.class_num_cache.float() / mem_size 153 | minority_ind = nonzero_indices(cls_proportion[cur_y] < threshold) 154 | 155 | minority_batch_x = cur_x[minority_ind] 156 | minority_batch_y = cur_y[minority_ind] 157 | return minority_batch_x, minority_batch_y 158 | -------------------------------------------------------------------------------- /models/ndpm/ndpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler 4 | 5 | from utils.utils import maybe_cuda 6 | from utils.global_vars import * 7 | from .expert import Expert 8 | from .priors import CumulativePrior 9 | 10 | 11 | class Ndpm(nn.Module): 12 | def __init__(self, params): 13 | super().__init__() 14 | self.params = params 15 | self.experts = nn.ModuleList([Expert(params)]) 16 | self.stm_capacity = params.stm_capacity 17 | self.stm_x = [] 18 | self.stm_y = [] 19 | self.prior = CumulativePrior(params) 20 | 21 | def get_experts(self): 22 | return tuple(self.experts.children()) 23 | 24 | def forward(self, x): 25 | with torch.no_grad(): 26 | if len(self.experts) == 1: 27 | raise RuntimeError('There\'s no expert to run on the 
input') 28 | x = maybe_cuda(x) 29 | log_evid = -self.experts[-1].g.collect_nll(x) # [B, 1+K] 30 | log_evid = log_evid[:, 1:].unsqueeze(2) # [B, K, 1] 31 | log_prior = -self.prior.nl_prior()[1:] # [K] 32 | log_prior -= torch.logsumexp(log_prior, dim=0) 33 | log_prior = log_prior.unsqueeze(0).unsqueeze(2) # [1, K, 1] 34 | log_joint = log_prior + log_evid # [B, K, 1] 35 | if not MODELS_NDPM_NDPM_DISABLE_D: 36 | log_pred = self.experts[-1].d.collect_forward(x) # [B, 1+K, C] 37 | log_pred = log_pred[:, 1:, :] # [B, K, C] 38 | log_joint = log_joint + log_pred # [B, K, C] 39 | 40 | log_joint = log_joint.logsumexp(dim=1).squeeze() # [B,] or [B, C] 41 | return log_joint 42 | 43 | 44 | def learn(self, x, y): 45 | x, y = maybe_cuda(x), maybe_cuda(y) 46 | 47 | if MODELS_NDPM_NDPM_SEND_TO_STM_ALWAYS: 48 | self.stm_x.extend(torch.unbind(x.cpu())) 49 | self.stm_y.extend(torch.unbind(y.cpu())) 50 | else: 51 | # Determine the destination of each data point 52 | nll = self.experts[-1].collect_nll(x, y) # [B, 1+K] 53 | nl_prior = self.prior.nl_prior() # [1+K] 54 | nl_joint = nll + nl_prior.unsqueeze(0).expand( 55 | nll.size(0), -1) # [B, 1+K] 56 | 57 | # Save to short-term memory 58 | destination = maybe_cuda(torch.argmin(nl_joint, dim=1)) # [B] 59 | to_stm = destination == 0 # [B] 60 | self.stm_x.extend(torch.unbind(x[to_stm].cpu())) 61 | self.stm_y.extend(torch.unbind(y[to_stm].cpu())) 62 | 63 | # Train expert 64 | with torch.no_grad(): 65 | min_joint = nl_joint.min(dim=1)[0].view(-1, 1) 66 | to_expert = torch.exp(-nl_joint + min_joint) # [B, 1+K] 67 | to_expert[:, 0] = 0. # [B, 1+K] 68 | to_expert = \ 69 | to_expert / (to_expert.sum(dim=1).view(-1, 1) + 1e-7) 70 | 71 | # Compute losses per expert 72 | nll_for_train = nll * (1. 
- to_stm.float()).unsqueeze(1) # [B,1+K] 73 | losses = (nll_for_train * to_expert).sum(0) # [1+K] 74 | 75 | # Record expert usage 76 | expert_usage = to_expert.sum(dim=0) # [K+1] 77 | self.prior.record_usage(expert_usage) 78 | 79 | # Do lr_decay implicitly 80 | if MODELS_NDPM_NDPM_IMPLICIT_LR_DECAY: 81 | losses = losses \ 82 | * self.params.stm_capacity / (self.prior.counts + 1e-8) 83 | loss = losses.sum() 84 | 85 | if loss.requires_grad: 86 | update_threshold = 0 87 | for k, usage in enumerate(expert_usage): 88 | if usage > update_threshold: 89 | self.experts[k].zero_grad() 90 | loss.backward() 91 | for k, usage in enumerate(expert_usage): 92 | if usage > update_threshold: 93 | self.experts[k].clip_grad() 94 | self.experts[k].optimizer_step() 95 | self.experts[k].lr_scheduler_step() 96 | 97 | # Sleep 98 | if len(self.stm_x) >= self.stm_capacity: 99 | dream_dataset = TensorDataset( 100 | torch.stack(self.stm_x), torch.stack(self.stm_y)) 101 | self.sleep(dream_dataset) 102 | self.stm_x = [] 103 | self.stm_y = [] 104 | 105 | def sleep(self, dream_dataset): 106 | print('\nGoing to sleep...') 107 | # Add new expert and optimizer 108 | expert = Expert(self.params, self.get_experts()) 109 | self.experts.append(expert) 110 | self.prior.add_expert() 111 | 112 | stacked_stm_x = torch.stack(self.stm_x) 113 | stacked_stm_y = torch.stack(self.stm_y) 114 | indices = torch.randperm(stacked_stm_x.size(0)) 115 | train_size = stacked_stm_x.size(0) - MODELS_NDPM_NDPM_SLEEP_SLEEP_VAL_SIZE 116 | dream_dataset = TensorDataset( 117 | stacked_stm_x[indices[:train_size]], 118 | stacked_stm_y[indices[:train_size]]) 119 | 120 | # Prepare data iterator 121 | self.prior.record_usage(len(dream_dataset), index=-1) 122 | dream_iterator = iter(DataLoader( 123 | dream_dataset, 124 | batch_size=MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE, 125 | num_workers=MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS, 126 | sampler=RandomSampler( 127 | dream_dataset, 128 | replacement=True, 129 | num_samples=( 130 | 
MODELS_NDPM_NDPM_SLEEP_STEP_G * 131 | MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE 132 | )) 133 | )) 134 | 135 | # Train generative component 136 | for step, (x, y) in enumerate(dream_iterator): 137 | step += 1 138 | x, y = maybe_cuda(x), maybe_cuda(y) 139 | g_loss = expert.g.nll(x, y, step=step) 140 | g_loss = (g_loss + MODELS_NDPM_NDPM_WEIGHT_DECAY 141 | * expert.g.weight_decay_loss()) 142 | expert.g.zero_grad() 143 | g_loss.mean().backward() 144 | expert.g.clip_grad() 145 | expert.g.optimizer.step() 146 | 147 | if step % MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP == 0: 148 | print('\r [Sleep-G %6d] loss: %5.1f' % ( 149 | step, g_loss.mean() 150 | ), end='') 151 | print() 152 | 153 | dream_iterator = iter(DataLoader( 154 | dream_dataset, 155 | batch_size=MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE, 156 | num_workers=MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS, 157 | sampler=RandomSampler( 158 | dream_dataset, 159 | replacement=True, 160 | num_samples=( 161 | MODELS_NDPM_NDPM_SLEEP_STEP_D * 162 | MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE) 163 | ) 164 | )) 165 | 166 | # Train discriminative component 167 | if not MODELS_NDPM_NDPM_DISABLE_D: 168 | for step, (x, y) in enumerate(dream_iterator): 169 | step += 1 170 | x, y = maybe_cuda(x), maybe_cuda(y) 171 | d_loss = expert.d.nll(x, y, step=step) 172 | d_loss = (d_loss + MODELS_NDPM_NDPM_WEIGHT_DECAY 173 | * expert.d.weight_decay_loss()) 174 | expert.d.zero_grad() 175 | d_loss.mean().backward() 176 | expert.d.clip_grad() 177 | expert.d.optimizer.step() 178 | 179 | if step % MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP == 0: 180 | print('\r [Sleep-D %6d] loss: %5.1f' % ( 181 | step, d_loss.mean() 182 | ), end='') 183 | 184 | expert.lr_scheduler_step() 185 | expert.lr_scheduler_step() 186 | expert.eval() 187 | print() 188 | 189 | @staticmethod 190 | def _nl_joint(nl_prior, nll): 191 | batch = nll.size(0) 192 | nl_prior = nl_prior.unsqueeze(0).expand(batch, -1) # [B, 1+K] 193 | return nll + nl_prior 194 | 195 | def train(self, mode=True): 196 | # Disabled 197 | pass 198 | 
-------------------------------------------------------------------------------- /continuum/non_stationary.py: -------------------------------------------------------------------------------- 1 | import random 2 | from copy import deepcopy 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from skimage.filters import gaussian 6 | from continuum.data_utils import train_val_test_split_ni 7 | 8 | 9 | class Original(object): 10 | def __init__(self, x, y, unroll=False, color=False): 11 | if color: 12 | self.x = x / 255.0 13 | else: 14 | self.x = x 15 | self.next_x = self.x 16 | self.next_y = y 17 | self.y = y 18 | self.unroll = unroll 19 | 20 | def get_dims(self): 21 | # Get data input and output dimensions 22 | print("input size {}\noutput size {}".format(self.x.shape[1], self.y.shape[1])) 23 | return self.x.shape[1], self.y.shape[1] 24 | 25 | def show_sample(self, num_plot=1): 26 | # idx = np.random.choice(self.x.shape[0]) 27 | for i in range(num_plot): 28 | plt.subplot(1, 2, 1) 29 | if self.x[i].shape[2] == 1: 30 | plt.imshow(np.squeeze(self.x[i])) 31 | else: 32 | plt.imshow(self.x[i]) 33 | plt.title("original task image") 34 | plt.subplot(1, 2, 2) 35 | if self.x[i].shape[2] == 1: 36 | plt.imshow(np.squeeze(self.next_x[i])) 37 | else: 38 | plt.imshow(self.next_x[i]) 39 | plt.title(self.get_name()) 40 | plt.axis('off') 41 | plt.show() 42 | 43 | def create_output(self): 44 | if self.unroll: 45 | ret = self.next_x.reshape((-1, self.x.shape[1] ** 2)), self.next_y 46 | else: 47 | ret = self.next_x, self.next_y 48 | return ret 49 | 50 | @staticmethod 51 | def clip_minmax(l, min_=0., max_=1.): 52 | return np.clip(l, min_, max_) 53 | 54 | def get_name(self): 55 | if hasattr(self, 'factor'): 56 | return str(self.__class__.__name__) + '_' + str(self.factor) 57 | 58 | def next_task(self, *args): 59 | self.next_x = self.x 60 | self.next_y = self.y 61 | return self.create_output() 62 | 63 | 64 | class Noisy(Original): 65 | def __init__(self, x, y, full=False, 
color=False): 66 | super(Noisy, self).__init__(x, y, full, color) 67 | 68 | def next_task(self, noise_factor=0.8, sig=0.1, noise_type='Gaussian'): 69 | next_x = deepcopy(self.x) 70 | self.factor = noise_factor 71 | if noise_type == 'Gaussian': 72 | self.next_x = next_x + noise_factor * np.random.normal(loc=0.0, scale=sig, size=next_x.shape) 73 | elif noise_factor == 'S&P': 74 | # TODO implement S&P 75 | pass 76 | 77 | self.next_x = super().clip_minmax(self.next_x, 0, 1) 78 | 79 | return super().create_output() 80 | 81 | 82 | class Blurring(Original): 83 | def __init__(self, x, y, full=False, color=False): 84 | super(Blurring, self).__init__(x, y, full, color) 85 | 86 | def next_task(self, blurry_factor=0.6, blurry_type='Gaussian'): 87 | next_x = deepcopy(self.x) 88 | self.factor = blurry_factor 89 | if blurry_type == 'Gaussian': 90 | self.next_x = gaussian(next_x, sigma=blurry_factor, multichannel=True) 91 | elif blurry_type == 'Average': 92 | pass 93 | # TODO implement average 94 | 95 | self.next_x = super().clip_minmax(self.next_x, 0, 1) 96 | 97 | return super().create_output() 98 | 99 | 100 | class Occlusion(Original): 101 | def __init__(self, x, y, full=False, color=False): 102 | super(Occlusion, self).__init__(x, y, full, color) 103 | 104 | def next_task(self, occlusion_factor=0.2): 105 | next_x = deepcopy(self.x) 106 | self.factor = occlusion_factor 107 | self.image_size = next_x.shape[1] 108 | 109 | occlusion_size = int(occlusion_factor * self.image_size) 110 | half_size = occlusion_size // 2 111 | occlusion_x = random.randint(min(half_size, self.image_size - half_size), 112 | max(half_size, self.image_size - half_size)) 113 | occlusion_y = random.randint(min(half_size, self.image_size - half_size), 114 | max(half_size, self.image_size - half_size)) 115 | 116 | # self.next_x = next_x.reshape((-1, self.image_size, self.image_size)) 117 | 118 | next_x[:, max((occlusion_x - half_size), 0):min((occlusion_x + half_size), self.image_size), \ 119 | max((occlusion_y 
- half_size), 0):min((occlusion_y + half_size), self.image_size)] = 1 120 | 121 | self.next_x = next_x 122 | super().clip_minmax(self.next_x, 0, 1) 123 | 124 | return super().create_output() 125 | 126 | 127 | def test_ns(x, y, ns_type, change_factor): 128 | ns_match = {'noise': Noisy, 'occlusion': Occlusion, 'blur': Blurring} 129 | change = ns_match[ns_type] 130 | tmp = change(x, y, color=True) 131 | tmp.next_task(change_factor) 132 | tmp.show_sample(10) 133 | 134 | 135 | ns_match = {'noise': Noisy, 'occlusion': Occlusion, 'blur': Blurring} 136 | 137 | 138 | def construct_ns_single(train_x_split, train_y_split, test_x_split, test_y_split, ns_type, change_factor, ns_task, 139 | plot=True): 140 | # Data splits 141 | train_list = [] 142 | test_list = [] 143 | change = ns_match[ns_type] 144 | i = 0 145 | if len(change_factor) == 1: 146 | change_factor = change_factor[0] 147 | for idx, val in enumerate(ns_task): 148 | if idx % 2 == 0: 149 | for _ in range(val): 150 | print(i, 'normal') 151 | # train 152 | tmp = Original(train_x_split[i], train_y_split[i], color=True) 153 | train_list.append(tmp.next_task()) 154 | if plot: 155 | tmp.show_sample() 156 | 157 | # test 158 | tmp_test = Original(test_x_split[i], test_y_split[i], color=True) 159 | test_list.append(tmp_test.next_task()) 160 | if plot: 161 | tmp_test.show_sample() 162 | 163 | i += 1 164 | else: 165 | for _ in range(val): 166 | print(i, 'change') 167 | # train 168 | tmp = change(train_x_split[i], train_y_split[i], color=True) 169 | train_list.append(tmp.next_task(change_factor)) 170 | if plot: 171 | tmp.show_sample() 172 | # test 173 | tmp_test = change(test_x_split[i], test_y_split[i], color=True) 174 | test_list.append(tmp_test.next_task(change_factor)) 175 | if plot: 176 | tmp_test.show_sample() 177 | 178 | i += 1 179 | return train_list, test_list 180 | 181 | 182 | def construct_ns_multiple(train_x_split, train_y_split, val_x_rdm_split, val_y_rdm_split, test_x_split, 183 | test_y_split, ns_type, 
change_factors, plot): 184 | train_list = [] 185 | val_list = [] 186 | test_list = [] 187 | ns_len = len(change_factors) 188 | for i in range(ns_len): 189 | factor = change_factors[i] 190 | if factor == 0: 191 | ns_generator = Original 192 | else: 193 | ns_generator = ns_match[ns_type] 194 | print(i, factor) 195 | # train 196 | tmp = ns_generator(train_x_split[i], train_y_split[i], color=True) 197 | train_list.append(tmp.next_task(factor)) 198 | if plot: 199 | tmp.show_sample() 200 | 201 | tmp_val = ns_generator(val_x_rdm_split[i], val_y_rdm_split[i], color=True) 202 | val_list.append(tmp_val.next_task(factor)) 203 | 204 | tmp_test = ns_generator(test_x_split[i], test_y_split[i], color=True) 205 | test_list.append(tmp_test.next_task(factor)) 206 | return train_list, val_list, test_list 207 | 208 | 209 | def construct_ns_multiple_wrapper(train_data, train_label, test_data, est_label, task_nums, img_size, 210 | val_size, ns_type, ns_factor, plot): 211 | train_data_rdm_split, train_label_rdm_split, val_data_rdm_split, val_label_rdm_split, test_data_rdm_split, test_label_rdm_split = train_val_test_split_ni( 212 | train_data, train_label, test_data, est_label, task_nums, img_size, 213 | val_size) 214 | train_set, val_set, test_set = construct_ns_multiple(train_data_rdm_split, train_label_rdm_split, 215 | val_data_rdm_split, val_label_rdm_split, 216 | test_data_rdm_split, test_label_rdm_split, 217 | ns_type, 218 | ns_factor, 219 | plot=plot) 220 | return train_set, val_set, test_set 221 | -------------------------------------------------------------------------------- /models/ndpm/classifier.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils.utils import maybe_cuda 6 | from .component import ComponentD 7 | from utils.global_vars import * 8 | from utils.setup_elements import n_classes 9 | 10 | 11 | class 
Classifier(ComponentD, ABC): 12 | def __init__(self, params, experts): 13 | super().__init__(params, experts) 14 | self.ce_loss = nn.NLLLoss(reduction='none') 15 | 16 | @abstractmethod 17 | def forward(self, x): 18 | """Output log P(y|x)""" 19 | pass 20 | 21 | def nll(self, x, y, step=None): 22 | x, y = maybe_cuda(x), maybe_cuda(y) 23 | log_softmax = self.forward(x) 24 | loss_pred = self.ce_loss(log_softmax, y) 25 | 26 | # Classifier chilling 27 | chilled_log_softmax = F.log_softmax( 28 | log_softmax / self.params.classifier_chill, dim=1) 29 | chilled_loss_pred = self.ce_loss(chilled_log_softmax, y) 30 | 31 | # Value with chill & gradient without chill 32 | loss_pred = loss_pred - loss_pred.detach() \ 33 | + chilled_loss_pred.detach() 34 | 35 | return loss_pred 36 | 37 | 38 | class SharingClassifier(Classifier, ABC): 39 | @abstractmethod 40 | def forward(self, x, collect=False): 41 | pass 42 | 43 | def collect_forward(self, x): 44 | dummy_pred = self.experts[0](x) 45 | preds, _ = self.forward(x, collect=True) 46 | return torch.stack([dummy_pred] + preds, dim=1) 47 | 48 | def collect_nll(self, x, y, step=None): 49 | preds = self.collect_forward(x) # [B, 1+K, C] 50 | loss_preds = [] 51 | for log_softmax in preds.unbind(dim=1): 52 | loss_pred = self.ce_loss(log_softmax, y) 53 | 54 | # Classifier chilling 55 | chilled_log_softmax = F.log_softmax( 56 | log_softmax / self.params.classifier_chill, dim=1) 57 | chilled_loss_pred = self.ce_loss(chilled_log_softmax, y) 58 | 59 | # Value with chill & gradient without chill 60 | loss_pred = loss_pred - loss_pred.detach() \ 61 | + chilled_loss_pred.detach() 62 | 63 | loss_preds.append(loss_pred) 64 | return torch.stack(loss_preds, dim=1) 65 | 66 | 67 | class BasicBlock(nn.Module): 68 | expansion = 1 69 | 70 | def __init__(self, inplanes, planes, stride=1, 71 | downsample=None, upsample=None, 72 | dilation=1, norm_layer=nn.BatchNorm2d): 73 | super(BasicBlock, self).__init__() 74 | if dilation > 1: 75 | raise NotImplementedError( 
# --------------------------------------------------------------------------------
# (continued) presumably /models/ndpm/classifier.py -- the file header and the
# start of BasicBlock.__init__ lie before this section.
# NOTE(review): indentation of the next two lines is reconstructed; it assumes the
# preceding (not visible) lines are `if dilation > 1:` / `raise NotImplementedError(`
# inside BasicBlock.__init__ -- confirm against the real file.
# --------------------------------------------------------------------------------
                "Dilation > 1 not supported in BasicBlock"
            )
        # conv1 is a 4x4 transposed conv only when an upsample shortcut is given
        # and stride != 1; otherwise it is a regular 3x3 conv.
        transpose = upsample is not None and stride != 1
        self.conv1 = (
            conv4x4t(inplanes, planes, stride) if transpose else
            conv3x3(inplanes, planes, stride)
        )
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.upsample = upsample
        self.stride = stride

    def forward(self, x):
        """Residual forward pass.

        conv1 -> norm -> ReLU -> conv2 -> norm, then add the shortcut (the input,
        passed through `downsample` or else `upsample` when one is set) and apply
        a final ReLU.
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # downsample takes precedence over upsample when both are set.
        if self.downsample is not None:
            identity = self.downsample(x)
        elif self.upsample is not None:
            identity = self.upsample(x)

        out += identity
        out = self.relu(out)

        return out


def conv4x4t(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """4x4 transposed convolution with padding"""
    return nn.ConvTranspose2d(
        in_planes, out_planes, kernel_size=4, stride=stride,
        padding=dilation, groups=groups, bias=False, dilation=dilation,
    )


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride,
        padding=dilation, groups=groups, bias=False, dilation=dilation
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
    )


class ResNetSharingClassifier(SharingClassifier):
    """ResNet-style classifier component that reuses features of earlier components.

    `self.precursors` holds `expert.d` for every expert except `self.experts[0]`.
    The first component (no precursors) is a plain ResNet. A later component runs
    the most recent precursor with `collect=True`, then concatenates (dim=1) that
    precursor's per-stage features with its own before every following stage and
    before the linear head.
    """
    block = BasicBlock
    num_blocks = [2, 2, 2, 2]
    # NOTE(review): __init__ always reassigns self.norm_layer (config value or
    # nn.BatchNorm2d), so this class-level default is never used by instances.
    norm_layer = nn.InstanceNorm2d

    def __init__(self, params, experts):
        super().__init__(params, experts)
        self.precursors = [expert.d for expert in self.experts[1:]]
        first = len(self.precursors) == 0

        if MODELS_NDPM_CLASSIFIER_NUM_BLOCKS is not None:
            num_blocks = MODELS_NDPM_CLASSIFIER_NUM_BLOCKS
        else:
            num_blocks = self.num_blocks
        if MODELS_NDPM_CLASSIFIER_NORM_LAYER is not None:
            self.norm_layer = getattr(nn, MODELS_NDPM_CLASSIFIER_NORM_LAYER)
        else:
            self.norm_layer = nn.BatchNorm2d

        num_classes = n_classes[params.data]
        # nf: width of this component's own layers (base width for the first
        # component, extension width afterwards).
        # nf_cat: width after concatenating with the precursor features
        # (base + one extension per precursor).
        nf = MODELS_NDPM_CLASSIFIER_CLS_NF_BASE if first else MODELS_NDPM_CLASSIFIER_CLS_NF_EXT
        nf_cat = MODELS_NDPM_CLASSIFIER_CLS_NF_BASE \
            + len(self.precursors) * MODELS_NDPM_CLASSIFIER_CLS_NF_EXT
        self.nf = MODELS_NDPM_CLASSIFIER_CLS_NF_BASE if first else MODELS_NDPM_CLASSIFIER_CLS_NF_EXT
        self.nf_cat = nf_cat

        # NOTE(review): the stem hard-codes 3 input channels.
        self.layer0 = nn.Sequential(
            nn.Conv2d(
                3, nf * 1, kernel_size=3, stride=1, padding=1, bias=False
            ),
            self.norm_layer(nf * 1),
            nn.ReLU()
        )
        # Each stage consumes the concatenated (nf_cat-based) features of the
        # previous stage and produces this component's own (nf-based) features.
        self.layer1 = self._make_layer(
            nf_cat * 1, nf * 1, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(
            nf_cat * 1, nf * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(
            nf_cat * 2, nf * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(
            nf_cat * 4, nf * 8, num_blocks[3], stride=2)
        # The head outputs log-probabilities (LogSoftmax).
        self.predict = nn.Sequential(
            nn.Linear(nf_cat * 8, num_classes),
            nn.LogSoftmax(dim=1)
        )
        self.setup_optimizer()

    def _make_layer(self, nf_in, nf_out, num_blocks, stride):
        """Build one stage of `num_blocks` blocks mapping nf_in -> nf_out channels.

        The first block applies `stride` and gets a 1x1-conv + norm shortcut
        whenever the stride or the channel count changes. The remaining blocks
        keep nf_out channels at stride 1.
        """
        norm_layer = self.norm_layer
        block = self.block
        downsample = None
        if stride != 1 or nf_in != nf_out:
            downsample = nn.Sequential(
                conv1x1(nf_in, nf_out, stride),
                norm_layer(nf_out),
            )
        layers = [block(
            nf_in, nf_out, stride,
            downsample=downsample,
            norm_layer=norm_layer
        )]
        for _ in range(1, num_blocks):
            layers.append(block(nf_out, nf_out, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x, collect=False):
        """Return this component's log-probabilities for `x`.

        Args:
            x: input batch (moved to GPU via maybe_cuda).
            collect: if True, return `(preds, features)`:
                - `preds` lists the predictions of every component up to and
                  including this one;
                - `features` holds the detached per-stage activations that the
                  next component concatenates with its own.
        """
        x = maybe_cuda(x)

        # First component
        if len(self.precursors) == 0:
            h1 = self.layer0(x)
            h2 = self.layer1(h1)
            h3 = self.layer2(h2)
            h4 = self.layer3(h3)
            h5 = self.layer4(h4)
            # Global average pool over the full spatial extent, then flatten.
            h5 = F.avg_pool2d(h5, h5.size(2)).view(h5.size(0), -1)
            pred = self.predict(h5)

            if collect:
                return [pred], [
                    h1.detach(), h2.detach(), h3.detach(),
                    h4.detach(), h5.detach()]
            else:
                return pred

        # Second or later component
        preds, features = self.precursors[-1](x, collect=True)
        h1 = self.layer0(x)
        h1_cat = torch.cat([features[0], h1], dim=1)
        h2 = self.layer1(h1_cat)
        h2_cat = torch.cat([features[1], h2], dim=1)
        h3 = self.layer2(h2_cat)
        h3_cat = torch.cat([features[2], h3], dim=1)
        h4 = self.layer3(h3_cat)
        h4_cat = torch.cat([features[3], h4], dim=1)
        h5 = self.layer4(h4_cat)
        h5 = F.avg_pool2d(h5, h5.size(2)).view(h5.size(0), -1)
        h5_cat = torch.cat([features[4], h5], dim=1)
        pred = self.predict(h5_cat)

        if collect:
            preds.append(pred)
            return preds, [
                h1_cat.detach(), h2_cat.detach(), h3_cat.detach(),
                h4_cat.detach(), h5_cat.detach(),
            ]
        else:
            return pred

# --------------------------------------------------------------------------------
# /utils/buffer/buffer_utils.py
# --------------------------------------------------------------------------------
import torch
import numpy as np
from utils.utils import maybe_cuda
from collections import defaultdict
from collections import Counter
import random


def random_retrieve(buffer, num_retrieve, excl_indices=None, return_indices=False):
    """Sample up to `num_retrieve` distinct filled buffer slots, uniformly without replacement.

    Candidates are slots 0 .. buffer.current_index - 1, minus any slot in
    `excl_indices`. The request is capped at the number of candidates.

    Returns:
        (x, y), or (x, y, indices) when `return_indices` is True. `indices` is a
        LongTensor of the chosen slots.
    """
    filled_indices = np.arange(buffer.current_index)
    if excl_indices is not None:
        excl_indices = list(excl_indices)
    else:
        excl_indices = []
    valid_indices = np.setdiff1d(filled_indices, np.array(excl_indices))
    num_retrieve = min(num_retrieve, valid_indices.shape[0])
    indices = torch.from_numpy(np.random.choice(valid_indices, num_retrieve, replace=False)).long()

    x = buffer.buffer_img[indices]

    y = buffer.buffer_label[indices]

    if return_indices:
        return x, y, indices
    else:
        return x, y


def balanced_retrieve(buffer, num_retrieve, excl_labels=None, return_indices=False):
    """Draw up to `num_retrieve // 10` samples from each of 10 randomly chosen labels, then shuffle.

    Labels come from `buffer.labeldict` (label -> candidate buffer indices, as
    used below); labels in `excl_labels` are skipped.
    NOTE(review): np.random.choice(..., 10, replace=False) raises ValueError when
    fewer than 10 labels are available -- confirm callers guarantee that.
    """
    avail_labels = list(buffer.labeldict.keys())
    if excl_labels is not None:
        avail_labels = np.array(list(set(avail_labels) - set(excl_labels)))
    if len(avail_labels) > 0:
        num_instances = num_retrieve // 10
        random_labels = np.random.choice(avail_labels, 10, replace=False)
        indices = []
        for label in random_labels:
            tmp_instances = min(num_instances, len(buffer.labeldict[label]))
            if tmp_instances:
                tmp_indices = np.random.choice(buffer.labeldict[label], tmp_instances, replace=False)
                indices.extend(tmp_indices)
    else:
        indices = []

    indices = np.random.permutation(indices)
    x = buffer.buffer_img[indices]
    y = buffer.buffer_label[indices]

    if return_indices:
        return x, y, indices
    else:
        return x, y


def match_retrieve(buffer, cur_y, exclud_idx=None):
    """For every label in `cur_y`, retrieve one buffer sample with the same label.

    The output is aligned position-wise with `cur_y`. Candidates come from
    buffer.buffer_tracker.class_index_cache, minus any index in `exclud_idx`.
    If some label has fewer candidates than occurrences in `cur_y`, the function
    prints a message and returns two empty tensors.
    """
    counter = Counter(cur_y.tolist())
    # Positions in cur_y for each label, so retrieved samples can be placed in order.
    idx_dict = defaultdict(list)
    for idx, val in enumerate(cur_y.tolist()):
        idx_dict[val].append(idx)
    select = [None] * len(cur_y)
    for y in counter:
        idx = buffer.buffer_tracker.class_index_cache[y]
        if exclud_idx is not None:
            idx = idx - set(exclud_idx.tolist())
        if not idx or len(idx) < counter[y]:
            print('match retrieve attempt fail')
            return torch.tensor([]), torch.tensor([])
        retrieved = random.sample(list(idx), counter[y])
        for idx, val in zip(idx_dict[y], retrieved):
            select[idx] = val
    indices = torch.tensor(select)
    x = buffer.buffer_img[indices]
    y = buffer.buffer_label[indices]
    return x, y

def cosine_similarity(x1, x2=None, eps=1e-8):
    """Pairwise cosine similarity between the rows of x1 and the rows of x2 (default: x1).

    The denominator (product of row L2 norms) is clamped at `eps`.
    """
    x2 = x1 if x2 is None else x2
    w1 = x1.norm(p=2, dim=1, keepdim=True)
    w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
    sim = torch.mm(x1, x2.t())/(w1 * w2.t()).clamp(min=eps)
    return sim


def get_grad_vector(pp, grad_dims):
    """
    Gather the gradients of the parameters yielded by `pp()` into one flat vector.

    grad_dims[k] is expected to be the element count of the k-th parameter. The
    vector is zero-initialised, so parameters whose .grad is None keep a zero
    slot. The counter advances for every parameter so slots stay aligned with
    grad_dims.
    """
    grads = maybe_cuda(torch.Tensor(sum(grad_dims)))
    grads.fill_(0.0)
    cnt = 0
    for param in pp():
        if param.grad is not None:
            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
            en = sum(grad_dims[:cnt + 1])
            grads[beg: en].copy_(param.grad.data.view(-1))
        cnt += 1
    return grads


class ClassBalancedRandomSampling:
    # For faster label-based sampling (e.g., class balanced sampling), cache class-index via auxiliary dictionary
    # Store {class, set of memory sample indices from class} key-value pairs to speed up label-based sampling
    # e.g., {<class label>: {<memory index>, <memory index>}, <class label>: {<memory index>}, ...}
    # Both caches are class-level state shared by all callers; they start as None
    # and are created on the first update_cache call.
    class_index_cache = None
    class_num_cache = None

    @classmethod
    def sample(cls, buffer_x, buffer_y, n_smp_cls, excl_indices=None, device="cpu"):
        """
        Take same number of random samples from each class from buffer.
        Requires class_index_cache to have been populated by update_cache.
        Args:
            buffer_x (tensor): data buffer.
            buffer_y (tensor): label buffer.
            n_smp_cls (int): number of samples to take from each class (fewer if a class has fewer valid indices).
            excl_indices (set): indices of buffered instances to be excluded from sampling.
            device (str): device for tensor allocation.
        Returns:
            x (tensor): class balanced random sample data tensor.
            y (tensor): class balanced random sample label tensor.
            sample_ind (tensor): class balanced random sample index tensor.
        """
        if excl_indices is None:
            excl_indices = set()

        # Get indices for class balanced random samples
        # cls_ind_cache = class_index_tensor_list_cache(buffer_y, num_class, excl_indices, device=device)

        sample_ind = torch.tensor([], device=device, dtype=torch.long)

        # Use cache to retrieve indices belonging to each class in buffer
        for ind_set in cls.class_index_cache.values():
            if ind_set:
                # Exclude some indices
                valid_ind = ind_set - excl_indices
                # Auxiliary indices for permutation
                perm_ind = torch.randperm(len(valid_ind), device=device)
                # Apply permutation, and select indices
                ind = torch.tensor(list(valid_ind), device=device, dtype=torch.long)[perm_ind][:n_smp_cls]
                sample_ind = torch.cat((sample_ind, ind))

        x = buffer_x[sample_ind]
        y = buffer_y[sample_ind]

        x = maybe_cuda(x)
        y = maybe_cuda(y)

        return x, y, sample_ind

    @classmethod
    def update_cache(cls, buffer_y, num_class, new_y=None, ind=None, device="cpu"):
        """
        Collect indices of buffered data from each class in set.
        Update class_index_cache with list of such sets.
        Args:
            buffer_y (tensor): label buffer.
            num_class (int): total number of unique class labels.
            new_y (tensor): label tensor for replacing memory samples at ind in buffer.
            ind (tensor): indices of memory samples to be updated.
            device (str): device for tensor allocation.
        """
        if cls.class_index_cache is None:
            # Initialize caches
            cls.class_index_cache = defaultdict(set)
            cls.class_num_cache = torch.zeros(num_class, dtype=torch.long, device=device)

        if new_y is not None:
            # If ASER update is being used, keep updating existing caches
            # Get labels of memory samples to be replaced
            orig_y = buffer_y[ind]
            # Update caches
            for i, ny, oy in zip(ind, new_y, orig_y):
                oy_int = oy.item()
                ny_int = ny.item()
                i_int = i.item()
                # Update dictionary according to new class label of index i
                if oy_int in cls.class_index_cache and i_int in cls.class_index_cache[oy_int]:
                    cls.class_index_cache[oy_int].remove(i_int)
                    cls.class_num_cache[oy_int] -= 1
                cls.class_index_cache[ny_int].add(i_int)
                cls.class_num_cache[ny_int] += 1
        else:
            # If only ASER retrieve is being used, reset cache and update it based on buffer
            # NOTE(review): only class_index_cache is rebuilt here; class_num_cache
            # is left unchanged.
            cls_ind_cache = defaultdict(set)
            for i, c in enumerate(buffer_y):
                cls_ind_cache[c.item()].add(i)
            cls.class_index_cache = cls_ind_cache


class BufferClassTracker(object):
    # For faster label-based sampling (e.g., class balanced sampling), cache class-index via auxiliary dictionary
    # Store {class, set of memory sample indices from class} key-value pairs to speed up label-based sampling
    # e.g., {<class label>: {<memory index>, <memory index>}, <class label>: {<memory index>}, ...}

    def __init__(self, num_class, device="cpu"):
        super().__init__()
        # Initialize caches
        # NOTE(review): `device` is accepted but not used.
        self.class_index_cache = defaultdict(set)
        self.class_num_cache = np.zeros(num_class)


    def update_cache(self, buffer_y, new_y=None, ind=None, ):
        """
        Record that the memory slots in `ind` now hold labels `new_y`.

        Each slot is moved from the cache entry of its current label (read from
        buffer_y before it is overwritten) to the entry of its new label, and the
        per-class counts are adjusted accordingly.
        Args:
            buffer_y (tensor): label buffer (still holding the old labels).
            new_y (tensor): label tensor for replacing memory samples at ind in buffer.
            ind: indices of memory samples to be updated.
                NOTE(review): elements of `ind` are stored as-is (no .item()), so this
                presumably expects plain Python ints (e.g. a list); tensor elements
                would hash by identity -- confirm at call sites.
        """

        # Get labels of memory samples to be replaced
        orig_y = buffer_y[ind]
        # Update caches
        for i, ny, oy in zip(ind, new_y, orig_y):
            oy_int = oy.item()
            ny_int = ny.item()
            # Update dictionary according to new class label of index i
            if oy_int in self.class_index_cache and i in self.class_index_cache[oy_int]:
                self.class_index_cache[oy_int].remove(i)
                self.class_num_cache[oy_int] -= 1

            self.class_index_cache[ny_int].add(i)
            self.class_num_cache[ny_int] += 1


    def check_tracker(self):
        """Debug helper: print the total of class_num_cache and the number of cached indices.

        Presumably the two values should agree.
        """
        print(self.class_num_cache.sum())
        print(len([k for i in self.class_index_cache.values() for k in i]))

# --------------------------------------------------------------------------------
# /models/ndpm/vae.py
# --------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from itertools import accumulate
import torch
import torch.nn as nn

from utils.utils import maybe_cuda
from .loss import bernoulli_nll, logistic_nll, gaussian_nll, laplace_nll
from .component import ComponentG
from .utils import Lambda
from utils.global_vars import *
from utils.setup_elements import input_size_match

class Vae(ComponentG, ABC):
    """Base VAE component; subclasses implement `encode` and `decode`.

    log_var_param (the observation log-variance, one value per input channel) is:
      - None when MODELS_NDPM_VAE_RECON_LOSS is 'bernoulli';
      - a learnable nn.Parameter when MODELS_NDPM_VAE_LEARN_X_LOG_VAR is set;
      - otherwise a fixed tensor.
    """
    def __init__(self, params, experts):
        super().__init__(params, experts)
        x_c, x_h, x_w = input_size_match[params.data]
        bernoulli = MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli'
        if bernoulli:
            self.log_var_param = None
        elif MODELS_NDPM_VAE_LEARN_X_LOG_VAR:
            self.log_var_param = nn.Parameter(
                torch.ones([x_c]) * MODELS_NDPM_VAE_X_LOG_VAR_PARAM,
                requires_grad=True
            )
        else:
            self.log_var_param = (
                maybe_cuda(torch.ones([x_c])) *
                MODELS_NDPM_VAE_X_LOG_VAR_PARAM
            )

    def forward(self, x):
        """Encode `x`, draw one latent sample, and return the decoder output."""
        x = maybe_cuda(x)
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var, 1)
        return self.decode(z)

    def nll(self, x, y=None, step=None):
        """Per-example negative ELBO, with shape [B]; `y` and `step` are unused here.

        The reconstruction NLL is summed over data dimensions and averaged over
        MODELS_NDPM_VAE_Z_SAMPLES latent samples; the KL to N(0, I) is then added.
        """
        x = maybe_cuda(x)
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var, MODELS_NDPM_VAE_Z_SAMPLES)
        x_mean = self.decode(z)
        x_mean = x_mean.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES, *x.shape[1:])
        # view(1, 1, -1, 1, 1) broadcasts the per-channel log-variance over a
        # 5-D [B, samples, C, H, W] x_mean; this assumes image input [B, C, H, W].
        x_log_var = (
            None if MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli' else
            self.log_var.view(1, 1, -1, 1, 1)
        )
        loss_recon = self.reconstruction_loss(x, x_mean, x_log_var)
        loss_recon = loss_recon.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES, -1)
        loss_recon = loss_recon.sum(2).mean(1)
        loss_kl = self.gaussian_kl(z_mean, z_log_var)
        loss_vae = loss_recon + loss_kl

        return loss_vae

    def sample(self, n=1):
        """Decode `n` latents drawn from N(0, I)."""
        z = maybe_cuda(torch.randn(n, MODELS_NDPM_VAE_Z_DIM))
        x_mean = self.decode(z)
        return x_mean

    def reconstruction_loss(self, x, x_mean, x_log_var=None):
        """Element-wise reconstruction NLL of type MODELS_NDPM_VAE_RECON_LOSS.

        Raises ValueError for an unknown loss type. If x_mean has more dims than
        x (an extra sample dim), x is unsqueezed at dim 1 so the two broadcast.
        """
        loss_type = MODELS_NDPM_VAE_RECON_LOSS
        loss = (
            bernoulli_nll if loss_type == 'bernoulli' else
            gaussian_nll if loss_type == 'gaussian' else
            laplace_nll if loss_type == 'laplace' else
            logistic_nll if loss_type == 'logistic' else None
        )
        if loss is None:
            raise ValueError('Unknown recon_loss type: {}'.format(loss_type))

        if len(x_mean.size()) > len(x.size()):
            x = x.unsqueeze(1)

        return (
            loss(x, x_mean) if x_log_var is None else
            loss(x, x_mean, x_log_var)
        )

    @staticmethod
    def gaussian_kl(q_mean, q_log_var, p_mean=None, p_log_var=None):
        """KL(q || p) between diagonal Gaussians given means and log-variances, summed over dim 1."""
        # p defaults to N(0, 1)
        zeros = torch.zeros_like(q_mean)
        p_mean = p_mean if p_mean is not None else zeros
        p_log_var = p_log_var if p_log_var is not None else zeros
        # calculate KL(q, p)
        kld = 0.5 * (
            p_log_var - q_log_var +
            (q_log_var.exp() + (q_mean - p_mean) ** 2) / p_log_var.exp() - 1
        )
        kld = kld.sum(1)
        return kld

    @staticmethod
    def reparameterize(z_mean, z_log_var, num_samples=1):
        """Draw `num_samples` reparameterised samples per row of z_mean / z_log_var.

        Returns a tensor of shape [B * num_samples, D], with the sample index
        varying fastest within each row of the input.
        """
        z_std = (z_log_var * 0.5).exp()
        z_std = z_std.unsqueeze(1).expand(-1, num_samples, -1)
        z_mean = z_mean.unsqueeze(1).expand(-1, num_samples, -1)
        unit_normal = torch.randn_like(z_std)
        z = z_mean + unit_normal * z_std
        z = z.view(-1, z_std.size(2))
        return z

    @abstractmethod
    def encode(self, x):
        """Return (z_mean, z_log_var) for input `x`."""
        pass

    @abstractmethod
    def decode(self, x):
        """Map latent samples back to input space."""
        pass

    @property
    def log_var(self):
        """Observation log-variance; None when the Bernoulli reconstruction loss is used."""
        # Both branches yield self.log_var_param.
        return (
            None if self.log_var_param is None else
            self.log_var_param
        )


class SharingVae(Vae, ABC):
    def collect_nll(self, x, y=None, step=None):
        """Collect NLL values

        Column 0 is the NLL from the dummy VAE (self.experts[0].g). The next
        columns come from the VAEs of self.experts[1:] and then self, using the
        encodings collected by `self.encode(x, collect=True)`.

        Returns:
            loss_vae: Tensor of shape [B, 1+K]
        """
        x = maybe_cuda(x)

        # Dummy VAE
        dummy_nll = self.experts[0].g.nll(x, y, step)

        # Encode
        z_means, z_log_vars, features = self.encode(x, collect=True)

        # Decode
        loss_vaes = [dummy_nll]
        vaes = [expert.g for expert in self.experts[1:]] + [self]
        x_logits = []
        for z_mean, z_log_var, vae in zip(z_means, z_log_vars, vaes):
            z = self.reparameterize(z_mean, z_log_var, MODELS_NDPM_VAE_Z_SAMPLES)
            if MODELS_NDPM_VAE_PRECURSOR_CONDITIONED_DECODER:
                # Defer: logits are accumulated across components below.
                x_logit = vae.decode(z, as_logit=True)
                x_logits.append(x_logit)
                continue
            x_mean = vae.decode(z)
            x_mean = x_mean.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES,
                                 *x.shape[1:])
            x_log_var = (
                None if MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli' else
                self.log_var.view(1, 1, -1, 1, 1)
            )
            loss_recon = self.reconstruction_loss(x, x_mean, x_log_var)
            loss_recon = loss_recon.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES,
                                         -1)
            loss_recon = loss_recon.sum(2).mean(1)
            loss_kl = self.gaussian_kl(z_mean, z_log_var)
            loss_vae = loss_recon + loss_kl

            loss_vaes.append(loss_vae)

        # Running sum of logits. Earlier terms are detached, so each component's
        # gradient flows only through its own logit.
        x_logits = list(accumulate(
            x_logits, func=(lambda x, y: x.detach() + y)
        ))
        for x_logit in x_logits:
            x_mean = torch.sigmoid(x_logit)
            x_mean = x_mean.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES,
                                 *x.shape[1:])
            x_log_var = (
                None if MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli' else
                self.log_var.view(1, 1, -1, 1, 1)
            )
            loss_recon = self.reconstruction_loss(x, x_mean, x_log_var)
            loss_recon = loss_recon.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES,
                                         -1)
            loss_recon = loss_recon.sum(2).mean(1)
            # NOTE(review): z_mean / z_log_var are not rebound in this loop; they
            # still hold the values from the last iteration of the loop above, so
            # every component here gets the same KL term -- confirm this is intended.
            loss_kl = self.gaussian_kl(z_mean, z_log_var)
            loss_vae = loss_recon + loss_kl
            loss_vaes.append(loss_vae)

        return torch.stack(loss_vaes, dim=1)

    @abstractmethod
    def encode(self, x, collect=False):
        """Encode `x`; with collect=True also return the lists/features of all precursors."""
        pass

    @abstractmethod
    def decode(self, z, as_logit=False):
        """
        Decoders do not share parameters.
        """
        pass


class CnnSharingVae(SharingVae):
    """Convolutional sharing VAE.

    The encoder concatenates the features of the most recent precursor (the
    `g` modules of self.experts[1:]). The decoder is private to this component
    and always uses the base width.
    """
    def __init__(self, params, experts):
        super().__init__(params, experts)
        self.precursors = [expert.g for expert in self.experts[1:]]
        first = len(self.precursors) == 0
        # NOTE(review): `MODLES_` (sic) must match the name defined in
        # utils.global_vars -- verify before renaming either side.
        nf_base, nf_ext = MODLES_NDPM_VAE_NF_BASE, MODELS_NDPM_VAE_NF_EXT
        nf = nf_base if first else nf_ext
        nf_cat = nf_base + len(self.precursors) * nf_ext

        h1_dim = 1 * nf
        h2_dim = 2 * nf
        fc_dim = 4 * nf
        h1_cat_dim = 1 * nf_cat
        h2_cat_dim = 2 * nf_cat
        fc_cat_dim = 4 * nf_cat

        x_c, x_h, x_w = input_size_match[params.data]

        self.fc_dim = fc_dim
        # Two MaxPool2d(2) stages in the encoder shrink each spatial side by 4.
        feature_volume = ((x_h // 4) * (x_w // 4) *
                          h2_cat_dim)

        self.enc1 = nn.Sequential(
            nn.Conv2d(x_c, h1_dim, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU()
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(h1_cat_dim, h2_dim, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU(),
            Lambda(lambda x: x.view(x.size(0), -1))
        )
        self.enc3 = nn.Sequential(
            nn.Linear(feature_volume, fc_dim),
            nn.ReLU()
        )
        self.enc_z_mean = nn.Linear(fc_cat_dim, MODELS_NDPM_VAE_Z_DIM)
        self.enc_z_log_var = nn.Linear(fc_cat_dim, MODELS_NDPM_VAE_Z_DIM)

        # Decoder mirrors the encoder: two stride-2 transposed convs restore
        # the x_h // 4, x_w // 4 feature map to full resolution.
        self.dec_z = nn.Sequential(
            nn.Linear(MODELS_NDPM_VAE_Z_DIM, 4 * nf_base),
            nn.ReLU()
        )
        self.dec3 = nn.Sequential(
            nn.Linear(
                4 * nf_base,
                (x_h // 4) * (x_w // 4) * 2 * nf_base),
            nn.ReLU()
        )
        self.dec2 = nn.Sequential(
            Lambda(lambda x: x.view(
                x.size(0), 2 * nf_base,
                x_h // 4, x_w // 4)),
            nn.ConvTranspose2d(2 * nf_base, 1 * nf_base,
                               kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        )
        self.dec1 = nn.ConvTranspose2d(1 * nf_base, x_c,
                                       kernel_size=4, stride=2, padding=1)

        self.setup_optimizer()

    def encode(self, x, collect=False):
        """Return (z_mean, z_log_var) for `x`.

        With collect=True, return (z_means, z_log_vars, features): lists covering
        every component up to this one, plus this component's detached
        (concatenated) features for the next component.
        """
        # When first component
        if len(self.precursors) == 0:
            h1 = self.enc1(x)
            h2 = self.enc2(h1)
            h3 = self.enc3(h2)
            z_mean = self.enc_z_mean(h3)
            z_log_var = self.enc_z_log_var(h3)

            if collect:
                return [z_mean], [z_log_var], \
                    [h1.detach(), h2.detach(), h3.detach()]
            else:
                return z_mean, z_log_var

        # Second or later component
        z_means, z_log_vars, features = \
            self.precursors[-1].encode(x, collect=True)

        h1 = self.enc1(x)
        h1_cat = torch.cat([features[0], h1], dim=1)
        h2 = self.enc2(h1_cat)
        h2_cat = torch.cat([features[1], h2], dim=1)
        h3 = self.enc3(h2_cat)
        h3_cat = torch.cat([features[2], h3], dim=1)
        z_mean = self.enc_z_mean(h3_cat)
        z_log_var = self.enc_z_log_var(h3_cat)

        if collect:
            z_means.append(z_mean)
            z_log_vars.append(z_log_var)
            features = [h1_cat.detach(), h2_cat.detach(), h3_cat.detach()]
            return z_means, z_log_vars, features
        else:
            return z_mean, z_log_var

    def decode(self, z, as_logit=False):
        """Decode latents `z`; return raw logits if `as_logit`, else sigmoid probabilities."""
        h3 = self.dec_z(z)
        h2 = self.dec3(h3)
        h1 = self.dec2(h2)
        x_logit = self.dec1(h1)
        return x_logit if as_logit else torch.sigmoid(x_logit)

# --------------------------------------------------------------------------------
# /agents/base.py
# --------------------------------------------------------------------------------
from abc import abstractmethod
import abc
import numpy as np
import torch
from torch.nn import functional as F
from utils.kd_manager import KdManager
from utils.utils import maybe_cuda, AverageMeter
from torch.utils.data import TensorDataset, DataLoader
import copy
from utils.loss import SupConLoss
import pickle


class ContinualLearner(torch.nn.Module, metaclass=abc.ABCMeta):
    '''
    Abstract module which is inherited by each and every continual learning algorithm.

    Subclasses implement `train_learner`. This base class tracks old and new
    labels across tasks, provides the training criterion, and evaluates
    per-task accuracy.
    '''

    def __init__(self, model, opt, params):
        super(ContinualLearner, self).__init__()
        self.params = params
        self.model = model
        self.opt = opt
        self.data = params.data
        self.cuda = params.cuda
        self.epoch = params.epoch
        self.batch = params.batch
        self.verbose = params.verbose
        # Labels of finished tasks / of the task currently being trained.
        self.old_labels = []
        self.new_labels = []
        self.task_seen = 0
        self.kd_manager = KdManager()
        # Per-evaluation statistics collected when params.error_analysis is set.
        self.error_list = []
        self.new_class_score = []
        self.old_class_score = []
        self.fc_norm_new = []
        self.fc_norm_old = []
        self.bias_norm_new = []
        self.bias_norm_old = []
        # label -> contiguous index (old labels first), used by separated_softmax.
        self.lbl_inv_map = {}
        # label -> index of the task in which it was first registered.
        self.class_task_map = {}

    def before_train(self, x_train, y_train):
        '''Register the labels of `y_train` as new labels of the current task.

        Each label is appended to new_labels, mapped in lbl_inv_map to an index
        after the old labels, and mapped in class_task_map to the current task.
        '''
        new_labels = list(set(y_train.tolist()))
        self.new_labels += new_labels
        for i, lbl in enumerate(new_labels):
            self.lbl_inv_map[lbl] = len(self.old_labels) + i

        for i in new_labels:
            self.class_task_map[i] = self.task_seen

    @abstractmethod
    def train_learner(self, x_train, y_train):
        pass

    def after_train(self):
        '''Finish the current task.

        - Moves new_labels into old_labels, keeping a copy in new_labels_zombie.
        - Increments task_seen.
        - If review_trick is set and a buffer exists, runs one pass over the
          buffer with gradients scaled by 1/10.
        - Updates the KD teacher when kd_trick is set or the agent is LWF.
        '''
        #self.old_labels = list(set(self.old_labels + self.new_labels))
        self.old_labels += self.new_labels
        self.new_labels_zombie = copy.deepcopy(self.new_labels)
        self.new_labels.clear()
        self.task_seen += 1
        if self.params.trick['review_trick'] and hasattr(self, 'buffer'):
            self.model.train()
            mem_x = self.buffer.buffer_img[:self.buffer.current_index]
            mem_y = self.buffer.buffer_label[:self.buffer.current_index]
            # criterion = torch.nn.CrossEntropyLoss(reduction='mean')
            if mem_x.size(0) > 0:
                rv_dataset = TensorDataset(mem_x, mem_y)
                rv_loader = DataLoader(rv_dataset, batch_size=self.params.eps_mem_batch, shuffle=True, num_workers=0,
                                       drop_last=True)
                for ep in range(1):
                    for i, batch_data in enumerate(rv_loader):
                        # batch update
                        batch_x, batch_y = batch_data
                        batch_x = maybe_cuda(batch_x, self.cuda)
                        batch_y = maybe_cuda(batch_y, self.cuda)
                        logits = self.model.forward(batch_x)
                        if self.params.agent in ['SCR', 'SSCR']:
                            # Two views per sample: the raw batch and self.transform(batch)
                            # (self.transform is presumably defined by the SCR agents).
                            logits = torch.cat([self.model.forward(batch_x).unsqueeze(1),
                                                self.model.forward(self.transform(batch_x)).unsqueeze(1)], dim=1)
                        loss = self.criterion(logits, batch_y)
                        self.opt.zero_grad()
                        loss.backward()
                        # Scale the review-step gradients down by 10x before stepping.
                        params = [p for p in self.model.parameters() if p.requires_grad and p.grad is not None]
                        grad = [p.grad.clone()/10. for p in params]
                        for g, p in zip(grad, params):
                            p.grad.data.copy_(g)
                        self.opt.step()

        if self.params.trick['kd_trick'] or self.params.agent == 'LWF':
            self.kd_manager.update_teacher(self.model)

    def criterion(self, logits, labels):
        '''Training loss, selected by the active trick / agent.

        - labels_trick: cross-entropy over only the label columns present in the
          batch, with labels remapped to 0..k-1.
        - separated_softmax: separate log-softmax over old and new label columns,
          followed by NLL with labels mapped through lbl_inv_map.
        - SCR / SCP / SSCR agents: SupConLoss.
        - otherwise: plain cross-entropy.
        '''
        labels = labels.clone()
        ce = torch.nn.CrossEntropyLoss(reduction='mean')
        if self.params.trick['labels_trick']:
            unq_lbls = labels.unique().sort()[0]
            for lbl_idx, lbl in enumerate(unq_lbls):
                labels[labels == lbl] = lbl_idx
            # Calculate the loss only over the heads that appear in the batch:
            return ce(logits[:, unq_lbls], labels)
        elif self.params.trick['separated_softmax']:
            old_ss = F.log_softmax(logits[:, self.old_labels], dim=1)
            new_ss = F.log_softmax(logits[:, self.new_labels], dim=1)
            ss = torch.cat([old_ss, new_ss], dim=1)
            for i, lbl in enumerate(labels):
                labels[i] = self.lbl_inv_map[lbl.item()]
            return F.nll_loss(ss, labels)
        elif self.params.agent in ['SCR', 'SCP', 'SSCR']:
            SC = SupConLoss(temperature=self.params.temp)
            return SC(logits, labels)
        else:
            return ce(logits, labels)

    def forward(self, x):
        return self.model.forward(x)

    def evaluate(self, test_loaders):
        '''Return a numpy array with the accuracy on each loader in `test_loaders`.

        With ncm_trick, or for the ICARL / SCR / SCP / SSCR agents, prediction is
        the nearest L2-normalised class mean of buffer-exemplar features.
        Otherwise it is the argmax of the model logits.

        With params.error_analysis, this also prints and stores an error
        breakdown, class scores and linear-layer norms, and pickles
        [correct_lb, predict_lb] to the file 'confusion'.
        '''
        self.model.eval()
        acc_array = np.zeros(len(test_loaders))
        if self.params.trick['ncm_trick'] or self.params.agent in ['ICARL', 'SCR', 'SCP', 'SSCR']:
            # Build one L2-normalised mean feature per old class from the buffer exemplars.
            exemplar_means = {}
            cls_exemplar = {cls: [] for cls in self.old_labels}
            buffer_filled = self.buffer.current_index
            for x, y in zip(self.buffer.buffer_img[:buffer_filled], self.buffer.buffer_label[:buffer_filled]):
                cls_exemplar[y.item()].append(x)
            for cls, exemplar in cls_exemplar.items():
                features = []
                # Extract feature for each exemplar in p_y
                for ex in exemplar:
                    feature = self.model.features(ex.unsqueeze(0)).detach().clone()
                    feature = feature.squeeze()
                    feature.data = feature.data / feature.data.norm() # Normalize
                    features.append(feature)
                if len(features) == 0:
                    # No exemplar for this class: fall back to a random mean.
                    # NOTE(review): `x` here is the last sample left over from the
                    # buffer loop above; it is undefined if the buffer is empty.
                    mu_y = maybe_cuda(torch.normal(0, 1, size=tuple(self.model.features(x.unsqueeze(0)).detach().size())), self.cuda)
                    mu_y = mu_y.squeeze()
                else:
                    features = torch.stack(features)
                    mu_y = features.mean(0).squeeze()
                mu_y.data = mu_y.data / mu_y.data.norm() # Normalize
                exemplar_means[cls] = mu_y
        with torch.no_grad():
            if self.params.error_analysis:
                # Counters for misclassified samples:
                #   no / nn: newest-task samples predicted as an old-only / other class
                #   on / oo: earlier-task samples predicted as a newest-task / other class
                error = 0
                no = 0
                nn = 0
                oo = 0
                on = 0
                new_class_score = AverageMeter()
                old_class_score = AverageMeter()
                correct_lb = []
                predict_lb = []
            for task, test_loader in enumerate(test_loaders):
                acc = AverageMeter()
                for i, (batch_x, batch_y) in enumerate(test_loader):
                    batch_x = maybe_cuda(batch_x, self.cuda)
                    batch_y = maybe_cuda(batch_y, self.cuda)
                    if self.params.trick['ncm_trick'] or self.params.agent in ['ICARL', 'SCR', 'SCP', 'SSCR']:
                        feature = self.model.features(batch_x) # (batch_size, feature_size)
                        for j in range(feature.size(0)): # Normalize
                            feature.data[j] = feature.data[j] / feature.data[j].norm()
                        feature = feature.unsqueeze(2) # (batch_size, feature_size, 1)
                        means = torch.stack([exemplar_means[cls] for cls in self.old_labels]) # (n_classes, feature_size)

                        # old NCM implementation
                        means = torch.stack([means] * batch_x.size(0)) # (batch_size, n_classes, feature_size)
                        means = means.transpose(1, 2)
                        feature = feature.expand_as(means) # (batch_size, feature_size, n_classes)
                        # NOTE(review): .squeeze() drops every size-1 dim, so a batch of
                        # size 1 (or a single class) leaves a 1-D tensor and dists.min(1)
                        # below raises.
                        dists = (feature - means).pow(2).sum(1).squeeze() # (batch_size, n_classes)
                        _, pred_label = dists.min(1)
                        # may be faster
                        # feature = feature.squeeze(2).T
                        # _, preds = torch.matmul(means, feature).max(0)
                        # pred_label indexes into old_labels; map back to real labels.
                        correct_cnt = (np.array(self.old_labels)[
                                           pred_label.tolist()] == batch_y.cpu().numpy()).sum().item() / batch_y.size(0)
                    else:
                        logits = self.model.forward(batch_x)
                        _, pred_label = torch.max(logits, 1)
                        correct_cnt = (pred_label == batch_y).sum().item()/batch_y.size(0)

                    if self.params.error_analysis:
                        # NOTE(review): `logits` is only assigned in the non-NCM branch
                        # above; the score updates below would fail on the NCM path.
                        correct_lb += [task] * len(batch_y)
                        for i in pred_label:
                            predict_lb.append(self.class_task_map[i.item()])
                        if task < self.task_seen-1:
                            # old test
                            total = (pred_label != batch_y).sum().item()
                            wrong = pred_label[pred_label != batch_y]
                            error += total
                            on_tmp = sum([(wrong == i).sum().item() for i in self.new_labels_zombie])
                            oo += total - on_tmp
                            on += on_tmp
                            old_class_score.update(logits[:, list(set(self.old_labels) - set(self.new_labels_zombie))].mean().item(), batch_y.size(0))
                        elif task == self.task_seen -1:
                            # new test
                            total = (pred_label != batch_y).sum().item()
                            error += total
                            wrong = pred_label[pred_label != batch_y]
                            no_tmp = sum([(wrong == i).sum().item() for i in list(set(self.old_labels) - set(self.new_labels_zombie))])
                            no += no_tmp
                            nn += total - no_tmp
                            new_class_score.update(logits[:, self.new_labels_zombie].mean().item(), batch_y.size(0))
                        else:
                            pass
                    acc.update(correct_cnt, batch_y.size(0))
                acc_array[task] = acc.avg()
        print(acc_array)
        if self.params.error_analysis:
            self.error_list.append((no, nn, oo, on))
            self.new_class_score.append(new_class_score.avg())
            self.old_class_score.append(old_class_score.avg())
            print("no ratio: {}\non ratio: {}".format(no/(no+nn+0.1), on/(oo+on+0.1)))
            print(self.error_list)
            print(self.new_class_score)
            print(self.old_class_score)
            self.fc_norm_new.append(self.model.linear.weight[self.new_labels_zombie].mean().item())
            self.fc_norm_old.append(self.model.linear.weight[list(set(self.old_labels) - set(self.new_labels_zombie))].mean().item())
            self.bias_norm_new.append(self.model.linear.bias[self.new_labels_zombie].mean().item())
            self.bias_norm_old.append(self.model.linear.bias[list(set(self.old_labels) - set(self.new_labels_zombie))].mean().item())
            print(self.fc_norm_old)
            print(self.fc_norm_new)
            print(self.bias_norm_old)
            print(self.bias_norm_new)
            with open('confusion', 'wb') as fp:
                pickle.dump([correct_lb, predict_lb], fp)
        return acc_array

# --------------------------------------------------------------------------------