├── agents ├── __init__.py ├── __pycache__ │ ├── agem.cpython-36.pyc │ ├── base.cpython-36.pyc │ ├── lwf.cpython-36.pyc │ ├── scr.cpython-36.pyc │ ├── cndpm.cpython-36.pyc │ ├── ewc_pp.cpython-36.pyc │ ├── gdumb.cpython-36.pyc │ ├── icarl.cpython-36.pyc │ ├── __init__.cpython-36.pyc │ ├── exp_replay.cpython-36.pyc │ └── exp_replay_dvc.cpython-36.pyc ├── cndpm.py ├── lwf.py ├── icarl.py ├── scr.py ├── gdumb.py ├── agem.py ├── ewc_pp.py ├── exp_replay.py └── exp_replay_dvc.py ├── utils ├── __init__.py ├── buffer │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── buffer.cpython-36.pyc │ │ ├── aser_utils.cpython-36.pyc │ │ ├── mem_match.cpython-36.pyc │ │ ├── aser_retrieve.cpython-36.pyc │ │ ├── aser_update.cpython-36.pyc │ │ ├── buffer_utils.cpython-36.pyc │ │ ├── mgi_retrieve.cpython-36.pyc │ │ ├── mir_retrieve.cpython-36.pyc │ │ ├── sc_retrieve.cpython-36.pyc │ │ ├── random_retrieve.cpython-36.pyc │ │ ├── gss_greedy_update.cpython-36.pyc │ │ └── reservoir_update.cpython-36.pyc │ ├── random_retrieve.py │ ├── sc_retrieve.py │ ├── mem_match.py │ ├── buffer.py │ ├── aser_update.py │ ├── reservoir_update.py │ ├── mir_retrieve.py │ ├── mgi_retrieve.py │ ├── aser_retrieve.py │ ├── gss_greedy_update.py │ └── aser_utils.py ├── __pycache__ │ ├── io.cpython-36.pyc │ ├── loss.cpython-36.pyc │ ├── utils.cpython-36.pyc │ ├── L_softmax.cpython-36.pyc │ ├── __init__.cpython-36.pyc │ ├── kd_manager.cpython-36.pyc │ ├── name_match.cpython-36.pyc │ ├── global_vars.cpython-36.pyc │ └── setup_elements.cpython-36.pyc ├── kd_manager.py ├── global_vars.py ├── io.py ├── name_match.py ├── setup_elements.py ├── L_softmax.py ├── loss.py └── utils.py ├── continuum ├── __init__.py ├── dataset_scripts │ ├── __init__.py │ ├── __pycache__ │ │ ├── cifar10.cpython-36.pyc │ │ ├── core50.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── cifar100.cpython-36.pyc │ │ ├── openloris.cpython-36.pyc │ │ ├── dataset_base.cpython-36.pyc │ │ └── mini_imagenet.cpython-36.pyc │ ├── dataset_base.py │ ├── cifar100.py │ ├── cifar10.py │ ├── mini_imagenet.py │ ├── openloris.py │ └── core50.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── continuum.cpython-36.pyc │ ├── data_utils.cpython-36.pyc │ └── non_stationary.cpython-36.pyc ├── continuum.py ├── data_utils.py └── non_stationary.py ├── experiment ├── __init__.py ├── __pycache__ │ ├── run.cpython-36.pyc │ ├── metrics.cpython-36.pyc │ ├── __init__.cpython-36.pyc │ └── tune_hyperparam.cpython-36.pyc ├── tune_hyperparam.py └── metrics.py ├── models ├── ndpm │ ├── __init__.py │ ├── __pycache__ │ │ ├── loss.cpython-36.pyc │ │ ├── ndpm.cpython-36.pyc │ │ ├── vae.cpython-36.pyc │ │ ├── expert.cpython-36.pyc │ │ ├── priors.cpython-36.pyc │ │ ├── utils.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── component.cpython-36.pyc │ │ └── classifier.cpython-36.pyc │ ├── utils.py │ ├── loss.py │ ├── priors.py │ ├── expert.py │ ├── component.py │ ├── ndpm.py │ ├── classifier.py │ └── vae.py ├── __init__.py ├── __pycache__ │ ├── resnet.cpython-36.pyc │ └── __init__.cpython-36.pyc ├── pretrained.py └── resnet.py ├── __pycache__ └── loss.cpython-36.pyc ├── features ├── __pycache__ │ └── extractor.cpython-36.pyc ├── configs.py ├── regnet.py ├── lenet.py ├── efficientnet.py ├── extractor.py ├── wide_resnet.py └── vgg.py ├── requirement.txt ├── .idea ├── misc.xml ├── modules.xml ├── online-continual-learning-main-realori.iml └── inspectionProfiles │ └── Project_Default.xml ├── README.md └── loss.py /agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /continuum/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiment/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/ndpm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/buffer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import resnet -------------------------------------------------------------------------------- /__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/io.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/io.cpython-36.pyc -------------------------------------------------------------------------------- /agents/__pycache__/agem.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/agem.cpython-36.pyc -------------------------------------------------------------------------------- /agents/__pycache__/base.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/base.cpython-36.pyc -------------------------------------------------------------------------------- /agents/__pycache__/lwf.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/lwf.cpython-36.pyc -------------------------------------------------------------------------------- /agents/__pycache__/scr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/scr.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /agents/__pycache__/cndpm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/cndpm.cpython-36.pyc -------------------------------------------------------------------------------- /agents/__pycache__/ewc_pp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/ewc_pp.cpython-36.pyc -------------------------------------------------------------------------------- /agents/__pycache__/gdumb.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/gdumb.cpython-36.pyc -------------------------------------------------------------------------------- /agents/__pycache__/icarl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/icarl.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /agents/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/run.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/experiment/__pycache__/run.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/ndpm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/ndpm.cpython-36.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/vae.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/L_softmax.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/L_softmax.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/kd_manager.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/kd_manager.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/name_match.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/name_match.cpython-36.pyc -------------------------------------------------------------------------------- /agents/__pycache__/exp_replay.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/exp_replay.cpython-36.pyc -------------------------------------------------------------------------------- /continuum/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/experiment/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /features/__pycache__/extractor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/features/__pycache__/extractor.cpython-36.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/expert.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/expert.cpython-36.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/priors.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/priors.cpython-36.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/global_vars.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/global_vars.cpython-36.pyc -------------------------------------------------------------------------------- /agents/__pycache__/exp_replay_dvc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/exp_replay_dvc.cpython-36.pyc -------------------------------------------------------------------------------- /continuum/__pycache__/continuum.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/__pycache__/continuum.cpython-36.pyc -------------------------------------------------------------------------------- /continuum/__pycache__/data_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/__pycache__/data_utils.cpython-36.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/experiment/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/component.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/component.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/setup_elements.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/setup_elements.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/buffer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/buffer.cpython-36.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/classifier.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/classifier.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/aser_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/aser_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/mem_match.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/mem_match.cpython-36.pyc -------------------------------------------------------------------------------- /continuum/__pycache__/non_stationary.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/__pycache__/non_stationary.cpython-36.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/tune_hyperparam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/experiment/__pycache__/tune_hyperparam.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/aser_retrieve.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/aser_retrieve.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/aser_update.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/aser_update.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/buffer_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/buffer_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/mgi_retrieve.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/mgi_retrieve.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/mir_retrieve.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/mir_retrieve.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/sc_retrieve.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/sc_retrieve.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/random_retrieve.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/random_retrieve.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/gss_greedy_update.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/gss_greedy_update.cpython-36.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/reservoir_update.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/reservoir_update.cpython-36.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/cifar10.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/cifar10.cpython-36.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/core50.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/core50.cpython-36.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/cifar100.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/cifar100.cpython-36.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/openloris.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/openloris.cpython-36.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/dataset_base.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/dataset_base.cpython-36.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/mini_imagenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/mini_imagenet.cpython-36.pyc -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.0 2 | torchvision==0.10.0 3 | matplotlib==3.3.4 4 | scikit-image==0.17.2 5 | timm==0.4.12 6 | kornia==0.5.4 7 | scikit-learn==0.24.2 8 | pandas==1.1.5 9 | psutil==5.8.0 10 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /models/pretrained.py: -------------------------------------------------------------------------------- 1 | import torchvision.models as models 2 | import torch 3 | 4 | def ResNet18_pretrained(n_classes): 5 | classifier = models.resnet18(pretrained=True) 6 | classifier.fc = torch.nn.Linear(512, n_classes) 7 | return classifier -------------------------------------------------------------------------------- /models/ndpm/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Lambda(nn.Module): 5 | def __init__(self, f=None): 6 | super().__init__() 7 | self.f = f if f is not None else (lambda x: x) 8 | 9 | def forward(self, *args, **kwargs): 10 | return self.f(*args, **kwargs) 11 | -------------------------------------------------------------------------------- /utils/buffer/random_retrieve.py: -------------------------------------------------------------------------------- 1 | from utils.buffer.buffer_utils import random_retrieve 2 | 3 | class Random_retrieve(object): 4 | def __init__(self, params): 5 | super().__init__() 6 | self.num_retrieve = params.eps_mem_batch 7 | 8 | def retrieve(self, buffer, **kwargs): 9 | return random_retrieve(buffer, self.num_retrieve) -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/online-continual-learning-main-realori.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /utils/buffer/sc_retrieve.py: -------------------------------------------------------------------------------- 1 | from utils.buffer.buffer_utils import match_retrieve 2 | import torch 3 | 4 | class Match_retrieve(object): 5 | def __init__(self, params): 6 | super().__init__() 7 | self.num_retrieve = params.eps_mem_batch 8 | self.warmup = params.warmup 9 | 10 | def retrieve(self, buffer, **kwargs): 11 | if buffer.n_seen_so_far > self.num_retrieve * self.warmup: 12 | cur_x, cur_y = kwargs['x'], kwargs['y'] 13 | return match_retrieve(buffer, cur_y) 14 | else: 15 | return torch.tensor([]), torch.tensor([]) -------------------------------------------------------------------------------- /features/configs.py: -------------------------------------------------------------------------------- 1 | 2 | vgg = { 3 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 4 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 5 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 6 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 7 | 'VGG7': [64, 'M', 128, 'M', 256, 'M', 512], 8 | } 9 | 10 | wide_resnet = { 11 | 'WideResNet28-10' : [28, 10, 0.3], 12 | 'WideResNet40-2' : [40, 2, 0.3], 13 | } 14 | 15 | shake_resnet = { 16 | 'ShakeResNet26-2x32d' : [26, 32], 17 | } 18 | -------------------------------------------------------------------------------- /features/regnet.py: -------------------------------------------------------------------------------- 1 | import timm 2 | from . import extractor 3 | 4 | class RegNetX002(extractor.BaseModule): 5 | def __init__(self, config, name): 6 | super(RegNetX002, self).__init__() 7 | 8 | self.name = name 9 | self.features = timm.create_model('regnetx_002') 10 | self.n_features = 368 11 | 12 | def forward(self, x): 13 | return self.features.forward_features(x) 14 | 15 | class RegNetY004(extractor.BaseModule): 16 | def __init__(self, config, name): 17 | super(RegNetY004, self).__init__() 18 | 19 | self.name = name 20 | self.features = timm.create_model('regnety_004') 21 | self.n_features = 440 22 | 23 | def forward(self, x): 24 | return self.features.forward_features(x) 25 | 26 | -------------------------------------------------------------------------------- /utils/kd_manager.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def loss_fn_kd(scores, target_scores, T=2.): 7 | log_scores_norm = F.log_softmax(scores / T, dim=1) 8 | targets_norm = F.softmax(target_scores / T, dim=1) 9 | # Calculate distillation loss (see e.g., Li and Hoiem, 2017) 10 | kd_loss = (-1 * targets_norm * log_scores_norm).sum(dim=1).mean() * T ** 2 11 | return kd_loss 12 | 13 | 14 | class KdManager: 15 | def __init__(self): 16 | self.teacher_model = None 17 | 18 | def update_teacher(self, model): 19 | self.teacher_model = copy.deepcopy(model) 20 | 21 | def get_kd_loss(self, cur_model_logits, x): 22 | if self.teacher_model is not None: 23 | with torch.no_grad(): 24 | prev_model_logits = self.teacher_model.forward(x) 25 | dist_loss = loss_fn_kd(cur_model_logits, prev_model_logits) 26 | else: 27 | dist_loss = 0 28 | return dist_loss 29 | -------------------------------------------------------------------------------- /features/lenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch.nn as nn 6 | from . import extractor 7 | 8 | 9 | class LeNet(extractor.BaseModule): 10 | def __init__(self, config, name): 11 | super(LeNet, self).__init__() 12 | self.name = name 13 | in_channels = config["channels"] 14 | self.conv1 = nn.Conv2d(in_channels=in_channels, 15 | out_channels=6, 16 | kernel_size=5) 17 | self.conv2 = nn.Conv2d(in_channels=6, 18 | out_channels=16, 19 | kernel_size=5) 20 | self.pool = nn.MaxPool2d(2) 21 | self.n_features = 400 22 | self.relu = nn.ReLU(inplace=True) 23 | 24 | def forward(self, x): 25 | x = self.relu(self.conv1(x)) 26 | x = self.pool(x) 27 | x = self.relu(self.conv2(x)) 28 | return x 29 | -------------------------------------------------------------------------------- /utils/buffer/mem_match.py: -------------------------------------------------------------------------------- 1 | from utils.buffer.buffer_utils import match_retrieve 2 | from utils.buffer.buffer_utils import random_retrieve 3 | import torch 4 | 5 | class MemMatch_retrieve(object): 6 | def __init__(self, params): 7 | super().__init__() 8 | self.num_retrieve = params.eps_mem_batch 9 | self.warmup = params.warmup 10 | 11 | 12 | def retrieve(self, buffer, **kwargs): 13 | match_x, match_y = torch.tensor([]), torch.tensor([]) 14 | candidate_x, candidate_y = torch.tensor([]), torch.tensor([]) 15 | if buffer.n_seen_so_far > self.num_retrieve * self.warmup: 16 | while match_x.size(0) == 0: 17 | candidate_x, candidate_y, indices = random_retrieve(buffer, self.num_retrieve,return_indices=True) 18 | if candidate_x.size(0) == 0: 19 | return candidate_x, candidate_y, match_x, match_y 20 | match_x, match_y = match_retrieve(buffer, candidate_y, indices) 21 | return candidate_x, candidate_y, match_x, match_y -------------------------------------------------------------------------------- /features/efficientnet.py: -------------------------------------------------------------------------------- 1 | import timm 2 | from . import extractor 3 | 4 | class EfficientNetB1(extractor.BaseModule): 5 | def __init__(self, config, name): 6 | super(EfficientNetB1, self).__init__() 7 | 8 | self.name = name 9 | self.features = timm.create_model('efficientnet_b1') 10 | self.n_features = 1280 11 | 12 | def forward(self, x): 13 | return self.features.forward_features(x) 14 | 15 | class EfficientNetB0(extractor.BaseModule): 16 | def __init__(self, config, name): 17 | super(EfficientNetB0, self).__init__() 18 | 19 | self.name = name 20 | self.features = timm.create_model('efficientnet_b0') 21 | self.n_features = 1280 22 | 23 | def forward(self, x): 24 | return self.features.forward_features(x) 25 | 26 | class EfficientNetV2S(extractor.BaseModule): 27 | def __init__(self, config, name): 28 | super(EfficientNetV2S, self).__init__() 29 | 30 | self.name = name 31 | drop_rate = config['dropout'] 32 | self.features = timm.create_model('efficientnetv2_s', drop_rate=drop_rate) 33 | self.n_features = 0 #FIXME: Rdit this 34 | 35 | def forward(self, x): 36 | return self.features(x) 37 | -------------------------------------------------------------------------------- /continuum/continuum.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # other imports 4 | from utils.name_match import data_objects 5 | 6 | class continuum(object): 7 | def __init__(self, dataset, scenario, params): 8 | """" Initialize Object """ 9 | self.data_object = data_objects[dataset](scenario, params) 10 | self.run = params.num_runs 11 | self.task_nums = self.data_object.task_nums 12 | self.cur_task = 0 13 | self.cur_run = -1 14 | 15 | def __iter__(self): 16 | return self 17 | 18 | def __next__(self): 19 | if self.cur_task == self.data_object.task_nums: 20 | raise StopIteration 21 | x_train, y_train, labels = self.data_object.new_task(self.cur_task, cur_run=self.cur_run) 22 | self.cur_task += 1 23 | return x_train, y_train, labels 24 | 25 | def test_data(self): 26 | return self.data_object.get_test_set() 27 | 28 | def clean_mem_test_set(self): 29 | self.data_object.clean_mem_test_set() 30 | 31 | def reset_run(self): 32 | self.cur_task = 0 33 | 34 | def new_run(self): 35 | self.cur_task = 0 36 | self.cur_run += 1 37 | self.data_object.new_run(cur_run=self.cur_run) 38 | 39 | 40 | -------------------------------------------------------------------------------- /utils/global_vars.py: -------------------------------------------------------------------------------- 1 | MODLES_NDPM_VAE_NF_BASE = 32 2 | MODELS_NDPM_VAE_NF_EXT = 4 3 | MODELS_NDPM_VAE_PRECURSOR_CONDITIONED_DECODER = False 4 | MODELS_NDPM_VAE_Z_DIM = 64 5 | MODELS_NDPM_VAE_RECON_LOSS = 'gaussian' 6 | MODELS_NDPM_VAE_LEARN_X_LOG_VAR = False 7 | MODELS_NDPM_VAE_X_LOG_VAR_PARAM = 0 8 | MODELS_NDPM_VAE_Z_SAMPLES = 16 9 | MODELS_NDPM_CLASSIFIER_NUM_BLOCKS = [1, 1, 1, 1] 10 | MODELS_NDPM_CLASSIFIER_NORM_LAYER = 'InstanceNorm2d' 11 | MODELS_NDPM_CLASSIFIER_CLS_NF_BASE = 20 12 | MODELS_NDPM_CLASSIFIER_CLS_NF_EXT = 4 13 | MODELS_NDPM_NDPM_DISABLE_D = False 14 | MODELS_NDPM_NDPM_SEND_TO_STM_ALWAYS = False 15 | MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE = 50 16 | MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS = 0 17 | MODELS_NDPM_NDPM_SLEEP_STEP_G = 4000 18 | MODELS_NDPM_NDPM_SLEEP_STEP_D = 1000 19 | MODELS_NDPM_NDPM_SLEEP_SLEEP_VAL_SIZE = 0 20 | MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP = 500 21 | MODELS_NDPM_NDPM_WEIGHT_DECAY = 0.00001 22 | MODELS_NDPM_NDPM_IMPLICIT_LR_DECAY = False 23 | MODELS_NDPM_COMPONENT_LR_SCHEDULER_G = {'type': 'MultiStepLR', 'options': {'milestones': [1], 'gamma': 0.2}} 24 | MODELS_NDPM_COMPONENT_LR_SCHEDULER_D = {'type': 'MultiStepLR', 'options': {'milestones': [1], 'gamma': 0.2}} 25 | MODELS_NDPM_COMPONENT_CLIP_GRAD = {'type': 'value', 'options': {'clip_value': 0.5}} 26 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 23 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/dataset_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import os 3 | 4 | class DatasetBase(ABC): 5 | def __init__(self, dataset, scenario, task_nums, run, params): 6 | super(DatasetBase, self).__init__() 7 | self.params = params 8 | self.scenario = scenario 9 | self.dataset = dataset 10 | self.task_nums = task_nums 11 | self.run = run 12 | self.root = os.path.join('./datasets', self.dataset) 13 | self.test_set = [] 14 | self.val_set = [] 15 | self._is_properly_setup() 16 | self.download_load() 17 | 18 | 19 | @abstractmethod 20 | def download_load(self): 21 | pass 22 | 23 | @abstractmethod 24 | def setup(self, **kwargs): 25 | pass 26 | 27 | @abstractmethod 28 | def new_task(self, cur_task, **kwargs): 29 | pass 30 | 31 | def _is_properly_setup(self): 32 | pass 33 | 34 | @abstractmethod 35 | def new_run(self, **kwargs): 36 | pass 37 | 38 | @property 39 | def dataset_info(self): 40 | return self.dataset 41 | 42 | def get_test_set(self): 43 | return self.test_set 44 | 45 | def clean_mem_test_set(self): 46 | self.test_set = None 47 | self.test_data = None 48 | self.test_label = None -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import pandas as pd 3 | import os 4 | import psutil 5 | import torch 6 | 7 | def load_yaml(path, key='parameters'): 8 | with open(path, 'r') as stream: 9 | try: 10 | return yaml.load(stream, Loader=yaml.FullLoader)[key] 11 | except yaml.YAMLError as exc: 12 | print(exc) 13 | 14 | def save_dataframe_csv(df, path, name): 15 | df.to_csv(path + '/' + name, index=False) 16 | 17 | 18 | def load_dataframe_csv(path, name=None, delimiter=None, names=None): 19 | if not name: 20 | return pd.read_csv(path, delimiter=delimiter, names=names) 21 | else: 22 | return pd.read_csv(path+name, delimiter=delimiter, names=names) 23 | 24 | def check_ram_usage(): 25 | """ 26 | Compute the RAM usage of the current process. 27 | Returns: 28 | mem (float): Memory occupation in Megabytes 29 | """ 30 | 31 | process = psutil.Process(os.getpid()) 32 | mem = process.memory_info().rss / (1024 * 1024) 33 | 34 | return mem 35 | 36 | def save_model(model, optimizer, opt, epoch, save_file): 37 | print('==> Saving...') 38 | state = { 39 | 'opt': opt, 40 | 'model': model.state_dict(), 41 | 'optimizer': optimizer.state_dict(), 42 | 'epoch': epoch, 43 | } 44 | torch.save(state, save_file) 45 | del state 46 | -------------------------------------------------------------------------------- /features/extractor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class BaseModule(nn.Module): 8 | def __init__(self): 9 | super(BaseModule, self).__init__() 10 | self.n_features = 0 11 | self._name = "BaseModule" 12 | 13 | def forward(self, x): 14 | return x 15 | 16 | @property 17 | def name(self): 18 | return self._name 19 | 20 | @name.setter 21 | def name(self, name): 22 | self._name = name 23 | 24 | def init_weights(self, std=0.01): 25 | print("Initialize weights of %s with normal dist: mean=0, std=%0.2f" % (type(self), std)) 26 | for m in self.modules(): 27 | if type(m) == nn.Linear: 28 | nn.init.normal_(m.weight, 0, std) 29 | if m.bias is not None: 30 | m.bias.data.zero_() 31 | elif isinstance(m, nn.BatchNorm2d): 32 | nn.init.constant_(m.weight, 1) 33 | if m.bias is not None: 34 | m.bias.data.zero_() 35 | elif type(m) == nn.Conv2d: 36 | nn.init.normal_(m.weight, 0, std) 37 | if m.bias is not None: 38 | m.bias.data.zero_() 39 | 40 | 41 | if __name__ == '__main__': 42 | net = BaseModule() 43 | print(net) 44 | print("n_features:", net.n_features) 45 | -------------------------------------------------------------------------------- /models/ndpm/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import Tensor 4 | from torch.nn.functional import binary_cross_entropy 5 | 6 | 7 | def gaussian_nll(x, mean, log_var, min_noise=0.001): 8 | return ( 9 | ((x - mean) ** 2 + min_noise) / (2 * log_var.exp() + 1e-8) 10 | + 0.5 * log_var + 0.5 * np.log(2 * np.pi) 11 | ) 12 | 13 | 14 | def laplace_nll(x, median, log_scale, min_noise=0.01): 15 | return ( 16 | ((x - median).abs() + min_noise) / (log_scale.exp() + 1e-8) 17 | + log_scale + np.log(2) 18 | ) 19 | 20 | 21 | def bernoulli_nll(x, p): 22 | # Broadcast 23 | x_exp, p_exp = [], [] 24 | for x_size, p_size in zip(x.size(), p.size()): 25 | if x_size > p_size: 26 | x_exp.append(-1) 27 | p_exp.append(x_size) 28 | elif x_size < p_size: 29 | x_exp.append(p_size) 30 | p_exp.append(-1) 31 | else: 32 | x_exp.append(-1) 33 | p_exp.append(-1) 34 | x = x.expand(*x_exp) 35 | p = p.expand(*p_exp) 36 | 37 | return binary_cross_entropy(p, x, reduction='none') 38 | 39 | 40 | def logistic_nll(x, mean, log_scale): 41 | bin_size = 1 / 256 42 | scale = log_scale.exp() 43 | x_centered = x - mean 44 | cdf1 = x_centered / scale 45 | cdf2 = (x_centered + bin_size) / scale 46 | p = torch.sigmoid(cdf2) - torch.sigmoid(cdf1) + 1e-12 47 | return -p.log() 48 | -------------------------------------------------------------------------------- /agents/cndpm.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from models.ndpm.ndpm import Ndpm 4 | from utils.setup_elements import transforms_match 5 | from torch.utils import data 6 | from utils.utils import maybe_cuda, AverageMeter 7 | import torch 8 | 9 | 10 | class Cndpm(ContinualLearner): 11 | def __init__(self, model, opt, params): 12 | super(Cndpm, self).__init__(model, opt, params) 13 | self.model = model 14 | 15 | 16 | def train_learner(self, x_train, y_train): 17 | # set up loader 18 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 19 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 20 | drop_last=True) 21 | # setup tracker 22 | losses_batch = AverageMeter() 23 | acc_batch = AverageMeter() 24 | 25 | self.model.train() 26 | 27 | for ep in range(self.epoch): 28 | for i, batch_data in enumerate(train_loader): 29 | # batch update 30 | batch_x, batch_y = batch_data 31 | batch_x = maybe_cuda(batch_x, self.cuda) 32 | batch_y = maybe_cuda(batch_y, self.cuda) 33 | self.model.learn(batch_x, batch_y) 34 | if self.params.verbose: 35 | print('\r[Step {:4}] STM: {:5}/{} | #Expert: {}'.format( 36 | i, 37 | len(self.model.stm_x), self.params.stm_capacity, 38 | len(self.model.experts) - 1 39 | ), end='') 40 | print() 41 | -------------------------------------------------------------------------------- /utils/buffer/buffer.py: -------------------------------------------------------------------------------- 1 | from utils.setup_elements import input_size_match 2 | from utils import name_match #import update_methods, retrieve_methods 3 | from utils.utils import maybe_cuda 4 | import torch 5 | from utils.buffer.buffer_utils import BufferClassTracker 6 | from utils.setup_elements import n_classes 7 | 8 | class Buffer(torch.nn.Module): 9 | def __init__(self, model, params): 10 | super().__init__() 11 | self.params = params 12 | self.model = model 13 | self.cuda = self.params.cuda 14 | self.current_index = 0 15 | self.n_seen_so_far = 0 16 | self.device = "cuda" if self.params.cuda else "cpu" 17 | 18 | # define buffer 19 | buffer_size = params.mem_size 20 | print('buffer has %d slots' % buffer_size) 21 | input_size = input_size_match[params.data] 22 | buffer_img = maybe_cuda(torch.FloatTensor(buffer_size, *input_size).fill_(0)) 23 | buffer_label = maybe_cuda(torch.LongTensor(buffer_size).fill_(0)) 24 | 25 | # registering as buffer allows us to save the object using `torch.save` 26 | self.register_buffer('buffer_img', buffer_img) 27 | self.register_buffer('buffer_label', buffer_label) 28 | 29 | # define update and retrieve method 30 | self.update_method = name_match.update_methods[params.update](params) 31 | self.retrieve_method = name_match.retrieve_methods[params.retrieve](params) 32 | 33 | if self.params.buffer_tracker: 34 | self.buffer_tracker = BufferClassTracker(n_classes[params.data], self.device) 35 | 36 | def update(self, x, y,**kwargs): 37 | return self.update_method.update(buffer=self, x=x, y=y, **kwargs) 38 | 39 | 40 | def retrieve(self, **kwargs): 41 | return self.retrieve_method.retrieve(buffer=self, **kwargs) -------------------------------------------------------------------------------- /models/ndpm/priors.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | 4 | from utils.utils import maybe_cuda 5 | 6 | 7 | class Prior(ABC): 8 | def __init__(self, params): 9 | self.params = params 10 | 11 | @abstractmethod 12 | def add_expert(self): 13 | pass 14 | 15 | @abstractmethod 16 | def record_usage(self, usage, index=None): 17 | pass 18 | 19 | @abstractmethod 20 | def nl_prior(self, normalize=False): 21 | pass 22 | 23 | 24 | class CumulativePrior(Prior): 25 | def __init__(self, params): 26 | super().__init__(params) 27 | self.log_counts = maybe_cuda(torch.tensor( 28 | params.log_alpha 29 | )).float().unsqueeze(0) 30 | 31 | def add_expert(self): 32 | self.log_counts = torch.cat( 33 | [self.log_counts, maybe_cuda(torch.zeros(1))], 34 | dim=0 35 | ) 36 | 37 | def record_usage(self, usage, index=None): 38 | """Record expert usage 39 | 40 | Args: 41 | usage: Tensor of shape [K+1] if index is None else scalar 42 | index: expert index 43 | """ 44 | if index is None: 45 | self.log_counts = torch.logsumexp(torch.stack([ 46 | self.log_counts, 47 | usage.log() 48 | ], dim=1), dim=1) 49 | else: 50 | self.log_counts[index] = torch.logsumexp(torch.stack([ 51 | self.log_counts[index], 52 | maybe_cuda(torch.tensor(usage)).float().log() 53 | ], dim=0), dim=0) 54 | 55 | def nl_prior(self, normalize=False): 56 | nl_prior = -self.log_counts 57 | if normalize: 58 | nl_prior += torch.logsumexp(self.log_counts, dim=0) 59 | return nl_prior 60 | 61 | @property 62 | def counts(self): 63 | return self.log_counts.exp() 64 | -------------------------------------------------------------------------------- /utils/name_match.py: -------------------------------------------------------------------------------- 1 | from agents.gdumb import Gdumb 2 | from continuum.dataset_scripts.cifar100 import CIFAR100 3 | from continuum.dataset_scripts.cifar10 import CIFAR10 4 | from continuum.dataset_scripts.core50 import CORE50 5 | from continuum.dataset_scripts.mini_imagenet import Mini_ImageNet 6 | from continuum.dataset_scripts.openloris import OpenLORIS 7 | from agents.exp_replay import ExperienceReplay 8 | from agents.exp_replay_dvc import ExperienceReplay_DVC 9 | from agents.agem import AGEM 10 | from agents.ewc_pp import EWC_pp 11 | from agents.cndpm import Cndpm 12 | from agents.lwf import Lwf 13 | from agents.icarl import Icarl 14 | from agents.scr import SupContrastReplay 15 | from utils.buffer.random_retrieve import Random_retrieve 16 | from utils.buffer.reservoir_update import Reservoir_update 17 | from utils.buffer.mir_retrieve import MIR_retrieve 18 | 19 | from utils.buffer.gss_greedy_update import GSSGreedyUpdate 20 | from utils.buffer.aser_retrieve import ASER_retrieve 21 | from utils.buffer.aser_update import ASER_update 22 | from utils.buffer.sc_retrieve import Match_retrieve 23 | from utils.buffer.mem_match import MemMatch_retrieve 24 | from utils.buffer.mgi_retrieve import MGI_retrieve 25 | 26 | data_objects = { 27 | 'cifar100': CIFAR100, 28 | 'cifar10': CIFAR10, 29 | 'core50': CORE50, 30 | 'mini_imagenet': Mini_ImageNet, 31 | 'openloris': OpenLORIS 32 | } 33 | 34 | agents = { 35 | 'ER': ExperienceReplay, 36 | 'ER_DVC': ExperienceReplay_DVC, 37 | 'EWC': EWC_pp, 38 | 'AGEM': AGEM, 39 | 'CNDPM': Cndpm, 40 | 'LWF': Lwf, 41 | 'ICARL': Icarl, 42 | 'GDUMB': Gdumb, 43 | 'SCR': SupContrastReplay, 44 | } 45 | 46 | retrieve_methods = { 47 | 'MIR': MIR_retrieve, 48 | 'MGI':MGI_retrieve, 49 | 'random': Random_retrieve, 50 | 'ASER': ASER_retrieve, 51 | 'match': Match_retrieve, 52 | 'mem_match': MemMatch_retrieve 53 | 54 | } 55 | 56 | update_methods = { 57 | 'random': Reservoir_update, 58 | 'GSS': GSSGreedyUpdate, 59 | 'ASER': ASER_update, 60 | } 61 | 62 | -------------------------------------------------------------------------------- /utils/buffer/aser_update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.buffer.reservoir_update import Reservoir_update 3 | from utils.buffer.buffer_utils import ClassBalancedRandomSampling, random_retrieve 4 | from utils.buffer.aser_utils import compute_knn_sv, add_minority_class_input 5 | from utils.setup_elements import n_classes 6 | from utils.utils import nonzero_indices, maybe_cuda 7 | 8 | 9 | class ASER_update(object): 10 | def __init__(self, params, **kwargs): 11 | super().__init__() 12 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 13 | self.k = params.k 14 | self.mem_size = params.mem_size 15 | self.num_tasks = params.num_tasks 16 | self.out_dim = n_classes[params.data] 17 | self.n_smp_cls = int(params.n_smp_cls) 18 | self.n_total_smp = int(params.n_smp_cls * self.out_dim) 19 | self.reservoir_update = Reservoir_update(params) 20 | ClassBalancedRandomSampling.class_index_cache = None 21 | 22 | def update(self, buffer, x, y, **kwargs): 23 | model = buffer.model 24 | 25 | place_left = self.mem_size - buffer.current_index 26 | 27 | # If buffer is not filled, use available space to store whole or part of batch 28 | if place_left: 29 | x_fit = x[:place_left] 30 | y_fit = y[:place_left] 31 | 32 | ind = torch.arange(start=buffer.current_index, end=buffer.current_index + x_fit.size(0), device=self.device) 33 | ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim, 34 | new_y=y_fit, ind=ind, device=self.device) 35 | self.reservoir_update.update(buffer, x_fit, y_fit) 36 | 37 | # If buffer is filled, update buffer by sv 38 | if buffer.current_index == self.mem_size: 39 | self.reservoir_update.update(buffer, x_fit, y_fit) 40 | 41 | 42 | ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim, 43 | new_y=y_upt, ind=ind_buffer, device=self.device) 44 | 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DVC 2 | Code For CVPR2022 paper "[Not Just Selection, but Exploration: Online Class-Incremental Continual Learning via Dual View Consistency](https://openaccess.thecvf.com/content/CVPR2022/papers/Gu_Not_Just_Selection_but_Exploration_Online_Class-Incremental_Continual_Learning_via_CVPR_2022_paper.pdf)" 3 | 4 | ## Usage 5 | 6 | ### Requirements 7 | requirements.txt 8 | 9 | ### Data preparation 10 | - CIFAR10 & CIFAR100 will be downloaded during the first run. (datasets/cifar10;/datasets/cifar100) 11 | - Mini-ImageNet: Download from https://www.kaggle.com/whitemoon/miniimagenet/download, and place it in datasets/mini_imagenet/ 12 | 13 | 14 | ### CIFAR-100 15 | ```shell 16 | python general_main.py --data cifar100 --cl_type nc --agent ER_DVC --retrieve MGI --update random --mem_size 1000 --dl_weight 4.0 17 | ``` 18 | 19 | ### CIFAR-10 20 | ```shell 21 | python general_main.py --data cifar10 --cl_type nc --agent ER_DVC --retrieve MGI --update random --mem_size 200 --dl_weight 2.0 --num_tasks 5 22 | ``` 23 | 24 | ### Mini-Imagenet 25 | ```shell 26 | python general_main.py --data mini_imagenet --cl_type nc --agent ER_DVC --retrieve MGI --update random --mem_size 1000 --dl_weight 4.0 27 | ``` 28 | 29 | 30 | 31 | 32 | 33 | ## Reference 34 | [online-continual-learning](https://github.com/RaptorMai/online-continual-learning) 35 | 36 | [agmax](https://github.com/roatienza/agmax) 37 | 38 | 39 | If our code or models help your work, please cite our [paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Gu_Not_Just_Selection_but_Exploration_Online_Class-Incremental_Continual_Learning_via_CVPR_2022_paper.pdf): 40 | 41 | ```shell 42 | @InProceedings{Gu_2022_CVPR, 43 | author = {Gu, Yanan and Yang, Xu and Wei, Kun and Deng, Cheng}, 44 | title = {Not Just Selection, but Exploration: Online Class-Incremental Continual Learning via Dual View Consistency}, 45 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 46 | month = {June}, 47 | year = {2022}, 48 | pages = {7442-7451} 49 | } 50 | ``` 51 | -------------------------------------------------------------------------------- /experiment/tune_hyperparam.py: -------------------------------------------------------------------------------- 1 | from types import SimpleNamespace 2 | from sklearn.model_selection import ParameterGrid 3 | from utils.setup_elements import setup_opt, setup_architecture 4 | from utils.utils import maybe_cuda 5 | from utils.name_match import agents 6 | import numpy as np 7 | from experiment.metrics import compute_performance 8 | 9 | 10 | def tune_hyper(tune_data, tune_test_loaders, default_params, tune_params): 11 | param_grid_list = list(ParameterGrid(tune_params)) 12 | print(len(param_grid_list)) 13 | tune_accs = [] 14 | tune_fgt = [] 15 | for param_set in param_grid_list: 16 | final_params = vars(default_params) 17 | print(param_set) 18 | final_params.update(param_set) 19 | final_params = SimpleNamespace(**final_params) 20 | accuracy_list = [] 21 | for run in range(final_params.num_runs_val): 22 | tmp_acc = [] 23 | model = setup_architecture(final_params) 24 | model = maybe_cuda(model, final_params.cuda) 25 | opt = setup_opt(final_params.optimizer, model, final_params.learning_rate, final_params.weight_decay) 26 | agent = agents[final_params.agent](model, opt, final_params) 27 | for i, (x_train, y_train, labels) in enumerate(tune_data): 28 | print("-----------tune run {} task {}-------------".format(run, i)) 29 | print('size: {}, {}'.format(x_train.shape, y_train.shape)) 30 | agent.train_learner(x_train, y_train) 31 | acc_array = agent.evaluate(tune_test_loaders) 32 | tmp_acc.append(acc_array) 33 | print( 34 | "-----------tune run {}-----------avg_end_acc {}-----------".format(run, np.mean(tmp_acc[-1]))) 35 | accuracy_list.append(np.array(tmp_acc)) 36 | accuracy_list = np.array(accuracy_list) 37 | avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt = compute_performance(accuracy_list) 38 | tune_accs.append(avg_end_acc[0]) 39 | tune_fgt.append(avg_end_fgt[0]) 40 | best_tune = param_grid_list[tune_accs.index(max(tune_accs))] 41 | return best_tune -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from timm.loss import LabelSmoothingCrossEntropy 6 | 7 | 8 | def cross_entropy_loss(z, zt, ytrue, label_smoothing=0): 9 | zz = torch.cat((z, zt)) 10 | yy = torch.cat((ytrue, ytrue)) 11 | if label_smoothing > 0: 12 | ce = LabelSmoothingCrossEntropy(label_smoothing)(zz, yy) 13 | else: 14 | ce = nn.CrossEntropyLoss()(zz, yy) 15 | return ce 16 | 17 | 18 | def cross_entropy(z, zt): 19 | # eps = np.finfo(float).eps 20 | Pz = F.softmax(z, dim=1) 21 | Pzt = F.softmax(zt, dim=1) 22 | # make sure no zero for log 23 | # Pz [(Pz < eps).data] = eps 24 | # Pzt [(Pzt < eps).data] = eps 25 | return -(Pz * torch.log(Pzt)).mean() 26 | 27 | 28 | def agmax_loss(y, ytrue, dl_weight=1.0): 29 | z, zt, zzt,_ = y 30 | Pz = F.softmax(z, dim=1) 31 | Pzt = F.softmax(zt, dim=1) 32 | Pzzt = F.softmax(zzt, dim=1) 33 | 34 | dl_loss = nn.L1Loss() 35 | yy = torch.cat((Pz, Pzt)) 36 | zz = torch.cat((Pzzt, Pzzt)) 37 | dl = dl_loss(zz, yy) 38 | dl *= dl_weight 39 | 40 | # -1/3*(H(z) + H(zt) + H(z, zt)), H(x) = -E[log(x)] 41 | entropy = entropy_loss(Pz, Pzt, Pzzt) 42 | return entropy, dl 43 | 44 | 45 | 46 | 47 | def clamp_to_eps(Pz, Pzt, Pzzt): 48 | eps = np.finfo(float).eps 49 | # make sure no zero for log 50 | Pz[(Pz < eps).data] = eps 51 | Pzt[(Pzt < eps).data] = eps 52 | Pzzt[(Pzzt < eps).data] = eps 53 | 54 | return Pz, Pzt, Pzzt 55 | 56 | 57 | def batch_probability(Pz, Pzt, Pzzt): 58 | Pz = Pz.sum(dim=0) 59 | Pzt = Pzt.sum(dim=0) 60 | Pzzt = Pzzt.sum(dim=0) 61 | 62 | Pz = Pz / Pz.sum() 63 | Pzt = Pzt / Pzt.sum() 64 | Pzzt = Pzzt / Pzzt.sum() 65 | 66 | # return Pz, Pzt, Pzzt 67 | return clamp_to_eps(Pz, Pzt, Pzzt) 68 | 69 | 70 | def entropy_loss(Pz, Pzt, Pzzt): 71 | # negative entropy loss 72 | Pz, Pzt, Pzzt = batch_probability(Pz, Pzt, Pzzt) 73 | entropy = (Pz * torch.log(Pz)).sum() 74 | entropy += (Pzt * torch.log(Pzt)).sum() 75 | entropy += (Pzzt * torch.log(Pzzt)).sum() 76 | entropy /= 3 77 | return entropy 78 | 79 | 80 | -------------------------------------------------------------------------------- /models/ndpm/expert.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from models.ndpm.classifier import ResNetSharingClassifier 4 | from models.ndpm.vae import CnnSharingVae 5 | from utils.utils import maybe_cuda 6 | 7 | from utils.global_vars import * 8 | 9 | 10 | class Expert(nn.Module): 11 | def __init__(self, params, experts=()): 12 | super().__init__() 13 | self.id = len(experts) 14 | self.experts = experts 15 | 16 | self.g = maybe_cuda(CnnSharingVae(params, experts)) 17 | self.d = maybe_cuda(ResNetSharingClassifier(params, experts)) if not MODELS_NDPM_NDPM_DISABLE_D else None 18 | 19 | 20 | # use random initialized g if it's a placeholder 21 | if self.id == 0: 22 | self.eval() 23 | for p in self.g.parameters(): 24 | p.requires_grad = False 25 | 26 | # use random initialized d if it's a placeholder 27 | if self.id == 0 and self.d is not None: 28 | for p in self.d.parameters(): 29 | p.requires_grad = False 30 | 31 | def forward(self, x): 32 | return self.d(x) 33 | 34 | def nll(self, x, y, step=None): 35 | """Negative log likelihood""" 36 | nll = self.g.nll(x, step) 37 | if self.d is not None: 38 | d_nll = self.d.nll(x, y, step) 39 | nll = nll + d_nll 40 | return nll 41 | 42 | def collect_nll(self, x, y, step=None): 43 | if self.id == 0: 44 | nll = self.nll(x, y, step) 45 | return nll.unsqueeze(1) 46 | 47 | nll = self.g.collect_nll(x, step) 48 | if self.d is not None: 49 | d_nll = self.d.collect_nll(x, y, step) 50 | nll = nll + d_nll 51 | 52 | return nll 53 | 54 | def lr_scheduler_step(self): 55 | if self.g.lr_scheduler is not NotImplemented: 56 | self.g.lr_scheduler.step() 57 | if self.d is not None and self.d.lr_scheduler is not NotImplemented: 58 | self.d.lr_scheduler.step() 59 | 60 | def clip_grad(self): 61 | self.g.clip_grad() 62 | if self.d is not None: 63 | self.d.clip_grad() 64 | 65 | def optimizer_step(self): 66 | self.g.optimizer.step() 67 | if self.d is not None: 68 | self.d.optimizer.step() 69 | -------------------------------------------------------------------------------- /agents/lwf.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from utils.setup_elements import transforms_match 4 | from torch.utils import data 5 | from utils.utils import maybe_cuda, AverageMeter 6 | import torch 7 | import copy 8 | 9 | 10 | class Lwf(ContinualLearner): 11 | def __init__(self, model, opt, params): 12 | super(Lwf, self).__init__(model, opt, params) 13 | 14 | def train_learner(self, x_train, y_train): 15 | self.before_train(x_train, y_train) 16 | 17 | # set up loader 18 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 19 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 20 | drop_last=True) 21 | 22 | # set up model 23 | self.model = self.model.train() 24 | 25 | # setup tracker 26 | losses_batch = AverageMeter() 27 | acc_batch = AverageMeter() 28 | 29 | for ep in range(self.epoch): 30 | for i, batch_data in enumerate(train_loader): 31 | # batch update 32 | batch_x, batch_y = batch_data 33 | batch_x = maybe_cuda(batch_x, self.cuda) 34 | batch_y = maybe_cuda(batch_y, self.cuda) 35 | 36 | logits = self.forward(batch_x) 37 | loss_old = self.kd_manager.get_kd_loss(logits, batch_x) 38 | loss_new = self.criterion(logits, batch_y) 39 | loss = 1/(self.task_seen + 1) * loss_new + (1 - 1/(self.task_seen + 1)) * loss_old 40 | _, pred_label = torch.max(logits, 1) 41 | correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0) 42 | # update tracker 43 | acc_batch.update(correct_cnt, batch_y.size(0)) 44 | losses_batch.update(loss, batch_y.size(0)) 45 | # backward 46 | self.opt.zero_grad() 47 | loss.backward() 48 | self.opt.step() 49 | 50 | if i % 100 == 1 and self.verbose: 51 | print( 52 | '==>>> it: {}, avg. loss: {:.6f}, ' 53 | 'running train acc: {:.3f}' 54 | .format(i, losses_batch.avg(), acc_batch.avg()) 55 | ) 56 | self.after_train() 57 | -------------------------------------------------------------------------------- /experiment/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import sem 3 | import scipy.stats as stats 4 | 5 | def compute_performance(end_task_acc_arr): 6 | """ 7 | Given test accuracy results from multiple runs saved in end_task_acc_arr, 8 | compute the average accuracy, forgetting, and task accuracies as well as their confidence intervals. 9 | 10 | :param end_task_acc_arr: (list) List of lists 11 | :param task_ids: (list or tuple) Task ids to keep track of 12 | :return: (avg_end_acc, forgetting, avg_acc_task) 13 | """ 14 | n_run, n_tasks = end_task_acc_arr.shape[:2] 15 | t_coef = stats.t.ppf((1+0.95) / 2, n_run-1) # t coefficient used to compute 95% CIs: mean +- t * 16 | 17 | # compute average test accuracy and CI 18 | end_acc = end_task_acc_arr[:, -1, :] # shape: (num_run, num_task) 19 | avg_acc_per_run = np.mean(end_acc, axis=1) # mean of end task accuracies per run 20 | avg_end_acc = (np.mean(avg_acc_per_run), t_coef * sem(avg_acc_per_run)) 21 | 22 | # compute forgetting 23 | best_acc = np.max(end_task_acc_arr, axis=1) 24 | final_forgets = best_acc - end_acc 25 | avg_fgt = np.mean(final_forgets, axis=1) 26 | avg_end_fgt = (np.mean(avg_fgt), t_coef * sem(avg_fgt)) 27 | 28 | # compute ACC 29 | acc_per_run = np.mean((np.sum(np.tril(end_task_acc_arr), axis=2) / 30 | (np.arange(n_tasks) + 1)), axis=1) 31 | avg_acc = (np.mean(acc_per_run), t_coef * sem(acc_per_run)) 32 | 33 | 34 | # compute BWT+ 35 | bwt_per_run = (np.sum(np.tril(end_task_acc_arr, -1), axis=(1,2)) - 36 | np.sum(np.diagonal(end_task_acc_arr, axis1=1, axis2=2) * 37 | (np.arange(n_tasks, 0, -1) - 1), axis=1)) / (n_tasks * (n_tasks - 1) / 2) 38 | bwtp_per_run = np.maximum(bwt_per_run, 0) 39 | avg_bwtp = (np.mean(bwtp_per_run), t_coef * sem(bwtp_per_run)) 40 | 41 | # compute FWT 42 | fwt_per_run = np.sum(np.triu(end_task_acc_arr, 1), axis=(1,2)) / (n_tasks * (n_tasks - 1) / 2) 43 | avg_fwt = (np.mean(fwt_per_run), t_coef * sem(fwt_per_run)) 44 | return avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt 45 | 46 | 47 | 48 | 49 | def single_run_avg_end_fgt(acc_array): 50 | best_acc = np.max(acc_array, axis=1) 51 | end_acc = acc_array[-1] 52 | final_forgets = best_acc - end_acc 53 | avg_fgt = np.mean(final_forgets) 54 | return avg_fgt 55 | -------------------------------------------------------------------------------- /utils/buffer/reservoir_update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Reservoir_update(object): 5 | def __init__(self, params): 6 | super().__init__() 7 | 8 | def update(self, buffer, x, y, **kwargs): 9 | batch_size = x.size(0) 10 | 11 | # add whatever still fits in the buffer 12 | place_left = max(0, buffer.buffer_img.size(0) - buffer.current_index) 13 | if place_left: 14 | offset = min(place_left, batch_size) 15 | buffer.buffer_img[buffer.current_index: buffer.current_index + offset].data.copy_(x[:offset]) 16 | buffer.buffer_label[buffer.current_index: buffer.current_index + offset].data.copy_(y[:offset]) 17 | 18 | 19 | buffer.current_index += offset 20 | buffer.n_seen_so_far += offset 21 | 22 | # everything was added 23 | if offset == x.size(0): 24 | filled_idx = list(range(buffer.current_index - offset, buffer.current_index, )) 25 | if buffer.params.buffer_tracker: 26 | buffer.buffer_tracker.update_cache(buffer.buffer_label, y[:offset], filled_idx) 27 | return filled_idx 28 | 29 | 30 | #TODO: the buffer tracker will have bug when the mem size can't be divided by batch size 31 | 32 | # remove what is already in the buffer 33 | x, y = x[place_left:], y[place_left:] 34 | 35 | indices = torch.FloatTensor(x.size(0)).to(x.device).uniform_(0, buffer.n_seen_so_far).long() 36 | valid_indices = (indices < buffer.buffer_img.size(0)).long() 37 | 38 | idx_new_data = valid_indices.nonzero().squeeze(-1) 39 | idx_buffer = indices[idx_new_data] 40 | 41 | buffer.n_seen_so_far += x.size(0) 42 | 43 | if idx_buffer.numel() == 0: 44 | return [] 45 | 46 | assert idx_buffer.max() < buffer.buffer_img.size(0) 47 | assert idx_buffer.max() < buffer.buffer_label.size(0) 48 | # assert idx_buffer.max() < self.buffer_task.size(0) 49 | 50 | assert idx_new_data.max() < x.size(0) 51 | assert idx_new_data.max() < y.size(0) 52 | 53 | idx_map = {idx_buffer[i].item(): idx_new_data[i].item() for i in range(idx_buffer.size(0))} 54 | 55 | replace_y = y[list(idx_map.values())] 56 | if buffer.params.buffer_tracker: 57 | buffer.buffer_tracker.update_cache(buffer.buffer_label, replace_y, list(idx_map.keys())) 58 | # perform overwrite op 59 | buffer.buffer_img[list(idx_map.keys())] = x[list(idx_map.values())] 60 | buffer.buffer_label[list(idx_map.keys())] = replace_y 61 | return list(idx_map.keys()) -------------------------------------------------------------------------------- /utils/buffer/mir_retrieve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.utils import maybe_cuda 3 | import torch.nn.functional as F 4 | from utils.buffer.buffer_utils import random_retrieve, get_grad_vector 5 | import copy 6 | 7 | 8 | class MIR_retrieve(object): 9 | def __init__(self, params, **kwargs): 10 | super().__init__() 11 | self.params = params 12 | self.subsample = params.subsample 13 | self.num_retrieve = params.eps_mem_batch 14 | 15 | def retrieve(self, buffer, **kwargs): 16 | sub_x, sub_y = random_retrieve(buffer, self.subsample) 17 | grad_dims = [] 18 | for param in buffer.model.parameters(): 19 | grad_dims.append(param.data.numel()) 20 | grad_vector = get_grad_vector(buffer.model.parameters, grad_dims) 21 | model_temp = self.get_future_step_parameters(buffer.model, grad_vector, grad_dims) 22 | if sub_x.size(0) > 0: 23 | with torch.no_grad(): 24 | logits_pre,_ = buffer.model.forward(sub_x) 25 | logits_post,_ = model_temp.forward(sub_x) 26 | pre_loss = F.cross_entropy(logits_pre, sub_y, reduction='none') 27 | post_loss = F.cross_entropy(logits_post, sub_y, reduction='none') 28 | scores = post_loss - pre_loss 29 | big_ind = scores.sort(descending=True)[1][:self.num_retrieve] 30 | return sub_x[big_ind], sub_y[big_ind] 31 | else: 32 | return sub_x, sub_y 33 | 34 | def get_future_step_parameters(self, model, grad_vector, grad_dims): 35 | """ 36 | computes \theta-\delta\theta 37 | :param this_net: 38 | :param grad_vector: 39 | :return: 40 | """ 41 | new_model = copy.deepcopy(model) 42 | self.overwrite_grad(new_model.parameters, grad_vector, grad_dims) 43 | with torch.no_grad(): 44 | for param in new_model.parameters(): 45 | if param.grad is not None: 46 | param.data = param.data - self.params.learning_rate * param.grad.data 47 | return new_model 48 | 49 | def overwrite_grad(self, pp, new_grad, grad_dims): 50 | """ 51 | This is used to overwrite the gradients with a new gradient 52 | vector, whenever violations occur. 53 | pp: parameters 54 | newgrad: corrected gradient 55 | grad_dims: list storing number of parameters at each layer 56 | """ 57 | cnt = 0 58 | for param in pp(): 59 | param.grad = torch.zeros_like(param.data) 60 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) 61 | en = sum(grad_dims[:cnt + 1]) 62 | this_grad = new_grad[beg: en].contiguous().view( 63 | param.data.size()) 64 | param.grad.data.copy_(this_grad) 65 | cnt += 1 -------------------------------------------------------------------------------- /utils/setup_elements.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.resnet import Reduced_ResNet18, SupConResNet,Reduced_ResNet18_DVC 3 | from torchvision import transforms 4 | import torch.nn as nn 5 | 6 | 7 | default_trick = {'labels_trick': False, 'kd_trick': False, 'separated_softmax': False, 8 | 'review_trick': False, 'ncm_trick': False, 'kd_trick_star': False} 9 | 10 | 11 | input_size_match = { 12 | 'cifar100': [3, 32, 32], 13 | 'cifar10': [3, 32, 32], 14 | 'core50': [3, 128, 128], 15 | 'mini_imagenet': [3, 84, 84], 16 | 'openloris': [3, 50, 50] 17 | } 18 | 19 | 20 | n_classes = { 21 | 'cifar100': 100, 22 | 'cifar10': 10, 23 | 'core50': 50, 24 | 'mini_imagenet': 100, 25 | 'openloris': 69 26 | } 27 | 28 | 29 | transforms_match = { 30 | 'core50': transforms.Compose([ 31 | transforms.ToTensor(), 32 | ]), 33 | 'cifar100': transforms.Compose([ 34 | transforms.ToTensor(), 35 | ]), 36 | 'cifar10': transforms.Compose([ 37 | transforms.ToTensor(), 38 | ]), 39 | 'mini_imagenet': transforms.Compose([ 40 | transforms.ToTensor()]), 41 | 'openloris': transforms.Compose([ 42 | transforms.ToTensor()]) 43 | } 44 | 45 | 46 | def setup_architecture(params): 47 | nclass = n_classes[params.data] 48 | if params.agent in ['SCR', 'SCP']: 49 | if params.data == 'mini_imagenet': 50 | return SupConResNet(640, head=params.head) 51 | return SupConResNet(head=params.head) 52 | if params.agent == 'CNDPM': 53 | from models.ndpm.ndpm import Ndpm 54 | return Ndpm(params) 55 | if params.data == 'cifar100': 56 | if params.agent == 'ER_DVC': 57 | return Reduced_ResNet18_DVC(nclass) 58 | else: 59 | return Reduced_ResNet18(nclass) 60 | elif params.data == 'cifar10': 61 | if params.agent == 'ER_DVC': 62 | return Reduced_ResNet18_DVC(nclass) 63 | else: 64 | return Reduced_ResNet18(nclass) 65 | elif params.data == 'core50': 66 | model = Reduced_ResNet18(nclass) 67 | model.backbone.linear = nn.Linear(2560, nclass, bias=True) 68 | return model 69 | elif params.data == 'mini_imagenet': 70 | if params.agent == 'ER_DVC': 71 | model= Reduced_ResNet18_DVC(nclass) 72 | model.backbone.linear = nn.Linear(640, nclass, bias=True) 73 | else: 74 | model = Reduced_ResNet18(nclass) 75 | model.linear = nn.Linear(640, nclass, bias=True) 76 | return model 77 | elif params.data == 'openloris': 78 | return Reduced_ResNet18(nclass) 79 | 80 | 81 | def setup_opt(optimizer, model, lr, wd): 82 | if optimizer == 'SGD': 83 | optim = torch.optim.SGD(model.parameters(), 84 | lr=lr, 85 | weight_decay=wd) 86 | elif optimizer == 'Adam': 87 | optim = torch.optim.Adam(model.parameters(), 88 | lr=lr, 89 | weight_decay=wd) 90 | else: 91 | raise Exception('wrong optimizer name') 92 | return optim 93 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/cifar100.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import datasets 3 | from continuum.data_utils import create_task_composition, load_task_with_labels 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns 6 | 7 | 8 | class CIFAR100(DatasetBase): 9 | def __init__(self, scenario, params): 10 | dataset = 'cifar100' 11 | if scenario == 'ni': 12 | num_tasks = len(params.ns_factor) 13 | else: 14 | num_tasks = params.num_tasks 15 | super(CIFAR100, self).__init__(dataset, scenario, num_tasks, params.num_runs, params) 16 | 17 | 18 | def download_load(self): 19 | dataset_train = datasets.CIFAR100(root=self.root, train=True, download=True) 20 | self.train_data = dataset_train.data 21 | self.train_label = np.array(dataset_train.targets) 22 | dataset_test = datasets.CIFAR100(root=self.root, train=False, download=True) 23 | self.test_data = dataset_test.data 24 | self.test_label = np.array(dataset_test.targets) 25 | 26 | def setup(self): 27 | if self.scenario == 'ni': 28 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 29 | self.train_label, 30 | self.test_data, self.test_label, 31 | self.task_nums, 32, 32 | self.params.val_size, 33 | self.params.ns_type, self.params.ns_factor, 34 | plot=self.params.plot_sample) 35 | elif self.scenario == 'nc': 36 | self.task_labels = create_task_composition(class_nums=100, num_tasks=self.task_nums, fixed_order=self.params.fix_order) 37 | self.test_set = [] 38 | for labels in self.task_labels: 39 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 40 | self.test_set.append((x_test, y_test)) 41 | else: 42 | raise Exception('wrong scenario') 43 | 44 | def new_task(self, cur_task, **kwargs): 45 | if self.scenario == 'ni': 46 | x_train, y_train = self.train_set[cur_task] 47 | labels = set(y_train) 48 | elif self.scenario == 'nc': 49 | labels = self.task_labels[cur_task] 50 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 51 | return x_train, y_train, labels 52 | 53 | def new_run(self, **kwargs): 54 | self.setup() 55 | return self.test_set 56 | 57 | def test_plot(self): 58 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 59 | self.params.ns_factor) 60 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import datasets 3 | from continuum.data_utils import create_task_composition, load_task_with_labels, shuffle_data 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns 6 | 7 | 8 | class CIFAR10(DatasetBase): 9 | def __init__(self, scenario, params): 10 | dataset = 'cifar10' 11 | if scenario == 'ni': 12 | num_tasks = len(params.ns_factor) 13 | else: 14 | num_tasks = params.num_tasks 15 | super(CIFAR10, self).__init__(dataset, scenario, num_tasks, params.num_runs, params) 16 | 17 | 18 | def download_load(self): 19 | dataset_train = datasets.CIFAR10(root=self.root, train=True, download=True) 20 | self.train_data = dataset_train.data 21 | self.train_label = np.array(dataset_train.targets) 22 | dataset_test = datasets.CIFAR10(root=self.root, train=False, download=True) 23 | self.test_data = dataset_test.data 24 | self.test_label = np.array(dataset_test.targets) 25 | 26 | def setup(self): 27 | if self.scenario == 'ni': 28 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 29 | self.train_label, 30 | self.test_data, self.test_label, 31 | self.task_nums, 32, 32 | self.params.val_size, 33 | self.params.ns_type, self.params.ns_factor, 34 | plot=self.params.plot_sample) 35 | elif self.scenario == 'nc': 36 | self.task_labels = create_task_composition(class_nums=10, num_tasks=self.task_nums, fixed_order=self.params.fix_order) 37 | self.test_set = [] 38 | for labels in self.task_labels: 39 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 40 | self.test_set.append((x_test, y_test)) 41 | else: 42 | raise Exception('wrong scenario') 43 | 44 | def new_task(self, cur_task, **kwargs): 45 | if self.scenario == 'ni': 46 | x_train, y_train = self.train_set[cur_task] 47 | labels = set(y_train) 48 | elif self.scenario == 'nc': 49 | labels = self.task_labels[cur_task] 50 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 51 | return x_train, y_train, labels 52 | 53 | def new_run(self, **kwargs): 54 | self.setup() 55 | return self.test_set 56 | 57 | def test_plot(self): 58 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 59 | self.params.ns_factor) 60 | -------------------------------------------------------------------------------- /agents/icarl.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from utils import utils 4 | from utils.buffer.buffer_utils import random_retrieve 5 | from utils.setup_elements import transforms_match 6 | from torch.utils import data 7 | import numpy as np 8 | from torch.nn import functional as F 9 | from utils.utils import maybe_cuda, AverageMeter 10 | from utils.buffer.buffer import Buffer 11 | import torch 12 | import copy 13 | 14 | 15 | class Icarl(ContinualLearner): 16 | def __init__(self, model, opt, params): 17 | super(Icarl, self).__init__(model, opt, params) 18 | self.model = model 19 | self.mem_size = params.mem_size 20 | self.buffer = Buffer(model, params) 21 | self.prev_model = None 22 | 23 | def train_learner(self, x_train, y_train): 24 | self.before_train(x_train, y_train) 25 | # set up loader 26 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 27 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 28 | drop_last=True) 29 | self.model.train() 30 | self.update_representation(train_loader) 31 | self.prev_model = copy.deepcopy(self.model) 32 | self.after_train() 33 | 34 | def update_representation(self, train_loader): 35 | updated_idx = [] 36 | for ep in range(self.epoch): 37 | for i, train_data in enumerate(train_loader): 38 | # batch update 39 | train_x, train_y = train_data 40 | train_x = maybe_cuda(train_x, self.cuda) 41 | train_y = maybe_cuda(train_y, self.cuda) 42 | train_y_copy = train_y.clone() 43 | for k, y in enumerate(train_y_copy): 44 | train_y_copy[k] = len(self.old_labels) + self.new_labels.index(y) 45 | all_cls_num = len(self.new_labels) + len(self.old_labels) 46 | target_labels = utils.ohe_label(train_y_copy, all_cls_num, device=train_y_copy.device).float() 47 | if self.prev_model is not None: 48 | mem_x, mem_y = random_retrieve(self.buffer, self.batch, 49 | excl_indices=updated_idx) 50 | mem_x = maybe_cuda(mem_x, self.cuda) 51 | batch_x = torch.cat([train_x, mem_x]) 52 | target_labels = torch.cat([target_labels, torch.zeros_like(target_labels)]) 53 | else: 54 | batch_x = train_x 55 | logits = self.forward(batch_x) 56 | self.opt.zero_grad() 57 | if self.prev_model is not None: 58 | with torch.no_grad(): 59 | q = torch.sigmoid(self.prev_model.forward(batch_x)) 60 | for k, y in enumerate(self.old_labels): 61 | target_labels[:, k] = q[:, k] 62 | loss = F.binary_cross_entropy_with_logits(logits[:, :all_cls_num], target_labels, reduction='none').sum(dim=1).mean() 63 | loss.backward() 64 | self.opt.step() 65 | updated_idx += self.buffer.update(train_x, train_y) 66 | -------------------------------------------------------------------------------- /utils/L_softmax.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from scipy.special import binom 5 | 6 | 7 | class LSoftmaxLinear(nn.Module): 8 | 9 | def __init__(self, input_features, output_features, margin): 10 | super().__init__() 11 | self.input_dim = input_features # number of input feature i.e. output of the last fc layer 12 | self.output_dim = output_features # number of output = class numbers 13 | self.margin = margin # m 14 | self.beta = 100 15 | self.beta_min = 0 16 | self.scale = 0.99 17 | 18 | # gpu or cpu 19 | 20 | # Initialize L-Softmax parameters 21 | self.weight = nn.Parameter(torch.FloatTensor(input_features, output_features)) 22 | self.divisor = math.pi / self.margin # pi/m 23 | self.C_m_2n = torch.Tensor(binom(margin, range(0, margin + 1, 2))).cuda() # C_m{2n} 24 | self.cos_powers = torch.Tensor(range(self.margin, -1, -2)).cuda() # m - 2n 25 | self.sin2_powers = torch.Tensor(range(len(self.cos_powers))).cuda() # n 26 | self.signs = torch.ones(margin // 2 + 1).cuda() # 1, -1, 1, -1, ... 27 | self.signs[1::2] = -1 28 | 29 | def calculate_cos_m_theta(self, cos_theta): 30 | sin2_theta = 1 - cos_theta**2 31 | cos_terms = cos_theta.unsqueeze(1) ** self.cos_powers.unsqueeze(0) # cos^{m - 2n} 32 | sin2_terms = (sin2_theta.unsqueeze(1) # sin2^{n} 33 | ** self.sin2_powers.unsqueeze(0)) 34 | 35 | cos_m_theta = (self.signs.unsqueeze(0) * # -1^{n} * C_m{2n} * cos^{m - 2n} * sin2^{n} 36 | self.C_m_2n.unsqueeze(0) * 37 | cos_terms * 38 | sin2_terms).sum(1) # summation of all terms 39 | 40 | return cos_m_theta 41 | 42 | def reset_parameters(self): 43 | nn.init.kaiming_normal_(self.weight.data.t()) 44 | 45 | def find_k(self, cos): 46 | # to account for acos numerical errors 47 | eps = 1e-7 48 | cos = torch.clamp(cos, -1 + eps, 1 - eps) 49 | acos = cos.acos() 50 | k = (acos / self.divisor).floor().detach() 51 | return k 52 | 53 | def forward(self, input, target=None): 54 | if self.training: 55 | assert target is not None 56 | x, w = input, self.weight 57 | beta = max(self.beta, self.beta_min) 58 | logit = x.mm(w) 59 | indexes = range(logit.size(0)) 60 | logit_target = logit[indexes, target] 61 | 62 | # cos(theta) = w * x / ||w||*||x|| 63 | w_target_norm = w[:, target].norm(p=2, dim=0) 64 | x_norm = x.norm(p=2, dim=1) 65 | cos_theta_target = logit_target / (w_target_norm * x_norm + 1e-10) 66 | 67 | # equation 7 68 | cos_m_theta_target = self.calculate_cos_m_theta(cos_theta_target) 69 | 70 | # find k in equation 6 71 | k = self.find_k(cos_theta_target) 72 | 73 | # f_y_i 74 | logit_target_updated = (w_target_norm * 75 | x_norm * 76 | (((-1) ** k * cos_m_theta_target) - 2 * k)) 77 | logit_target_updated_beta = (logit_target_updated + beta * logit[indexes, target]) / (1 + beta) 78 | 79 | logit[indexes, target] = logit_target_updated_beta 80 | self.beta *= self.scale 81 | return logit 82 | else: 83 | assert target is None 84 | return input.mm(self.weight) -------------------------------------------------------------------------------- /continuum/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils import data 4 | from utils.setup_elements import transforms_match 5 | 6 | def create_task_composition(class_nums, num_tasks, fixed_order=False): 7 | classes_per_task = class_nums // num_tasks 8 | total_classes = classes_per_task * num_tasks 9 | label_array = np.arange(0, total_classes) 10 | if not fixed_order: 11 | np.random.shuffle(label_array) 12 | 13 | task_labels = [] 14 | for tt in range(num_tasks): 15 | tt_offset = tt * classes_per_task 16 | task_labels.append(list(label_array[tt_offset:tt_offset + classes_per_task])) 17 | print('Task: {}, Labels:{}'.format(tt, task_labels[tt])) 18 | return task_labels 19 | 20 | 21 | def load_task_with_labels_torch(x, y, labels): 22 | tmp = [] 23 | for i in labels: 24 | tmp.append((y == i).nonzero().view(-1)) 25 | idx = torch.cat(tmp) 26 | return x[idx], y[idx] 27 | 28 | 29 | def load_task_with_labels(x, y, labels): 30 | tmp = [] 31 | for i in labels: 32 | tmp.append((np.where(y == i)[0])) 33 | idx = np.concatenate(tmp, axis=None) 34 | return x[idx], y[idx] 35 | 36 | 37 | 38 | class dataset_transform(data.Dataset): 39 | def __init__(self, x, y, transform=None): 40 | self.x = x 41 | self.y = torch.from_numpy(y).type(torch.LongTensor) 42 | self.transform = transform # save the transform 43 | 44 | def __len__(self): 45 | return len(self.y)#self.x.shape[0] # return 1 as we have only one image 46 | 47 | def __getitem__(self, idx): 48 | # return the augmented image 49 | if self.transform: 50 | x = self.transform(self.x[idx]) 51 | else: 52 | x = self.x[idx] 53 | 54 | return x.float(), self.y[idx] 55 | 56 | 57 | def setup_test_loader(test_data, params): 58 | test_loaders = [] 59 | 60 | for (x_test, y_test) in test_data: 61 | test_dataset = dataset_transform(x_test, y_test, transform=transforms_match[params.data]) 62 | test_loader = data.DataLoader(test_dataset, batch_size=params.test_batch, shuffle=True, num_workers=0) 63 | test_loaders.append(test_loader) 64 | return test_loaders 65 | 66 | 67 | def shuffle_data(x, y): 68 | perm_inds = np.arange(0, x.shape[0]) 69 | np.random.shuffle(perm_inds) 70 | rdm_x = x[perm_inds] 71 | rdm_y = y[perm_inds] 72 | return rdm_x, rdm_y 73 | 74 | 75 | def train_val_test_split_ni(train_data, train_label, test_data, test_label, task_nums, img_size, val_size=0.1): 76 | train_data_rdm, train_label_rdm = shuffle_data(train_data, train_label) 77 | val_size = int(len(train_data_rdm) * val_size) 78 | val_data_rdm, val_label_rdm = train_data_rdm[:val_size], train_label_rdm[:val_size] 79 | train_data_rdm, train_label_rdm = train_data_rdm[val_size:], train_label_rdm[val_size:] 80 | test_data_rdm, test_label_rdm = shuffle_data(test_data, test_label) 81 | train_data_rdm_split = train_data_rdm.reshape(task_nums, -1, img_size, img_size, 3) 82 | train_label_rdm_split = train_label_rdm.reshape(task_nums, -1) 83 | val_data_rdm_split = val_data_rdm.reshape(task_nums, -1, img_size, img_size, 3) 84 | val_label_rdm_split = val_label_rdm.reshape(task_nums, -1) 85 | test_data_rdm_split = test_data_rdm.reshape(task_nums, -1, img_size, img_size, 3) 86 | test_label_rdm_split = test_label_rdm.reshape(task_nums, -1) 87 | return train_data_rdm_split, train_label_rdm_split, val_data_rdm_split, val_label_rdm_split, test_data_rdm_split, test_label_rdm_split -------------------------------------------------------------------------------- /agents/scr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from utils.buffer.buffer import Buffer 4 | from agents.base import ContinualLearner 5 | from continuum.data_utils import dataset_transform 6 | from utils.setup_elements import transforms_match, input_size_match 7 | from utils.utils import maybe_cuda, AverageMeter 8 | from kornia.augmentation import RandomResizedCrop, RandomHorizontalFlip, ColorJitter, RandomGrayscale 9 | import torch.nn as nn 10 | 11 | class SupContrastReplay(ContinualLearner): 12 | def __init__(self, model, opt, params): 13 | super(SupContrastReplay, self).__init__(model, opt, params) 14 | self.buffer = Buffer(model, params) 15 | self.mem_size = params.mem_size 16 | self.eps_mem_batch = params.eps_mem_batch 17 | self.mem_iters = params.mem_iters 18 | self.transform = nn.Sequential( 19 | RandomResizedCrop(size=(input_size_match[self.params.data][1], input_size_match[self.params.data][2]), scale=(0.2, 1.)), 20 | RandomHorizontalFlip(), 21 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8), 22 | RandomGrayscale(p=0.2) 23 | 24 | ) 25 | self.criterion_ce= torch.nn.CrossEntropyLoss(reduction='mean') 26 | 27 | def train_learner(self, x_train, y_train): 28 | self.before_train(x_train, y_train) 29 | # set up loader 30 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 31 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 32 | drop_last=True) 33 | # set up model 34 | self.model = self.model.train() 35 | 36 | # setup tracker 37 | losses = AverageMeter() 38 | acc_batch = AverageMeter() 39 | 40 | for ep in range(self.epoch): 41 | for i, batch_data in enumerate(train_loader): 42 | # batch update 43 | batch_x, batch_y = batch_data 44 | batch_x = maybe_cuda(batch_x, self.cuda) 45 | batch_y = maybe_cuda(batch_y, self.cuda) 46 | 47 | for j in range(self.mem_iters): 48 | mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y) 49 | 50 | if mem_x.size(0) > 0: 51 | mem_x = maybe_cuda(mem_x, self.cuda) 52 | mem_y = maybe_cuda(mem_y, self.cuda) 53 | combined_batch = torch.cat((mem_x, batch_x)) 54 | combined_labels = torch.cat((mem_y, batch_y)) 55 | combined_batch_aug = self.transform(combined_batch) 56 | batch_features, batch_logits = self.model.forward(combined_batch) 57 | batch_features_aug, batch_logits_aug = self.model.forward(combined_batch_aug) 58 | features = torch.cat([batch_features.unsqueeze(1), batch_features_aug.unsqueeze(1)], dim=1) 59 | loss_1 = self.criterion(features, combined_labels) 60 | loss_2 = self.criterion_ce(batch_logits,combined_labels) 61 | loss_3 = self.criterion_ce(batch_logits_aug, combined_labels) 62 | loss = 0.2*loss_1 + loss_2 + loss_3 63 | losses.update(loss, batch_y.size(0)) 64 | self.opt.zero_grad() 65 | loss.backward() 66 | self.opt.step() 67 | 68 | # update mem 69 | self.buffer.update(batch_x, batch_y) 70 | if i % 100 == 1 and self.verbose: 71 | print( 72 | '==>>> it: {}, avg. loss: {:.6f}, ' 73 | .format(i, losses.avg(), acc_batch.avg()) 74 | ) 75 | self.after_train() 76 | -------------------------------------------------------------------------------- /features/wide_resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch.nn as nn 6 | from . import extractor 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=True) 13 | 14 | 15 | def conv1x1(in_planes, out_planes, stride=1): 16 | """1x1 convolution""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 18 | bias=True) 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | def __init__(self, inplanes, planes, dropout, stride=1): 23 | super(BasicBlock, self).__init__() 24 | self.bn1 = nn.BatchNorm2d(inplanes) 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.dropout = nn.Dropout(dropout) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | if stride != 1 or inplanes != planes: 31 | self.shortcut = conv1x1(inplanes, planes, stride) 32 | self.use_conv1x1 = True 33 | else: 34 | self.use_conv1x1 = False 35 | 36 | def forward(self, x): 37 | out = self.bn1(x) 38 | out = self.relu(out) 39 | 40 | if self.use_conv1x1: 41 | shortcut = self.shortcut(out) 42 | else: 43 | shortcut = x 44 | 45 | out = self.conv1(out) 46 | 47 | out = self.bn2(out) 48 | out = self.relu(out) 49 | out = self.dropout(out) 50 | out = self.conv2(out) 51 | 52 | out += shortcut 53 | 54 | return out 55 | 56 | 57 | class WideResNet(extractor.BaseModule): 58 | def __init__(self, config, name): 59 | super(WideResNet, self).__init__() 60 | self.name = name 61 | depth = config["depth"] 62 | width = config["width"] 63 | dropout = config["dropout"] 64 | in_channels = config["channels"] 65 | print("%s depth: %d, width: %d, dropout=%f" \ 66 | % (type(self), depth, width, dropout)) 67 | 68 | layer = (depth - 4) // 6 69 | 70 | self.inplanes = 16 71 | self.conv = conv3x3(in_channels, 16) 72 | self.layer1 = self._make_layer(16*width, layer, dropout) 73 | self.layer2 = self._make_layer(32*width, layer, dropout, stride=2) 74 | self.layer3 = self._make_layer(64*width, layer, dropout, stride=2) 75 | self.bn = nn.BatchNorm2d(64*width) 76 | self.relu = nn.ReLU(inplace=True) 77 | 78 | self.n_features = 64 * width 79 | #self.avgpool = nn.AdaptiveAvgPool2d(1) 80 | #self.fc = nn.Linear(64*width, num_classes) 81 | 82 | # use default init 83 | #for m in self.modules(): 84 | # if isinstance(m, nn.Conv2d): 85 | # m.weight.data = nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') 86 | 87 | def _make_layer(self, planes, blocks, dropout, stride=1): 88 | layers = [] 89 | for i in range(blocks): 90 | layers.append(BasicBlock(self.inplanes, planes, dropout, stride if i == 0 else 1)) 91 | self.inplanes = planes 92 | 93 | return nn.Sequential(*layers) 94 | 95 | 96 | def forward(self, x): 97 | x = self.conv(x) 98 | 99 | x = self.layer1(x) 100 | x = self.layer2(x) 101 | x = self.layer3(x) 102 | 103 | x = self.bn(x) 104 | x = self.relu(x) 105 | #x = self.avgpool(x) 106 | #x = x.view(x.size(0), -1) 107 | #x = self.fc(x) 108 | 109 | return x 110 | 111 | 112 | 113 | if __name__ == '__main__': 114 | pass 115 | -------------------------------------------------------------------------------- /agents/gdumb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import math 4 | from agents.base import ContinualLearner 5 | from continuum.data_utils import dataset_transform 6 | from utils.setup_elements import transforms_match, setup_architecture, setup_opt 7 | from utils.utils import maybe_cuda, EarlyStopping 8 | import numpy as np 9 | import random 10 | 11 | 12 | class Gdumb(ContinualLearner): 13 | def __init__(self, model, opt, params): 14 | super(Gdumb, self).__init__(model, opt, params) 15 | self.mem_img = {} 16 | self.mem_c = {} 17 | #self.early_stopping = EarlyStopping(self.params.min_delta, self.params.patience, self.params.cumulative_delta) 18 | 19 | def greedy_balancing_update(self, x, y): 20 | k_c = self.params.mem_size // max(1, len(self.mem_img)) 21 | if y not in self.mem_img or self.mem_c[y] < k_c: 22 | if sum(self.mem_c.values()) >= self.params.mem_size: 23 | cls_max = max(self.mem_c.items(), key=lambda k:k[1])[0] 24 | idx = random.randrange(self.mem_c[cls_max]) 25 | self.mem_img[cls_max].pop(idx) 26 | self.mem_c[cls_max] -= 1 27 | if y not in self.mem_img: 28 | self.mem_img[y] = [] 29 | self.mem_c[y] = 0 30 | self.mem_img[y].append(x) 31 | self.mem_c[y] += 1 32 | 33 | def train_learner(self, x_train, y_train): 34 | self.before_train(x_train, y_train) 35 | # set up loader 36 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 37 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 38 | drop_last=True) 39 | 40 | for i, batch_data in enumerate(train_loader): 41 | # batch update 42 | batch_x, batch_y = batch_data 43 | batch_x = maybe_cuda(batch_x, self.cuda) 44 | batch_y = maybe_cuda(batch_y, self.cuda) 45 | # update mem 46 | for j in range(len(batch_x)): 47 | self.greedy_balancing_update(batch_x[j], batch_y[j].item()) 48 | #self.early_stopping.reset() 49 | self.train_mem() 50 | self.after_train() 51 | 52 | def train_mem(self): 53 | mem_x = [] 54 | mem_y = [] 55 | for i in self.mem_img.keys(): 56 | mem_x += self.mem_img[i] 57 | mem_y += [i] * self.mem_c[i] 58 | 59 | mem_x = torch.stack(mem_x) 60 | mem_y = torch.LongTensor(mem_y) 61 | self.model = setup_architecture(self.params) 62 | self.model = maybe_cuda(self.model, self.cuda) 63 | opt = setup_opt(self.params.optimizer, self.model, self.params.learning_rate, self.params.weight_decay) 64 | #scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=1, T_mult=2, eta_min=self.params.minlr) 65 | 66 | #loss = math.inf 67 | for i in range(self.params.mem_epoch): 68 | idx = np.random.permutation(len(mem_x)).tolist() 69 | mem_x = maybe_cuda(mem_x[idx], self.cuda) 70 | mem_y = maybe_cuda(mem_y[idx], self.cuda) 71 | self.model = self.model.train() 72 | batch_size = self.params.batch 73 | #scheduler.step() 74 | #if opt.param_groups[0]['lr'] == self.params.learning_rate: 75 | # if self.early_stopping.step(-loss): 76 | # return 77 | for j in range(len(mem_y) // batch_size): 78 | opt.zero_grad() 79 | logits = self.model.forward(mem_x[batch_size * j:batch_size * (j + 1)]) 80 | loss = self.criterion(logits, mem_y[batch_size * j:batch_size * (j + 1)]) 81 | loss.backward() 82 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip) 83 | opt.step() 84 | 85 | -------------------------------------------------------------------------------- /utils/buffer/mgi_retrieve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.buffer.buffer_utils import random_retrieve, get_grad_vector 4 | import copy 5 | from torchvision.transforms import RandomResizedCrop, RandomHorizontalFlip,ColorJitter,RandomGrayscale 6 | from utils.setup_elements import transforms_match, input_size_match 7 | import torch.nn as nn 8 | from utils.setup_elements import n_classes 9 | 10 | 11 | class MGI_retrieve(object): 12 | def __init__(self, params, **kwargs): 13 | super().__init__() 14 | self.params = params 15 | self.subsample = params.subsample 16 | self.num_retrieve = params.eps_mem_batch 17 | self.transform = nn.Sequential( 18 | RandomResizedCrop(size=(input_size_match[self.params.data][1], input_size_match[self.params.data][2]), 19 | scale=(0.2, 1.)).cuda(), 20 | RandomHorizontalFlip().cuda(), 21 | ColorJitter(0.4, 0.4, 0.4, 0.1), 22 | RandomGrayscale(p=0.2) 23 | 24 | ) 25 | self.out_dim = n_classes[params.data] 26 | 27 | def retrieve(self, buffer, **kwargs): 28 | sub_x, sub_y = random_retrieve(buffer, self.subsample) 29 | grad_dims = [] 30 | for param in buffer.model.parameters(): 31 | grad_dims.append(param.data.numel()) 32 | grad_vector = get_grad_vector(buffer.model.parameters, grad_dims) 33 | model_temp = self.get_future_step_parameters(buffer.model, grad_vector, grad_dims) 34 | if sub_x.size(0) > 0: 35 | with torch.no_grad(): 36 | sub_x_aug = self.transform(sub_x) 37 | logits_pre = buffer.model(sub_x,sub_x_aug) 38 | logits_post = model_temp(sub_x,sub_x_aug) 39 | 40 | z_pre, zt_pre, zzt_pre,fea_z_pre = logits_pre 41 | z_post, zt_post, zzt_post,fea_z_post = logits_post 42 | 43 | 44 | grads_pre_z= torch.sum(torch.abs(F.softmax(z_pre, dim=1) - F.one_hot(sub_y, self.out_dim)), 1) 45 | mgi_pre_z = grads_pre_z * fea_z_pre[0].reshape(-1) 46 | grads_post_z = torch.sum(torch.abs(F.softmax(z_post, dim=1) - F.one_hot(sub_y, self.out_dim)), 1) # N * 1 47 | mgi_post_z = grads_post_z * fea_z_post[0].reshape(-1) 48 | 49 | 50 | scores = mgi_post_z - mgi_pre_z 51 | 52 | big_ind = scores.sort(descending=True)[1][:int(self.num_retrieve)] 53 | 54 | 55 | 56 | return sub_x[big_ind], sub_x_aug[big_ind],sub_y[big_ind] 57 | else: 58 | return sub_x, sub_x,sub_y 59 | 60 | def get_future_step_parameters(self, model, grad_vector, grad_dims): 61 | """ 62 | computes \theta-\delta\theta 63 | :param this_net: 64 | :param grad_vector: 65 | :return: 66 | """ 67 | new_model = copy.deepcopy(model) 68 | self.overwrite_grad(new_model.parameters, grad_vector, grad_dims) 69 | with torch.no_grad(): 70 | for param in new_model.parameters(): 71 | if param.grad is not None: 72 | param.data = param.data - self.params.learning_rate * param.grad.data 73 | return new_model 74 | 75 | def overwrite_grad(self, pp, new_grad, grad_dims): 76 | """ 77 | This is used to overwrite the gradients with a new gradient 78 | vector, whenever violations occur. 79 | pp: parameters 80 | newgrad: corrected gradient 81 | grad_dims: list storing number of parameters at each layer 82 | """ 83 | cnt = 0 84 | for param in pp(): 85 | param.grad = torch.zeros_like(param.data) 86 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) 87 | en = sum(grad_dims[:cnt + 1]) 88 | this_grad = new_grad[beg: en].contiguous().view( 89 | param.data.size()) 90 | param.grad.data.copy_(this_grad) 91 | cnt += 1 -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Yonglong Tian (yonglong@mit.edu) 3 | Date: May 07, 2020 4 | """ 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class SupConLoss(nn.Module): 12 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 13 | It also supports the unsupervised contrastive loss in SimCLR""" 14 | def __init__(self, temperature=0.07, contrast_mode='all'): 15 | super(SupConLoss, self).__init__() 16 | self.temperature = temperature 17 | self.contrast_mode = contrast_mode 18 | 19 | def forward(self, features, labels=None, mask=None): 20 | """Compute loss for model. If both `labels` and `mask` are None, 21 | it degenerates to SimCLR unsupervised loss: 22 | https://arxiv.org/pdf/2002.05709.pdf 23 | 24 | Args: 25 | features: hidden vector of shape [bsz, n_views, ...]. 26 | labels: ground truth of shape [bsz]. 27 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 28 | has the same class as sample i. Can be asymmetric. 29 | Returns: 30 | A loss scalar. 31 | """ 32 | device = (torch.device('cuda') 33 | if features.is_cuda 34 | else torch.device('cpu')) 35 | 36 | if len(features.shape) < 3: 37 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 38 | 'at least 3 dimensions are required') 39 | if len(features.shape) > 3: 40 | features = features.view(features.shape[0], features.shape[1], -1) 41 | 42 | batch_size = features.shape[0] 43 | if labels is not None and mask is not None: 44 | raise ValueError('Cannot define both `labels` and `mask`') 45 | elif labels is None and mask is None: 46 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 47 | elif labels is not None: 48 | labels = labels.contiguous().view(-1, 1) 49 | if labels.shape[0] != batch_size: 50 | raise ValueError('Num of labels does not match num of features') 51 | mask = torch.eq(labels, labels.T).float().to(device) 52 | else: 53 | mask = mask.float().to(device) 54 | 55 | contrast_count = features.shape[1] 56 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 57 | if self.contrast_mode == 'one': 58 | anchor_feature = features[:, 0] 59 | anchor_count = 1 60 | elif self.contrast_mode == 'all': 61 | anchor_feature = contrast_feature 62 | anchor_count = contrast_count 63 | else: 64 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 65 | 66 | # compute logits 67 | anchor_dot_contrast = torch.div( 68 | torch.matmul(anchor_feature, contrast_feature.T), 69 | self.temperature) 70 | # for numerical stability 71 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 72 | logits = anchor_dot_contrast - logits_max.detach() 73 | 74 | # tile mask 75 | mask = mask.repeat(anchor_count, contrast_count) 76 | # mask-out self-contrast cases 77 | logits_mask = torch.scatter( 78 | torch.ones_like(mask), 79 | 1, 80 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 81 | 0 82 | ) 83 | mask = mask * logits_mask 84 | 85 | # compute log_prob 86 | exp_logits = torch.exp(logits) * logits_mask 87 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 88 | 89 | # compute mean of log-likelihood over positive 90 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 91 | 92 | # loss 93 | loss = -1 * mean_log_prob_pos 94 | loss = loss.view(anchor_count, batch_size).mean() 95 | 96 | return loss 97 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from continuum.data_utils import create_task_composition, load_task_with_labels, shuffle_data 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns 6 | 7 | TEST_SPLIT = 1 / 6 8 | 9 | 10 | class Mini_ImageNet(DatasetBase): 11 | def __init__(self, scenario, params): 12 | dataset = 'mini_imagenet' 13 | if scenario == 'ni': 14 | num_tasks = len(params.ns_factor) 15 | else: 16 | num_tasks = params.num_tasks 17 | super(Mini_ImageNet, self).__init__(dataset, scenario, num_tasks, params.num_runs, params) 18 | 19 | 20 | def download_load(self): 21 | train_in = open("datasets/mini_imagenet/mini-imagenet-cache-train.pkl", "rb") 22 | train = pickle.load(train_in) 23 | train_x = train["image_data"].reshape([64, 600, 84, 84, 3]) 24 | val_in = open("datasets/mini_imagenet/mini-imagenet-cache-val.pkl", "rb") 25 | val = pickle.load(val_in) 26 | val_x = val['image_data'].reshape([16, 600, 84, 84, 3]) 27 | test_in = open("datasets/mini_imagenet/mini-imagenet-cache-test.pkl", "rb") 28 | test = pickle.load(test_in) 29 | test_x = test['image_data'].reshape([20, 600, 84, 84, 3]) 30 | all_data = np.vstack((train_x, val_x, test_x)) 31 | train_data = [] 32 | train_label = [] 33 | test_data = [] 34 | test_label = [] 35 | for i in range(len(all_data)): 36 | cur_x = all_data[i] 37 | cur_y = np.ones((600,)) * i 38 | rdm_x, rdm_y = shuffle_data(cur_x, cur_y) 39 | x_test = rdm_x[: int(600 * TEST_SPLIT)] 40 | y_test = rdm_y[: int(600 * TEST_SPLIT)] 41 | x_train = rdm_x[int(600 * TEST_SPLIT):] 42 | y_train = rdm_y[int(600 * TEST_SPLIT):] 43 | train_data.append(x_train) 44 | train_label.append(y_train) 45 | test_data.append(x_test) 46 | test_label.append(y_test) 47 | self.train_data = np.concatenate(train_data) 48 | self.train_label = np.concatenate(train_label) 49 | self.test_data = np.concatenate(test_data) 50 | self.test_label = np.concatenate(test_label) 51 | 52 | def new_run(self, **kwargs): 53 | self.setup() 54 | return self.test_set 55 | 56 | def new_task(self, cur_task, **kwargs): 57 | if self.scenario == 'ni': 58 | x_train, y_train = self.train_set[cur_task] 59 | labels = set(y_train) 60 | elif self.scenario == 'nc': 61 | labels = self.task_labels[cur_task] 62 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 63 | else: 64 | raise Exception('unrecognized scenario') 65 | return x_train, y_train, labels 66 | 67 | def setup(self): 68 | if self.scenario == 'ni': 69 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 70 | self.train_label, 71 | self.test_data, self.test_label, 72 | self.task_nums, 84, 73 | self.params.val_size, 74 | self.params.ns_type, self.params.ns_factor, 75 | plot=self.params.plot_sample) 76 | 77 | elif self.scenario == 'nc': 78 | self.task_labels = create_task_composition(class_nums=100, num_tasks=self.task_nums, 79 | fixed_order=self.params.fix_order) 80 | self.test_set = [] 81 | for labels in self.task_labels: 82 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 83 | self.test_set.append((x_test, y_test)) 84 | 85 | def test_plot(self): 86 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 87 | self.params.ns_factor) 88 | -------------------------------------------------------------------------------- /models/ndpm/component.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | import torch.nn as nn 4 | from typing import Tuple 5 | 6 | from utils.utils import maybe_cuda 7 | from utils.global_vars import * 8 | 9 | 10 | class Component(nn.Module, ABC): 11 | def __init__(self, params, experts: Tuple): 12 | super().__init__() 13 | self.params = params 14 | self.experts = experts 15 | 16 | self.optimizer = NotImplemented 17 | self.lr_scheduler = NotImplemented 18 | 19 | @abstractmethod 20 | def nll(self, x, y, step=None): 21 | """Return NLL""" 22 | pass 23 | 24 | @abstractmethod 25 | def collect_nll(self, x, y, step=None): 26 | """Return NLLs including previous experts""" 27 | pass 28 | 29 | def _clip_grad_value(self, clip_value): 30 | for group in self.optimizer.param_groups: 31 | nn.utils.clip_grad_value_(group['params'], clip_value) 32 | 33 | def _clip_grad_norm(self, max_norm, norm_type=2): 34 | for group in self.optimizer.param_groups: 35 | nn.utils.clip_grad_norm_(group['params'], max_norm, norm_type) 36 | 37 | def clip_grad(self): 38 | clip_grad_config = MODELS_NDPM_COMPONENT_CLIP_GRAD 39 | if clip_grad_config['type'] == 'value': 40 | self._clip_grad_value(**clip_grad_config['options']) 41 | elif clip_grad_config['type'] == 'norm': 42 | self._clip_grad_norm(**clip_grad_config['options']) 43 | else: 44 | raise ValueError('Invalid clip_grad type: {}' 45 | .format(clip_grad_config['type'])) 46 | 47 | @staticmethod 48 | def build_optimizer(optim_config, params): 49 | return getattr(torch.optim, optim_config['type'])( 50 | params, **optim_config['options']) 51 | 52 | @staticmethod 53 | def build_lr_scheduler(lr_config, optimizer): 54 | return getattr(torch.optim.lr_scheduler, lr_config['type'])( 55 | optimizer, **lr_config['options']) 56 | 57 | def weight_decay_loss(self): 58 | loss = maybe_cuda(torch.zeros([])) 59 | for param in self.parameters(): 60 | loss += torch.norm(param) ** 2 61 | return loss 62 | 63 | 64 | class ComponentG(Component, ABC): 65 | def setup_optimizer(self): 66 | self.optimizer = self.build_optimizer( 67 | {'type': self.params.optimizer, 'options': {'lr': self.params.learning_rate}}, self.parameters()) 68 | self.lr_scheduler = self.build_lr_scheduler( 69 | MODELS_NDPM_COMPONENT_LR_SCHEDULER_G, self.optimizer) 70 | 71 | def collect_nll(self, x, y=None, step=None): 72 | """Default `collect_nll` 73 | 74 | Warning: Parameter-sharing components should implement their own 75 | `collect_nll` 76 | 77 | Returns: 78 | nll: Tensor of shape [B, 1+K] 79 | """ 80 | outputs = [expert.g.nll(x, y, step) for expert in self.experts] 81 | nll = outputs 82 | output = self.nll(x, y, step) 83 | nll.append(output) 84 | return torch.stack(nll, dim=1) 85 | 86 | 87 | 88 | class ComponentD(Component, ABC): 89 | def setup_optimizer(self): 90 | self.optimizer = self.build_optimizer( 91 | {'type': self.params.optimizer, 'options': {'lr': self.params.learning_rate}}, self.parameters()) 92 | self.lr_scheduler = self.build_lr_scheduler( 93 | MODELS_NDPM_COMPONENT_LR_SCHEDULER_D, self.optimizer) 94 | 95 | def collect_forward(self, x): 96 | """Default `collect_forward` 97 | 98 | Warning: Parameter-sharing components should implement their own 99 | `collect_forward` 100 | 101 | Returns: 102 | output: Tensor of shape [B, 1+K, C] 103 | """ 104 | outputs = [expert.d(x) for expert in self.experts] 105 | outputs.append(self.forward(x)) 106 | return torch.stack(outputs, 1) 107 | 108 | def collect_nll(self, x, y, step=None): 109 | """Default `collect_nll` 110 | 111 | Warning: Parameter-sharing components should implement their own 112 | `collect_nll` 113 | 114 | Returns: 115 | nll: Tensor of shape [B, 1+K] 116 | """ 117 | outputs = [expert.d.nll(x, y, step) for expert in self.experts] 118 | nll = outputs 119 | output = self.nll(x, y, step) 120 | nll.append(output) 121 | return torch.stack(nll, dim=1) 122 | -------------------------------------------------------------------------------- /utils/buffer/aser_retrieve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.buffer.buffer_utils import random_retrieve, ClassBalancedRandomSampling 3 | from utils.buffer.aser_utils import compute_knn_sv 4 | from utils.utils import maybe_cuda 5 | from utils.setup_elements import n_classes 6 | 7 | 8 | class ASER_retrieve(object): 9 | def __init__(self, params, **kwargs): 10 | super().__init__() 11 | self.num_retrieve = params.eps_mem_batch 12 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 13 | self.k = params.k 14 | self.mem_size = params.mem_size 15 | self.aser_type = params.aser_type 16 | self.n_smp_cls = int(params.n_smp_cls) 17 | self.out_dim = n_classes[params.data] 18 | self.is_aser_upt = params.update == "ASER" 19 | ClassBalancedRandomSampling.class_index_cache = None 20 | 21 | def retrieve(self, buffer, **kwargs): 22 | model = buffer.model 23 | 24 | if buffer.n_seen_so_far <= self.mem_size: 25 | # Use random retrieval until buffer is filled 26 | ret_x, ret_y = random_retrieve(buffer, self.num_retrieve) 27 | else: 28 | # Use ASER retrieval if buffer is filled 29 | cur_x, cur_y = kwargs['x'], kwargs['y'] 30 | buffer_x, buffer_y = buffer.buffer_img, buffer.buffer_label 31 | ret_x, ret_y = self._retrieve_by_knn_sv(model, buffer_x, buffer_y, cur_x, cur_y, self.num_retrieve) 32 | return ret_x, ret_y 33 | 34 | def _retrieve_by_knn_sv(self, model, buffer_x, buffer_y, cur_x, cur_y, num_retrieve): 35 | """ 36 | Retrieves data instances with top-N Shapley Values from candidate set. 37 | Args: 38 | model (object): neural network. 39 | buffer_x (tensor): data buffer. 40 | buffer_y (tensor): label buffer. 41 | cur_x (tensor): current input data tensor. 42 | cur_y (tensor): current input label tensor. 43 | num_retrieve (int): number of data instances to be retrieved. 44 | Returns 45 | ret_x (tensor): retrieved data tensor. 46 | ret_y (tensor): retrieved label tensor. 47 | """ 48 | cur_x = maybe_cuda(cur_x) 49 | cur_y = maybe_cuda(cur_y) 50 | 51 | # Reset and update ClassBalancedRandomSampling cache if ASER update is not enabled 52 | if not self.is_aser_upt: 53 | ClassBalancedRandomSampling.update_cache(buffer_y, self.out_dim) 54 | 55 | # Get candidate data for retrieval (i.e., cand <- class balanced subsamples from memory) 56 | cand_x, cand_y, cand_ind = \ 57 | ClassBalancedRandomSampling.sample(buffer_x, buffer_y, self.n_smp_cls, device=self.device) 58 | 59 | # Type 1 - Adversarial SV 60 | # Get evaluation data for type 1 (i.e., eval <- current input) 61 | eval_adv_x, eval_adv_y = cur_x, cur_y 62 | # Compute adversarial Shapley value of candidate data 63 | # (i.e., sv wrt current input) 64 | sv_matrix_adv = compute_knn_sv(model, eval_adv_x, eval_adv_y, cand_x, cand_y, self.k, device=self.device) 65 | 66 | if self.aser_type != "neg_sv": 67 | # Type 2 - Cooperative SV 68 | # Get evaluation data for type 2 69 | # (i.e., eval <- class balanced subsamples from memory excluding those already in candidate set) 70 | excl_indices = set(cand_ind.tolist()) 71 | eval_coop_x, eval_coop_y, _ = \ 72 | ClassBalancedRandomSampling.sample(buffer_x, buffer_y, self.n_smp_cls, 73 | excl_indices=excl_indices, device=self.device) 74 | # Compute Shapley value 75 | sv_matrix_coop = \ 76 | compute_knn_sv(model, eval_coop_x, eval_coop_y, cand_x, cand_y, self.k, device=self.device) 77 | if self.aser_type == "asv": 78 | # Use extremal SVs for computation 79 | sv = sv_matrix_coop.max(0).values - sv_matrix_adv.min(0).values 80 | else: 81 | # Use mean variation for aser_type == "asvm" or anything else 82 | sv = sv_matrix_coop.mean(0) - sv_matrix_adv.mean(0) 83 | else: 84 | # aser_type == "neg_sv" 85 | # No Type 1 - Cooperative SV; Use sum of Adversarial SV only 86 | sv = sv_matrix_adv.sum(0) * -1 87 | 88 | ret_ind = sv.argsort(descending=True) 89 | 90 | ret_x = cand_x[ret_ind][:num_retrieve] 91 | ret_y = cand_y[ret_ind][:num_retrieve] 92 | return ret_x, ret_y 93 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def maybe_cuda(what, use_cuda=True, **kw): 5 | """ 6 | Moves `what` to CUDA and returns it, if `use_cuda` and it's available. 7 | Args: 8 | what (object): any object to move to eventually gpu 9 | use_cuda (bool): if we want to use gpu or cpu. 10 | Returns 11 | object: the same object but eventually moved to gpu. 12 | """ 13 | 14 | if use_cuda is not False and torch.cuda.is_available(): 15 | what = what.cuda() 16 | return what 17 | 18 | 19 | def boolean_string(s): 20 | if s not in {'False', 'True'}: 21 | raise ValueError('Not a valid boolean string') 22 | return s == 'True' 23 | 24 | 25 | class AverageMeter(object): 26 | """Computes and stores the average and current value""" 27 | 28 | def __init__(self): 29 | self.reset() 30 | 31 | def reset(self): 32 | self.sum = 0 33 | self.count = 0 34 | 35 | def update(self, val, n): 36 | self.sum += val * n 37 | self.count += n 38 | 39 | def avg(self): 40 | if self.count == 0: 41 | return 0 42 | return float(self.sum) / self.count 43 | 44 | 45 | def mini_batch_deep_features(model, total_x, num): 46 | """ 47 | Compute deep features with mini-batches. 48 | Args: 49 | model (object): neural network. 50 | total_x (tensor): data tensor. 51 | num (int): number of data. 52 | Returns 53 | deep_features (tensor): deep feature representation of data tensor. 54 | """ 55 | is_train = False 56 | if model.training: 57 | is_train = True 58 | model.eval() 59 | if hasattr(model, "features"): 60 | model_has_feature_extractor = True 61 | else: 62 | model_has_feature_extractor = False 63 | # delete the last fully connected layer 64 | modules = list(model.children())[:-1] 65 | # make feature extractor 66 | model_features = torch.nn.Sequential(*modules) 67 | 68 | with torch.no_grad(): 69 | bs = 64 70 | num_itr = num // bs + int(num % bs > 0) 71 | sid = 0 72 | deep_features_list = [] 73 | for i in range(num_itr): 74 | eid = sid + bs if i != num_itr - 1 else num 75 | batch_x = total_x[sid: eid] 76 | 77 | if model_has_feature_extractor: 78 | batch_deep_features_ = model.features(batch_x) 79 | else: 80 | batch_deep_features_ = torch.squeeze(model_features(batch_x)) 81 | 82 | deep_features_list.append(batch_deep_features_.reshape((batch_x.size(0), -1))) 83 | sid = eid 84 | if num_itr == 1: 85 | deep_features_ = deep_features_list[0] 86 | else: 87 | deep_features_ = torch.cat(deep_features_list, 0) 88 | if is_train: 89 | model.train() 90 | return deep_features_ 91 | 92 | 93 | def euclidean_distance(u, v): 94 | euclidean_distance_ = (u - v).pow(2).sum(1) 95 | return euclidean_distance_ 96 | 97 | 98 | def ohe_label(label_tensor, dim, device="cpu"): 99 | # Returns one-hot-encoding of input label tensor 100 | n_labels = label_tensor.size(0) 101 | zero_tensor = torch.zeros((n_labels, dim), device=device, dtype=torch.long) 102 | return zero_tensor.scatter_(1, label_tensor.reshape((n_labels, 1)), 1) 103 | 104 | 105 | def nonzero_indices(bool_mask_tensor): 106 | # Returns tensor which contains indices of nonzero elements in bool_mask_tensor 107 | return bool_mask_tensor.nonzero(as_tuple=True)[0] 108 | 109 | 110 | class EarlyStopping(): 111 | def __init__(self, min_delta, patience, cumulative_delta): 112 | self.min_delta = min_delta 113 | self.patience = patience 114 | self.cumulative_delta = cumulative_delta 115 | self.counter = 0 116 | self.best_score = None 117 | 118 | def step(self, score): 119 | if self.best_score is None: 120 | self.best_score = score 121 | elif score <= self.best_score + self.min_delta: 122 | if not self.cumulative_delta and score > self.best_score: 123 | self.best_score = score 124 | self.counter += 1 125 | if self.counter >= self.patience: 126 | return True 127 | else: 128 | self.best_score = score 129 | self.counter = 0 130 | return False 131 | 132 | def reset(self): 133 | self.counter = 0 134 | self.best_score = None 135 | -------------------------------------------------------------------------------- /agents/agem.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from utils.setup_elements import transforms_match 4 | from torch.utils import data 5 | from utils.buffer.buffer import Buffer 6 | from utils.utils import maybe_cuda, AverageMeter 7 | import torch 8 | 9 | 10 | class AGEM(ContinualLearner): 11 | def __init__(self, model, opt, params): 12 | super(AGEM, self).__init__(model, opt, params) 13 | self.buffer = Buffer(model, params) 14 | self.mem_size = params.mem_size 15 | self.eps_mem_batch = params.eps_mem_batch 16 | self.mem_iters = params.mem_iters 17 | 18 | def train_learner(self, x_train, y_train): 19 | self.before_train(x_train, y_train) 20 | 21 | # set up loader 22 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 23 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 24 | drop_last=True) 25 | # set up model 26 | self.model = self.model.train() 27 | 28 | # setup tracker 29 | losses_batch = AverageMeter() 30 | acc_batch = AverageMeter() 31 | 32 | for ep in range(self.epoch): 33 | for i, batch_data in enumerate(train_loader): 34 | # batch update 35 | batch_x, batch_y = batch_data 36 | batch_x = maybe_cuda(batch_x, self.cuda) 37 | batch_y = maybe_cuda(batch_y, self.cuda) 38 | for j in range(self.mem_iters): 39 | logits = self.forward(batch_x) 40 | loss = self.criterion(logits, batch_y) 41 | if self.params.trick['kd_trick']: 42 | loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \ 43 | self.kd_manager.get_kd_loss(logits, batch_x) 44 | if self.params.trick['kd_trick_star']: 45 | loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \ 46 | (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x) 47 | _, pred_label = torch.max(logits, 1) 48 | correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0) 49 | # update tracker 50 | acc_batch.update(correct_cnt, batch_y.size(0)) 51 | losses_batch.update(loss, batch_y.size(0)) 52 | # backward 53 | self.opt.zero_grad() 54 | loss.backward() 55 | 56 | if self.task_seen > 0: 57 | # sample from memory of previous tasks 58 | mem_x, mem_y = self.buffer.retrieve() 59 | if mem_x.size(0) > 0: 60 | params = [p for p in self.model.parameters() if p.requires_grad] 61 | # gradient computed using current batch 62 | grad = [p.grad.clone() for p in params] 63 | mem_x = maybe_cuda(mem_x, self.cuda) 64 | mem_y = maybe_cuda(mem_y, self.cuda) 65 | mem_logits = self.forward(mem_x) 66 | loss_mem = self.criterion(mem_logits, mem_y) 67 | self.opt.zero_grad() 68 | loss_mem.backward() 69 | # gradient computed using memory samples 70 | grad_ref = [p.grad.clone() for p in params] 71 | 72 | # inner product of grad and grad_ref 73 | prod = sum([torch.sum(g * g_r) for g, g_r in zip(grad, grad_ref)]) 74 | if prod < 0: 75 | prod_ref = sum([torch.sum(g_r ** 2) for g_r in grad_ref]) 76 | # do projection 77 | grad = [g - prod / prod_ref * g_r for g, g_r in zip(grad, grad_ref)] 78 | # replace params' grad 79 | for g, p in zip(grad, params): 80 | p.grad.data.copy_(g) 81 | self.opt.step() 82 | # update mem 83 | self.buffer.update(batch_x, batch_y) 84 | 85 | if i % 100 == 1 and self.verbose: 86 | print( 87 | '==>>> it: {}, avg. loss: {:.6f}, ' 88 | 'running train acc: {:.3f}' 89 | .format(i, losses_batch.avg(), acc_batch.avg()) 90 | ) 91 | self.after_train() -------------------------------------------------------------------------------- /continuum/dataset_scripts/openloris.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from PIL import Image 3 | import numpy as np 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | import time 6 | from continuum.data_utils import shuffle_data 7 | 8 | 9 | class OpenLORIS(DatasetBase): 10 | """ 11 | tasks_nums is predefined and it depends on the ns_type. 12 | """ 13 | def __init__(self, scenario, params): # scenario refers to "ni" or "nc" 14 | dataset = 'openloris' 15 | self.ns_type = params.ns_type 16 | task_nums = openloris_ntask[self.ns_type] # ns_type can be (illumination, occlusion, pixel, clutter, sequence) 17 | super(OpenLORIS, self).__init__(dataset, scenario, task_nums, params.num_runs, params) 18 | 19 | 20 | def download_load(self): 21 | s = time.time() 22 | self.train_set = [] 23 | for batch_num in range(1, self.task_nums+1): 24 | train_x = [] 25 | train_y = [] 26 | test_x = [] 27 | test_y = [] 28 | for i in range(len(datapath)): 29 | train_temp = glob.glob('datasets/openloris/' + self.ns_type + '/train/task{}/{}/*.jpg'.format(batch_num, datapath[i])) 30 | 31 | train_x.extend([np.array(Image.open(x).convert('RGB').resize((50, 50))) for x in train_temp]) 32 | train_y.extend([i] * len(train_temp)) 33 | 34 | test_temp = glob.glob( 35 | 'datasets/openloris/' + self.ns_type + '/test/task{}/{}/*.jpg'.format(batch_num, datapath[i])) 36 | 37 | test_x.extend([np.array(Image.open(x).convert('RGB').resize((50, 50))) for x in test_temp]) 38 | test_y.extend([i] * len(test_temp)) 39 | 40 | print(" --> batch{}'-dataset consisting of {} samples".format(batch_num, len(train_x))) 41 | print(" --> test'-dataset consisting of {} samples".format(len(test_x))) 42 | self.train_set.append((np.array(train_x), np.array(train_y))) 43 | self.test_set.append((np.array(test_x), np.array(test_y))) 44 | e = time.time() 45 | print('loading time: {}'.format(str(e - s))) 46 | 47 | def new_run(self, **kwargs): 48 | pass 49 | 50 | def new_task(self, cur_task, **kwargs): 51 | train_x, train_y = self.train_set[cur_task] 52 | # get val set 53 | train_x_rdm, train_y_rdm = shuffle_data(train_x, train_y) 54 | val_size = int(len(train_x_rdm) * self.params.val_size) 55 | val_data_rdm, val_label_rdm = train_x_rdm[:val_size], train_y_rdm[:val_size] 56 | train_data_rdm, train_label_rdm = train_x_rdm[val_size:], train_y_rdm[val_size:] 57 | self.val_set.append((val_data_rdm, val_label_rdm)) 58 | labels = set(train_label_rdm) 59 | return train_data_rdm, train_label_rdm, labels 60 | 61 | def setup(self, **kwargs): 62 | pass 63 | 64 | 65 | 66 | openloris_ntask = { 67 | 'illumination': 9, 68 | 'occlusion': 9, 69 | 'pixel': 9, 70 | 'clutter': 9, 71 | 'sequence': 12 72 | } 73 | 74 | datapath = ['bottle_01', 'bottle_02', 'bottle_03', 'bottle_04', 'bowl_01', 'bowl_02', 'bowl_03', 'bowl_04', 'bowl_05', 75 | 'corkscrew_01', 'cottonswab_01', 'cottonswab_02', 'cup_01', 'cup_02', 'cup_03', 'cup_04', 'cup_05', 76 | 'cup_06', 'cup_07', 'cup_08', 'cup_10', 'cushion_01', 'cushion_02', 'cushion_03', 'glasses_01', 77 | 'glasses_02', 'glasses_03', 'glasses_04', 'knife_01', 'ladle_01', 'ladle_02', 'ladle_03', 'ladle_04', 78 | 'mask_01', 'mask_02', 'mask_03', 'mask_04', 'mask_05', 'paper_cutter_01', 'paper_cutter_02', 79 | 'paper_cutter_03', 'paper_cutter_04', 'pencil_01', 'pencil_02', 'pencil_03', 'pencil_04', 'pencil_05', 80 | 'plasticbag_01', 'plasticbag_02', 'plasticbag_03', 'plug_01', 'plug_02', 'plug_03', 'plug_04', 'pot_01', 81 | 'scissors_01', 'scissors_02', 'scissors_03', 'stapler_01', 'stapler_02', 'stapler_03', 'thermometer_01', 82 | 'thermometer_02', 'thermometer_03', 'toy_01', 'toy_02', 'toy_03', 'toy_04', 'toy_05','nail_clippers_01','nail_clippers_02', 83 | 'nail_clippers_03', 'bracelet_01', 'bracelet_02','bracelet_03', 'comb_01','comb_02', 84 | 'comb_03', 'umbrella_01','umbrella_02','umbrella_03','socks_01','socks_02','socks_03', 85 | 'toothpaste_01','toothpaste_02','toothpaste_03','wallet_01','wallet_02','wallet_03', 86 | 'headphone_01','headphone_02','headphone_03', 'key_01','key_02','key_03', 87 | 'battery_01', 'battery_02', 'mouse_01', 'pencilcase_01', 'pencilcase_02', 'tape_01', 88 | 'chopsticks_01', 'chopsticks_02', 'chopsticks_03', 89 | 'notebook_01', 'notebook_02', 'notebook_03', 90 | 'spoon_01', 'spoon_02', 'spoon_03', 91 | 'tissue_01', 'tissue_02', 'tissue_03', 92 | 'clamp_01', 'clamp_02', 'hat_01', 'hat_02', 'u_disk_01', 'u_disk_02', 'swimming_glasses_01' 93 | ] 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /agents/ewc_pp.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from utils.setup_elements import transforms_match 4 | from torch.utils import data 5 | from utils.utils import maybe_cuda, AverageMeter 6 | import torch 7 | 8 | class EWC_pp(ContinualLearner): 9 | def __init__(self, model, opt, params): 10 | super(EWC_pp, self).__init__(model, opt, params) 11 | self.weights = {n: p for n, p in self.model.named_parameters() if p.requires_grad} 12 | self.lambda_ = params.lambda_ 13 | self.alpha = params.alpha 14 | self.fisher_update_after = params.fisher_update_after 15 | self.prev_params = {} 16 | self.running_fisher = self.init_fisher() 17 | self.tmp_fisher = self.init_fisher() 18 | self.normalized_fisher = self.init_fisher() 19 | 20 | def train_learner(self, x_train, y_train): 21 | self.before_train(x_train, y_train) 22 | # set up loader 23 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 24 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 25 | drop_last=True) 26 | # setup tracker 27 | losses_batch = AverageMeter() 28 | acc_batch = AverageMeter() 29 | 30 | # set up model 31 | self.model.train() 32 | 33 | for ep in range(self.epoch): 34 | for i, batch_data in enumerate(train_loader): 35 | # batch update 36 | batch_x, batch_y = batch_data 37 | batch_x = maybe_cuda(batch_x, self.cuda) 38 | batch_y = maybe_cuda(batch_y, self.cuda) 39 | 40 | # update the running fisher 41 | if (ep * len(train_loader) + i + 1) % self.fisher_update_after == 0: 42 | self.update_running_fisher() 43 | 44 | out = self.forward(batch_x) 45 | loss = self.total_loss(out, batch_y) 46 | if self.params.trick['kd_trick']: 47 | loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \ 48 | self.kd_manager.get_kd_loss(out, batch_x) 49 | if self.params.trick['kd_trick_star']: 50 | loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \ 51 | (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(out, batch_x) 52 | # update tracker 53 | losses_batch.update(loss.item(), batch_y.size(0)) 54 | _, pred_label = torch.max(out, 1) 55 | acc = (pred_label == batch_y).sum().item() / batch_y.size(0) 56 | acc_batch.update(acc, batch_y.size(0)) 57 | # backward 58 | self.opt.zero_grad() 59 | loss.backward() 60 | 61 | # accumulate the fisher of current batch 62 | self.accum_fisher() 63 | self.opt.step() 64 | 65 | if i % 100 == 1 and self.verbose: 66 | print( 67 | '==>>> it: {}, avg. loss: {:.6f}, ' 68 | 'running train acc: {:.3f}' 69 | .format(i, losses_batch.avg(), acc_batch.avg()) 70 | ) 71 | 72 | # save params for current task 73 | for n, p in self.weights.items(): 74 | self.prev_params[n] = p.clone().detach() 75 | 76 | # update normalized fisher of current task 77 | max_fisher = max([torch.max(m) for m in self.running_fisher.values()]) 78 | min_fisher = min([torch.min(m) for m in self.running_fisher.values()]) 79 | for n, p in self.running_fisher.items(): 80 | self.normalized_fisher[n] = (p - min_fisher) / (max_fisher - min_fisher + 1e-32) 81 | self.after_train() 82 | 83 | def total_loss(self, inputs, targets): 84 | # cross entropy loss 85 | loss = self.criterion(inputs, targets) 86 | if len(self.prev_params) > 0: 87 | # add regularization loss 88 | reg_loss = 0 89 | for n, p in self.weights.items(): 90 | reg_loss += (self.normalized_fisher[n] * (p - self.prev_params[n]) ** 2).sum() 91 | loss += self.lambda_ * reg_loss 92 | return loss 93 | 94 | def init_fisher(self): 95 | return {n: p.clone().detach().fill_(0) for n, p in self.model.named_parameters() if p.requires_grad} 96 | 97 | def update_running_fisher(self): 98 | for n, p in self.running_fisher.items(): 99 | self.running_fisher[n] = (1. - self.alpha) * p \ 100 | + 1. / self.fisher_update_after * self.alpha * self.tmp_fisher[n] 101 | # reset the accumulated fisher 102 | self.tmp_fisher = self.init_fisher() 103 | 104 | def accum_fisher(self): 105 | for n, p in self.tmp_fisher.items(): 106 | p += self.weights[n].grad ** 2 -------------------------------------------------------------------------------- /agents/exp_replay.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from utils.buffer.buffer import Buffer 4 | from agents.base import ContinualLearner 5 | from continuum.data_utils import dataset_transform 6 | from utils.setup_elements import transforms_match 7 | from utils.utils import maybe_cuda, AverageMeter 8 | 9 | 10 | class ExperienceReplay(ContinualLearner): 11 | def __init__(self, model, opt, params): 12 | super(ExperienceReplay, self).__init__(model, opt, params) 13 | self.buffer = Buffer(model, params) 14 | self.mem_size = params.mem_size 15 | self.agent = params.agent 16 | self.eps_mem_batch = params.eps_mem_batch 17 | self.mem_iters = params.mem_iters 18 | 19 | def train_learner(self, x_train, y_train): 20 | self.before_train(x_train, y_train) 21 | # set up loader 22 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 23 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 24 | drop_last=True) 25 | # set up model 26 | self.model = self.model.train() 27 | 28 | # setup tracker 29 | losses_batch = AverageMeter() 30 | losses_mem = AverageMeter() 31 | acc_batch = AverageMeter() 32 | acc_mem = AverageMeter() 33 | 34 | for ep in range(self.epoch): 35 | for i, batch_data in enumerate(train_loader): 36 | # batch update 37 | batch_x, batch_y = batch_data 38 | batch_x = maybe_cuda(batch_x, self.cuda) 39 | batch_y = maybe_cuda(batch_y, self.cuda) 40 | for j in range(self.mem_iters): 41 | logits,_ = self.model.forward(batch_x) 42 | loss = self.criterion(logits, batch_y) 43 | if self.params.trick['kd_trick']: 44 | loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \ 45 | self.kd_manager.get_kd_loss(logits, batch_x) 46 | if self.params.trick['kd_trick_star']: 47 | loss = 1/((self.task_seen + 1) ** 0.5) * loss + \ 48 | (1 - 1/((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x) 49 | _, pred_label = torch.max(logits, 1) 50 | correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0) 51 | # update tracker 52 | acc_batch.update(correct_cnt, batch_y.size(0)) 53 | losses_batch.update(loss, batch_y.size(0)) 54 | # backward 55 | self.opt.zero_grad() 56 | loss.backward() 57 | 58 | # mem update 59 | mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y) 60 | if mem_x.size(0) > 0: 61 | mem_x = maybe_cuda(mem_x, self.cuda) 62 | mem_y = maybe_cuda(mem_y, self.cuda) 63 | mem_logits,_ = self.model.forward(mem_x) 64 | loss_mem = self.criterion(mem_logits, mem_y) 65 | if self.params.trick['kd_trick']: 66 | loss_mem = 1 / (self.task_seen + 1) * loss_mem + (1 - 1 / (self.task_seen + 1)) * \ 67 | self.kd_manager.get_kd_loss(mem_logits, mem_x) 68 | if self.params.trick['kd_trick_star']: 69 | loss_mem = 1 / ((self.task_seen + 1) ** 0.5) * loss_mem + \ 70 | (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(mem_logits, 71 | mem_x) 72 | # update tracker 73 | losses_mem.update(loss_mem, mem_y.size(0)) 74 | _, pred_label = torch.max(mem_logits, 1) 75 | correct_cnt = (pred_label == mem_y).sum().item() / mem_y.size(0) 76 | acc_mem.update(correct_cnt, mem_y.size(0)) 77 | 78 | loss_mem.backward() 79 | 80 | if self.params.update == 'ASER' or self.params.retrieve == 'ASER': 81 | # opt update 82 | self.opt.zero_grad() 83 | combined_batch = torch.cat((mem_x, batch_x)) 84 | combined_labels = torch.cat((mem_y, batch_y)) 85 | combined_logits = self.model.forward(combined_batch) 86 | loss_combined = self.criterion(combined_logits, combined_labels) 87 | loss_combined.backward() 88 | self.opt.step() 89 | else: 90 | self.opt.step() 91 | 92 | # update mem 93 | self.buffer.update(batch_x, batch_y) 94 | 95 | if i % 100 == 1 and self.verbose: 96 | print( 97 | '==>>> it: {}, avg. loss: {:.6f}, ' 98 | 'running train acc: {:.3f}' 99 | .format(i, losses_batch.avg(), acc_batch.avg()) 100 | ) 101 | print( 102 | '==>>> it: {}, mem avg. loss: {:.6f}, ' 103 | 'running mem acc: {:.3f}' 104 | .format(i, losses_mem.avg(), acc_mem.avg()) 105 | ) 106 | self.after_train() -------------------------------------------------------------------------------- /continuum/dataset_scripts/core50.py: -------------------------------------------------------------------------------- 1 | import os 2 | from continuum.dataset_scripts.dataset_base import DatasetBase 3 | import pickle as pkl 4 | import logging 5 | from hashlib import md5 6 | import numpy as np 7 | from PIL import Image 8 | from continuum.data_utils import shuffle_data, load_task_with_labels 9 | import time 10 | 11 | core50_ntask = { 12 | 'ni': 8, 13 | 'nc': 9, 14 | 'nic': 79, 15 | 'nicv2_79': 79, 16 | 'nicv2_196': 196, 17 | 'nicv2_391': 391 18 | } 19 | 20 | 21 | class CORE50(DatasetBase): 22 | def __init__(self, scenario, params): 23 | if isinstance(params.num_runs, int) and params.num_runs > 10: 24 | raise Exception('the max number of runs for CORE50 is 10') 25 | dataset = 'core50' 26 | task_nums = core50_ntask[scenario] 27 | super(CORE50, self).__init__(dataset, scenario, task_nums, params.num_runs, params) 28 | 29 | 30 | def download_load(self): 31 | 32 | print("Loading paths...") 33 | with open(os.path.join(self.root, 'paths.pkl'), 'rb') as f: 34 | self.paths = pkl.load(f) 35 | 36 | print("Loading LUP...") 37 | with open(os.path.join(self.root, 'LUP.pkl'), 'rb') as f: 38 | self.LUP = pkl.load(f) 39 | 40 | print("Loading labels...") 41 | with open(os.path.join(self.root, 'labels.pkl'), 'rb') as f: 42 | self.labels = pkl.load(f) 43 | 44 | 45 | 46 | def setup(self, cur_run): 47 | self.val_set = [] 48 | self.test_set = [] 49 | print('Loading test set...') 50 | test_idx_list = self.LUP[self.scenario][cur_run][-1] 51 | 52 | #test paths 53 | test_paths = [] 54 | for idx in test_idx_list: 55 | test_paths.append(os.path.join(self.root, self.paths[idx])) 56 | 57 | # test imgs 58 | self.test_data = self.get_batch_from_paths(test_paths) 59 | self.test_label = np.asarray(self.labels[self.scenario][cur_run][-1]) 60 | 61 | 62 | if self.scenario == 'nc': 63 | self.task_labels = self.labels[self.scenario][cur_run][:-1] 64 | for labels in self.task_labels: 65 | labels = list(set(labels)) 66 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 67 | self.test_set.append((x_test, y_test)) 68 | elif self.scenario == 'ni': 69 | self.test_set = [(self.test_data, self.test_label)] 70 | 71 | def new_task(self, cur_task, **kwargs): 72 | cur_run = kwargs['cur_run'] 73 | s = time.time() 74 | train_idx_list = self.LUP[self.scenario][cur_run][cur_task] 75 | print("Loading data...") 76 | # Getting the actual paths 77 | train_paths = [] 78 | for idx in train_idx_list: 79 | train_paths.append(os.path.join(self.root, self.paths[idx])) 80 | # loading imgs 81 | train_x = self.get_batch_from_paths(train_paths) 82 | train_y = self.labels[self.scenario][cur_run][cur_task] 83 | train_y = np.asarray(train_y) 84 | # get val set 85 | train_x_rdm, train_y_rdm = shuffle_data(train_x, train_y) 86 | val_size = int(len(train_x_rdm) * self.params.val_size) 87 | val_data_rdm, val_label_rdm = train_x_rdm[:val_size], train_y_rdm[:val_size] 88 | train_data_rdm, train_label_rdm = train_x_rdm[val_size:], train_y_rdm[val_size:] 89 | self.val_set.append((val_data_rdm, val_label_rdm)) 90 | e = time.time() 91 | print('loading time {}'.format(str(e-s))) 92 | return train_data_rdm, train_label_rdm, set(train_label_rdm) 93 | 94 | 95 | def new_run(self, **kwargs): 96 | cur_run = kwargs['cur_run'] 97 | self.setup(cur_run) 98 | 99 | 100 | @staticmethod 101 | def get_batch_from_paths(paths, compress=False, snap_dir='', 102 | on_the_fly=True, verbose=False): 103 | """ Given a number of abs. paths it returns the numpy array 104 | of all the images. """ 105 | 106 | # Getting root logger 107 | log = logging.getLogger('mylogger') 108 | 109 | # If we do not process data on the fly we check if the same train 110 | # filelist has been already processed and saved. If so, we load it 111 | # directly. In either case we end up returning x and y, as the full 112 | # training set and respective labels. 113 | num_imgs = len(paths) 114 | hexdigest = md5(''.join(paths).encode('utf-8')).hexdigest() 115 | log.debug("Paths Hex: " + str(hexdigest)) 116 | loaded = False 117 | x = None 118 | file_path = None 119 | 120 | if compress: 121 | file_path = snap_dir + hexdigest + ".npz" 122 | if os.path.exists(file_path) and not on_the_fly: 123 | loaded = True 124 | with open(file_path, 'rb') as f: 125 | npzfile = np.load(f) 126 | x, y = npzfile['x'] 127 | else: 128 | x_file_path = snap_dir + hexdigest + "_x.bin" 129 | if os.path.exists(x_file_path) and not on_the_fly: 130 | loaded = True 131 | with open(x_file_path, 'rb') as f: 132 | x = np.fromfile(f, dtype=np.uint8) \ 133 | .reshape(num_imgs, 128, 128, 3) 134 | 135 | # Here we actually load the images. 136 | if not loaded: 137 | # Pre-allocate numpy arrays 138 | x = np.zeros((num_imgs, 128, 128, 3), dtype=np.uint8) 139 | 140 | for i, path in enumerate(paths): 141 | if verbose: 142 | print("\r" + path + " processed: " + str(i + 1), end='') 143 | x[i] = np.array(Image.open(path)) 144 | 145 | if verbose: 146 | print() 147 | 148 | if not on_the_fly: 149 | # Then we save x 150 | if compress: 151 | with open(file_path, 'wb') as g: 152 | np.savez_compressed(g, x=x) 153 | else: 154 | x.tofile(snap_dir + hexdigest + "_x.bin") 155 | 156 | assert (x is not None), 'Problems loading data. x is None!' 157 | 158 | return x 159 | -------------------------------------------------------------------------------- /agents/exp_replay_dvc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from utils.buffer.buffer import Buffer 4 | from agents.base import ContinualLearner 5 | from continuum.data_utils import dataset_transform 6 | import torch.nn as nn 7 | from utils.setup_elements import transforms_match, input_size_match 8 | from utils.utils import maybe_cuda, AverageMeter 9 | from torchvision.transforms import RandomResizedCrop, RandomHorizontalFlip, ColorJitter, RandomGrayscale 10 | from loss import agmax_loss, cross_entropy_loss 11 | 12 | 13 | class ExperienceReplay_DVC(ContinualLearner): 14 | def __init__(self, model, opt, params): 15 | super(ExperienceReplay_DVC, self).__init__(model, opt, params) 16 | self.buffer = Buffer(model, params) 17 | self.mem_size = params.mem_size 18 | self.agent = params.agent 19 | self.dl_weight = params.dl_weight 20 | 21 | self.eps_mem_batch = params.eps_mem_batch 22 | self.mem_iters = params.mem_iters 23 | self.transform = nn.Sequential( 24 | RandomResizedCrop(size=(input_size_match[self.params.data][1], input_size_match[self.params.data][2]), scale=(0.2, 1.)), 25 | RandomHorizontalFlip(), 26 | ColorJitter(0.4, 0.4, 0.4, 0.1), 27 | RandomGrayscale(p=0.2) 28 | 29 | ) 30 | self.L2loss = torch.nn.MSELoss() 31 | 32 | def train_learner(self, x_train, y_train): 33 | self.before_train(x_train, y_train) 34 | # set up loader 35 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 36 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 37 | drop_last=True) 38 | # set up model 39 | self.model = self.model.train() 40 | self.transform = self.transform.cuda() 41 | # setup tracker 42 | losses_batch = AverageMeter() 43 | losses_mem = AverageMeter() 44 | acc_batch = AverageMeter() 45 | acc_mem = AverageMeter() 46 | 47 | for ep in range(self.epoch): 48 | for i, batch_data in enumerate(train_loader): 49 | # batch update 50 | batch_x, batch_y = batch_data 51 | batch_x = maybe_cuda(batch_x, self.cuda) 52 | batch_x_aug = self.transform(batch_x) 53 | batch_y = maybe_cuda(batch_y, self.cuda) 54 | for j in range(self.mem_iters): 55 | y = self.model(batch_x, batch_x_aug) 56 | z, zt, _,_ = y 57 | ce = cross_entropy_loss(z, zt, batch_y, label_smoothing=0) 58 | 59 | 60 | agreement_loss, dl = agmax_loss(y, batch_y, dl_weight=self.dl_weight) 61 | loss = ce + agreement_loss + dl 62 | 63 | if self.params.trick['kd_trick']: 64 | loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \ 65 | self.kd_manager.get_kd_loss(z, batch_x) 66 | if self.params.trick['kd_trick_star']: 67 | loss = 1/((self.task_seen + 1) ** 0.5) * loss + \ 68 | (1 - 1/((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(z, batch_x) 69 | _, pred_label = torch.max(z, 1) 70 | correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0) 71 | # update tracker 72 | acc_batch.update(correct_cnt, batch_y.size(0)) 73 | losses_batch.update(loss, batch_y.size(0)) 74 | # backward 75 | self.opt.zero_grad() 76 | loss.backward() 77 | 78 | # mem update 79 | if self.params.retrieve == 'MGI': 80 | mem_x, mem_x_aug, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y) 81 | 82 | else: 83 | mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y) 84 | if mem_x.size(0) > 0: 85 | mem_x_aug = self.transform(mem_x) 86 | 87 | if mem_x.size(0) > 0: 88 | mem_x = maybe_cuda(mem_x, self.cuda) 89 | mem_x_aug = maybe_cuda(mem_x_aug, self.cuda) 90 | mem_y = maybe_cuda(mem_y, self.cuda) 91 | y = self.model(mem_x, mem_x_aug) 92 | z, zt, _,_ = y 93 | ce = cross_entropy_loss(z, zt, mem_y, label_smoothing=0) 94 | agreement_loss, dl = agmax_loss(y, mem_y, dl_weight=self.dl_weight) 95 | loss_mem = ce + agreement_loss + dl 96 | 97 | if self.params.trick['kd_trick']: 98 | loss_mem = 1 / (self.task_seen + 1) * loss_mem + (1 - 1 / (self.task_seen + 1)) * \ 99 | self.kd_manager.get_kd_loss(z, mem_x) 100 | if self.params.trick['kd_trick_star']: 101 | loss_mem = 1 / ((self.task_seen + 1) ** 0.5) * loss_mem + \ 102 | (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(z, 103 | mem_x) 104 | # update tracker 105 | losses_mem.update(loss_mem, mem_y.size(0)) 106 | _, pred_label = torch.max(z, 1) 107 | correct_cnt = (pred_label == mem_y).sum().item() / mem_y.size(0) 108 | acc_mem.update(correct_cnt, mem_y.size(0)) 109 | 110 | loss_mem.backward() 111 | self.opt.step() 112 | 113 | # update mem 114 | self.buffer.update(batch_x, batch_y) 115 | 116 | if i % 100 == 1 and self.verbose: 117 | print( 118 | '==>>> it: {}, avg. loss: {:.6f}, ' 119 | 'running train acc: {:.3f}' 120 | .format(i, losses_batch.avg(), acc_batch.avg()) 121 | ) 122 | print( 123 | '==>>> it: {}, mem avg. loss: {:.6f}, ' 124 | 'running mem acc: {:.3f}' 125 | .format(i, losses_mem.avg(), acc_mem.avg()) 126 | ) 127 | self.after_train() 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /utils/buffer/gss_greedy_update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from utils.buffer.buffer_utils import get_grad_vector, cosine_similarity 5 | from utils.utils import maybe_cuda 6 | 7 | class GSSGreedyUpdate(object): 8 | def __init__(self, params): 9 | super().__init__() 10 | # the number of gradient vectors to estimate new samples similarity, line 5 in alg.2 11 | self.mem_strength = params.gss_mem_strength 12 | self.gss_batch_size = params.gss_batch_size 13 | self.buffer_score = maybe_cuda(torch.FloatTensor(params.mem_size).fill_(0)) 14 | 15 | def update(self, buffer, x, y, **kwargs): 16 | buffer.model.eval() 17 | 18 | grad_dims = [] 19 | for param in buffer.model.parameters(): 20 | grad_dims.append(param.data.numel()) 21 | 22 | place_left = buffer.buffer_img.size(0) - buffer.current_index 23 | if place_left <= 0: # buffer is full 24 | batch_sim, mem_grads = self.get_batch_sim(buffer, grad_dims, x, y) 25 | if batch_sim < 0: 26 | buffer_score = self.buffer_score[:buffer.current_index].cpu() 27 | buffer_sim = (buffer_score - torch.min(buffer_score)) / \ 28 | ((torch.max(buffer_score) - torch.min(buffer_score)) + 0.01) 29 | # draw candidates for replacement from the buffer 30 | index = torch.multinomial(buffer_sim, x.size(0), replacement=False) 31 | # estimate the similarity of each sample in the recieved batch 32 | # to the randomly drawn samples from the buffer. 33 | batch_item_sim = self.get_each_batch_sample_sim(buffer, grad_dims, mem_grads, x, y) 34 | # normalize to [0,1] 35 | scaled_batch_item_sim = ((batch_item_sim + 1) / 2).unsqueeze(1) 36 | buffer_repl_batch_sim = ((self.buffer_score[index] + 1) / 2).unsqueeze(1) 37 | # draw an event to decide on replacement decision 38 | outcome = torch.multinomial(torch.cat((scaled_batch_item_sim, buffer_repl_batch_sim), dim=1), 1, 39 | replacement=False) 40 | # replace samples with outcome =1 41 | added_indx = torch.arange(end=batch_item_sim.size(0)) 42 | sub_index = outcome.squeeze(1).bool() 43 | buffer.buffer_img[index[sub_index]] = x[added_indx[sub_index]].clone() 44 | buffer.buffer_label[index[sub_index]] = y[added_indx[sub_index]].clone() 45 | self.buffer_score[index[sub_index]] = batch_item_sim[added_indx[sub_index]].clone() 46 | else: 47 | offset = min(place_left, x.size(0)) 48 | x = x[:offset] 49 | y = y[:offset] 50 | # first buffer insertion 51 | if buffer.current_index == 0: 52 | batch_sample_memory_cos = torch.zeros(x.size(0)) + 0.1 53 | else: 54 | # draw random samples from buffer 55 | mem_grads = self.get_rand_mem_grads(buffer, grad_dims) 56 | # estimate a score for each added sample 57 | batch_sample_memory_cos = self.get_each_batch_sample_sim(buffer, grad_dims, mem_grads, x, y) 58 | buffer.buffer_img[buffer.current_index:buffer.current_index + offset].data.copy_(x) 59 | buffer.buffer_label[buffer.current_index:buffer.current_index + offset].data.copy_(y) 60 | self.buffer_score[buffer.current_index:buffer.current_index + offset] \ 61 | .data.copy_(batch_sample_memory_cos) 62 | buffer.current_index += offset 63 | buffer.model.train() 64 | 65 | def get_batch_sim(self, buffer, grad_dims, batch_x, batch_y): 66 | """ 67 | Args: 68 | buffer: memory buffer 69 | grad_dims: gradient dimensions 70 | batch_x: batch images 71 | batch_y: batch labels 72 | Returns: score of current batch, gradient from memory subsets 73 | """ 74 | mem_grads = self.get_rand_mem_grads(buffer, grad_dims) 75 | buffer.model.zero_grad() 76 | loss = F.cross_entropy(buffer.model.forward(batch_x), batch_y) 77 | loss.backward() 78 | batch_grad = get_grad_vector(buffer.model.parameters, grad_dims).unsqueeze(0) 79 | batch_sim = max(cosine_similarity(mem_grads, batch_grad)) 80 | return batch_sim, mem_grads 81 | 82 | def get_rand_mem_grads(self, buffer, grad_dims): 83 | """ 84 | Args: 85 | buffer: memory buffer 86 | grad_dims: gradient dimensions 87 | Returns: gradient from memory subsets 88 | """ 89 | gss_batch_size = min(self.gss_batch_size, buffer.current_index) 90 | num_mem_subs = min(self.mem_strength, buffer.current_index // gss_batch_size) 91 | mem_grads = maybe_cuda(torch.zeros(num_mem_subs, sum(grad_dims), dtype=torch.float32)) 92 | shuffeled_inds = torch.randperm(buffer.current_index) 93 | for i in range(num_mem_subs): 94 | random_batch_inds = shuffeled_inds[ 95 | i * gss_batch_size:i * gss_batch_size + gss_batch_size] 96 | batch_x = buffer.buffer_img[random_batch_inds] 97 | batch_y = buffer.buffer_label[random_batch_inds] 98 | buffer.model.zero_grad() 99 | loss = F.cross_entropy(buffer.model.forward(batch_x), batch_y) 100 | loss.backward() 101 | mem_grads[i].data.copy_(get_grad_vector(buffer.model.parameters, grad_dims)) 102 | return mem_grads 103 | 104 | def get_each_batch_sample_sim(self, buffer, grad_dims, mem_grads, batch_x, batch_y): 105 | """ 106 | Args: 107 | buffer: memory buffer 108 | grad_dims: gradient dimensions 109 | mem_grads: gradient from memory subsets 110 | batch_x: batch images 111 | batch_y: batch labels 112 | Returns: score of each sample from current batch 113 | """ 114 | cosine_sim = maybe_cuda(torch.zeros(batch_x.size(0))) 115 | for i, (x, y) in enumerate(zip(batch_x, batch_y)): 116 | buffer.model.zero_grad() 117 | ptloss = F.cross_entropy(buffer.model.forward(x.unsqueeze(0)), y.unsqueeze(0)) 118 | ptloss.backward() 119 | # add the new grad to the memory grads and add it is cosine similarity 120 | this_grad = get_grad_vector(buffer.model.parameters, grad_dims).unsqueeze(0) 121 | cosine_sim[i] = max(cosine_similarity(mem_grads, this_grad)) 122 | return cosine_sim 123 | -------------------------------------------------------------------------------- /utils/buffer/aser_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.utils import maybe_cuda, mini_batch_deep_features, euclidean_distance, nonzero_indices, ohe_label 3 | from utils.setup_elements import n_classes 4 | from utils.buffer.buffer_utils import ClassBalancedRandomSampling 5 | 6 | 7 | def compute_knn_sv(model, eval_x, eval_y, cand_x, cand_y, k, device="cpu"): 8 | """ 9 | Compute KNN SV of candidate data w.r.t. evaluation data. 10 | Args: 11 | model (object): neural network. 12 | eval_x (tensor): evaluation data tensor. 13 | eval_y (tensor): evaluation label tensor. 14 | cand_x (tensor): candidate data tensor. 15 | cand_y (tensor): candidate label tensor. 16 | k (int): number of nearest neighbours. 17 | device (str): device for tensor allocation. 18 | Returns 19 | sv_matrix (tensor): KNN Shapley value matrix of candidate data w.r.t. evaluation data. 20 | """ 21 | # Compute KNN SV score for candidate samples w.r.t. evaluation samples 22 | n_eval = eval_x.size(0) 23 | n_cand = cand_x.size(0) 24 | # Initialize SV matrix to matrix of -1 25 | sv_matrix = torch.zeros((n_eval, n_cand), device=device) 26 | # Get deep features 27 | eval_df, cand_df = deep_features(model, eval_x, n_eval, cand_x, n_cand) 28 | # Sort indices based on distance in deep feature space 29 | sorted_ind_mat = sorted_cand_ind(eval_df, cand_df, n_eval, n_cand) 30 | 31 | # Evaluation set labels 32 | el = eval_y 33 | el_vec = el.repeat([n_cand, 1]).T 34 | # Sorted candidate set labels 35 | cl = cand_y[sorted_ind_mat] 36 | 37 | # Indicator function matrix 38 | indicator = (el_vec == cl).float() 39 | indicator_next = torch.zeros_like(indicator, device=device) 40 | indicator_next[:, 0:n_cand - 1] = indicator[:, 1:] 41 | indicator_diff = indicator - indicator_next 42 | 43 | cand_ind = torch.arange(n_cand, dtype=torch.float, device=device) + 1 44 | denom_factor = cand_ind.clone() 45 | denom_factor[:n_cand - 1] = denom_factor[:n_cand - 1] * k 46 | numer_factor = cand_ind.clone() 47 | numer_factor[k:n_cand - 1] = k 48 | numer_factor[n_cand - 1] = 1 49 | factor = numer_factor / denom_factor 50 | 51 | indicator_factor = indicator_diff * factor 52 | indicator_factor_cumsum = indicator_factor.flip(1).cumsum(1).flip(1) 53 | 54 | # Row indices 55 | row_ind = torch.arange(n_eval, device=device) 56 | row_mat = torch.repeat_interleave(row_ind, n_cand).reshape([n_eval, n_cand]) 57 | 58 | # Compute SV recursively 59 | sv_matrix[row_mat, sorted_ind_mat] = indicator_factor_cumsum 60 | 61 | return sv_matrix 62 | 63 | 64 | def deep_features(model, eval_x, n_eval, cand_x, n_cand): 65 | """ 66 | Compute deep features of evaluation and candidate data. 67 | Args: 68 | model (object): neural network. 69 | eval_x (tensor): evaluation data tensor. 70 | n_eval (int): number of evaluation data. 71 | cand_x (tensor): candidate data tensor. 72 | n_cand (int): number of candidate data. 73 | Returns 74 | eval_df (tensor): deep features of evaluation data. 75 | cand_df (tensor): deep features of evaluation data. 76 | """ 77 | # Get deep features 78 | if cand_x is None: 79 | num = n_eval 80 | total_x = eval_x 81 | else: 82 | num = n_eval + n_cand 83 | total_x = torch.cat((eval_x, cand_x), 0) 84 | 85 | # compute deep features with mini-batches 86 | total_x = maybe_cuda(total_x) 87 | deep_features_ = mini_batch_deep_features(model, total_x, num) 88 | 89 | eval_df = deep_features_[0:n_eval] 90 | cand_df = deep_features_[n_eval:] 91 | return eval_df, cand_df 92 | 93 | 94 | def sorted_cand_ind(eval_df, cand_df, n_eval, n_cand): 95 | """ 96 | Sort indices of candidate data according to 97 | their Euclidean distance to each evaluation data in deep feature space. 98 | Args: 99 | eval_df (tensor): deep features of evaluation data. 100 | cand_df (tensor): deep features of evaluation data. 101 | n_eval (int): number of evaluation data. 102 | n_cand (int): number of candidate data. 103 | Returns 104 | sorted_cand_ind (tensor): sorted indices of candidate set w.r.t. each evaluation data. 105 | """ 106 | # Sort indices of candidate set according to distance w.r.t. evaluation set in deep feature space 107 | # Preprocess feature vectors to facilitate vector-wise distance computation 108 | eval_df_repeat = eval_df.repeat([1, n_cand]).reshape([n_eval * n_cand, eval_df.shape[1]]) 109 | cand_df_tile = cand_df.repeat([n_eval, 1]) 110 | # Compute distance between evaluation and candidate feature vectors 111 | distance_vector = euclidean_distance(eval_df_repeat, cand_df_tile) 112 | # Turn distance vector into distance matrix 113 | distance_matrix = distance_vector.reshape((n_eval, n_cand)) 114 | # Sort candidate set indices based on distance 115 | sorted_cand_ind_ = distance_matrix.argsort(1) 116 | return sorted_cand_ind_ 117 | 118 | 119 | def add_minority_class_input(cur_x, cur_y, mem_size, num_class): 120 | """ 121 | Find input instances from minority classes, and concatenate them to evaluation data/label tensors later. 122 | This facilitates the inclusion of minority class samples into memory when ASER's update method is used under online-class incremental setting. 123 | 124 | More details: 125 | 126 | Evaluation set may not contain any samples from minority classes (i.e., those classes with very few number of corresponding samples stored in the memory). 127 | This happens after task changes in online-class incremental setting. 128 | Minority class samples can then get very low or negative KNN-SV, making it difficult to store any of them in the memory. 129 | 130 | By identifying minority class samples in the current input batch, and concatenating them to the evaluation set, 131 | KNN-SV of the minority class samples can be artificially boosted (i.e., positive value with larger magnitude). 132 | This allows to quickly accomodate new class samples in the memory right after task changes. 133 | 134 | Threshold for being a minority class is a hyper-parameter related to the class proportion. 135 | In this implementation, it is randomly selected between 0 and 1 / number of all classes for each current input batch. 136 | 137 | 138 | Args: 139 | cur_x (tensor): current input data tensor. 140 | cur_y (tensor): current input label tensor. 141 | mem_size (int): memory size. 142 | num_class (int): number of classes in dataset. 143 | Returns 144 | minority_batch_x (tensor): subset of current input data from minority class. 145 | minority_batch_y (tensor): subset of current input label from minority class. 146 | """ 147 | # Select input instances from minority classes that will be concatenated to pre-selected data 148 | threshold = torch.tensor(1).float().uniform_(0, 1 / num_class).item() 149 | 150 | # If number of buffered samples from certain class is lower than random threshold, 151 | # that class is minority class 152 | cls_proportion = ClassBalancedRandomSampling.class_num_cache.float() / mem_size 153 | minority_ind = nonzero_indices(cls_proportion[cur_y] < threshold) 154 | 155 | minority_batch_x = cur_x[minority_ind] 156 | minority_batch_y = cur_y[minority_ind] 157 | return minority_batch_x, minority_batch_y 158 | -------------------------------------------------------------------------------- /models/ndpm/ndpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler 4 | 5 | from utils.utils import maybe_cuda 6 | from utils.global_vars import * 7 | from .expert import Expert 8 | from .priors import CumulativePrior 9 | 10 | 11 | class Ndpm(nn.Module): 12 | def __init__(self, params): 13 | super().__init__() 14 | self.params = params 15 | self.experts = nn.ModuleList([Expert(params)]) 16 | self.stm_capacity = params.stm_capacity 17 | self.stm_x = [] 18 | self.stm_y = [] 19 | self.prior = CumulativePrior(params) 20 | 21 | def get_experts(self): 22 | return tuple(self.experts.children()) 23 | 24 | def forward(self, x): 25 | with torch.no_grad(): 26 | if len(self.experts) == 1: 27 | raise RuntimeError('There\'s no expert to run on the input') 28 | x = maybe_cuda(x) 29 | log_evid = -self.experts[-1].g.collect_nll(x) # [B, 1+K] 30 | log_evid = log_evid[:, 1:].unsqueeze(2) # [B, K, 1] 31 | log_prior = -self.prior.nl_prior()[1:] # [K] 32 | log_prior -= torch.logsumexp(log_prior, dim=0) 33 | log_prior = log_prior.unsqueeze(0).unsqueeze(2) # [1, K, 1] 34 | log_joint = log_prior + log_evid # [B, K, 1] 35 | if not MODELS_NDPM_NDPM_DISABLE_D: 36 | log_pred = self.experts[-1].d.collect_forward(x) # [B, 1+K, C] 37 | log_pred = log_pred[:, 1:, :] # [B, K, C] 38 | log_joint = log_joint + log_pred # [B, K, C] 39 | 40 | log_joint = log_joint.logsumexp(dim=1).squeeze() # [B,] or [B, C] 41 | return log_joint 42 | 43 | 44 | def learn(self, x, y): 45 | x, y = maybe_cuda(x), maybe_cuda(y) 46 | 47 | if MODELS_NDPM_NDPM_SEND_TO_STM_ALWAYS: 48 | self.stm_x.extend(torch.unbind(x.cpu())) 49 | self.stm_y.extend(torch.unbind(y.cpu())) 50 | else: 51 | # Determine the destination of each data point 52 | nll = self.experts[-1].collect_nll(x, y) # [B, 1+K] 53 | nl_prior = self.prior.nl_prior() # [1+K] 54 | nl_joint = nll + nl_prior.unsqueeze(0).expand( 55 | nll.size(0), -1) # [B, 1+K] 56 | 57 | # Save to short-term memory 58 | destination = maybe_cuda(torch.argmin(nl_joint, dim=1)) # [B] 59 | to_stm = destination == 0 # [B] 60 | self.stm_x.extend(torch.unbind(x[to_stm].cpu())) 61 | self.stm_y.extend(torch.unbind(y[to_stm].cpu())) 62 | 63 | # Train expert 64 | with torch.no_grad(): 65 | min_joint = nl_joint.min(dim=1)[0].view(-1, 1) 66 | to_expert = torch.exp(-nl_joint + min_joint) # [B, 1+K] 67 | to_expert[:, 0] = 0. # [B, 1+K] 68 | to_expert = \ 69 | to_expert / (to_expert.sum(dim=1).view(-1, 1) + 1e-7) 70 | 71 | # Compute losses per expert 72 | nll_for_train = nll * (1. - to_stm.float()).unsqueeze(1) # [B,1+K] 73 | losses = (nll_for_train * to_expert).sum(0) # [1+K] 74 | 75 | # Record expert usage 76 | expert_usage = to_expert.sum(dim=0) # [K+1] 77 | self.prior.record_usage(expert_usage) 78 | 79 | # Do lr_decay implicitly 80 | if MODELS_NDPM_NDPM_IMPLICIT_LR_DECAY: 81 | losses = losses \ 82 | * self.params.stm_capacity / (self.prior.counts + 1e-8) 83 | loss = losses.sum() 84 | 85 | if loss.requires_grad: 86 | update_threshold = 0 87 | for k, usage in enumerate(expert_usage): 88 | if usage > update_threshold: 89 | self.experts[k].zero_grad() 90 | loss.backward() 91 | for k, usage in enumerate(expert_usage): 92 | if usage > update_threshold: 93 | self.experts[k].clip_grad() 94 | self.experts[k].optimizer_step() 95 | self.experts[k].lr_scheduler_step() 96 | 97 | # Sleep 98 | if len(self.stm_x) >= self.stm_capacity: 99 | dream_dataset = TensorDataset( 100 | torch.stack(self.stm_x), torch.stack(self.stm_y)) 101 | self.sleep(dream_dataset) 102 | self.stm_x = [] 103 | self.stm_y = [] 104 | 105 | def sleep(self, dream_dataset): 106 | print('\nGoing to sleep...') 107 | # Add new expert and optimizer 108 | expert = Expert(self.params, self.get_experts()) 109 | self.experts.append(expert) 110 | self.prior.add_expert() 111 | 112 | stacked_stm_x = torch.stack(self.stm_x) 113 | stacked_stm_y = torch.stack(self.stm_y) 114 | indices = torch.randperm(stacked_stm_x.size(0)) 115 | train_size = stacked_stm_x.size(0) - MODELS_NDPM_NDPM_SLEEP_SLEEP_VAL_SIZE 116 | dream_dataset = TensorDataset( 117 | stacked_stm_x[indices[:train_size]], 118 | stacked_stm_y[indices[:train_size]]) 119 | 120 | # Prepare data iterator 121 | self.prior.record_usage(len(dream_dataset), index=-1) 122 | dream_iterator = iter(DataLoader( 123 | dream_dataset, 124 | batch_size=MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE, 125 | num_workers=MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS, 126 | sampler=RandomSampler( 127 | dream_dataset, 128 | replacement=True, 129 | num_samples=( 130 | MODELS_NDPM_NDPM_SLEEP_STEP_G * 131 | MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE 132 | )) 133 | )) 134 | 135 | # Train generative component 136 | for step, (x, y) in enumerate(dream_iterator): 137 | step += 1 138 | x, y = maybe_cuda(x), maybe_cuda(y) 139 | g_loss = expert.g.nll(x, y, step=step) 140 | g_loss = (g_loss + MODELS_NDPM_NDPM_WEIGHT_DECAY 141 | * expert.g.weight_decay_loss()) 142 | expert.g.zero_grad() 143 | g_loss.mean().backward() 144 | expert.g.clip_grad() 145 | expert.g.optimizer.step() 146 | 147 | if step % MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP == 0: 148 | print('\r [Sleep-G %6d] loss: %5.1f' % ( 149 | step, g_loss.mean() 150 | ), end='') 151 | print() 152 | 153 | dream_iterator = iter(DataLoader( 154 | dream_dataset, 155 | batch_size=MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE, 156 | num_workers=MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS, 157 | sampler=RandomSampler( 158 | dream_dataset, 159 | replacement=True, 160 | num_samples=( 161 | MODELS_NDPM_NDPM_SLEEP_STEP_D * 162 | MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE) 163 | ) 164 | )) 165 | 166 | # Train discriminative component 167 | if not MODELS_NDPM_NDPM_DISABLE_D: 168 | for step, (x, y) in enumerate(dream_iterator): 169 | step += 1 170 | x, y = maybe_cuda(x), maybe_cuda(y) 171 | d_loss = expert.d.nll(x, y, step=step) 172 | d_loss = (d_loss + MODELS_NDPM_NDPM_WEIGHT_DECAY 173 | * expert.d.weight_decay_loss()) 174 | expert.d.zero_grad() 175 | d_loss.mean().backward() 176 | expert.d.clip_grad() 177 | expert.d.optimizer.step() 178 | 179 | if step % MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP == 0: 180 | print('\r [Sleep-D %6d] loss: %5.1f' % ( 181 | step, d_loss.mean() 182 | ), end='') 183 | 184 | expert.lr_scheduler_step() 185 | expert.lr_scheduler_step() 186 | expert.eval() 187 | print() 188 | 189 | @staticmethod 190 | def _nl_joint(nl_prior, nll): 191 | batch = nll.size(0) 192 | nl_prior = nl_prior.unsqueeze(0).expand(batch, -1) # [B, 1+K] 193 | return nll + nl_prior 194 | 195 | def train(self, mode=True): 196 | # Disabled 197 | pass 198 | -------------------------------------------------------------------------------- /features/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | #from .utils import load_state_dict_from_url 3 | 4 | from . import extractor 5 | 6 | __all__ = [ 7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 8 | 'vgg19_bn', 'vgg19', 9 | ] 10 | 11 | 12 | model_urls = { 13 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 14 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 15 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 16 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 17 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 18 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 19 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 20 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 21 | } 22 | 23 | 24 | class VGG(extractor.BaseModule): 25 | 26 | def __init__(self, config, name): #features, num_classes=1000, init_weights=True): 27 | super(VGG, self).__init__() 28 | 29 | #self.features = features 30 | self.name = name 31 | cfg = config["cfg"] 32 | in_channels = config["channels"] 33 | batch_norm = config["batch_norm"] 34 | self.features = make_layers(cfgs[cfg], batch_norm=batch_norm, in_channels=in_channels) 35 | self.n_features = 512 36 | #self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 37 | #self.classifier = nn.Sequential( 38 | # nn.Linear(512 * 7 * 7, 4096), 39 | # nn.ReLU(True), 40 | # nn.Dropout(), 41 | # nn.Linear(4096, 4096), 42 | # nn.ReLU(True), 43 | # nn.Dropout(), 44 | # nn.Linear(4096, num_classes), 45 | #) 46 | #if init_weights: 47 | #self._initialize_weights() 48 | 49 | def forward(self, x): 50 | x = self.features(x) 51 | #x = self.avgpool(x) 52 | #x = torch.flatten(x, 1) 53 | #x = self.classifier(x) 54 | return x 55 | 56 | 57 | def _initialize_weights(self): 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 61 | if m.bias is not None: 62 | nn.init.constant_(m.bias, 0) 63 | elif isinstance(m, nn.BatchNorm2d): 64 | nn.init.constant_(m.weight, 1) 65 | nn.init.constant_(m.bias, 0) 66 | elif isinstance(m, nn.Linear): 67 | nn.init.normal_(m.weight, 0, 0.01) 68 | nn.init.constant_(m.bias, 0) 69 | 70 | 71 | def make_layers(cfg, batch_norm=False, in_channels=3): 72 | layers = [] 73 | #in_channels = 3 74 | for v in cfg: 75 | if v == 'M': 76 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 77 | else: 78 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 79 | if batch_norm: 80 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 81 | else: 82 | layers += [conv2d, nn.ReLU(inplace=True)] 83 | in_channels = v 84 | return nn.Sequential(*layers) 85 | 86 | 87 | cfgs = { 88 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 89 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 90 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 91 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 92 | } 93 | 94 | 95 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): 96 | if pretrained: 97 | kwargs['init_weights'] = False 98 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 99 | if pretrained: 100 | state_dict = load_state_dict_from_url(model_urls[arch], 101 | progress=progress) 102 | model.load_state_dict(state_dict) 103 | return model 104 | 105 | 106 | def vgg11(pretrained=False, progress=True, **kwargs): 107 | r"""VGG 11-layer model (configuration "A") from 108 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 109 | 110 | Args: 111 | pretrained (bool): If True, returns a model pre-trained on ImageNet 112 | progress (bool): If True, displays a progress bar of the download to stderr 113 | """ 114 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 115 | 116 | 117 | def vgg11_bn(pretrained=False, progress=True, **kwargs): 118 | r"""VGG 11-layer model (configuration "A") with batch normalization 119 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 120 | 121 | Args: 122 | pretrained (bool): If True, returns a model pre-trained on ImageNet 123 | progress (bool): If True, displays a progress bar of the download to stderr 124 | """ 125 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 126 | 127 | 128 | def vgg13(pretrained=False, progress=True, **kwargs): 129 | r"""VGG 13-layer model (configuration "B") 130 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 131 | 132 | Args: 133 | pretrained (bool): If True, returns a model pre-trained on ImageNet 134 | progress (bool): If True, displays a progress bar of the download to stderr 135 | """ 136 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 137 | 138 | 139 | def vgg13_bn(pretrained=False, progress=True, **kwargs): 140 | r"""VGG 13-layer model (configuration "B") with batch normalization 141 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 142 | 143 | Args: 144 | pretrained (bool): If True, returns a model pre-trained on ImageNet 145 | progress (bool): If True, displays a progress bar of the download to stderr 146 | """ 147 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 148 | 149 | 150 | def vgg16(pretrained=False, progress=True, **kwargs): 151 | r"""VGG 16-layer model (configuration "D") 152 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 153 | 154 | Args: 155 | pretrained (bool): If True, returns a model pre-trained on ImageNet 156 | progress (bool): If True, displays a progress bar of the download to stderr 157 | """ 158 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 159 | 160 | 161 | def vgg16_bn(pretrained=False, progress=True, **kwargs): 162 | r"""VGG 16-layer model (configuration "D") with batch normalization 163 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 164 | 165 | Args: 166 | pretrained (bool): If True, returns a model pre-trained on ImageNet 167 | progress (bool): If True, displays a progress bar of the download to stderr 168 | """ 169 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 170 | 171 | 172 | def vgg19(pretrained=False, progress=True, **kwargs): 173 | r"""VGG 19-layer model (configuration "E") 174 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 175 | 176 | Args: 177 | pretrained (bool): If True, returns a model pre-trained on ImageNet 178 | progress (bool): If True, displays a progress bar of the download to stderr 179 | """ 180 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 181 | 182 | 183 | def vgg19_bn(pretrained=False, progress=True, **kwargs): 184 | r"""VGG 19-layer model (configuration 'E') with batch normalization 185 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 186 | 187 | Args: 188 | pretrained (bool): If True, returns a model pre-trained on ImageNet 189 | progress (bool): If True, displays a progress bar of the download to stderr 190 | """ 191 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) 192 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from https://github.com/facebookresearch/GradientEpisodicMemory 3 | & 4 | https://github.com/kuangliu/pytorch-cifar 5 | """ 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | from torch.nn.functional import relu, avg_pool2d 9 | import torch 10 | from features.extractor import BaseModule 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, in_planes, planes, stride=1): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = conv3x3(in_planes, planes, stride) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = conv3x3(planes, planes) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride != 1 or in_planes != self.expansion * planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, 29 | stride=stride, bias=False), 30 | nn.BatchNorm2d(self.expansion * planes) 31 | ) 32 | 33 | def forward(self, x): 34 | out = relu(self.bn1(self.conv1(x))) 35 | out = self.bn2(self.conv2(out)) 36 | out += self.shortcut(x) 37 | out = relu(out) 38 | return out 39 | 40 | class Bottleneck(nn.Module): 41 | expansion = 4 42 | 43 | def __init__(self, in_planes, planes, stride=1): 44 | super(Bottleneck, self).__init__() 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 48 | stride=stride, padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion * 51 | planes, kernel_size=1, bias=False) 52 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 53 | 54 | self.shortcut = nn.Sequential() 55 | if stride != 1 or in_planes != self.expansion * planes: 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d(in_planes, self.expansion * planes, 58 | kernel_size=1, stride=stride, bias=False), 59 | nn.BatchNorm2d(self.expansion * planes) 60 | ) 61 | 62 | def forward(self, x): 63 | out = relu(self.bn1(self.conv1(x))) 64 | out = relu(self.bn2(self.conv2(out))) 65 | out = self.bn3(self.conv3(out)) 66 | out += self.shortcut(x) 67 | out = relu(out) 68 | return out 69 | 70 | class ResNet(nn.Module): 71 | def __init__(self, block, num_blocks, num_classes, nf, bias): 72 | super(ResNet, self).__init__() 73 | self.in_planes = nf 74 | self.conv1 = conv3x3(3, nf * 1) 75 | self.bn1 = nn.BatchNorm2d(nf * 1) 76 | self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1) 77 | self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2) 78 | self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2) 79 | self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2) 80 | self.linear = nn.Linear(nf * 8 * block.expansion, num_classes, bias=bias) 81 | 82 | 83 | def _make_layer(self, block, planes, num_blocks, stride): 84 | strides = [stride] + [1] * (num_blocks - 1) 85 | layers = [] 86 | for stride in strides: 87 | layers.append(block(self.in_planes, planes, stride)) 88 | self.in_planes = planes * block.expansion 89 | return nn.Sequential(*layers) 90 | 91 | def features(self, x): 92 | '''Features before FC layers''' 93 | out = relu(self.bn1(self.conv1(x))) 94 | out = self.layer1(out) 95 | out = self.layer2(out) 96 | out = self.layer3(out) 97 | out = self.layer4(out) 98 | out = avg_pool2d(out, 4) 99 | out = out.contiguous().view(out.size(0), -1) 100 | return out 101 | 102 | def logits(self, x): 103 | '''Apply the last FC linear mapping to get logits''' 104 | x = self.linear(x) 105 | return x 106 | 107 | def forward(self, x): 108 | out = self.features(x) 109 | logits = self.logits(out) 110 | return logits,out 111 | 112 | 113 | class QNet(BaseModule): 114 | def __init__(self, 115 | n_units, 116 | n_classes): 117 | super(QNet, self).__init__() 118 | 119 | self.model = nn.Sequential( 120 | nn.Linear(2 * n_classes, n_units), 121 | nn.ReLU(True), 122 | nn.Linear(n_units, n_classes), 123 | ) 124 | 125 | def forward(self, zcat): 126 | zzt = self.model(zcat) 127 | return zzt 128 | 129 | 130 | class DVCNet(BaseModule): 131 | def __init__(self, 132 | backbone, 133 | n_units, 134 | n_classes, 135 | has_mi_qnet=True): 136 | super(DVCNet, self).__init__() 137 | 138 | self.backbone = backbone 139 | self.has_mi_qnet = has_mi_qnet 140 | 141 | if has_mi_qnet: 142 | self.qnet = QNet(n_units=n_units, 143 | n_classes=n_classes) 144 | 145 | def forward(self, x, xt): 146 | size = x.size(0) 147 | xx = torch.cat((x, xt)) 148 | zz,fea = self.backbone(xx) 149 | z = zz[0:size] 150 | zt = zz[size:] 151 | 152 | fea_z = fea[0:size] 153 | fea_zt = fea[size:] 154 | 155 | if not self.has_mi_qnet: 156 | return z, zt, None 157 | 158 | zcat = torch.cat((z, zt), dim=1) 159 | zzt = self.qnet(zcat) 160 | 161 | return z, zt, zzt,[torch.sum(torch.abs(fea_z), 1).reshape(-1, 1),torch.sum(torch.abs(fea_zt), 1).reshape(-1, 1)] 162 | 163 | 164 | def Reduced_ResNet18_DVC(nclasses, nf=20, bias=True): 165 | """ 166 | Reduced ResNet18 as in GEM MIR(note that nf=20). 167 | """ 168 | backnone = ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf, bias) 169 | return DVCNet(backbone=backnone,n_units=128,n_classes=nclasses,has_mi_qnet=True) 170 | 171 | 172 | def Reduced_ResNet18(nclasses, nf=20, bias=True): 173 | """ 174 | Reduced ResNet18 as in GEM MIR(note that nf=20). 175 | """ 176 | return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf, bias) 177 | 178 | 179 | 180 | def ResNet18(nclasses, nf=64, bias=True): 181 | return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf, bias) 182 | 183 | ''' 184 | See https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 185 | ''' 186 | 187 | def ResNet34(nclasses, nf=64, bias=True): 188 | return ResNet(BasicBlock, [3, 4, 6, 3], nclasses, nf, bias) 189 | 190 | def ResNet50(nclasses, nf=64, bias=True): 191 | return ResNet(Bottleneck, [3, 4, 6, 3], nclasses, nf, bias) 192 | 193 | 194 | def ResNet101(nclasses, nf=64, bias=True): 195 | return ResNet(Bottleneck, [3, 4, 23, 3], nclasses, nf, bias) 196 | 197 | 198 | def ResNet152(nclasses, nf=64, bias=True): 199 | return ResNet(Bottleneck, [3, 8, 36, 3], nclasses, nf, bias) 200 | 201 | 202 | class SupConResNet(nn.Module): 203 | """backbone + projection head""" 204 | def __init__(self, dim_in=160, head='mlp', feat_dim=128): 205 | super(SupConResNet, self).__init__() 206 | self.encoder = Reduced_ResNet18(100) 207 | if head == 'linear': 208 | self.head = nn.Linear(dim_in, feat_dim) 209 | elif head == 'mlp': 210 | self.head = nn.Sequential( 211 | nn.Linear(dim_in, dim_in), 212 | nn.ReLU(inplace=True), 213 | nn.Linear(dim_in, feat_dim) 214 | ) 215 | elif head == 'None': 216 | self.head = None 217 | else: 218 | raise NotImplementedError( 219 | 'head not supported: {}'.format(head)) 220 | 221 | def forward(self, x): 222 | feat = self.encoder.features(x) 223 | logit = self.encoder.linear(feat) 224 | if self.head: 225 | feat = F.normalize(self.head(feat), dim=1) 226 | else: 227 | feat = F.normalize(feat, dim=1) 228 | return feat 229 | 230 | def features(self, x): 231 | return self.encoder.features(x) -------------------------------------------------------------------------------- /continuum/non_stationary.py: -------------------------------------------------------------------------------- 1 | import random 2 | from copy import deepcopy 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from skimage.filters import gaussian 6 | from continuum.data_utils import train_val_test_split_ni 7 | 8 | 9 | class Original(object): 10 | def __init__(self, x, y, unroll=False, color=False): 11 | if color: 12 | self.x = x / 255.0 13 | else: 14 | self.x = x 15 | self.next_x = self.x 16 | self.next_y = y 17 | self.y = y 18 | self.unroll = unroll 19 | 20 | def get_dims(self): 21 | # Get data input and output dimensions 22 | print("input size {}\noutput size {}".format(self.x.shape[1], self.y.shape[1])) 23 | return self.x.shape[1], self.y.shape[1] 24 | 25 | def show_sample(self, num_plot=1): 26 | # idx = np.random.choice(self.x.shape[0]) 27 | for i in range(num_plot): 28 | plt.subplot(1, 2, 1) 29 | if self.x[i].shape[2] == 1: 30 | plt.imshow(np.squeeze(self.x[i])) 31 | else: 32 | plt.imshow(self.x[i]) 33 | plt.title("original task image") 34 | plt.subplot(1, 2, 2) 35 | if self.x[i].shape[2] == 1: 36 | plt.imshow(np.squeeze(self.next_x[i])) 37 | else: 38 | plt.imshow(self.next_x[i]) 39 | plt.title(self.get_name()) 40 | plt.axis('off') 41 | plt.show() 42 | 43 | def create_output(self): 44 | if self.unroll: 45 | ret = self.next_x.reshape((-1, self.x.shape[1] ** 2)), self.next_y 46 | else: 47 | ret = self.next_x, self.next_y 48 | return ret 49 | 50 | @staticmethod 51 | def clip_minmax(l, min_=0., max_=1.): 52 | return np.clip(l, min_, max_) 53 | 54 | def get_name(self): 55 | if hasattr(self, 'factor'): 56 | return str(self.__class__.__name__) + '_' + str(self.factor) 57 | 58 | def next_task(self, *args): 59 | self.next_x = self.x 60 | self.next_y = self.y 61 | return self.create_output() 62 | 63 | 64 | class Noisy(Original): 65 | def __init__(self, x, y, full=False, color=False): 66 | super(Noisy, self).__init__(x, y, full, color) 67 | 68 | def next_task(self, noise_factor=0.8, sig=0.1, noise_type='Gaussian'): 69 | next_x = deepcopy(self.x) 70 | self.factor = noise_factor 71 | if noise_type == 'Gaussian': 72 | self.next_x = next_x + noise_factor * np.random.normal(loc=0.0, scale=sig, size=next_x.shape) 73 | elif noise_factor == 'S&P': 74 | # TODO implement S&P 75 | pass 76 | 77 | self.next_x = super().clip_minmax(self.next_x, 0, 1) 78 | 79 | return super().create_output() 80 | 81 | 82 | class Blurring(Original): 83 | def __init__(self, x, y, full=False, color=False): 84 | super(Blurring, self).__init__(x, y, full, color) 85 | 86 | def next_task(self, blurry_factor=0.6, blurry_type='Gaussian'): 87 | next_x = deepcopy(self.x) 88 | self.factor = blurry_factor 89 | if blurry_type == 'Gaussian': 90 | self.next_x = gaussian(next_x, sigma=blurry_factor, multichannel=True) 91 | elif blurry_type == 'Average': 92 | pass 93 | # TODO implement average 94 | 95 | self.next_x = super().clip_minmax(self.next_x, 0, 1) 96 | 97 | return super().create_output() 98 | 99 | 100 | class Occlusion(Original): 101 | def __init__(self, x, y, full=False, color=False): 102 | super(Occlusion, self).__init__(x, y, full, color) 103 | 104 | def next_task(self, occlusion_factor=0.2): 105 | next_x = deepcopy(self.x) 106 | self.factor = occlusion_factor 107 | self.image_size = next_x.shape[1] 108 | 109 | occlusion_size = int(occlusion_factor * self.image_size) 110 | half_size = occlusion_size // 2 111 | occlusion_x = random.randint(min(half_size, self.image_size - half_size), 112 | max(half_size, self.image_size - half_size)) 113 | occlusion_y = random.randint(min(half_size, self.image_size - half_size), 114 | max(half_size, self.image_size - half_size)) 115 | 116 | # self.next_x = next_x.reshape((-1, self.image_size, self.image_size)) 117 | 118 | next_x[:, max((occlusion_x - half_size), 0):min((occlusion_x + half_size), self.image_size), \ 119 | max((occlusion_y - half_size), 0):min((occlusion_y + half_size), self.image_size)] = 1 120 | 121 | self.next_x = next_x 122 | super().clip_minmax(self.next_x, 0, 1) 123 | 124 | return super().create_output() 125 | 126 | 127 | def test_ns(x, y, ns_type, change_factor): 128 | ns_match = {'noise': Noisy, 'occlusion': Occlusion, 'blur': Blurring} 129 | change = ns_match[ns_type] 130 | tmp = change(x, y, color=True) 131 | tmp.next_task(change_factor) 132 | tmp.show_sample(10) 133 | 134 | 135 | ns_match = {'noise': Noisy, 'occlusion': Occlusion, 'blur': Blurring} 136 | 137 | 138 | def construct_ns_single(train_x_split, train_y_split, test_x_split, test_y_split, ns_type, change_factor, ns_task, 139 | plot=True): 140 | # Data splits 141 | train_list = [] 142 | test_list = [] 143 | change = ns_match[ns_type] 144 | i = 0 145 | if len(change_factor) == 1: 146 | change_factor = change_factor[0] 147 | for idx, val in enumerate(ns_task): 148 | if idx % 2 == 0: 149 | for _ in range(val): 150 | print(i, 'normal') 151 | # train 152 | tmp = Original(train_x_split[i], train_y_split[i], color=True) 153 | train_list.append(tmp.next_task()) 154 | if plot: 155 | tmp.show_sample() 156 | 157 | # test 158 | tmp_test = Original(test_x_split[i], test_y_split[i], color=True) 159 | test_list.append(tmp_test.next_task()) 160 | if plot: 161 | tmp_test.show_sample() 162 | 163 | i += 1 164 | else: 165 | for _ in range(val): 166 | print(i, 'change') 167 | # train 168 | tmp = change(train_x_split[i], train_y_split[i], color=True) 169 | train_list.append(tmp.next_task(change_factor)) 170 | if plot: 171 | tmp.show_sample() 172 | # test 173 | tmp_test = change(test_x_split[i], test_y_split[i], color=True) 174 | test_list.append(tmp_test.next_task(change_factor)) 175 | if plot: 176 | tmp_test.show_sample() 177 | 178 | i += 1 179 | return train_list, test_list 180 | 181 | 182 | def construct_ns_multiple(train_x_split, train_y_split, val_x_rdm_split, val_y_rdm_split, test_x_split, 183 | test_y_split, ns_type, change_factors, plot): 184 | train_list = [] 185 | val_list = [] 186 | test_list = [] 187 | ns_len = len(change_factors) 188 | for i in range(ns_len): 189 | factor = change_factors[i] 190 | if factor == 0: 191 | ns_generator = Original 192 | else: 193 | ns_generator = ns_match[ns_type] 194 | print(i, factor) 195 | # train 196 | tmp = ns_generator(train_x_split[i], train_y_split[i], color=True) 197 | train_list.append(tmp.next_task(factor)) 198 | if plot: 199 | tmp.show_sample() 200 | 201 | tmp_val = ns_generator(val_x_rdm_split[i], val_y_rdm_split[i], color=True) 202 | val_list.append(tmp_val.next_task(factor)) 203 | 204 | tmp_test = ns_generator(test_x_split[i], test_y_split[i], color=True) 205 | test_list.append(tmp_test.next_task(factor)) 206 | return train_list, val_list, test_list 207 | 208 | 209 | def construct_ns_multiple_wrapper(train_data, train_label, test_data, est_label, task_nums, img_size, 210 | val_size, ns_type, ns_factor, plot): 211 | train_data_rdm_split, train_label_rdm_split, val_data_rdm_split, val_label_rdm_split, test_data_rdm_split, test_label_rdm_split = train_val_test_split_ni( 212 | train_data, train_label, test_data, est_label, task_nums, img_size, 213 | val_size) 214 | train_set, val_set, test_set = construct_ns_multiple(train_data_rdm_split, train_label_rdm_split, 215 | val_data_rdm_split, val_label_rdm_split, 216 | test_data_rdm_split, test_label_rdm_split, 217 | ns_type, 218 | ns_factor, 219 | plot=plot) 220 | return train_set, val_set, test_set 221 | -------------------------------------------------------------------------------- /models/ndpm/classifier.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils.utils import maybe_cuda 6 | from .component import ComponentD 7 | from utils.global_vars import * 8 | from utils.setup_elements import n_classes 9 | 10 | 11 | class Classifier(ComponentD, ABC): 12 | def __init__(self, params, experts): 13 | super().__init__(params, experts) 14 | self.ce_loss = nn.NLLLoss(reduction='none') 15 | 16 | @abstractmethod 17 | def forward(self, x): 18 | """Output log P(y|x)""" 19 | pass 20 | 21 | def nll(self, x, y, step=None): 22 | x, y = maybe_cuda(x), maybe_cuda(y) 23 | log_softmax = self.forward(x) 24 | loss_pred = self.ce_loss(log_softmax, y) 25 | 26 | # Classifier chilling 27 | chilled_log_softmax = F.log_softmax( 28 | log_softmax / self.params.classifier_chill, dim=1) 29 | chilled_loss_pred = self.ce_loss(chilled_log_softmax, y) 30 | 31 | # Value with chill & gradient without chill 32 | loss_pred = loss_pred - loss_pred.detach() \ 33 | + chilled_loss_pred.detach() 34 | 35 | return loss_pred 36 | 37 | 38 | class SharingClassifier(Classifier, ABC): 39 | @abstractmethod 40 | def forward(self, x, collect=False): 41 | pass 42 | 43 | def collect_forward(self, x): 44 | dummy_pred = self.experts[0](x) 45 | preds, _ = self.forward(x, collect=True) 46 | return torch.stack([dummy_pred] + preds, dim=1) 47 | 48 | def collect_nll(self, x, y, step=None): 49 | preds = self.collect_forward(x) # [B, 1+K, C] 50 | loss_preds = [] 51 | for log_softmax in preds.unbind(dim=1): 52 | loss_pred = self.ce_loss(log_softmax, y) 53 | 54 | # Classifier chilling 55 | chilled_log_softmax = F.log_softmax( 56 | log_softmax / self.params.classifier_chill, dim=1) 57 | chilled_loss_pred = self.ce_loss(chilled_log_softmax, y) 58 | 59 | # Value with chill & gradient without chill 60 | loss_pred = loss_pred - loss_pred.detach() \ 61 | + chilled_loss_pred.detach() 62 | 63 | loss_preds.append(loss_pred) 64 | return torch.stack(loss_preds, dim=1) 65 | 66 | 67 | class BasicBlock(nn.Module): 68 | expansion = 1 69 | 70 | def __init__(self, inplanes, planes, stride=1, 71 | downsample=None, upsample=None, 72 | dilation=1, norm_layer=nn.BatchNorm2d): 73 | super(BasicBlock, self).__init__() 74 | if dilation > 1: 75 | raise NotImplementedError( 76 | "Dilation > 1 not supported in BasicBlock" 77 | ) 78 | transpose = upsample is not None and stride != 1 79 | self.conv1 = ( 80 | conv4x4t(inplanes, planes, stride) if transpose else 81 | conv3x3(inplanes, planes, stride) 82 | ) 83 | self.bn1 = norm_layer(planes) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.conv2 = conv3x3(planes, planes) 86 | self.bn2 = norm_layer(planes) 87 | self.downsample = downsample 88 | self.upsample = upsample 89 | self.stride = stride 90 | 91 | def forward(self, x): 92 | identity = x 93 | 94 | out = self.conv1(x) 95 | out = self.bn1(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv2(out) 99 | out = self.bn2(out) 100 | 101 | if self.downsample is not None: 102 | identity = self.downsample(x) 103 | elif self.upsample is not None: 104 | identity = self.upsample(x) 105 | 106 | out += identity 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | def conv4x4t(in_planes, out_planes, stride=1, groups=1, dilation=1): 113 | """4x4 transposed convolution with padding""" 114 | return nn.ConvTranspose2d( 115 | in_planes, out_planes, kernel_size=4, stride=stride, 116 | padding=dilation, groups=groups, bias=False, dilation=dilation, 117 | ) 118 | 119 | 120 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 121 | """3x3 convolution with padding""" 122 | return nn.Conv2d( 123 | in_planes, out_planes, kernel_size=3, stride=stride, 124 | padding=dilation, groups=groups, bias=False, dilation=dilation 125 | ) 126 | 127 | 128 | def conv1x1(in_planes, out_planes, stride=1): 129 | """1x1 convolution""" 130 | return nn.Conv2d( 131 | in_planes, out_planes, kernel_size=1, stride=stride, bias=False 132 | ) 133 | 134 | 135 | class ResNetSharingClassifier(SharingClassifier): 136 | block = BasicBlock 137 | num_blocks = [2, 2, 2, 2] 138 | norm_layer = nn.InstanceNorm2d 139 | 140 | def __init__(self, params, experts): 141 | super().__init__(params, experts) 142 | self.precursors = [expert.d for expert in self.experts[1:]] 143 | first = len(self.precursors) == 0 144 | 145 | if MODELS_NDPM_CLASSIFIER_NUM_BLOCKS is not None: 146 | num_blocks = MODELS_NDPM_CLASSIFIER_NUM_BLOCKS 147 | else: 148 | num_blocks = self.num_blocks 149 | if MODELS_NDPM_CLASSIFIER_NORM_LAYER is not None: 150 | self.norm_layer = getattr(nn, MODELS_NDPM_CLASSIFIER_NORM_LAYER) 151 | else: 152 | self.norm_layer = nn.BatchNorm2d 153 | 154 | num_classes = n_classes[params.data] 155 | nf = MODELS_NDPM_CLASSIFIER_CLS_NF_BASE if first else MODELS_NDPM_CLASSIFIER_CLS_NF_EXT 156 | nf_cat = MODELS_NDPM_CLASSIFIER_CLS_NF_BASE \ 157 | + len(self.precursors) * MODELS_NDPM_CLASSIFIER_CLS_NF_EXT 158 | self.nf = MODELS_NDPM_CLASSIFIER_CLS_NF_BASE if first else MODELS_NDPM_CLASSIFIER_CLS_NF_EXT 159 | self.nf_cat = nf_cat 160 | 161 | self.layer0 = nn.Sequential( 162 | nn.Conv2d( 163 | 3, nf * 1, kernel_size=3, stride=1, padding=1, bias=False 164 | ), 165 | self.norm_layer(nf * 1), 166 | nn.ReLU() 167 | ) 168 | self.layer1 = self._make_layer( 169 | nf_cat * 1, nf * 1, num_blocks[0], stride=1) 170 | self.layer2 = self._make_layer( 171 | nf_cat * 1, nf * 2, num_blocks[1], stride=2) 172 | self.layer3 = self._make_layer( 173 | nf_cat * 2, nf * 4, num_blocks[2], stride=2) 174 | self.layer4 = self._make_layer( 175 | nf_cat * 4, nf * 8, num_blocks[3], stride=2) 176 | self.predict = nn.Sequential( 177 | nn.Linear(nf_cat * 8, num_classes), 178 | nn.LogSoftmax(dim=1) 179 | ) 180 | self.setup_optimizer() 181 | 182 | def _make_layer(self, nf_in, nf_out, num_blocks, stride): 183 | norm_layer = self.norm_layer 184 | block = self.block 185 | downsample = None 186 | if stride != 1 or nf_in != nf_out: 187 | downsample = nn.Sequential( 188 | conv1x1(nf_in, nf_out, stride), 189 | norm_layer(nf_out), 190 | ) 191 | layers = [block( 192 | nf_in, nf_out, stride, 193 | downsample=downsample, 194 | norm_layer=norm_layer 195 | )] 196 | for _ in range(1, num_blocks): 197 | layers.append(block(nf_out, nf_out, norm_layer=norm_layer)) 198 | 199 | return nn.Sequential(*layers) 200 | 201 | def forward(self, x, collect=False): 202 | x = maybe_cuda(x) 203 | 204 | # First component 205 | if len(self.precursors) == 0: 206 | h1 = self.layer0(x) 207 | h2 = self.layer1(h1) 208 | h3 = self.layer2(h2) 209 | h4 = self.layer3(h3) 210 | h5 = self.layer4(h4) 211 | h5 = F.avg_pool2d(h5, h5.size(2)).view(h5.size(0), -1) 212 | pred = self.predict(h5) 213 | 214 | if collect: 215 | return [pred], [ 216 | h1.detach(), h2.detach(), h3.detach(), 217 | h4.detach(), h5.detach()] 218 | else: 219 | return pred 220 | 221 | # Second or layer component 222 | preds, features = self.precursors[-1](x, collect=True) 223 | h1 = self.layer0(x) 224 | h1_cat = torch.cat([features[0], h1], dim=1) 225 | h2 = self.layer1(h1_cat) 226 | h2_cat = torch.cat([features[1], h2], dim=1) 227 | h3 = self.layer2(h2_cat) 228 | h3_cat = torch.cat([features[2], h3], dim=1) 229 | h4 = self.layer3(h3_cat) 230 | h4_cat = torch.cat([features[3], h4], dim=1) 231 | h5 = self.layer4(h4_cat) 232 | h5 = F.avg_pool2d(h5, h5.size(2)).view(h5.size(0), -1) 233 | h5_cat = torch.cat([features[4], h5], dim=1) 234 | pred = self.predict(h5_cat) 235 | 236 | if collect: 237 | preds.append(pred) 238 | return preds, [ 239 | h1_cat.detach(), h2_cat.detach(), h3_cat.detach(), 240 | h4_cat.detach(), h5_cat.detach(), 241 | ] 242 | else: 243 | return pred 244 | -------------------------------------------------------------------------------- /models/ndpm/vae.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from itertools import accumulate 3 | import torch 4 | import torch.nn as nn 5 | 6 | from utils.utils import maybe_cuda 7 | from .loss import bernoulli_nll, logistic_nll, gaussian_nll, laplace_nll 8 | from .component import ComponentG 9 | from .utils import Lambda 10 | from utils.global_vars import * 11 | from utils.setup_elements import input_size_match 12 | 13 | class Vae(ComponentG, ABC): 14 | def __init__(self, params, experts): 15 | super().__init__(params, experts) 16 | x_c, x_h, x_w = input_size_match[params.data] 17 | bernoulli = MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli' 18 | if bernoulli: 19 | self.log_var_param = None 20 | elif MODELS_NDPM_VAE_LEARN_X_LOG_VAR: 21 | self.log_var_param = nn.Parameter( 22 | torch.ones([x_c]) * MODELS_NDPM_VAE_X_LOG_VAR_PARAM, 23 | requires_grad=True 24 | ) 25 | else: 26 | self.log_var_param = ( 27 | maybe_cuda(torch.ones([x_c])) * 28 | MODELS_NDPM_VAE_X_LOG_VAR_PARAM 29 | ) 30 | 31 | def forward(self, x): 32 | x = maybe_cuda(x) 33 | z_mean, z_log_var = self.encode(x) 34 | z = self.reparameterize(z_mean, z_log_var, 1) 35 | return self.decode(z) 36 | 37 | def nll(self, x, y=None, step=None): 38 | x = maybe_cuda(x) 39 | z_mean, z_log_var = self.encode(x) 40 | z = self.reparameterize(z_mean, z_log_var, MODELS_NDPM_VAE_Z_SAMPLES) 41 | x_mean = self.decode(z) 42 | x_mean = x_mean.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES, *x.shape[1:]) 43 | x_log_var = ( 44 | None if MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli' else 45 | self.log_var.view(1, 1, -1, 1, 1) 46 | ) 47 | loss_recon = self.reconstruction_loss(x, x_mean, x_log_var) 48 | loss_recon = loss_recon.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES, -1) 49 | loss_recon = loss_recon.sum(2).mean(1) 50 | loss_kl = self.gaussian_kl(z_mean, z_log_var) 51 | loss_vae = loss_recon + loss_kl 52 | 53 | return loss_vae 54 | 55 | def sample(self, n=1): 56 | z = maybe_cuda(torch.randn(n, MODELS_NDPM_VAE_Z_DIM)) 57 | x_mean = self.decode(z) 58 | return x_mean 59 | 60 | def reconstruction_loss(self, x, x_mean, x_log_var=None): 61 | loss_type = MODELS_NDPM_VAE_RECON_LOSS 62 | loss = ( 63 | bernoulli_nll if loss_type == 'bernoulli' else 64 | gaussian_nll if loss_type == 'gaussian' else 65 | laplace_nll if loss_type == 'laplace' else 66 | logistic_nll if loss_type == 'logistic' else None 67 | ) 68 | if loss is None: 69 | raise ValueError('Unknown recon_loss type: {}'.format(loss_type)) 70 | 71 | if len(x_mean.size()) > len(x.size()): 72 | x = x.unsqueeze(1) 73 | 74 | return ( 75 | loss(x, x_mean) if x_log_var is None else 76 | loss(x, x_mean, x_log_var) 77 | ) 78 | 79 | @staticmethod 80 | def gaussian_kl(q_mean, q_log_var, p_mean=None, p_log_var=None): 81 | # p defaults to N(0, 1) 82 | zeros = torch.zeros_like(q_mean) 83 | p_mean = p_mean if p_mean is not None else zeros 84 | p_log_var = p_log_var if p_log_var is not None else zeros 85 | # calcaulate KL(q, p) 86 | kld = 0.5 * ( 87 | p_log_var - q_log_var + 88 | (q_log_var.exp() + (q_mean - p_mean) ** 2) / p_log_var.exp() - 1 89 | ) 90 | kld = kld.sum(1) 91 | return kld 92 | 93 | @staticmethod 94 | def reparameterize(z_mean, z_log_var, num_samples=1): 95 | z_std = (z_log_var * 0.5).exp() 96 | z_std = z_std.unsqueeze(1).expand(-1, num_samples, -1) 97 | z_mean = z_mean.unsqueeze(1).expand(-1, num_samples, -1) 98 | unit_normal = torch.randn_like(z_std) 99 | z = z_mean + unit_normal * z_std 100 | z = z.view(-1, z_std.size(2)) 101 | return z 102 | 103 | @abstractmethod 104 | def encode(self, x): 105 | pass 106 | 107 | @abstractmethod 108 | def decode(self, x): 109 | pass 110 | 111 | @property 112 | def log_var(self): 113 | return ( 114 | None if self.log_var_param is None else 115 | self.log_var_param 116 | ) 117 | 118 | 119 | class SharingVae(Vae, ABC): 120 | def collect_nll(self, x, y=None, step=None): 121 | """Collect NLL values 122 | 123 | Returns: 124 | loss_vae: Tensor of shape [B, 1+K] 125 | """ 126 | x = maybe_cuda(x) 127 | 128 | # Dummy VAE 129 | dummy_nll = self.experts[0].g.nll(x, y, step) 130 | 131 | # Encode 132 | z_means, z_log_vars, features = self.encode(x, collect=True) 133 | 134 | # Decode 135 | loss_vaes = [dummy_nll] 136 | vaes = [expert.g for expert in self.experts[1:]] + [self] 137 | x_logits = [] 138 | for z_mean, z_log_var, vae in zip(z_means, z_log_vars, vaes): 139 | z = self.reparameterize(z_mean, z_log_var, MODELS_NDPM_VAE_Z_SAMPLES) 140 | if MODELS_NDPM_VAE_PRECURSOR_CONDITIONED_DECODER: 141 | x_logit = vae.decode(z, as_logit=True) 142 | x_logits.append(x_logit) 143 | continue 144 | x_mean = vae.decode(z) 145 | x_mean = x_mean.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES, 146 | *x.shape[1:]) 147 | x_log_var = ( 148 | None if MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli' else 149 | self.log_var.view(1, 1, -1, 1, 1) 150 | ) 151 | loss_recon = self.reconstruction_loss(x, x_mean, x_log_var) 152 | loss_recon = loss_recon.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES, 153 | -1) 154 | loss_recon = loss_recon.sum(2).mean(1) 155 | loss_kl = self.gaussian_kl(z_mean, z_log_var) 156 | loss_vae = loss_recon + loss_kl 157 | 158 | loss_vaes.append(loss_vae) 159 | 160 | x_logits = list(accumulate( 161 | x_logits, func=(lambda x, y: x.detach() + y) 162 | )) 163 | for x_logit in x_logits: 164 | x_mean = torch.sigmoid(x_logit) 165 | x_mean = x_mean.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES, 166 | *x.shape[1:]) 167 | x_log_var = ( 168 | None if MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli' else 169 | self.log_var.view(1, 1, -1, 1, 1) 170 | ) 171 | loss_recon = self.reconstruction_loss(x, x_mean, x_log_var) 172 | loss_recon = loss_recon.view(x.size(0), MODELS_NDPM_VAE_Z_SAMPLES, 173 | -1) 174 | loss_recon = loss_recon.sum(2).mean(1) 175 | loss_kl = self.gaussian_kl(z_mean, z_log_var) 176 | loss_vae = loss_recon + loss_kl 177 | loss_vaes.append(loss_vae) 178 | 179 | return torch.stack(loss_vaes, dim=1) 180 | 181 | @abstractmethod 182 | def encode(self, x, collect=False): 183 | pass 184 | 185 | @abstractmethod 186 | def decode(self, z, as_logit=False): 187 | """ 188 | Decode do not share parameters 189 | """ 190 | pass 191 | 192 | 193 | class CnnSharingVae(SharingVae): 194 | def __init__(self, params, experts): 195 | super().__init__(params, experts) 196 | self.precursors = [expert.g for expert in self.experts[1:]] 197 | first = len(self.precursors) == 0 198 | nf_base, nf_ext = MODLES_NDPM_VAE_NF_BASE, MODELS_NDPM_VAE_NF_EXT 199 | nf = nf_base if first else nf_ext 200 | nf_cat = nf_base + len(self.precursors) * nf_ext 201 | 202 | h1_dim = 1 * nf 203 | h2_dim = 2 * nf 204 | fc_dim = 4 * nf 205 | h1_cat_dim = 1 * nf_cat 206 | h2_cat_dim = 2 * nf_cat 207 | fc_cat_dim = 4 * nf_cat 208 | 209 | x_c, x_h, x_w = input_size_match[params.data] 210 | 211 | self.fc_dim = fc_dim 212 | feature_volume = ((x_h // 4) * (x_w // 4) * 213 | h2_cat_dim) 214 | 215 | self.enc1 = nn.Sequential( 216 | nn.Conv2d(x_c, h1_dim, 3, 1, 1), 217 | nn.MaxPool2d(2), 218 | nn.ReLU() 219 | ) 220 | self.enc2 = nn.Sequential( 221 | nn.Conv2d(h1_cat_dim, h2_dim, 3, 1, 1), 222 | nn.MaxPool2d(2), 223 | nn.ReLU(), 224 | Lambda(lambda x: x.view(x.size(0), -1)) 225 | ) 226 | self.enc3 = nn.Sequential( 227 | nn.Linear(feature_volume, fc_dim), 228 | nn.ReLU() 229 | ) 230 | self.enc_z_mean = nn.Linear(fc_cat_dim, MODELS_NDPM_VAE_Z_DIM) 231 | self.enc_z_log_var = nn.Linear(fc_cat_dim, MODELS_NDPM_VAE_Z_DIM) 232 | 233 | self.dec_z = nn.Sequential( 234 | nn.Linear(MODELS_NDPM_VAE_Z_DIM, 4 * nf_base), 235 | nn.ReLU() 236 | ) 237 | self.dec3 = nn.Sequential( 238 | nn.Linear( 239 | 4 * nf_base, 240 | (x_h // 4) * (x_w // 4) * 2 * nf_base), 241 | nn.ReLU() 242 | ) 243 | self.dec2 = nn.Sequential( 244 | Lambda(lambda x: x.view( 245 | x.size(0), 2 * nf_base, 246 | x_h // 4, x_w // 4)), 247 | nn.ConvTranspose2d(2 * nf_base, 1 * nf_base, 248 | kernel_size=4, stride=2, padding=1), 249 | nn.ReLU() 250 | ) 251 | self.dec1 = nn.ConvTranspose2d(1 * nf_base, x_c, 252 | kernel_size=4, stride=2, padding=1) 253 | 254 | self.setup_optimizer() 255 | 256 | def encode(self, x, collect=False): 257 | # When first component 258 | if len(self.precursors) == 0: 259 | h1 = self.enc1(x) 260 | h2 = self.enc2(h1) 261 | h3 = self.enc3(h2) 262 | z_mean = self.enc_z_mean(h3) 263 | z_log_var = self.enc_z_log_var(h3) 264 | 265 | if collect: 266 | return [z_mean], [z_log_var], \ 267 | [h1.detach(), h2.detach(), h3.detach()] 268 | else: 269 | return z_mean, z_log_var 270 | 271 | # Second or later component 272 | z_means, z_log_vars, features = \ 273 | self.precursors[-1].encode(x, collect=True) 274 | 275 | h1 = self.enc1(x) 276 | h1_cat = torch.cat([features[0], h1], dim=1) 277 | h2 = self.enc2(h1_cat) 278 | h2_cat = torch.cat([features[1], h2], dim=1) 279 | h3 = self.enc3(h2_cat) 280 | h3_cat = torch.cat([features[2], h3], dim=1) 281 | z_mean = self.enc_z_mean(h3_cat) 282 | z_log_var = self.enc_z_log_var(h3_cat) 283 | 284 | if collect: 285 | z_means.append(z_mean) 286 | z_log_vars.append(z_log_var) 287 | features = [h1_cat.detach(), h2_cat.detach(), h3_cat.detach()] 288 | return z_means, z_log_vars, features 289 | else: 290 | return z_mean, z_log_var 291 | 292 | def decode(self, z, as_logit=False): 293 | h3 = self.dec_z(z) 294 | h2 = self.dec3(h3) 295 | h1 = self.dec2(h2) 296 | x_logit = self.dec1(h1) 297 | return x_logit if as_logit else torch.sigmoid(x_logit) 298 | --------------------------------------------------------------------------------