├── agents
├── __init__.py
├── __pycache__
│ ├── agem.cpython-36.pyc
│ ├── base.cpython-36.pyc
│ ├── lwf.cpython-36.pyc
│ ├── scr.cpython-36.pyc
│ ├── cndpm.cpython-36.pyc
│ ├── ewc_pp.cpython-36.pyc
│ ├── gdumb.cpython-36.pyc
│ ├── icarl.cpython-36.pyc
│ ├── __init__.cpython-36.pyc
│ ├── exp_replay.cpython-36.pyc
│ └── exp_replay_dvc.cpython-36.pyc
├── cndpm.py
├── lwf.py
├── icarl.py
├── scr.py
├── gdumb.py
├── agem.py
├── ewc_pp.py
├── exp_replay.py
└── exp_replay_dvc.py
├── utils
├── __init__.py
├── buffer
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── buffer.cpython-36.pyc
│ │ ├── aser_utils.cpython-36.pyc
│ │ ├── mem_match.cpython-36.pyc
│ │ ├── aser_retrieve.cpython-36.pyc
│ │ ├── aser_update.cpython-36.pyc
│ │ ├── buffer_utils.cpython-36.pyc
│ │ ├── mgi_retrieve.cpython-36.pyc
│ │ ├── mir_retrieve.cpython-36.pyc
│ │ ├── sc_retrieve.cpython-36.pyc
│ │ ├── random_retrieve.cpython-36.pyc
│ │ ├── gss_greedy_update.cpython-36.pyc
│ │ └── reservoir_update.cpython-36.pyc
│ ├── random_retrieve.py
│ ├── sc_retrieve.py
│ ├── mem_match.py
│ ├── buffer.py
│ ├── aser_update.py
│ ├── reservoir_update.py
│ ├── mir_retrieve.py
│ ├── mgi_retrieve.py
│ ├── aser_retrieve.py
│ ├── gss_greedy_update.py
│ └── aser_utils.py
├── __pycache__
│ ├── io.cpython-36.pyc
│ ├── loss.cpython-36.pyc
│ ├── utils.cpython-36.pyc
│ ├── L_softmax.cpython-36.pyc
│ ├── __init__.cpython-36.pyc
│ ├── kd_manager.cpython-36.pyc
│ ├── name_match.cpython-36.pyc
│ ├── global_vars.cpython-36.pyc
│ └── setup_elements.cpython-36.pyc
├── kd_manager.py
├── global_vars.py
├── io.py
├── name_match.py
├── setup_elements.py
├── L_softmax.py
├── loss.py
└── utils.py
├── continuum
├── __init__.py
├── dataset_scripts
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── cifar10.cpython-36.pyc
│ │ ├── core50.cpython-36.pyc
│ │ ├── __init__.cpython-36.pyc
│ │ ├── cifar100.cpython-36.pyc
│ │ ├── openloris.cpython-36.pyc
│ │ ├── dataset_base.cpython-36.pyc
│ │ └── mini_imagenet.cpython-36.pyc
│ ├── dataset_base.py
│ ├── cifar100.py
│ ├── cifar10.py
│ ├── mini_imagenet.py
│ ├── openloris.py
│ └── core50.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── continuum.cpython-36.pyc
│ ├── data_utils.cpython-36.pyc
│ └── non_stationary.cpython-36.pyc
├── continuum.py
├── data_utils.py
└── non_stationary.py
├── experiment
├── __init__.py
├── __pycache__
│ ├── run.cpython-36.pyc
│ ├── metrics.cpython-36.pyc
│ ├── __init__.cpython-36.pyc
│ └── tune_hyperparam.cpython-36.pyc
├── tune_hyperparam.py
└── metrics.py
├── models
├── ndpm
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── loss.cpython-36.pyc
│ │ ├── ndpm.cpython-36.pyc
│ │ ├── vae.cpython-36.pyc
│ │ ├── expert.cpython-36.pyc
│ │ ├── priors.cpython-36.pyc
│ │ ├── utils.cpython-36.pyc
│ │ ├── __init__.cpython-36.pyc
│ │ ├── component.cpython-36.pyc
│ │ └── classifier.cpython-36.pyc
│ ├── utils.py
│ ├── loss.py
│ ├── priors.py
│ ├── expert.py
│ ├── component.py
│ ├── ndpm.py
│ ├── classifier.py
│ └── vae.py
├── __init__.py
├── __pycache__
│ ├── resnet.cpython-36.pyc
│ └── __init__.cpython-36.pyc
├── pretrained.py
└── resnet.py
├── __pycache__
└── loss.cpython-36.pyc
├── features
├── __pycache__
│ └── extractor.cpython-36.pyc
├── configs.py
├── regnet.py
├── lenet.py
├── efficientnet.py
├── extractor.py
├── wide_resnet.py
└── vgg.py
├── requirement.txt
├── .idea
├── misc.xml
├── modules.xml
├── online-continual-learning-main-realori.iml
└── inspectionProfiles
│ └── Project_Default.xml
├── README.md
└── loss.py
/agents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/continuum/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiment/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/ndpm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/buffer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/continuum/dataset_scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import resnet
--------------------------------------------------------------------------------
/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/io.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/io.cpython-36.pyc
--------------------------------------------------------------------------------
/agents/__pycache__/agem.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/agem.cpython-36.pyc
--------------------------------------------------------------------------------
/agents/__pycache__/base.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/base.cpython-36.pyc
--------------------------------------------------------------------------------
/agents/__pycache__/lwf.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/lwf.cpython-36.pyc
--------------------------------------------------------------------------------
/agents/__pycache__/scr.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/scr.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/agents/__pycache__/cndpm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/cndpm.cpython-36.pyc
--------------------------------------------------------------------------------
/agents/__pycache__/ewc_pp.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/ewc_pp.cpython-36.pyc
--------------------------------------------------------------------------------
/agents/__pycache__/gdumb.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/gdumb.cpython-36.pyc
--------------------------------------------------------------------------------
/agents/__pycache__/icarl.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/icarl.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/agents/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/experiment/__pycache__/run.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/experiment/__pycache__/run.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/ndpm/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/models/ndpm/__pycache__/ndpm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/ndpm.cpython-36.pyc
--------------------------------------------------------------------------------
/models/ndpm/__pycache__/vae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/vae.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/L_softmax.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/L_softmax.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/kd_manager.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/kd_manager.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/name_match.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/name_match.cpython-36.pyc
--------------------------------------------------------------------------------
/agents/__pycache__/exp_replay.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/exp_replay.cpython-36.pyc
--------------------------------------------------------------------------------
/continuum/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/experiment/__pycache__/metrics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/experiment/__pycache__/metrics.cpython-36.pyc
--------------------------------------------------------------------------------
/features/__pycache__/extractor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/features/__pycache__/extractor.cpython-36.pyc
--------------------------------------------------------------------------------
/models/ndpm/__pycache__/expert.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/expert.cpython-36.pyc
--------------------------------------------------------------------------------
/models/ndpm/__pycache__/priors.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/priors.cpython-36.pyc
--------------------------------------------------------------------------------
/models/ndpm/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/global_vars.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/global_vars.cpython-36.pyc
--------------------------------------------------------------------------------
/agents/__pycache__/exp_replay_dvc.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/agents/__pycache__/exp_replay_dvc.cpython-36.pyc
--------------------------------------------------------------------------------
/continuum/__pycache__/continuum.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/__pycache__/continuum.cpython-36.pyc
--------------------------------------------------------------------------------
/continuum/__pycache__/data_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/__pycache__/data_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/experiment/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/experiment/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/ndpm/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/ndpm/__pycache__/component.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/component.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/setup_elements.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/__pycache__/setup_elements.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/buffer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/buffer.cpython-36.pyc
--------------------------------------------------------------------------------
/models/ndpm/__pycache__/classifier.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/models/ndpm/__pycache__/classifier.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/aser_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/aser_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/mem_match.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/mem_match.cpython-36.pyc
--------------------------------------------------------------------------------
/continuum/__pycache__/non_stationary.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/__pycache__/non_stationary.cpython-36.pyc
--------------------------------------------------------------------------------
/experiment/__pycache__/tune_hyperparam.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/experiment/__pycache__/tune_hyperparam.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/aser_retrieve.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/aser_retrieve.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/aser_update.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/aser_update.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/buffer_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/buffer_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/mgi_retrieve.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/mgi_retrieve.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/mir_retrieve.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/mir_retrieve.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/sc_retrieve.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/sc_retrieve.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/random_retrieve.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/random_retrieve.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/gss_greedy_update.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/gss_greedy_update.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/buffer/__pycache__/reservoir_update.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/utils/buffer/__pycache__/reservoir_update.cpython-36.pyc
--------------------------------------------------------------------------------
/continuum/dataset_scripts/__pycache__/cifar10.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/cifar10.cpython-36.pyc
--------------------------------------------------------------------------------
/continuum/dataset_scripts/__pycache__/core50.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/core50.cpython-36.pyc
--------------------------------------------------------------------------------
/continuum/dataset_scripts/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/continuum/dataset_scripts/__pycache__/cifar100.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/cifar100.cpython-36.pyc
--------------------------------------------------------------------------------
/continuum/dataset_scripts/__pycache__/openloris.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/openloris.cpython-36.pyc
--------------------------------------------------------------------------------
/continuum/dataset_scripts/__pycache__/dataset_base.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/dataset_base.cpython-36.pyc
--------------------------------------------------------------------------------
/continuum/dataset_scripts/__pycache__/mini_imagenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YananGu/DVC/HEAD/continuum/dataset_scripts/__pycache__/mini_imagenet.cpython-36.pyc
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | torch==1.9.0
2 | torchvision==0.10.0
3 | matplotlib==3.3.4
4 | scikit-image==0.17.2
5 | timm==0.4.12
6 | kornia==0.5.4
7 | scikit-learn==0.24.2
8 | pandas==1.1.5
9 | psutil==5.8.0
10 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/models/pretrained.py:
--------------------------------------------------------------------------------
1 | import torchvision.models as models
2 | import torch
3 |
def ResNet18_pretrained(n_classes):
    """Build a pretrained torchvision ResNet-18 with a new classification head.

    Args:
        n_classes: number of outputs of the replacement final linear layer.

    Returns:
        The ``models.resnet18(pretrained=True)`` network whose ``fc`` attribute
        has been replaced by a freshly initialised ``torch.nn.Linear``.
    """
    classifier = models.resnet18(pretrained=True)
    # Take the head's input width from the existing layer rather than
    # hard-coding 512, so the new layer always matches the backbone.
    in_features = classifier.fc.in_features
    classifier.fc = torch.nn.Linear(in_features, n_classes)
    return classifier
--------------------------------------------------------------------------------
/models/ndpm/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
class Lambda(nn.Module):
    """Module wrapper around an arbitrary callable.

    Args:
        f: callable invoked by ``forward``. When ``None`` (the default) a
            single-argument identity function is used instead.
    """

    def __init__(self, f=None):
        super().__init__()
        if f is None:
            # Fall back to passing the input straight through.
            f = lambda x: x
        self.f = f

    def forward(self, *args, **kwargs):
        """Call the wrapped callable with the given arguments and return its result."""
        fn = self.f
        return fn(*args, **kwargs)
11 |
--------------------------------------------------------------------------------
/utils/buffer/random_retrieve.py:
--------------------------------------------------------------------------------
1 | from utils.buffer.buffer_utils import random_retrieve
2 |
class Random_retrieve(object):
    """Retrieval strategy that forwards to ``random_retrieve``.

    Args:
        params: object exposing ``eps_mem_batch``; its value is stored as
            ``num_retrieve`` and passed on at every retrieval.
    """

    def __init__(self, params):
        super().__init__()
        self.num_retrieve = params.eps_mem_batch

    def retrieve(self, buffer, **kwargs):
        """Return ``random_retrieve(buffer, self.num_retrieve)``.

        Extra keyword arguments are accepted but not used.
        """
        how_many = self.num_retrieve
        return random_retrieve(buffer, how_many)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/online-continual-learning-main-realori.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/utils/buffer/sc_retrieve.py:
--------------------------------------------------------------------------------
1 | from utils.buffer.buffer_utils import match_retrieve
2 | import torch
3 |
class Match_retrieve(object):
    """Retrieval strategy that forwards to ``match_retrieve`` after a warm-up.

    Args:
        params: object exposing ``eps_mem_batch`` (stored as ``num_retrieve``)
            and ``warmup``.
    """

    def __init__(self, params):
        super().__init__()
        self.num_retrieve = params.eps_mem_batch
        self.warmup = params.warmup

    def retrieve(self, buffer, **kwargs):
        """Retrieve samples from ``buffer`` for the current labels.

        Args:
            buffer: object exposing ``n_seen_so_far``.
            **kwargs: must contain ``y`` (the current labels) once the warm-up
                is over. Other keywords, such as ``x``, are accepted and ignored.

        Returns:
            ``match_retrieve(buffer, y)`` when ``buffer.n_seen_so_far`` exceeds
            ``num_retrieve * warmup``; otherwise a pair of empty tensors.
        """
        # Still warming up: hand back empty tensors.
        if not buffer.n_seen_so_far > self.num_retrieve * self.warmup:
            return torch.tensor([]), torch.tensor([])
        # Only the labels are needed; the previously unused ``x`` lookup is dropped.
        cur_y = kwargs['y']
        return match_retrieve(buffer, cur_y)
--------------------------------------------------------------------------------
/features/configs.py:
--------------------------------------------------------------------------------
1 |
# VGG layouts keyed by variant name; each entry is kept in its original order,
# grouped one stage per line for readability.
# NOTE(review): presumably integers are conv output channels and 'M' marks a
# max-pool step -- confirm against the code that consumes this table.
vgg = {
    'A': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'B': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'D': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'E': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
    'VGG7': [
        64, 'M',
        128, 'M',
        256, 'M',
        512,
    ],
}

# NOTE(review): looks like [depth, widen factor, dropout rate] -- confirm
# against the WideResNet constructor.
wide_resnet = {
    'WideResNet28-10': [28, 10, 0.3],
    'WideResNet40-2': [40, 2, 0.3],
}

# NOTE(review): the meaning of the two numbers is not visible here -- confirm
# against the ShakeResNet constructor.
shake_resnet = {
    'ShakeResNet26-2x32d': [26, 32],
}
18 |
--------------------------------------------------------------------------------
/features/regnet.py:
--------------------------------------------------------------------------------
1 | import timm
2 | from . import extractor
3 |
class RegNetX002(extractor.BaseModule):
    """Feature extractor built on the timm model ``'regnetx_002'``.

    Args:
        config: accepted for signature compatibility; not read by this class.
        name: stored on the instance as ``self.name``.
    """

    def __init__(self, config, name):
        super().__init__()

        self.name = name
        backbone = timm.create_model('regnetx_002')
        self.features = backbone
        self.n_features = 368

    def forward(self, x):
        """Return ``forward_features(x)`` of the wrapped timm model."""
        backbone = self.features
        return backbone.forward_features(x)
14 |
class RegNetY004(extractor.BaseModule):
    """Feature extractor built on the timm model ``'regnety_004'``.

    Args:
        config: accepted for signature compatibility; not read by this class.
        name: stored on the instance as ``self.name``.
    """

    def __init__(self, config, name):
        super().__init__()

        self.name = name
        backbone = timm.create_model('regnety_004')
        self.features = backbone
        self.n_features = 440

    def forward(self, x):
        """Return ``forward_features(x)`` of the wrapped timm model."""
        backbone = self.features
        return backbone.forward_features(x)
25 |
26 |
--------------------------------------------------------------------------------
/utils/kd_manager.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from torch.nn import functional as F
4 |
5 |
def loss_fn_kd(scores, target_scores, T=2.):
    """Temperature-scaled distillation loss between two sets of logits.

    Args:
        scores: logits of the model being trained; softmax is taken over dim 1.
        target_scores: logits providing the soft targets; softmax over dim 1.
        T: temperature both logit sets are divided by (default 2.0).

    Returns:
        The mean over dim 0 of the per-row cross-entropy between the softened
        target distribution and the softened prediction, multiplied by ``T ** 2``.
    """
    student_log_probs = F.log_softmax(scores / T, dim=1)
    teacher_probs = F.softmax(target_scores / T, dim=1)
    # Per-row cross-entropy of the softened distributions
    # (see e.g., Li and Hoiem, 2017).
    per_row = -(teacher_probs * student_log_probs).sum(dim=1)
    return per_row.mean() * T ** 2
12 |
13 |
class KdManager:
    """Keeps a deep-copied teacher model and computes a distillation loss from it."""

    def __init__(self):
        # No teacher until ``update_teacher`` is called.
        self.teacher_model = None

    def update_teacher(self, model):
        """Replace the stored teacher with a deep copy of ``model``."""
        snapshot = copy.deepcopy(model)
        self.teacher_model = snapshot

    def get_kd_loss(self, cur_model_logits, x):
        """Return the distillation loss of the current logits against the teacher.

        Args:
            cur_model_logits: logits of the current model.
            x: input passed to the teacher's ``forward``.

        Returns:
            ``0`` when no teacher has been stored; otherwise
            ``loss_fn_kd(cur_model_logits, teacher_logits)``, with the teacher
            logits computed under ``torch.no_grad()``.
        """
        teacher = self.teacher_model
        if teacher is None:
            return 0
        with torch.no_grad():
            prev_model_logits = teacher.forward(x)
        return loss_fn_kd(cur_model_logits, prev_model_logits)
29 |
--------------------------------------------------------------------------------
/features/lenet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch.nn as nn
6 | from . import extractor
7 |
8 |
class LeNet(extractor.BaseModule):
    """Two-convolution LeNet-style feature extractor.

    Args:
        config: mapping read for ``config["channels"]``, the number of input
            channels of the first convolution.
        name: stored on the instance as ``self.name``.
    """

    def __init__(self, config, name):
        super().__init__()
        self.name = name
        n_input_channels = config["channels"]
        self.conv1 = nn.Conv2d(n_input_channels, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool = nn.MaxPool2d(2)
        self.n_features = 400
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv1 -> ReLU -> max-pool -> conv2 -> ReLU and return the result.

        No pooling follows the second convolution and nothing is flattened here.
        """
        # NOTE(review): n_features = 400 equals 16*5*5, but for a 32x32 input
        # this method yields 16x10x10 -- confirm whether another pooling step
        # is applied elsewhere.
        out = self.conv1(x)
        out = self.relu(out)
        out = self.pool(out)
        out = self.conv2(out)
        return self.relu(out)
29 |
--------------------------------------------------------------------------------
/utils/buffer/mem_match.py:
--------------------------------------------------------------------------------
1 | from utils.buffer.buffer_utils import match_retrieve
2 | from utils.buffer.buffer_utils import random_retrieve
3 | import torch
4 |
class MemMatch_retrieve(object):
    """Retrieval strategy pairing ``random_retrieve`` candidates with ``match_retrieve`` results.

    Args:
        params: object exposing ``eps_mem_batch`` (stored as ``num_retrieve``)
            and ``warmup``.
    """

    def __init__(self, params):
        super().__init__()
        self.num_retrieve = params.eps_mem_batch
        self.warmup = params.warmup

    def retrieve(self, buffer, **kwargs):
        """Return ``(candidate_x, candidate_y, match_x, match_y)``.

        While ``buffer.n_seen_so_far`` does not exceed ``num_retrieve * warmup``
        all four tensors are empty. Afterwards candidates are drawn repeatedly
        until ``match_retrieve`` returns a non-empty ``match_x``, or until the
        draw itself comes back empty.
        """
        cand_x, cand_y = torch.tensor([]), torch.tensor([])
        hit_x, hit_y = torch.tensor([]), torch.tensor([])

        # Warm-up: nothing is retrieved yet.
        if buffer.n_seen_so_far <= self.num_retrieve * self.warmup:
            return cand_x, cand_y, hit_x, hit_y

        while hit_x.size(0) == 0:
            cand_x, cand_y, picked = random_retrieve(
                buffer, self.num_retrieve, return_indices=True)
            if cand_x.size(0) == 0:
                # Nothing could be drawn; stop looking for matches.
                break
            hit_x, hit_y = match_retrieve(buffer, cand_y, picked)
        return cand_x, cand_y, hit_x, hit_y
--------------------------------------------------------------------------------
/features/efficientnet.py:
--------------------------------------------------------------------------------
1 | import timm
2 | from . import extractor
3 |
class EfficientNetB1(extractor.BaseModule):
    """Feature extractor built on the timm model ``'efficientnet_b1'``.

    Args:
        config: accepted for signature compatibility; not read by this class.
        name: stored on the instance as ``self.name``.
    """

    def __init__(self, config, name):
        super().__init__()

        self.name = name
        backbone = timm.create_model('efficientnet_b1')
        self.features = backbone
        self.n_features = 1280

    def forward(self, x):
        """Return ``forward_features(x)`` of the wrapped timm model."""
        backbone = self.features
        return backbone.forward_features(x)
14 |
class EfficientNetB0(extractor.BaseModule):
    """Feature extractor wrapping timm's ``efficientnet_b0`` model."""

    def __init__(self, config, name):
        """Create the backbone; ``config`` is accepted but not read here."""
        super().__init__()
        self.name = name
        backbone = timm.create_model('efficientnet_b0')
        self.features = backbone
        self.n_features = 1280

    def forward(self, x):
        """Return the backbone's ``forward_features`` output for ``x``."""
        backbone = self.features
        return backbone.forward_features(x)
25 |
class EfficientNetV2S(extractor.BaseModule):
    """Wrapper around timm's ``efficientnetv2_s`` model."""

    def __init__(self, config, name):
        """Create the backbone with the dropout rate given by ``config['dropout']``."""
        super().__init__()
        self.name = name
        backbone = timm.create_model('efficientnetv2_s', drop_rate=config['dropout'])
        self.features = backbone
        # FIXME: edit this -- the feature width is still a placeholder.
        self.n_features = 0

    def forward(self, x):
        """Run the full timm model on ``x`` (not ``forward_features``)."""
        backbone = self.features
        return backbone(x)
37 |
--------------------------------------------------------------------------------
/continuum/continuum.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # other imports
4 | from utils.name_match import data_objects
5 |
class continuum(object):
    """Iterator over the tasks of a dataset object looked up in ``data_objects``."""

    def __init__(self, dataset, scenario, params):
        """Build the dataset object and reset the task and run counters.

        Args:
            dataset: key into ``data_objects``.
            scenario: forwarded to the dataset constructor.
            params: forwarded to the dataset constructor; ``params.num_runs``
                is stored as ``self.run``.
        """
        factory = data_objects[dataset]
        self.data_object = factory(scenario, params)
        self.run = params.num_runs
        self.task_nums = self.data_object.task_nums
        self.cur_task = 0
        # The first new_run() call moves this to 0.
        self.cur_run = -1

    def __iter__(self):
        return self

    def __next__(self):
        """Return ``(x_train, y_train, labels)`` for the next task, then advance."""
        if self.cur_task == self.data_object.task_nums:
            raise StopIteration
        task = self.data_object.new_task(self.cur_task, cur_run=self.cur_run)
        x_train, y_train, labels = task
        self.cur_task += 1
        return x_train, y_train, labels

    def test_data(self):
        """Return the dataset object's test set."""
        return self.data_object.get_test_set()

    def clean_mem_test_set(self):
        """Ask the dataset object to drop its test data."""
        self.data_object.clean_mem_test_set()

    def reset_run(self):
        """Rewind to the first task of the current run."""
        self.cur_task = 0

    def new_run(self):
        """Rewind to the first task, bump the run counter and notify the dataset."""
        self.cur_task = 0
        self.cur_run += 1
        self.data_object.new_run(cur_run=self.cur_run)
38 |
39 |
40 |
--------------------------------------------------------------------------------
/utils/global_vars.py:
--------------------------------------------------------------------------------
# Module-level configuration constants; the names are prefixed with the
# component they appear to configure (NDPM VAE, classifier, NDPM, component).

# --- VAE ---
# NOTE(review): "MODLES" looks like a misspelling of "MODELS", but the name is
# part of this module's public surface -- check every user before renaming.
MODLES_NDPM_VAE_NF_BASE = 32
MODELS_NDPM_VAE_NF_EXT = 4
MODELS_NDPM_VAE_PRECURSOR_CONDITIONED_DECODER = False
MODELS_NDPM_VAE_Z_DIM = 64
MODELS_NDPM_VAE_RECON_LOSS = 'gaussian'
MODELS_NDPM_VAE_LEARN_X_LOG_VAR = False
MODELS_NDPM_VAE_X_LOG_VAR_PARAM = 0
MODELS_NDPM_VAE_Z_SAMPLES = 16
# --- Classifier ---
MODELS_NDPM_CLASSIFIER_NUM_BLOCKS = [1, 1, 1, 1]
MODELS_NDPM_CLASSIFIER_NORM_LAYER = 'InstanceNorm2d'
MODELS_NDPM_CLASSIFIER_CLS_NF_BASE = 20
MODELS_NDPM_CLASSIFIER_CLS_NF_EXT = 4
# --- NDPM ---
MODELS_NDPM_NDPM_DISABLE_D = False
MODELS_NDPM_NDPM_SEND_TO_STM_ALWAYS = False
MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE = 50
MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS = 0
MODELS_NDPM_NDPM_SLEEP_STEP_G = 4000
MODELS_NDPM_NDPM_SLEEP_STEP_D = 1000
MODELS_NDPM_NDPM_SLEEP_SLEEP_VAL_SIZE = 0
MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP = 500
MODELS_NDPM_NDPM_WEIGHT_DECAY = 0.00001
MODELS_NDPM_NDPM_IMPLICIT_LR_DECAY = False
# --- Component: scheduler / gradient-clipping specs as {'type', 'options'} dicts ---
MODELS_NDPM_COMPONENT_LR_SCHEDULER_G = {'type': 'MultiStepLR', 'options': {'milestones': [1], 'gamma': 0.2}}
MODELS_NDPM_COMPONENT_LR_SCHEDULER_D = {'type': 'MultiStepLR', 'options': {'milestones': [1], 'gamma': 0.2}}
MODELS_NDPM_COMPONENT_CLIP_GRAD = {'type': 'value', 'options': {'clip_value': 0.5}}
26 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/continuum/dataset_scripts/dataset_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import os
3 |
class DatasetBase(ABC):
    """Abstract base for dataset wrappers.

    Subclasses must implement ``download_load``, ``setup``, ``new_task`` and
    ``new_run``. The constructor stores its arguments, derives ``root`` as
    ``./datasets/<dataset>`` and then calls ``_is_properly_setup()`` followed
    by ``download_load()``.
    """

    def __init__(self, dataset, scenario, task_nums, run, params):
        super().__init__()
        self.params = params
        self.scenario = scenario
        self.dataset = dataset
        self.task_nums = task_nums
        self.run = run
        self.root = os.path.join('./datasets', self.dataset)
        self.test_set, self.val_set = [], []
        self._is_properly_setup()
        self.download_load()

    @abstractmethod
    def download_load(self):
        """Obtain and load the raw data (invoked from ``__init__``)."""

    @abstractmethod
    def setup(self, **kwargs):
        """Prepare the dataset; arguments are subclass-defined."""

    @abstractmethod
    def new_task(self, cur_task, **kwargs):
        """Produce the data for task index ``cur_task``."""

    def _is_properly_setup(self):
        """Validation hook run before ``download_load``; a no-op by default."""

    @abstractmethod
    def new_run(self, **kwargs):
        """Prepare the dataset for a new run."""

    @property
    def dataset_info(self):
        """The ``dataset`` value given to the constructor."""
        return self.dataset

    def get_test_set(self):
        """Return the current ``test_set``."""
        return self.test_set

    def clean_mem_test_set(self):
        """Drop references to the test data by setting them to ``None``."""
        self.test_set = self.test_data = self.test_label = None
--------------------------------------------------------------------------------
/utils/io.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import pandas as pd
3 | import os
4 | import psutil
5 | import torch
6 |
def load_yaml(path, key='parameters'):
    """Parse the YAML file at ``path`` and return its top-level entry ``key``.

    A ``yaml.YAMLError`` raised while parsing is printed and the function then
    returns ``None``; a missing ``key`` is not handled here.
    """
    with open(path, 'r') as stream:
        try:
            document = yaml.load(stream, Loader=yaml.FullLoader)
            return document[key]
        except yaml.YAMLError as exc:
            print(exc)
    return None
13 |
def save_dataframe_csv(df, path, name):
    """Write ``df`` as CSV (without the index column) to ``path + '/' + name``."""
    target = '/'.join((path, name))
    df.to_csv(target, index=False)
16 |
17 |
def load_dataframe_csv(path, name=None, delimiter=None, names=None):
    """Read a CSV file into a DataFrame.

    When ``name`` is falsy, ``path`` is the file itself; otherwise the file is
    ``path + name`` (plain concatenation -- no separator is inserted).
    ``delimiter`` and ``names`` are forwarded to ``pd.read_csv``.
    """
    target = path + name if name else path
    return pd.read_csv(target, delimiter=delimiter, names=names)
23 |
def check_ram_usage():
    """
    Report the resident set size of the current process.
    Returns:
        mem (float): ``memory_info().rss`` divided by 1024 * 1024 (mebibytes)
    """
    bytes_per_mb = 1024 * 1024
    current = psutil.Process(os.getpid())
    return current.memory_info().rss / bytes_per_mb
35 |
def save_model(model, optimizer, opt, epoch, save_file):
    """Serialize a training checkpoint to ``save_file`` with ``torch.save``.

    The saved dict has the keys ``'opt'``, ``'model'`` (the model's
    state_dict), ``'optimizer'`` (the optimizer's state_dict) and ``'epoch'``.
    """
    print('==> Saving...')
    checkpoint = dict(
        opt=opt,
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=epoch,
    )
    torch.save(checkpoint, save_file)
    del checkpoint
46 |
--------------------------------------------------------------------------------
/features/extractor.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
class BaseModule(nn.Module):
    """Identity ``nn.Module`` providing ``name``, ``n_features`` and weight init."""

    def __init__(self):
        super().__init__()
        self.n_features = 0
        self._name = "BaseModule"

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x

    @property
    def name(self):
        """Name of this module (backed by ``_name``)."""
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    def init_weights(self, std=0.01):
        """Re-initialise the weights of every submodule.

        Modules whose type is exactly ``nn.Linear`` or ``nn.Conv2d`` get
        weights drawn from N(0, std); ``nn.BatchNorm2d`` instances get weight 1.
        Biases, where present, are zeroed.
        """
        print("Initialize weights of %s with normal dist: mean=0, std=%0.2f" % (type(self), std))
        for module in self.modules():
            # Exact type match on purpose: subclasses of Linear/Conv2d are skipped.
            if type(module) in (nn.Linear, nn.Conv2d):
                nn.init.normal_(module.weight, 0, std)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
            else:
                continue
            if module.bias is not None:
                module.bias.data.zero_()
39 |
40 |
if __name__ == '__main__':
    # Smoke check when run as a script: build the base module, print it and
    # its n_features value.
    net = BaseModule()
    print(net)
    print("n_features:", net.n_features)
45 |
--------------------------------------------------------------------------------
/models/ndpm/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import Tensor
4 | from torch.nn.functional import binary_cross_entropy
5 |
6 |
def gaussian_nll(x, mean, log_var, min_noise=0.001):
    """Element-wise Gaussian negative log-likelihood of ``x``.

    Computes ``((x - mean)^2 + min_noise) / (2 * exp(log_var) + 1e-8)
    + 0.5 * log_var + 0.5 * log(2 * pi)``; the two small constants keep the
    numerator and denominator away from zero.
    """
    numerator = (x - mean) ** 2 + min_noise
    denominator = 2 * log_var.exp() + 1e-8
    log_norm = 0.5 * np.log(2 * np.pi)
    return numerator / denominator + 0.5 * log_var + log_norm
12 |
13 |
def laplace_nll(x, median, log_scale, min_noise=0.01):
    """Element-wise Laplace negative log-likelihood of ``x``.

    Computes ``(|x - median| + min_noise) / (exp(log_scale) + 1e-8)
    + log_scale + log(2)``.
    """
    deviation = torch.abs(x - median) + min_noise
    scale = log_scale.exp() + 1e-8
    return deviation / scale + log_scale + np.log(2)
19 |
20 |
def bernoulli_nll(x, p):
    """Element-wise Bernoulli negative log-likelihood of targets ``x`` under ``p``.

    ``x`` and ``p`` are broadcast against each other first, because
    ``binary_cross_entropy`` needs input and target of the same shape.
    Standard broadcasting is used, so operands of different rank (for example
    a per-feature ``x`` against a batched ``p``) are accepted as well as
    same-rank operands with size-1 dimensions.

    Args:
        x: target tensor with values in [0, 1].
        p: probability tensor with values in [0, 1].

    Returns:
        Tensor of the broadcast shape holding the unreduced binary
        cross-entropy of ``p`` against ``x``.
    """
    x, p = torch.broadcast_tensors(x, p)
    return binary_cross_entropy(p, x, reduction='none')
38 |
39 |
def logistic_nll(x, mean, log_scale):
    """Element-wise discretized-logistic negative log-likelihood.

    The probability of ``x`` is the logistic CDF mass over the bin
    ``[x, x + 1/256)``, centred at ``mean`` with scale ``exp(log_scale)``;
    1e-12 is added before taking the log.
    """
    bin_size = 1 / 256
    scale = log_scale.exp()
    offset = x - mean
    lower_edge = torch.sigmoid(offset / scale)
    upper_edge = torch.sigmoid((offset + bin_size) / scale)
    prob = upper_edge - lower_edge + 1e-12
    return -torch.log(prob)
48 |
--------------------------------------------------------------------------------
/agents/cndpm.py:
--------------------------------------------------------------------------------
1 | from agents.base import ContinualLearner
2 | from continuum.data_utils import dataset_transform
3 | from models.ndpm.ndpm import Ndpm
4 | from utils.setup_elements import transforms_match
5 | from torch.utils import data
6 | from utils.utils import maybe_cuda, AverageMeter
7 | import torch
8 |
9 |
class Cndpm(ContinualLearner):
    """Agent that feeds mini-batches to the model's own ``learn`` method."""

    def __init__(self, model, opt, params):
        super().__init__(model, opt, params)
        self.model = model

    def train_learner(self, x_train, y_train):
        """Train on one task: every batch is passed to ``self.model.learn``.

        With ``self.params.verbose`` set, a one-line progress message is
        rewritten in place after each batch.
        """
        # set up loader
        dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        loader = data.DataLoader(dataset, batch_size=self.batch, shuffle=True,
                                 num_workers=0, drop_last=True)
        # trackers are created here but not updated in this method
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        self.model.train()

        progress = '\r[Step {:4}] STM: {:5}/{} | #Expert: {}'
        for _ in range(self.epoch):
            for step, (batch_x, batch_y) in enumerate(loader):
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                self.model.learn(batch_x, batch_y)
                if not self.params.verbose:
                    continue
                print(progress.format(
                    step,
                    len(self.model.stm_x), self.params.stm_capacity,
                    len(self.model.experts) - 1
                ), end='')
        # terminate the in-place progress line
        print()
41 |
--------------------------------------------------------------------------------
/utils/buffer/buffer.py:
--------------------------------------------------------------------------------
1 | from utils.setup_elements import input_size_match
2 | from utils import name_match #import update_methods, retrieve_methods
3 | from utils.utils import maybe_cuda
4 | import torch
5 | from utils.buffer.buffer_utils import BufferClassTracker
6 | from utils.setup_elements import n_classes
7 |
class Buffer(torch.nn.Module):
    """Fixed-size replay memory with pluggable update and retrieve strategies."""

    def __init__(self, model, params):
        """Allocate zero-filled storage and instantiate the strategy objects.

        Reads ``params.cuda``, ``params.mem_size``, ``params.data``,
        ``params.update``, ``params.retrieve`` and ``params.buffer_tracker``.
        """
        super().__init__()
        self.params = params
        self.model = model
        self.cuda = self.params.cuda
        self.current_index = 0
        self.n_seen_so_far = 0
        self.device = "cuda" if self.params.cuda else "cpu"

        # define buffer
        n_slots = params.mem_size
        print('buffer has %d slots' % n_slots)
        sample_shape = input_size_match[params.data]
        images = maybe_cuda(torch.zeros(n_slots, *sample_shape, dtype=torch.float32))
        labels = maybe_cuda(torch.zeros(n_slots, dtype=torch.int64))

        # registering as buffer allows us to save the object using `torch.save`
        self.register_buffer('buffer_img', images)
        self.register_buffer('buffer_label', labels)

        # define update and retrieve method
        update_cls = name_match.update_methods[params.update]
        retrieve_cls = name_match.retrieve_methods[params.retrieve]
        self.update_method = update_cls(params)
        self.retrieve_method = retrieve_cls(params)

        if self.params.buffer_tracker:
            self.buffer_tracker = BufferClassTracker(n_classes[params.data], self.device)

    def update(self, x, y, **kwargs):
        """Delegate to the update strategy, passing this buffer along."""
        return self.update_method.update(buffer=self, x=x, y=y, **kwargs)

    def retrieve(self, **kwargs):
        """Delegate to the retrieve strategy, passing this buffer along."""
        return self.retrieve_method.retrieve(buffer=self, **kwargs)
--------------------------------------------------------------------------------
/models/ndpm/priors.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import torch
3 |
4 | from utils.utils import maybe_cuda
5 |
6 |
class Prior(ABC):
    """Abstract interface of a prior over experts."""

    def __init__(self, params):
        self.params = params

    @abstractmethod
    def add_expert(self):
        """Extend the prior with an entry for a new expert."""

    @abstractmethod
    def record_usage(self, usage, index=None):
        """Record expert usage, optionally for a single expert ``index``."""

    @abstractmethod
    def nl_prior(self, normalize=False):
        """Return the negative log prior, optionally normalized."""
22 |
23 |
class CumulativePrior(Prior):
    """Prior that accumulates expert usage counts in log space."""

    def __init__(self, params):
        """Start with a single entry initialised from ``params.log_alpha``."""
        super().__init__(params)
        initial = maybe_cuda(torch.tensor(params.log_alpha))
        self.log_counts = initial.float().unsqueeze(0)

    def add_expert(self):
        """Append an entry with log-count 0 for a new expert."""
        fresh = maybe_cuda(torch.zeros(1))
        self.log_counts = torch.cat([self.log_counts, fresh], dim=0)

    def record_usage(self, usage, index=None):
        """Record expert usage

        Args:
            usage: Tensor of shape [K+1] if index is None else scalar
            index: expert index
        """
        if index is not None:
            # Single expert: log(count + usage) computed via logsumexp.
            log_usage = maybe_cuda(torch.tensor(usage)).float().log()
            pair = torch.stack([self.log_counts[index], log_usage], dim=0)
            self.log_counts[index] = torch.logsumexp(pair, dim=0)
            return
        # All experts at once.
        pairs = torch.stack([self.log_counts, usage.log()], dim=1)
        self.log_counts = torch.logsumexp(pairs, dim=1)

    def nl_prior(self, normalize=False):
        """Return ``-log_counts``; with ``normalize`` add the log of the total count."""
        nl_prior = -self.log_counts
        if normalize:
            nl_prior += torch.logsumexp(self.log_counts, dim=0)
        return nl_prior

    @property
    def counts(self):
        """Counts in linear space (``exp(log_counts)``)."""
        return self.log_counts.exp()
64 |
--------------------------------------------------------------------------------
/utils/name_match.py:
--------------------------------------------------------------------------------
1 | from agents.gdumb import Gdumb
2 | from continuum.dataset_scripts.cifar100 import CIFAR100
3 | from continuum.dataset_scripts.cifar10 import CIFAR10
4 | from continuum.dataset_scripts.core50 import CORE50
5 | from continuum.dataset_scripts.mini_imagenet import Mini_ImageNet
6 | from continuum.dataset_scripts.openloris import OpenLORIS
7 | from agents.exp_replay import ExperienceReplay
8 | from agents.exp_replay_dvc import ExperienceReplay_DVC
9 | from agents.agem import AGEM
10 | from agents.ewc_pp import EWC_pp
11 | from agents.cndpm import Cndpm
12 | from agents.lwf import Lwf
13 | from agents.icarl import Icarl
14 | from agents.scr import SupContrastReplay
15 | from utils.buffer.random_retrieve import Random_retrieve
16 | from utils.buffer.reservoir_update import Reservoir_update
17 | from utils.buffer.mir_retrieve import MIR_retrieve
18 |
19 | from utils.buffer.gss_greedy_update import GSSGreedyUpdate
20 | from utils.buffer.aser_retrieve import ASER_retrieve
21 | from utils.buffer.aser_update import ASER_update
22 | from utils.buffer.sc_retrieve import Match_retrieve
23 | from utils.buffer.mem_match import MemMatch_retrieve
24 | from utils.buffer.mgi_retrieve import MGI_retrieve
25 |
# Registries mapping configuration strings to the classes imported above.

# Dataset name -> dataset class.
data_objects = {
    'cifar100': CIFAR100,
    'cifar10': CIFAR10,
    'core50': CORE50,
    'mini_imagenet': Mini_ImageNet,
    'openloris': OpenLORIS
}

# Agent name -> agent class.
agents = {
    'ER': ExperienceReplay,
    'ER_DVC': ExperienceReplay_DVC,
    'EWC': EWC_pp,
    'AGEM': AGEM,
    'CNDPM': Cndpm,
    'LWF': Lwf,
    'ICARL': Icarl,
    'GDUMB': Gdumb,
    'SCR': SupContrastReplay,
}

# Retrieval-strategy name -> retrieve class.
retrieve_methods = {
    'MIR': MIR_retrieve,
    'MGI':MGI_retrieve,
    'random': Random_retrieve,
    'ASER': ASER_retrieve,
    'match': Match_retrieve,
    'mem_match': MemMatch_retrieve

}

# Update-strategy name -> update class.
update_methods = {
    'random': Reservoir_update,
    'GSS': GSSGreedyUpdate,
    'ASER': ASER_update,
}
61 |
62 |
--------------------------------------------------------------------------------
/utils/buffer/aser_update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.buffer.reservoir_update import Reservoir_update
3 | from utils.buffer.buffer_utils import ClassBalancedRandomSampling, random_retrieve
4 | from utils.buffer.aser_utils import compute_knn_sv, add_minority_class_input
5 | from utils.setup_elements import n_classes
6 | from utils.utils import nonzero_indices, maybe_cuda
7 |
8 |
class ASER_update(object):
    """Buffer update strategy that keeps the ClassBalancedRandomSampling cache in sync.

    NOTE(review): ``update`` below references names that are never assigned
    (see the notes in its body); it looks like part of the original logic was
    removed -- confirm the intended behaviour before relying on this class.
    """

    def __init__(self, params, **kwargs):
        """Read sizes from ``params`` and reset the shared class-index cache."""
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.k = params.k
        self.mem_size = params.mem_size
        self.num_tasks = params.num_tasks
        self.out_dim = n_classes[params.data]
        self.n_smp_cls = int(params.n_smp_cls)
        self.n_total_smp = int(params.n_smp_cls * self.out_dim)
        self.reservoir_update = Reservoir_update(params)
        # Class-level cache: resetting it here affects every user of the class.
        ClassBalancedRandomSampling.class_index_cache = None

    def update(self, buffer, x, y, **kwargs):
        """Store the batch ``(x, y)`` in ``buffer`` and refresh the class cache."""
        # NOTE(review): `model` is assigned but not used in this method.
        model = buffer.model

        place_left = self.mem_size - buffer.current_index

        # If buffer is not filled, use available space to store whole or part of batch
        if place_left:
            x_fit = x[:place_left]
            y_fit = y[:place_left]

            ind = torch.arange(start=buffer.current_index, end=buffer.current_index + x_fit.size(0), device=self.device)
            ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim,
                                                     new_y=y_fit, ind=ind, device=self.device)
            self.reservoir_update.update(buffer, x_fit, y_fit)

        # If buffer is filled, update buffer by sv
        if buffer.current_index == self.mem_size:
            # NOTE(review): x_fit / y_fit are only assigned in the branch above,
            # so this raises NameError when the buffer was already full on entry;
            # when both branches run, the same slice is passed to the reservoir
            # update a second time.
            self.reservoir_update.update(buffer, x_fit, y_fit)


            # NOTE(review): y_upt and ind_buffer are never assigned in this
            # method -- this call raises NameError as written.
            ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim,
                                                     new_y=y_upt, ind=ind_buffer, device=self.device)
44 |
45 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DVC
2 | Code For CVPR2022 paper "[Not Just Selection, but Exploration: Online Class-Incremental Continual Learning via Dual View Consistency](https://openaccess.thecvf.com/content/CVPR2022/papers/Gu_Not_Just_Selection_but_Exploration_Online_Class-Incremental_Continual_Learning_via_CVPR_2022_paper.pdf)"
3 |
4 | ## Usage
5 |
6 | ### Requirements
7 | Install the dependencies listed in `requirements.txt` (for example with `pip install -r requirements.txt`).
8 |
9 | ### Data preparation
10 | - CIFAR-10 and CIFAR-100 are downloaded automatically during the first run (into `datasets/cifar10` and `datasets/cifar100`).
11 | - Mini-ImageNet: Download from https://www.kaggle.com/whitemoon/miniimagenet/download, and place it in datasets/mini_imagenet/
12 |
13 |
14 | ### CIFAR-100
15 | ```shell
16 | python general_main.py --data cifar100 --cl_type nc --agent ER_DVC --retrieve MGI --update random --mem_size 1000 --dl_weight 4.0
17 | ```
18 |
19 | ### CIFAR-10
20 | ```shell
21 | python general_main.py --data cifar10 --cl_type nc --agent ER_DVC --retrieve MGI --update random --mem_size 200 --dl_weight 2.0 --num_tasks 5
22 | ```
23 |
24 | ### Mini-Imagenet
25 | ```shell
26 | python general_main.py --data mini_imagenet --cl_type nc --agent ER_DVC --retrieve MGI --update random --mem_size 1000 --dl_weight 4.0
27 | ```
28 |
29 |
30 |
31 |
32 |
33 | ## Reference
34 | [online-continual-learning](https://github.com/RaptorMai/online-continual-learning)
35 |
36 | [agmax](https://github.com/roatienza/agmax)
37 |
38 |
39 | If our code or models help your work, please cite our [paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Gu_Not_Just_Selection_but_Exploration_Online_Class-Incremental_Continual_Learning_via_CVPR_2022_paper.pdf):
40 |
41 | ```bibtex
42 | @InProceedings{Gu_2022_CVPR,
43 | author = {Gu, Yanan and Yang, Xu and Wei, Kun and Deng, Cheng},
44 | title = {Not Just Selection, but Exploration: Online Class-Incremental Continual Learning via Dual View Consistency},
45 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
46 | month = {June},
47 | year = {2022},
48 | pages = {7442-7451}
49 | }
50 | ```
51 |
--------------------------------------------------------------------------------
/experiment/tune_hyperparam.py:
--------------------------------------------------------------------------------
1 | from types import SimpleNamespace
2 | from sklearn.model_selection import ParameterGrid
3 | from utils.setup_elements import setup_opt, setup_architecture
4 | from utils.utils import maybe_cuda
5 | from utils.name_match import agents
6 | import numpy as np
7 | from experiment.metrics import compute_performance
8 |
9 |
def tune_hyper(tune_data, tune_test_loaders, default_params, tune_params):
    """Grid-search hyper-parameters and return the best-performing combination.

    For every combination in ``ParameterGrid(tune_params)`` a fresh model,
    optimizer and agent are built ``num_runs_val`` times, trained on each task
    yielded by ``tune_data`` and evaluated on ``tune_test_loaders`` after
    every task. Combinations are ranked by the average end accuracy reported
    by ``compute_performance``.

    Args:
        tune_data: iterable yielding ``(x_train, y_train, labels)`` per task.
        tune_test_loaders: passed to ``agent.evaluate`` after each task.
        default_params: namespace-like object with the default settings. It is
            NOT modified; each combination is merged into a copy.
        tune_params: mapping accepted by ``sklearn``'s ``ParameterGrid``.

    Returns:
        dict: the parameter combination with the highest average end accuracy
        (the first one on ties).
    """
    param_grid_list = list(ParameterGrid(tune_params))
    print(len(param_grid_list))
    tune_accs = []
    for param_set in param_grid_list:
        print(param_set)
        # vars() returns the live __dict__ of default_params; merge into a new
        # dict so the caller's defaults are not overwritten by the grid values.
        merged = {**vars(default_params), **param_set}
        final_params = SimpleNamespace(**merged)
        accuracy_list = []
        for run in range(final_params.num_runs_val):
            tmp_acc = []
            model = setup_architecture(final_params)
            model = maybe_cuda(model, final_params.cuda)
            opt = setup_opt(final_params.optimizer, model, final_params.learning_rate, final_params.weight_decay)
            agent = agents[final_params.agent](model, opt, final_params)
            for i, (x_train, y_train, labels) in enumerate(tune_data):
                print("-----------tune run {} task {}-------------".format(run, i))
                print('size: {}, {}'.format(x_train.shape, y_train.shape))
                agent.train_learner(x_train, y_train)
                acc_array = agent.evaluate(tune_test_loaders)
                tmp_acc.append(acc_array)
            print(
                "-----------tune run {}-----------avg_end_acc {}-----------".format(run, np.mean(tmp_acc[-1])))
            accuracy_list.append(np.array(tmp_acc))
        accuracy_list = np.array(accuracy_list)
        # Only the mean end accuracy is used for ranking.
        avg_end_acc, _, _, _, _ = compute_performance(accuracy_list)
        tune_accs.append(avg_end_acc[0])
    best_tune = param_grid_list[tune_accs.index(max(tune_accs))]
    return best_tune
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from timm.loss import LabelSmoothingCrossEntropy
6 |
7 |
def cross_entropy_loss(z, zt, ytrue, label_smoothing=0):
    """Cross-entropy of both logit sets ``z`` and ``zt`` against the same labels.

    The two logit tensors are concatenated and ``ytrue`` is duplicated to
    match. A positive ``label_smoothing`` selects
    ``LabelSmoothingCrossEntropy``; otherwise ``nn.CrossEntropyLoss`` is used.
    """
    logits = torch.cat((z, zt))
    targets = torch.cat((ytrue, ytrue))
    if label_smoothing > 0:
        criterion = LabelSmoothingCrossEntropy(label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()
    return criterion(logits, targets)
16 |
17 |
def cross_entropy(z, zt):
    """Mean cross-entropy between ``softmax(z)`` and ``softmax(zt)``.

    Computes ``-(softmax(z) * log(softmax(zt))).mean()`` over all elements,
    with class scores along dim 1.

    The log-probabilities come from ``log_softmax`` rather than
    ``log(softmax(...))``: the latter underflows to ``log(0) = -inf`` when a
    logit is far below the row maximum and turns the loss into inf/NaN.

    Args:
        z: logits, class dimension at dim 1.
        zt: logits with the same shape as ``z``.

    Returns:
        Scalar tensor.
    """
    Pz = F.softmax(z, dim=1)
    log_Pzt = F.log_softmax(zt, dim=1)
    return -(Pz * log_Pzt).mean()
26 |
27 |
def agmax_loss(y, ytrue, dl_weight=1.0):
    """Return ``(entropy, dl)`` for a 4-tuple of network outputs.

    ``y`` unpacks to ``(z, zt, zzt, _)``; the fourth element and ``ytrue`` are
    not used. ``dl`` is the L1 distance between ``softmax(zzt)`` (repeated
    twice) and the concatenation of ``softmax(z)`` and ``softmax(zt)``, scaled
    by ``dl_weight``. ``entropy`` is the value of ``entropy_loss`` on the
    three softmax outputs.
    """
    z, zt, zzt, _ = y
    Pz, Pzt, Pzzt = (F.softmax(logits, dim=1) for logits in (z, zt, zzt))

    views = torch.cat((Pz, Pzt))
    joint = torch.cat((Pzzt, Pzzt))
    dl = nn.L1Loss()(joint, views)
    dl *= dl_weight

    # -1/3*(H(z) + H(zt) + H(z, zt)), H(x) = -E[log(x)]
    return entropy_loss(Pz, Pzt, Pzzt), dl
43 |
44 |
45 |
46 |
def clamp_to_eps(Pz, Pzt, Pzzt):
    """Raise every entry below machine epsilon to epsilon, in place.

    The three tensors are modified in place and returned, so a following
    ``log`` never sees a zero.
    """
    eps = np.finfo(float).eps
    for probs in (Pz, Pzt, Pzzt):
        too_small = (probs < eps).data
        probs[too_small] = eps
    return Pz, Pzt, Pzzt
55 |
56 |
def batch_probability(Pz, Pzt, Pzzt):
    """Collapse each input over dim 0 into a normalized distribution.

    Every tensor is summed over dim 0 and divided by its total; the three
    results are then passed through ``clamp_to_eps``.
    """
    def _marginal(probs):
        summed = probs.sum(dim=0)
        return summed / summed.sum()

    return clamp_to_eps(_marginal(Pz), _marginal(Pzt), _marginal(Pzzt))
68 |
69 |
def entropy_loss(Pz, Pzt, Pzzt):
    """Negative entropy averaged over the three batch-level distributions.

    Each input is first reduced by ``batch_probability``; the result is the
    mean over the three of ``sum(p * log(p))``.
    """
    marginals = batch_probability(Pz, Pzt, Pzzt)
    negative_entropy = sum((p * torch.log(p)).sum() for p in marginals)
    return negative_entropy / 3
78 |
79 |
80 |
--------------------------------------------------------------------------------
/models/ndpm/expert.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from models.ndpm.classifier import ResNetSharingClassifier
4 | from models.ndpm.vae import CnnSharingVae
5 | from utils.utils import maybe_cuda
6 |
7 | from utils.global_vars import *
8 |
9 |
class Expert(nn.Module):
    """One expert: a generative part ``g`` and an optional discriminative part ``d``."""

    def __init__(self, params, experts=()):
        """Build ``g`` (and ``d`` unless ``MODELS_NDPM_NDPM_DISABLE_D`` is set).

        The expert id is the number of existing ``experts``. The expert with
        id 0 is a placeholder: it is switched to eval mode and its parameters
        are frozen.
        """
        super().__init__()
        self.id = len(experts)
        self.experts = experts

        self.g = maybe_cuda(CnnSharingVae(params, experts))
        if MODELS_NDPM_NDPM_DISABLE_D:
            self.d = None
        else:
            self.d = maybe_cuda(ResNetSharingClassifier(params, experts))

        # the placeholder keeps its randomly initialized g (and d) untrained
        if self.id == 0:
            self.eval()
            for weight in self.g.parameters():
                weight.requires_grad = False
            if self.d is not None:
                for weight in self.d.parameters():
                    weight.requires_grad = False

    def forward(self, x):
        """Return the output of ``d`` for ``x``."""
        return self.d(x)

    def nll(self, x, y, step=None):
        """Negative log likelihood"""
        total = self.g.nll(x, step)
        if self.d is None:
            return total
        return total + self.d.nll(x, y, step)

    def collect_nll(self, x, y, step=None):
        """Collect the NLL via ``g``/``d`` ``collect_nll``.

        For the placeholder (id 0) this is its own ``nll`` with a new
        dimension inserted at position 1.
        """
        if self.id == 0:
            return self.nll(x, y, step).unsqueeze(1)

        total = self.g.collect_nll(x, step)
        if self.d is None:
            return total
        return total + self.d.collect_nll(x, y, step)

    def lr_scheduler_step(self):
        """Step the schedulers of ``g`` and ``d`` unless they are ``NotImplemented``."""
        if self.g.lr_scheduler is not NotImplemented:
            self.g.lr_scheduler.step()
        if self.d is None:
            return
        if self.d.lr_scheduler is not NotImplemented:
            self.d.lr_scheduler.step()

    def clip_grad(self):
        """Clip gradients of ``g`` and, when present, ``d``."""
        self.g.clip_grad()
        if self.d is not None:
            self.d.clip_grad()

    def optimizer_step(self):
        """Step the optimizers of ``g`` and, when present, ``d``."""
        self.g.optimizer.step()
        if self.d is not None:
            self.d.optimizer.step()
69 |
--------------------------------------------------------------------------------
/agents/lwf.py:
--------------------------------------------------------------------------------
1 | from agents.base import ContinualLearner
2 | from continuum.data_utils import dataset_transform
3 | from utils.setup_elements import transforms_match
4 | from torch.utils import data
5 | from utils.utils import maybe_cuda, AverageMeter
6 | import torch
7 | import copy
8 |
9 |
class Lwf(ContinualLearner):
    """Agent combining the new-task loss with a distillation loss on old outputs."""

    def __init__(self, model, opt, params):
        super(Lwf, self).__init__(model, opt, params)

    def train_learner(self, x_train, y_train):
        """Train on one task.

        The total loss is ``w * loss_new + (1 - w) * loss_old`` with
        ``w = 1 / (task_seen + 1)``; ``loss_old`` comes from
        ``self.kd_manager.get_kd_loss`` and ``loss_new`` from
        ``self.criterion``.
        """
        self.before_train(x_train, y_train)

        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)

        # set up model
        self.model = self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                logits = self.forward(batch_x)
                loss_old = self.kd_manager.get_kd_loss(logits, batch_x)
                loss_new = self.criterion(logits, batch_y)
                new_weight = 1 / (self.task_seen + 1)
                loss = new_weight * loss_new + (1 - new_weight) * loss_old
                _, pred_label = torch.max(logits, 1)
                correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                # update tracker
                acc_batch.update(correct_cnt, batch_y.size(0))
                # Record a plain float: handing the loss tensor to the meter
                # would keep each iteration's autograd graph referenced.
                losses_batch.update(loss.item(), batch_y.size(0))
                # backward
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                        .format(i, losses_batch.avg(), acc_batch.avg())
                    )
        self.after_train()
57 |
--------------------------------------------------------------------------------
/experiment/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.stats import sem
3 | import scipy.stats as stats
4 |
def compute_performance(end_task_acc_arr):
    """
    Summarize test accuracies collected over several runs.

    Each metric is returned as a ``(mean over runs, 95% CI half-width)`` pair,
    where the half-width is the Student-t coefficient times the standard error.

    :param end_task_acc_arr: (ndarray) indexed as [run, evaluation step, task]
    :return: (avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt)
    """
    n_run, n_tasks = end_task_acc_arr.shape[:2]
    # t coefficient used to compute 95% CIs: mean +- t * sem
    t_coef = stats.t.ppf((1 + 0.95) / 2, n_run - 1)

    def _mean_and_ci(per_run):
        return (np.mean(per_run), t_coef * sem(per_run))

    # average end accuracy: accuracies measured at the last evaluation step
    end_acc = end_task_acc_arr[:, -1, :]
    avg_end_acc = _mean_and_ci(np.mean(end_acc, axis=1))

    # forgetting: best accuracy over the evaluation steps minus the final one
    final_forgets = np.max(end_task_acc_arr, axis=1) - end_acc
    avg_end_fgt = _mean_and_ci(np.mean(final_forgets, axis=1))

    # ACC: lower-triangular row sums divided by 1..n_tasks, averaged per run
    seen_counts = np.arange(n_tasks) + 1
    row_sums = np.sum(np.tril(end_task_acc_arr), axis=2)
    avg_acc = _mean_and_ci(np.mean(row_sums / seen_counts, axis=1))

    n_pairs = n_tasks * (n_tasks - 1) / 2

    # BWT+: strictly-lower-triangular sum minus the weighted diagonal, floored at 0
    below_diag = np.sum(np.tril(end_task_acc_arr, -1), axis=(1, 2))
    diagonal = np.diagonal(end_task_acc_arr, axis1=1, axis2=2)
    later_steps = np.arange(n_tasks, 0, -1) - 1
    bwt_per_run = (below_diag - np.sum(diagonal * later_steps, axis=1)) / n_pairs
    avg_bwtp = _mean_and_ci(np.maximum(bwt_per_run, 0))

    # FWT: strictly-upper-triangular sum
    above_diag = np.sum(np.triu(end_task_acc_arr, 1), axis=(1, 2))
    avg_fwt = _mean_and_ci(above_diag / n_pairs)

    return avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt
45 |
46 |
47 |
48 |
def single_run_avg_end_fgt(acc_array):
    """Average end forgetting of a single run.

    Args:
        acc_array: 2-D array indexed as [evaluation step, task], i.e. row
            ``i`` holds the per-task accuracies measured after training step
            ``i``.

    Returns:
        Mean over tasks of (best accuracy the task ever reached) minus (its
        accuracy at the last evaluation step).
    """
    acc_array = np.asarray(acc_array)
    # Best accuracy per task: reduce over the evaluation-step axis (axis 0),
    # so each task is compared with its own history.
    best_acc = np.max(acc_array, axis=0)
    end_acc = acc_array[-1]
    final_forgets = best_acc - end_acc
    return np.mean(final_forgets)
55 |
--------------------------------------------------------------------------------
/utils/buffer/reservoir_update.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class Reservoir_update(object):
    """Buffer update rule: fill free slots first, then overwrite random slots."""

    def __init__(self, params):
        super().__init__()

    def update(self, buffer, x, y, **kwargs):
        """Insert the batch ``(x, y)`` into ``buffer``.

        :param buffer: object exposing ``buffer_img``, ``buffer_label``,
            ``current_index``, ``n_seen_so_far``, ``params.buffer_tracker`` and
            (when tracking is enabled) ``buffer_tracker``
        :param x: batch of samples, first dimension is the batch
        :param y: labels aligned with ``x``
        :return: list of buffer positions written by this call; when the batch
            only partly fits into the free space, only the positions
            overwritten by the sampling step are returned
        """
        n_incoming = x.size(0)
        capacity = buffer.buffer_img.size(0)

        # Phase 1: copy as much of the batch as fits into the unused tail.
        free_slots = max(0, capacity - buffer.current_index)
        if free_slots:
            n_fill = min(free_slots, n_incoming)
            start = buffer.current_index
            stop = start + n_fill
            buffer.buffer_img[start:stop].data.copy_(x[:n_fill])
            buffer.buffer_label[start:stop].data.copy_(y[:n_fill])

            buffer.current_index = stop
            buffer.n_seen_so_far += n_fill

            # The whole batch fitted: report the freshly filled positions.
            if n_fill == n_incoming:
                filled_idx = list(range(start, stop))
                if buffer.params.buffer_tracker:
                    buffer.buffer_tracker.update_cache(buffer.buffer_label, y[:n_fill], filled_idx)
                return filled_idx

        # TODO: the buffer tracker will have bug when the mem size can't be divided by batch size

        # Phase 2: drop the part that was already stored above.
        x, y = x[free_slots:], y[free_slots:]

        # One uniform draw in [0, n_seen_so_far) per remaining sample; a sample
        # is kept only when its draw lands inside the buffer.
        draws = torch.FloatTensor(x.size(0)).to(x.device).uniform_(0, buffer.n_seen_so_far).long()
        inside = (draws < capacity).long()

        src_idx = inside.nonzero().squeeze(-1)
        dst_idx = draws[src_idx]

        buffer.n_seen_so_far += x.size(0)

        if dst_idx.numel() == 0:
            return []

        assert dst_idx.max() < buffer.buffer_img.size(0)
        assert dst_idx.max() < buffer.buffer_label.size(0)

        assert src_idx.max() < x.size(0)
        assert src_idx.max() < y.size(0)

        # slot -> sample; if two samples drew the same slot the later one wins.
        slot_to_sample = dict(zip(dst_idx.tolist(), src_idx.tolist()))
        slots = list(slot_to_sample.keys())
        samples = list(slot_to_sample.values())

        replace_y = y[samples]
        if buffer.params.buffer_tracker:
            buffer.buffer_tracker.update_cache(buffer.buffer_label, replace_y, slots)
        # perform overwrite op
        buffer.buffer_img[slots] = x[samples]
        buffer.buffer_label[slots] = replace_y
        return slots
--------------------------------------------------------------------------------
/utils/buffer/mir_retrieve.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.utils import maybe_cuda
3 | import torch.nn.functional as F
4 | from utils.buffer.buffer_utils import random_retrieve, get_grad_vector
5 | import copy
6 |
7 |
class MIR_retrieve(object):
    """Retrieve the buffered samples whose loss grows most after a virtual step."""

    def __init__(self, params, **kwargs):
        super().__init__()
        self.params = params
        self.subsample = params.subsample
        self.num_retrieve = params.eps_mem_batch

    def retrieve(self, buffer, **kwargs):
        """Pick up to ``num_retrieve`` samples out of a random subsample.

        A copy of ``buffer.model`` is moved one gradient step ahead; candidates
        are ranked by (loss after the step - loss before the step), largest
        first.

        :param buffer: object exposing ``model``; also handed to
            ``random_retrieve``
        :return: ``(x, y)`` of the selected samples, or the (empty) subsample
            unchanged when nothing could be drawn
        """
        sub_x, sub_y = random_retrieve(buffer, self.subsample)
        grad_dims = [param.data.numel() for param in buffer.model.parameters()]
        grad_vector = get_grad_vector(buffer.model.parameters, grad_dims)
        model_temp = self.get_future_step_parameters(buffer.model, grad_vector, grad_dims)
        if sub_x.size(0) == 0:
            return sub_x, sub_y
        with torch.no_grad():
            # The model's forward returns a pair; only the first item is used.
            logits_pre, _ = buffer.model.forward(sub_x)
            logits_post, _ = model_temp.forward(sub_x)
            pre_loss = F.cross_entropy(logits_pre, sub_y, reduction='none')
            post_loss = F.cross_entropy(logits_post, sub_y, reduction='none')
            scores = post_loss - pre_loss
            big_ind = scores.sort(descending=True)[1][:self.num_retrieve]
        return sub_x[big_ind], sub_y[big_ind]

    def get_future_step_parameters(self, model, grad_vector, grad_dims):
        r"""Return a deep copy of ``model`` moved one step: \theta - lr * \delta\theta.

        :param model: network to copy; it is left untouched
        :param grad_vector: flat gradient, laid out in ``model.parameters()`` order
        :param grad_dims: number of elements of each parameter, same order
        :return: the updated copy
        """
        new_model = copy.deepcopy(model)
        self.overwrite_grad(new_model.parameters, grad_vector, grad_dims)
        with torch.no_grad():
            for param in new_model.parameters():
                if param.grad is not None:
                    param.data = param.data - self.params.learning_rate * param.grad.data
        return new_model

    def overwrite_grad(self, pp, new_grad, grad_dims):
        """Overwrite the gradients of the parameters with slices of ``new_grad``.

        pp: callable returning the parameters (e.g. ``model.parameters``)
        new_grad: flat gradient vector
        grad_dims: list storing number of parameters at each layer
        """
        # Keep a running offset instead of re-summing grad_dims per parameter.
        beg = 0
        for cnt, param in enumerate(pp()):
            param.grad = torch.zeros_like(param.data)
            en = beg + grad_dims[cnt]
            this_grad = new_grad[beg:en].contiguous().view(param.data.size())
            param.grad.data.copy_(this_grad)
            beg = en
--------------------------------------------------------------------------------
/utils/setup_elements.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models.resnet import Reduced_ResNet18, SupConResNet,Reduced_ResNet18_DVC
3 | from torchvision import transforms
4 | import torch.nn as nn
5 |
6 |
# Switches for the optional tricks; all of them default to False.
default_trick = dict(
    labels_trick=False,
    kd_trick=False,
    separated_softmax=False,
    review_trick=False,
    ncm_trick=False,
    kd_trick_star=False,
)


# Input size per dataset -- presumably [channels, height, width]; entries 1 and
# 2 are used as the crop size elsewhere in this code base.
input_size_match = {
    'cifar100': [3, 32, 32],
    'cifar10': [3, 32, 32],
    'core50': [3, 128, 128],
    'mini_imagenet': [3, 84, 84],
    'openloris': [3, 50, 50],
}


# Number of classes per dataset.
n_classes = {
    'cifar100': 100,
    'cifar10': 10,
    'core50': 50,
    'mini_imagenet': 100,
    'openloris': 69,
}


# Every dataset gets its own Compose that only converts the input to a tensor.
transforms_match = {
    dataset_name: transforms.Compose([transforms.ToTensor()])
    for dataset_name in ('core50', 'cifar100', 'cifar10', 'mini_imagenet', 'openloris')
}
44 |
45 |
def setup_architecture(params):
    """Build the network selected by ``params.agent`` and ``params.data``.

    :param params: configuration exposing ``data`` and ``agent`` (and ``head``
        for the 'SCR' / 'SCP' agents)
    :return: the constructed model
    :raises KeyError: when ``params.data`` has no entry in ``n_classes``
    """
    nclass = n_classes[params.data]
    agent, dataset = params.agent, params.data

    # Supervised-contrastive agents use their own backbone.
    if agent in ('SCR', 'SCP'):
        if dataset == 'mini_imagenet':
            return SupConResNet(640, head=params.head)
        return SupConResNet(head=params.head)

    if agent == 'CNDPM':
        # Imported lazily so the dependency is only needed for this agent.
        from models.ndpm.ndpm import Ndpm
        return Ndpm(params)

    use_dvc = agent == 'ER_DVC'

    if dataset in ('cifar100', 'cifar10'):
        return Reduced_ResNet18_DVC(nclass) if use_dvc else Reduced_ResNet18(nclass)

    if dataset == 'core50':
        model = Reduced_ResNet18(nclass)
        # NOTE(review): this branch replaces ``model.backbone.linear`` on a
        # Reduced_ResNet18 while the mini_imagenet branch below replaces
        # ``model.linear`` on the same class -- confirm which attribute exists.
        model.backbone.linear = nn.Linear(2560, nclass, bias=True)
        return model

    if dataset == 'mini_imagenet':
        if use_dvc:
            model = Reduced_ResNet18_DVC(nclass)
            model.backbone.linear = nn.Linear(640, nclass, bias=True)
        else:
            model = Reduced_ResNet18(nclass)
            model.linear = nn.Linear(640, nclass, bias=True)
        return model

    if dataset == 'openloris':
        return Reduced_ResNet18(nclass)

    return None
79 |
80 |
def setup_opt(optimizer, model, lr, wd):
    """Create the optimizer named ``optimizer`` for the parameters of ``model``.

    :param optimizer: 'SGD' or 'Adam'
    :param model: module whose ``parameters()`` are optimised
    :param lr: learning rate
    :param wd: weight decay
    :return: the torch optimizer instance
    :raises Exception: for any other optimizer name
    """
    if optimizer == 'SGD':
        optim_cls = torch.optim.SGD
    elif optimizer == 'Adam':
        optim_cls = torch.optim.Adam
    else:
        raise Exception('wrong optimizer name')
    return optim_cls(model.parameters(), lr=lr, weight_decay=wd)
93 |
--------------------------------------------------------------------------------
/continuum/dataset_scripts/cifar100.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision import datasets
3 | from continuum.data_utils import create_task_composition, load_task_with_labels
4 | from continuum.dataset_scripts.dataset_base import DatasetBase
5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns
6 |
7 |
class CIFAR100(DatasetBase):
    """CIFAR-100 data source for the 'ni' and 'nc' scenarios."""

    def __init__(self, scenario, params):
        # In the 'ni' scenario the task count follows the number of
        # non-stationarity factors; otherwise it is taken from the parameters.
        n_tasks = len(params.ns_factor) if scenario == 'ni' else params.num_tasks
        super().__init__('cifar100', scenario, n_tasks, params.num_runs, params)

    def download_load(self):
        """Download (if needed) and keep the raw train/test arrays and labels."""
        train_split = datasets.CIFAR100(root=self.root, train=True, download=True)
        self.train_data = train_split.data
        self.train_label = np.array(train_split.targets)
        test_split = datasets.CIFAR100(root=self.root, train=False, download=True)
        self.test_data = test_split.data
        self.test_label = np.array(test_split.targets)

    def setup(self):
        """Prepare the task splits for the configured scenario.

        :raises Exception: when the scenario is neither 'ni' nor 'nc'
        """
        if self.scenario == 'ni':
            splits = construct_ns_multiple_wrapper(
                self.train_data, self.train_label,
                self.test_data, self.test_label,
                self.task_nums, 32,
                self.params.val_size,
                self.params.ns_type, self.params.ns_factor,
                plot=self.params.plot_sample)
            self.train_set, self.val_set, self.test_set = splits
        elif self.scenario == 'nc':
            self.task_labels = create_task_composition(
                class_nums=100, num_tasks=self.task_nums, fixed_order=self.params.fix_order)
            # One (x_test, y_test) pair per task, restricted to that task's labels.
            self.test_set = [
                load_task_with_labels(self.test_data, self.test_label, task_classes)
                for task_classes in self.task_labels
            ]
        else:
            raise Exception('wrong scenario')

    def new_task(self, cur_task, **kwargs):
        """Return ``(x_train, y_train, labels)`` for task ``cur_task``."""
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        return x_train, y_train, labels

    def new_run(self, **kwargs):
        """Rebuild the splits and return the test set of the new run."""
        self.setup()
        return self.test_set

    def test_plot(self):
        """Run ``test_ns`` on the first ten training samples."""
        sample_x, sample_y = self.train_data[:10], self.train_label[:10]
        test_ns(sample_x, sample_y, self.params.ns_type, self.params.ns_factor)
60 |
--------------------------------------------------------------------------------
/continuum/dataset_scripts/cifar10.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision import datasets
3 | from continuum.data_utils import create_task_composition, load_task_with_labels, shuffle_data
4 | from continuum.dataset_scripts.dataset_base import DatasetBase
5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns
6 |
7 |
class CIFAR10(DatasetBase):
    """CIFAR-10 data source for the 'ni' and 'nc' scenarios."""

    def __init__(self, scenario, params):
        # 'ni' derives the number of tasks from the non-stationarity factors.
        n_tasks = len(params.ns_factor) if scenario == 'ni' else params.num_tasks
        super().__init__('cifar10', scenario, n_tasks, params.num_runs, params)

    def download_load(self):
        """Download (if needed) and keep the raw train/test arrays and labels."""
        train_split = datasets.CIFAR10(root=self.root, train=True, download=True)
        self.train_data = train_split.data
        self.train_label = np.array(train_split.targets)
        test_split = datasets.CIFAR10(root=self.root, train=False, download=True)
        self.test_data = test_split.data
        self.test_label = np.array(test_split.targets)

    def setup(self):
        """Prepare the task splits for the configured scenario.

        :raises Exception: when the scenario is neither 'ni' nor 'nc'
        """
        if self.scenario == 'ni':
            splits = construct_ns_multiple_wrapper(
                self.train_data, self.train_label,
                self.test_data, self.test_label,
                self.task_nums, 32,
                self.params.val_size,
                self.params.ns_type, self.params.ns_factor,
                plot=self.params.plot_sample)
            self.train_set, self.val_set, self.test_set = splits
        elif self.scenario == 'nc':
            self.task_labels = create_task_composition(
                class_nums=10, num_tasks=self.task_nums, fixed_order=self.params.fix_order)
            # One (x_test, y_test) pair per task, restricted to that task's labels.
            self.test_set = [
                load_task_with_labels(self.test_data, self.test_label, task_classes)
                for task_classes in self.task_labels
            ]
        else:
            raise Exception('wrong scenario')

    def new_task(self, cur_task, **kwargs):
        """Return ``(x_train, y_train, labels)`` for task ``cur_task``."""
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        return x_train, y_train, labels

    def new_run(self, **kwargs):
        """Rebuild the splits and return the test set of the new run."""
        self.setup()
        return self.test_set

    def test_plot(self):
        """Run ``test_ns`` on the first ten training samples."""
        sample_x, sample_y = self.train_data[:10], self.train_label[:10]
        test_ns(sample_x, sample_y, self.params.ns_type, self.params.ns_factor)
60 |
--------------------------------------------------------------------------------
/agents/icarl.py:
--------------------------------------------------------------------------------
1 | from agents.base import ContinualLearner
2 | from continuum.data_utils import dataset_transform
3 | from utils import utils
4 | from utils.buffer.buffer_utils import random_retrieve
5 | from utils.setup_elements import transforms_match
6 | from torch.utils import data
7 | import numpy as np
8 | from torch.nn import functional as F
9 | from utils.utils import maybe_cuda, AverageMeter
10 | from utils.buffer.buffer import Buffer
11 | import torch
12 | import copy
13 |
14 |
class Icarl(ContinualLearner):
    """Agent trained with per-class binary cross-entropy plus distillation.

    Targets for previously seen classes come from a frozen copy of the model
    taken after the previous call to ``train_learner``.
    """

    def __init__(self, model, opt, params):
        super(Icarl, self).__init__(model, opt, params)
        self.model = model
        self.mem_size = params.mem_size
        self.buffer = Buffer(model, params)
        # Snapshot of the model after the previous task; None during the first.
        self.prev_model = None

    def train_learner(self, x_train, y_train):
        """Train on one task and snapshot the resulting model.

        :param x_train: training samples of the task
        :param y_train: labels aligned with ``x_train``
        """
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        self.model.train()
        self.update_representation(train_loader)
        self.prev_model = copy.deepcopy(self.model)
        self.after_train()

    def update_representation(self, train_loader):
        """Run the optimisation loop over ``train_loader`` for ``self.epoch`` epochs.

        :param train_loader: iterable yielding ``(x, y)`` batches
        """
        updated_idx = []
        for ep in range(self.epoch):
            for i, train_data in enumerate(train_loader):
                # batch update
                train_x, train_y = train_data
                train_x = maybe_cuda(train_x, self.cuda)
                train_y = maybe_cuda(train_y, self.cuda)

                # Re-index labels: old classes occupy the first columns, the
                # new classes follow in the order of ``self.new_labels``.
                n_old = len(self.old_labels)
                train_y_copy = train_y.clone()
                for k, y in enumerate(train_y_copy):
                    train_y_copy[k] = n_old + self.new_labels.index(y)
                all_cls_num = len(self.new_labels) + n_old
                target_labels = utils.ohe_label(train_y_copy, all_cls_num, device=train_y_copy.device).float()

                if self.prev_model is not None:
                    mem_x, mem_y = random_retrieve(self.buffer, self.batch,
                                                   excl_indices=updated_idx)
                    mem_x = maybe_cuda(mem_x, self.cuda)
                    batch_x = torch.cat([train_x, mem_x])
                    # Replayed samples get all-zero targets. Size the block by
                    # the number of samples actually retrieved so that a short
                    # replay batch cannot misalign inputs and targets.
                    mem_targets = target_labels.new_zeros((mem_x.size(0), target_labels.size(1)))
                    target_labels = torch.cat([target_labels, mem_targets])
                else:
                    batch_x = train_x

                logits = self.forward(batch_x)
                self.opt.zero_grad()
                if self.prev_model is not None:
                    # Distillation: old-class targets are the previous model's
                    # sigmoid outputs for the same inputs.
                    with torch.no_grad():
                        q = torch.sigmoid(self.prev_model.forward(batch_x))
                    target_labels[:, :n_old] = q[:, :n_old]
                loss = F.binary_cross_entropy_with_logits(
                    logits[:, :all_cls_num], target_labels, reduction='none').sum(dim=1).mean()
                loss.backward()
                self.opt.step()
                updated_idx += self.buffer.update(train_x, train_y)
66 |
--------------------------------------------------------------------------------
/utils/L_softmax.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from scipy.special import binom
5 |
6 |
class LSoftmaxLinear(nn.Module):
    """Linear classification layer with a large-margin (L-Softmax) target logit.

    In training mode the logit of the target class is replaced by a blend of
    the plain logit and a margin-penalised one; in eval mode the layer is a
    plain matrix product with ``weight``.

    :param input_features: size of the incoming feature vector
    :param output_features: number of outputs (classes)
    :param margin: integer margin ``m``
    """

    def __init__(self, input_features, output_features, margin):
        super().__init__()
        self.input_dim = input_features  # number of input feature i.e. output of the last fc layer
        self.output_dim = output_features  # number of output = class numbers
        self.margin = margin  # m
        self.beta = 100
        self.beta_min = 0
        self.scale = 0.99

        # Initialize L-Softmax parameters
        self.weight = nn.Parameter(torch.FloatTensor(input_features, output_features))
        self.divisor = math.pi / self.margin  # pi/m
        # The constant tables are created on the CPU and moved to the device of
        # the incoming activations on use, so the layer also works on hosts
        # without CUDA and follows wherever the inputs live.
        self.C_m_2n = torch.Tensor(binom(margin, list(range(0, margin + 1, 2))))  # C_m{2n}
        self.cos_powers = torch.Tensor(list(range(self.margin, -1, -2)))  # m - 2n
        self.sin2_powers = torch.Tensor(list(range(len(self.cos_powers))))  # n
        self.signs = torch.ones(margin // 2 + 1)  # 1, -1, 1, -1, ...
        self.signs[1::2] = -1

        # torch.FloatTensor(...) above allocates uninitialised memory; give the
        # weight a defined initial value. Callers may still call
        # reset_parameters() again.
        self.reset_parameters()

    def calculate_cos_m_theta(self, cos_theta):
        """Expand cos(m * theta) as a polynomial in cos(theta).

        :param cos_theta: 1-D tensor of cosines
        :return: tensor of the same length holding cos(m * theta)
        """
        device = cos_theta.device
        cos_powers = self.cos_powers.to(device)
        sin2_powers = self.sin2_powers.to(device)
        signs = self.signs.to(device)
        coefficients = self.C_m_2n.to(device)

        sin2_theta = 1 - cos_theta ** 2
        cos_terms = cos_theta.unsqueeze(1) ** cos_powers.unsqueeze(0)  # cos^{m - 2n}
        sin2_terms = sin2_theta.unsqueeze(1) ** sin2_powers.unsqueeze(0)  # sin2^{n}

        # sum_n (-1)^n * C_m{2n} * cos^{m - 2n} * sin2^{n}
        cos_m_theta = (signs.unsqueeze(0) *
                       coefficients.unsqueeze(0) *
                       cos_terms *
                       sin2_terms).sum(1)
        return cos_m_theta

    def reset_parameters(self):
        """Re-initialise ``weight`` with Kaiming-normal values (on its transpose)."""
        nn.init.kaiming_normal_(self.weight.data.t())

    def find_k(self, cos):
        """Return ``floor(acos(cos) / (pi / m))``, detached from the graph."""
        # to account for acos numerical errors
        eps = 1e-7
        cos = torch.clamp(cos, -1 + eps, 1 - eps)
        acos = cos.acos()
        k = (acos / self.divisor).floor().detach()
        return k

    def forward(self, input, target=None):
        """Compute the logits.

        :param input: 2-D feature batch, one row per sample
        :param target: class index per sample; required in training mode and
            must be None in eval mode
        :return: logits, one row per sample
        """
        if self.training:
            assert target is not None
            x, w = input, self.weight
            beta = max(self.beta, self.beta_min)
            logit = x.mm(w)
            indexes = torch.arange(logit.size(0), device=logit.device)
            logit_target = logit[indexes, target]

            # cos(theta) = w * x / ||w||*||x||
            w_target_norm = w[:, target].norm(p=2, dim=0)
            x_norm = x.norm(p=2, dim=1)
            cos_theta_target = logit_target / (w_target_norm * x_norm + 1e-10)

            # equation 7
            cos_m_theta_target = self.calculate_cos_m_theta(cos_theta_target)

            # find k in equation 6
            k = self.find_k(cos_theta_target)

            # f_y_i
            logit_target_updated = (w_target_norm *
                                    x_norm *
                                    (((-1) ** k * cos_m_theta_target) - 2 * k))
            # Blend the margin logit with the plain one; beta shrinks by
            # ``scale`` on every training call.
            logit_target_updated_beta = (logit_target_updated + beta * logit[indexes, target]) / (1 + beta)

            logit[indexes, target] = logit_target_updated_beta
            self.beta *= self.scale
            return logit
        else:
            assert target is None
            return input.mm(self.weight)
--------------------------------------------------------------------------------
/continuum/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils import data
4 | from utils.setup_elements import transforms_match
5 |
def create_task_composition(class_nums, num_tasks, fixed_order=False):
    """Split the class labels into ``num_tasks`` equally sized groups.

    Only ``(class_nums // num_tasks) * num_tasks`` labels are used; labels
    beyond that are left out. Each task's labels are printed.

    :param class_nums: total number of classes
    :param num_tasks: number of tasks
    :param fixed_order: keep ascending label order instead of shuffling
    :return: list with one list of labels per task
    """
    per_task = class_nums // num_tasks
    labels = np.arange(0, per_task * num_tasks)
    if not fixed_order:
        np.random.shuffle(labels)

    task_labels = []
    # Each row of the reshaped array is one task's contiguous chunk of labels.
    for task_id, chunk in enumerate(labels.reshape(num_tasks, per_task)):
        task_labels.append(list(chunk))
        print('Task: {}, Labels:{}'.format(task_id, task_labels[task_id]))
    return task_labels
19 |
20 |
def load_task_with_labels_torch(x, y, labels):
    """Select the samples of tensors ``x`` / ``y`` whose label is in ``labels``.

    The result is grouped label by label, in the order given by ``labels``.
    """
    per_label_idx = [(y == label).nonzero().view(-1) for label in labels]
    idx = torch.cat(per_label_idx)
    return x[idx], y[idx]
27 |
28 |
def load_task_with_labels(x, y, labels):
    """Select the samples of arrays ``x`` / ``y`` whose label is in ``labels``.

    The result is grouped label by label, in the order given by ``labels``.
    """
    per_label_idx = [np.where(y == label)[0] for label in labels]
    idx = np.concatenate(per_label_idx, axis=None)
    return x[idx], y[idx]
35 |
36 |
37 |
class dataset_transform(data.Dataset):
    """Dataset over ``(x, y)`` that applies an optional per-sample transform.

    :param x: indexable collection of samples
    :param y: NumPy array of labels; stored as a LongTensor
    :param transform: optional callable applied to each sample on access
    """

    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = torch.from_numpy(y).type(torch.LongTensor)
        self.transform = transform  # save the transform

    def __len__(self):
        # The length is defined by the number of labels.
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(sample as float tensor, label)`` for position ``idx``."""
        if self.transform:
            x = self.transform(self.x[idx])
        else:
            # Without a transform the raw sample may still be a NumPy array,
            # which has no ``.float()``; convert it first. Tensors pass through
            # unchanged.
            x = torch.as_tensor(self.x[idx])

        return x.float(), self.y[idx]
55 |
56 |
def setup_test_loader(test_data, params):
    """Build one DataLoader per ``(x_test, y_test)`` pair in ``test_data``.

    :param test_data: iterable of ``(x_test, y_test)`` pairs
    :param params: configuration exposing ``data`` and ``test_batch``
    :return: list of loaders, in the order of ``test_data``
    """
    return [
        data.DataLoader(
            dataset_transform(x_test, y_test, transform=transforms_match[params.data]),
            batch_size=params.test_batch,
            shuffle=True,
            num_workers=0,
        )
        for (x_test, y_test) in test_data
    ]
65 |
66 |
def shuffle_data(x, y):
    """Return ``x`` and ``y`` reordered by one shared random permutation."""
    order = np.arange(0, x.shape[0])
    np.random.shuffle(order)
    return x[order], y[order]
73 |
74 |
def train_val_test_split_ni(train_data, train_label, test_data, test_label, task_nums, img_size, val_size=0.1):
    """Shuffle, carve out a validation part and split everything into tasks.

    :param train_data: training images; reshaped to
        ``(task_nums, -1, img_size, img_size, 3)``
    :param train_label: labels aligned with ``train_data``
    :param test_data: test images, reshaped the same way
    :param test_label: labels aligned with ``test_data``
    :param task_nums: number of tasks (leading dimension of every output)
    :param img_size: side length of the square images
    :param val_size: fraction of the shuffled training data used for validation
    :return: train data, train labels, val data, val labels, test data,
        test labels -- each split per task
    """
    shuffled_x, shuffled_y = shuffle_data(train_data, train_label)
    n_val = int(len(shuffled_x) * val_size)
    # The first ``n_val`` shuffled samples form the validation part.
    val_x, val_y = shuffled_x[:n_val], shuffled_y[:n_val]
    fit_x, fit_y = shuffled_x[n_val:], shuffled_y[n_val:]
    test_x, test_y = shuffle_data(test_data, test_label)

    image_shape = (task_nums, -1, img_size, img_size, 3)
    label_shape = (task_nums, -1)
    return (fit_x.reshape(image_shape), fit_y.reshape(label_shape),
            val_x.reshape(image_shape), val_y.reshape(label_shape),
            test_x.reshape(image_shape), test_y.reshape(label_shape))
--------------------------------------------------------------------------------
/agents/scr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import data
3 | from utils.buffer.buffer import Buffer
4 | from agents.base import ContinualLearner
5 | from continuum.data_utils import dataset_transform
6 | from utils.setup_elements import transforms_match, input_size_match
7 | from utils.utils import maybe_cuda, AverageMeter
8 | from kornia.augmentation import RandomResizedCrop, RandomHorizontalFlip, ColorJitter, RandomGrayscale
9 | import torch.nn as nn
10 |
class SupContrastReplay(ContinualLearner):
    """Replay agent combining ``self.criterion`` on features with two CE terms.

    Every update uses the current batch together with samples retrieved from
    the buffer, once as-is and once augmented.
    """

    def __init__(self, model, opt, params):
        super(SupContrastReplay, self).__init__(model, opt, params)
        self.buffer = Buffer(model, params)
        self.mem_size = params.mem_size
        self.eps_mem_batch = params.eps_mem_batch
        self.mem_iters = params.mem_iters
        # Augmentation pipeline producing the second view of each batch; the
        # crop keeps the dataset's spatial size.
        crop_size = (input_size_match[self.params.data][1], input_size_match[self.params.data][2])
        self.transform = nn.Sequential(
            RandomResizedCrop(size=crop_size, scale=(0.2, 1.)),
            RandomHorizontalFlip(),
            ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
            RandomGrayscale(p=0.2)
        )
        self.criterion_ce = torch.nn.CrossEntropyLoss(reduction='mean')

    def train_learner(self, x_train, y_train):
        """Train on one task.

        :param x_train: training samples of the task
        :param y_train: labels aligned with ``x_train``
        """
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # set up model
        self.model = self.model.train()

        # setup tracker
        losses = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                for _ in range(self.mem_iters):
                    mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y)

                    # Nothing is optimised while the buffer returns no samples.
                    if mem_x.size(0) > 0:
                        mem_x = maybe_cuda(mem_x, self.cuda)
                        mem_y = maybe_cuda(mem_y, self.cuda)
                        combined_batch = torch.cat((mem_x, batch_x))
                        combined_labels = torch.cat((mem_y, batch_y))
                        combined_batch_aug = self.transform(combined_batch)
                        batch_features, batch_logits = self.model.forward(combined_batch)
                        batch_features_aug, batch_logits_aug = self.model.forward(combined_batch_aug)
                        # Stack the two views along a new dimension 1.
                        features = torch.cat([batch_features.unsqueeze(1), batch_features_aug.unsqueeze(1)], dim=1)
                        loss_1 = self.criterion(features, combined_labels)
                        loss_2 = self.criterion_ce(batch_logits, combined_labels)
                        loss_3 = self.criterion_ce(batch_logits_aug, combined_labels)
                        loss = 0.2 * loss_1 + loss_2 + loss_3
                        # Record a plain float so the meter does not keep the
                        # graph-attached loss tensor alive.
                        losses.update(loss.item(), batch_y.size(0))
                        self.opt.zero_grad()
                        loss.backward()
                        self.opt.step()

                # update mem
                self.buffer.update(batch_x, batch_y)
                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        .format(i, losses.avg())
                    )
        self.after_train()
76 |
--------------------------------------------------------------------------------
/features/wide_resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch.nn as nn
6 | from . import extractor
7 |
8 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a biased 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=True)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
13 |
14 |
def conv1x1(in_planes, out_planes, stride=1):
    """Return a biased 1x1 ``nn.Conv2d`` with the given stride."""
    conv_kwargs = dict(kernel_size=1, stride=stride, bias=True)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
19 |
20 |
class BasicBlock(nn.Module):
    """Residual block: BN-ReLU-conv3x3, then BN-ReLU-dropout-conv3x3, plus shortcut.

    A 1x1 convolution is used on the shortcut whenever the stride or the
    channel count changes; otherwise the input itself is added.
    """

    def __init__(self, inplanes, planes, dropout, stride=1):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.dropout = nn.Dropout(dropout)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.relu = nn.ReLU(inplace=True)
        # A projection is needed only when the output shape differs from the input.
        needs_projection = stride != 1 or inplanes != planes
        if needs_projection:
            self.shortcut = conv1x1(inplanes, planes, stride)
        self.use_conv1x1 = needs_projection

    def forward(self, x):
        activated = self.relu(self.bn1(x))

        # The projection shortcut is fed the activated tensor; the identity
        # shortcut is the block input itself.
        shortcut = self.shortcut(activated) if self.use_conv1x1 else x

        out = self.conv1(activated)
        out = self.dropout(self.relu(self.bn2(out)))
        out = self.conv2(out)

        out += shortcut
        return out
55 |
56 |
class WideResNet(extractor.BaseModule):
    """Wide residual feature extractor; the classification head is not included.

    :param config: mapping with the keys "depth", "width", "dropout" and "channels"
    :param name: stored as ``self.name``
    """

    def __init__(self, config, name):
        super(WideResNet, self).__init__()
        self.name = name
        depth = config["depth"]
        width = config["width"]
        dropout = config["dropout"]
        in_channels = config["channels"]
        summary = "%s depth: %d, width: %d, dropout=%f" % (type(self), depth, width, dropout)
        print(summary)

        # Number of BasicBlocks in each of the three stages.
        blocks_per_stage = (depth - 4) // 6

        # ``self.inplanes`` tracks the channel count entering the next block.
        self.inplanes = 16
        self.conv = conv3x3(in_channels, 16)
        self.layer1 = self._make_layer(16 * width, blocks_per_stage, dropout)
        self.layer2 = self._make_layer(32 * width, blocks_per_stage, dropout, stride=2)
        self.layer3 = self._make_layer(64 * width, blocks_per_stage, dropout, stride=2)
        self.bn = nn.BatchNorm2d(64 * width)
        self.relu = nn.ReLU(inplace=True)

        # Channel count of the feature map returned by forward().
        self.n_features = 64 * width

    def _make_layer(self, planes, blocks, dropout, stride=1):
        """Stack ``blocks`` BasicBlocks; only the first one applies ``stride``."""
        stage = []
        for block_idx in range(blocks):
            block_stride = stride if block_idx == 0 else 1
            stage.append(BasicBlock(self.inplanes, planes, dropout, block_stride))
            self.inplanes = planes

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the feature map after the last BN + ReLU."""
        x = self.conv(x)

        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)

        x = self.relu(self.bn(x))
        return x
110 |
111 |
112 |
113 | if __name__ == '__main__':
114 | pass
115 |
--------------------------------------------------------------------------------
/agents/gdumb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import data
3 | import math
4 | from agents.base import ContinualLearner
5 | from continuum.data_utils import dataset_transform
6 | from utils.setup_elements import transforms_match, setup_architecture, setup_opt
7 | from utils.utils import maybe_cuda, EarlyStopping
8 | import numpy as np
9 | import random
10 |
11 |
class Gdumb(ContinualLearner):
    """Agent that keeps a class-balanced memory and retrains a fresh model on it."""

    def __init__(self, model, opt, params):
        super(Gdumb, self).__init__(model, opt, params)
        self.mem_img = {}  # label -> list of stored samples
        self.mem_c = {}    # label -> number of stored samples

    def greedy_balancing_update(self, x, y):
        """Offer sample ``x`` with label ``y`` to the memory.

        The sample is stored when its class is new or still below the per-class
        quota; if the memory is full, a random sample of the currently largest
        class is evicted first.
        """
        quota = self.params.mem_size // max(1, len(self.mem_img))
        if y in self.mem_img and self.mem_c[y] >= quota:
            return

        if sum(self.mem_c.values()) >= self.params.mem_size:
            # Evict one random sample from the class holding the most samples.
            largest_cls = max(self.mem_c, key=self.mem_c.get)
            victim = random.randrange(self.mem_c[largest_cls])
            self.mem_img[largest_cls].pop(victim)
            self.mem_c[largest_cls] -= 1

        if y not in self.mem_img:
            self.mem_img[y] = []
            self.mem_c[y] = 0
        self.mem_img[y].append(x)
        self.mem_c[y] += 1

    def train_learner(self, x_train, y_train):
        """Feed one task into the memory, then retrain on the memory."""
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)

        for batch_x, batch_y in train_loader:
            batch_x = maybe_cuda(batch_x, self.cuda)
            batch_y = maybe_cuda(batch_y, self.cuda)
            # update mem, one sample at a time
            for sample_idx in range(len(batch_x)):
                self.greedy_balancing_update(batch_x[sample_idx], batch_y[sample_idx].item())

        self.train_mem()
        self.after_train()

    def train_mem(self):
        """Re-create the model and train it on the memory for ``mem_epoch`` epochs."""
        mem_x, mem_y = [], []
        for label, images in self.mem_img.items():
            mem_x.extend(images)
            mem_y.extend([label] * self.mem_c[label])

        mem_x = torch.stack(mem_x)
        mem_y = torch.LongTensor(mem_y)
        # A fresh network and optimizer are built for every call.
        self.model = setup_architecture(self.params)
        self.model = maybe_cuda(self.model, self.cuda)
        opt = setup_opt(self.params.optimizer, self.model, self.params.learning_rate, self.params.weight_decay)

        for _ in range(self.params.mem_epoch):
            # Reshuffle the memory at the start of every epoch.
            order = np.random.permutation(len(mem_x)).tolist()
            mem_x = maybe_cuda(mem_x[order], self.cuda)
            mem_y = maybe_cuda(mem_y[order], self.cuda)
            self.model = self.model.train()
            batch_size = self.params.batch
            # Only full batches are used; a trailing partial batch is skipped.
            n_batches = len(mem_y) // batch_size
            for start in range(0, n_batches * batch_size, batch_size):
                stop = start + batch_size
                opt.zero_grad()
                logits = self.model.forward(mem_x[start:stop])
                loss = self.criterion(logits, mem_y[start:stop])
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip)
                opt.step()
84 |
85 |
--------------------------------------------------------------------------------
/utils/buffer/mgi_retrieve.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from utils.buffer.buffer_utils import random_retrieve, get_grad_vector
4 | import copy
5 | from torchvision.transforms import RandomResizedCrop, RandomHorizontalFlip,ColorJitter,RandomGrayscale
6 | from utils.setup_elements import transforms_match, input_size_match
7 | import torch.nn as nn
8 | from utils.setup_elements import n_classes
9 |
10 |
class MGI_retrieve(object):
    """Retrieve the buffer samples whose score grows most after a virtual step.

    A random subsample of the buffer is scored with the current model and
    with a copy of the model that has taken one gradient step; the samples
    with the largest score increase are returned together with an augmented
    view of each.
    """

    def __init__(self, params, **kwargs):
        """Read retrieval settings from ``params`` and build the augmentation.

        Args:
            params: run configuration; ``subsample``, ``eps_mem_batch``,
                ``data`` and (later) ``learning_rate`` are read from it.
            **kwargs: accepted for interface compatibility, unused.
        """
        super().__init__()
        self.params = params
        self.subsample = params.subsample
        self.num_retrieve = params.eps_mem_batch
        input_size = input_size_match[self.params.data]
        # NOTE(review): only the first two transforms call .cuda(), kept as in
        # the original; presumably a no-op for tensor-free modules - confirm.
        self.transform = nn.Sequential(
            RandomResizedCrop(size=(input_size[1], input_size[2]),
                              scale=(0.2, 1.)).cuda(),
            RandomHorizontalFlip().cuda(),
            ColorJitter(0.4, 0.4, 0.4, 0.1),
            RandomGrayscale(p=0.2)
        )
        self.out_dim = n_classes[params.data]

    def retrieve(self, buffer, **kwargs):
        """Pick the ``num_retrieve`` highest-scoring samples of a subsample.

        Args:
            buffer: memory buffer exposing ``model``; passed to
                ``random_retrieve`` to draw the candidate subsample.
            **kwargs: accepted for interface compatibility, unused.

        Returns:
            Tuple ``(x, x_aug, y)`` of the selected inputs, their augmented
            views and labels. When the subsample is empty the (empty)
            subsample is returned as ``(sub_x, sub_x, sub_y)``.
        """
        sub_x, sub_y = random_retrieve(buffer, self.subsample)
        if sub_x.size(0) == 0:
            # Nothing to rank: skip the model copy and gradient flattening.
            return sub_x, sub_x, sub_y

        grad_dims = [param.data.numel() for param in buffer.model.parameters()]
        grad_vector = get_grad_vector(buffer.model.parameters, grad_dims)
        model_temp = self.get_future_step_parameters(buffer.model, grad_vector, grad_dims)

        with torch.no_grad():
            sub_x_aug = self.transform(sub_x)
            # The model is called with both views and returns four outputs;
            # only the first and the last are used for scoring.
            z_pre, _, _, fea_z_pre = buffer.model(sub_x, sub_x_aug)
            z_post, _, _, fea_z_post = model_temp(sub_x, sub_x_aug)

            target = F.one_hot(sub_y, self.out_dim)
            # L1 distance between the softmax prediction and the one-hot label.
            grads_pre_z = torch.sum(torch.abs(F.softmax(z_pre, dim=1) - target), 1)
            mgi_pre_z = grads_pre_z * fea_z_pre[0].reshape(-1)
            grads_post_z = torch.sum(torch.abs(F.softmax(z_post, dim=1) - target), 1)
            mgi_post_z = grads_post_z * fea_z_post[0].reshape(-1)

            # Largest increase caused by the virtual update comes first.
            scores = mgi_post_z - mgi_pre_z
            big_ind = scores.sort(descending=True)[1][:int(self.num_retrieve)]

        return sub_x[big_ind], sub_x_aug[big_ind], sub_y[big_ind]

    def get_future_step_parameters(self, model, grad_vector, grad_dims):
        r"""Return a copy of ``model`` after one step: \theta - lr * \delta\theta.

        Args:
            model: network to copy; it is left untouched.
            grad_vector: flat tensor holding the gradient of every parameter.
            grad_dims: number of elements of each parameter, in order.

        Returns:
            A deep copy of ``model`` whose parameters were moved by
            ``self.params.learning_rate`` along ``-grad_vector``.
        """
        new_model = copy.deepcopy(model)
        self.overwrite_grad(new_model.parameters, grad_vector, grad_dims)
        with torch.no_grad():
            for param in new_model.parameters():
                if param.grad is not None:
                    param.data = param.data - self.params.learning_rate * param.grad.data
        return new_model

    def overwrite_grad(self, pp, new_grad, grad_dims):
        """Overwrite every parameter gradient with a slice of ``new_grad``.

        Args:
            pp: callable returning the parameters (e.g. ``model.parameters``).
            new_grad: flat gradient vector to distribute.
            grad_dims: list storing the number of elements of each parameter.
        """
        # Running offset instead of re-summing grad_dims for every parameter.
        beg = 0
        for cnt, param in enumerate(pp()):
            en = beg + grad_dims[cnt]
            param.grad = torch.zeros_like(param.data)
            this_grad = new_grad[beg:en].contiguous().view(param.data.size())
            param.grad.data.copy_(this_grad)
            beg = en
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Yonglong Tian (yonglong@mit.edu)
3 | Date: May 07, 2020
4 | """
5 | from __future__ import print_function
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""

    def __init__(self, temperature=0.07, contrast_mode='all'):
        """Store the softmax temperature and the anchor selection mode.

        Args:
            temperature: divisor applied to the feature dot products.
            contrast_mode: ``'all'`` uses every view as anchor, ``'one'``
                uses only the first view of each sample.
        """
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does
                not match the batch size, or if `contrast_mode` is unknown.
        """
        # Follow the features' own device so non-default GPUs work too.
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # An anchor without any positive pair has a zero mask row; dividing by
        # one there yields a zero contribution instead of 0/0 = NaN.
        pos_per_anchor = mask.sum(1)
        pos_per_anchor = torch.where(pos_per_anchor < 1e-6,
                                     torch.ones_like(pos_per_anchor),
                                     pos_per_anchor)
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_per_anchor

        # loss
        loss = -1 * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
97 |
--------------------------------------------------------------------------------
/continuum/dataset_scripts/mini_imagenet.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | from continuum.data_utils import create_task_composition, load_task_with_labels, shuffle_data
4 | from continuum.dataset_scripts.dataset_base import DatasetBase
5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns
6 |
7 | TEST_SPLIT = 1 / 6
8 |
9 |
class Mini_ImageNet(DatasetBase):
    """Mini-ImageNet continuum built from the three cached pickle splits."""

    def __init__(self, scenario, params):
        """Configure the dataset for the given scenario.

        Args:
            scenario: ``'ni'`` or ``'nc'``; for ``'ni'`` the number of tasks
                is the length of ``params.ns_factor``, otherwise
                ``params.num_tasks``.
            params: run configuration forwarded to ``DatasetBase``.
        """
        dataset = 'mini_imagenet'
        if scenario == 'ni':
            num_tasks = len(params.ns_factor)
        else:
            num_tasks = params.num_tasks
        super(Mini_ImageNet, self).__init__(dataset, scenario, num_tasks, params.num_runs, params)

    @staticmethod
    def _load_images(path, num_classes):
        """Load one cached split and reshape it to [classes, 600, 84, 84, 3].

        The file is closed as soon as it has been unpickled.
        NOTE: pickle must only be used on trusted, locally stored caches.
        """
        with open(path, "rb") as handle:
            cache = pickle.load(handle)
        return cache["image_data"].reshape([num_classes, 600, 84, 84, 3])

    def download_load(self):
        """Read all 100 classes and split each class into train and test.

        For every class the 600 images are shuffled and the first
        ``int(600 * TEST_SPLIT)`` become test samples; the rest are used for
        training. Results are stored in ``train_data``, ``train_label``,
        ``test_data`` and ``test_label``.
        """
        train_x = self._load_images("datasets/mini_imagenet/mini-imagenet-cache-train.pkl", 64)
        val_x = self._load_images("datasets/mini_imagenet/mini-imagenet-cache-val.pkl", 16)
        test_x = self._load_images("datasets/mini_imagenet/mini-imagenet-cache-test.pkl", 20)
        all_data = np.vstack((train_x, val_x, test_x))

        n_test = int(600 * TEST_SPLIT)
        train_data = []
        train_label = []
        test_data = []
        test_label = []
        for i in range(len(all_data)):
            cur_x = all_data[i]
            # Class index doubles as the label of all 600 images.
            cur_y = np.ones((600,)) * i
            rdm_x, rdm_y = shuffle_data(cur_x, cur_y)
            test_data.append(rdm_x[:n_test])
            test_label.append(rdm_y[:n_test])
            train_data.append(rdm_x[n_test:])
            train_label.append(rdm_y[n_test:])
        self.train_data = np.concatenate(train_data)
        self.train_label = np.concatenate(train_label)
        self.test_data = np.concatenate(test_data)
        self.test_label = np.concatenate(test_label)

    def new_run(self, **kwargs):
        """Prepare a fresh run and return its test set."""
        self.setup()
        return self.test_set

    def new_task(self, cur_task, **kwargs):
        """Return ``(x_train, y_train, labels)`` for task ``cur_task``.

        Raises:
            Exception: if the scenario is neither ``'ni'`` nor ``'nc'``.
        """
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        else:
            raise Exception('unrecognized scenario')
        return x_train, y_train, labels

    def setup(self):
        """Build the per-task data structures for the configured scenario."""
        if self.scenario == 'ni':
            self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(
                self.train_data,
                self.train_label,
                self.test_data, self.test_label,
                self.task_nums, 84,
                self.params.val_size,
                self.params.ns_type, self.params.ns_factor,
                plot=self.params.plot_sample)

        elif self.scenario == 'nc':
            self.task_labels = create_task_composition(class_nums=100, num_tasks=self.task_nums,
                                                       fixed_order=self.params.fix_order)
            # One test split per task, restricted to that task's labels.
            self.test_set = []
            for labels in self.task_labels:
                x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels)
                self.test_set.append((x_test, y_test))

    def test_plot(self):
        """Run ``test_ns`` on the first ten training samples."""
        test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type,
                self.params.ns_factor)
88 |
--------------------------------------------------------------------------------
/models/ndpm/component.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import torch
3 | import torch.nn as nn
4 | from typing import Tuple
5 |
6 | from utils.utils import maybe_cuda
7 | from utils.global_vars import *
8 |
9 |
class Component(nn.Module, ABC):
    """Abstract base of an NDPM component that knows its previous experts."""

    def __init__(self, params, experts: Tuple):
        """Keep the run configuration and the tuple of existing experts.

        The optimizer and scheduler start as ``NotImplemented`` placeholders.
        """
        super().__init__()
        self.params = params
        self.experts = experts

        self.optimizer = NotImplemented
        self.lr_scheduler = NotImplemented

    @abstractmethod
    def nll(self, x, y, step=None):
        """Return NLL"""
        pass

    @abstractmethod
    def collect_nll(self, x, y, step=None):
        """Return NLLs including previous experts"""
        pass

    def _clip_grad_value(self, clip_value):
        """Clamp gradients of every optimizer param group to ``clip_value``."""
        for param_group in self.optimizer.param_groups:
            nn.utils.clip_grad_value_(param_group['params'], clip_value)

    def _clip_grad_norm(self, max_norm, norm_type=2):
        """Rescale gradients of every optimizer param group to ``max_norm``."""
        for param_group in self.optimizer.param_groups:
            nn.utils.clip_grad_norm_(param_group['params'], max_norm, norm_type)

    def clip_grad(self):
        """Clip gradients as configured by ``MODELS_NDPM_COMPONENT_CLIP_GRAD``.

        Raises:
            ValueError: if the configured type is neither 'value' nor 'norm'.
        """
        config = MODELS_NDPM_COMPONENT_CLIP_GRAD
        clippers = {
            'value': self._clip_grad_value,
            'norm': self._clip_grad_norm,
        }
        kind = config['type']
        if kind not in clippers:
            raise ValueError('Invalid clip_grad type: {}'
                             .format(config['type']))
        clippers[kind](**config['options'])

    @staticmethod
    def build_optimizer(optim_config, params):
        """Instantiate ``torch.optim.<type>`` with the configured options."""
        optimizer_cls = getattr(torch.optim, optim_config['type'])
        return optimizer_cls(params, **optim_config['options'])

    @staticmethod
    def build_lr_scheduler(lr_config, optimizer):
        """Instantiate ``torch.optim.lr_scheduler.<type>`` for ``optimizer``."""
        scheduler_cls = getattr(torch.optim.lr_scheduler, lr_config['type'])
        return scheduler_cls(optimizer, **lr_config['options'])

    def weight_decay_loss(self):
        """Return the sum of squared norms of all parameters."""
        total = maybe_cuda(torch.zeros([]))
        for weight in self.parameters():
            total += torch.norm(weight) ** 2
        return total
62 |
63 |
class ComponentG(Component, ABC):
    """Component variant whose experts are reached through ``expert.g``."""

    def setup_optimizer(self):
        """Create the optimizer from ``params`` and the G lr scheduler."""
        optim_config = {
            'type': self.params.optimizer,
            'options': {'lr': self.params.learning_rate},
        }
        self.optimizer = self.build_optimizer(optim_config, self.parameters())
        self.lr_scheduler = self.build_lr_scheduler(
            MODELS_NDPM_COMPONENT_LR_SCHEDULER_G, self.optimizer)

    def collect_nll(self, x, y=None, step=None):
        """Default `collect_nll`

        Warning: Parameter-sharing components should implement their own
            `collect_nll`

        Returns:
            nll: Tensor of shape [B, 1+K]
        """
        # Previous experts first, this component's own NLL last.
        collected = [expert.g.nll(x, y, step) for expert in self.experts]
        collected.append(self.nll(x, y, step))
        return torch.stack(collected, dim=1)
85 |
86 |
87 |
class ComponentD(Component, ABC):
    """Component variant whose experts are reached through ``expert.d``."""

    def setup_optimizer(self):
        """Create the optimizer from ``params`` and the D lr scheduler."""
        optim_config = {
            'type': self.params.optimizer,
            'options': {'lr': self.params.learning_rate},
        }
        self.optimizer = self.build_optimizer(optim_config, self.parameters())
        self.lr_scheduler = self.build_lr_scheduler(
            MODELS_NDPM_COMPONENT_LR_SCHEDULER_D, self.optimizer)

    def collect_forward(self, x):
        """Default `collect_forward`

        Warning: Parameter-sharing components should implement their own
            `collect_forward`

        Returns:
            output: Tensor of shape [B, 1+K, C]
        """
        # Previous experts first, this component's own output last.
        collected = [expert.d(x) for expert in self.experts]
        collected.append(self.forward(x))
        return torch.stack(collected, 1)

    def collect_nll(self, x, y, step=None):
        """Default `collect_nll`

        Warning: Parameter-sharing components should implement their own
            `collect_nll`

        Returns:
            nll: Tensor of shape [B, 1+K]
        """
        collected = [expert.d.nll(x, y, step) for expert in self.experts]
        collected.append(self.nll(x, y, step))
        return torch.stack(collected, dim=1)
122 |
--------------------------------------------------------------------------------
/utils/buffer/aser_retrieve.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.buffer.buffer_utils import random_retrieve, ClassBalancedRandomSampling
3 | from utils.buffer.aser_utils import compute_knn_sv
4 | from utils.utils import maybe_cuda
5 | from utils.setup_elements import n_classes
6 |
7 |
class ASER_retrieve(object):
    """Retrieve memory samples ranked by KNN Shapley values (ASER)."""

    def __init__(self, params, **kwargs):
        """Read the ASER settings from ``params`` and reset the class cache.

        Args:
            params: run configuration; ``eps_mem_batch``, ``k``, ``mem_size``,
                ``aser_type``, ``n_smp_cls``, ``data`` and ``update`` are read.
            **kwargs: accepted for interface compatibility, unused.
        """
        super().__init__()
        self.num_retrieve = params.eps_mem_batch
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.k = params.k
        self.mem_size = params.mem_size
        self.aser_type = params.aser_type
        self.n_smp_cls = int(params.n_smp_cls)
        self.out_dim = n_classes[params.data]
        self.is_aser_upt = params.update == "ASER"
        ClassBalancedRandomSampling.class_index_cache = None

    def retrieve(self, buffer, **kwargs):
        """Return ``(x, y)`` retrieved from ``buffer``.

        Random retrieval is used while ``buffer.n_seen_so_far`` does not
        exceed the memory size; afterwards samples are ranked by Shapley
        value against the current batch given as ``kwargs['x']`` and
        ``kwargs['y']``.
        """
        if buffer.n_seen_so_far <= self.mem_size:
            # Buffer not yet full: plain random retrieval.
            return random_retrieve(buffer, self.num_retrieve)

        # Buffer is full: rank candidates by their Shapley values.
        return self._retrieve_by_knn_sv(
            buffer.model,
            buffer.buffer_img, buffer.buffer_label,
            kwargs['x'], kwargs['y'],
            self.num_retrieve,
        )

    def _retrieve_by_knn_sv(self, model, buffer_x, buffer_y, cur_x, cur_y, num_retrieve):
        """
            Retrieves data instances with top-N Shapley Values from candidate set.
                Args:
                    model (object): neural network.
                    buffer_x (tensor): data buffer.
                    buffer_y (tensor): label buffer.
                    cur_x (tensor): current input data tensor.
                    cur_y (tensor): current input label tensor.
                    num_retrieve (int): number of data instances to be retrieved.
                Returns
                    ret_x (tensor): retrieved data tensor.
                    ret_y (tensor): retrieved label tensor.
        """
        cur_x = maybe_cuda(cur_x)
        cur_y = maybe_cuda(cur_y)

        # Without the ASER update nobody else maintains the class index
        # cache, so refresh it here.
        if not self.is_aser_upt:
            ClassBalancedRandomSampling.update_cache(buffer_y, self.out_dim)

        # Candidates: class-balanced subsample of the memory.
        cand_x, cand_y, cand_ind = ClassBalancedRandomSampling.sample(
            buffer_x, buffer_y, self.n_smp_cls, device=self.device)

        # Type 1 - adversarial SV: candidates evaluated on the current input.
        sv_matrix_adv = compute_knn_sv(model, cur_x, cur_y, cand_x, cand_y, self.k,
                                       device=self.device)

        if self.aser_type == "neg_sv":
            # No Type 2 (cooperative) SV: negated sum of adversarial SVs only.
            sv = sv_matrix_adv.sum(0) * -1
        else:
            # Type 2 - cooperative SV: evaluated on another class-balanced
            # subsample that excludes the candidates themselves.
            excl_indices = set(cand_ind.tolist())
            eval_coop_x, eval_coop_y, _ = ClassBalancedRandomSampling.sample(
                buffer_x, buffer_y, self.n_smp_cls,
                excl_indices=excl_indices, device=self.device)
            sv_matrix_coop = compute_knn_sv(model, eval_coop_x, eval_coop_y, cand_x, cand_y,
                                            self.k, device=self.device)
            if self.aser_type == "asv":
                # Extremal values of both SV types.
                sv = sv_matrix_coop.max(0).values - sv_matrix_adv.min(0).values
            else:
                # Mean variation for "asvm" or any other type.
                sv = sv_matrix_coop.mean(0) - sv_matrix_adv.mean(0)

        ranking = sv.argsort(descending=True)
        ret_x = cand_x[ranking][:num_retrieve]
        ret_y = cand_y[ranking][:num_retrieve]
        return ret_x, ret_y
93 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def maybe_cuda(what, use_cuda=True, **kw):
    """
    Return `what` moved to CUDA when requested and possible.

    Args:
        what (object): any object exposing ``.cuda()``.
        use_cuda (bool): anything other than the literal ``False`` requests
            the move; it only happens when CUDA is available.
        **kw: accepted and ignored.
    Returns
        object: ``what.cuda()`` if moved, otherwise `what` itself.
    """
    move_requested = use_cuda is not False
    if move_requested and torch.cuda.is_available():
        return what.cuda()
    return what
17 |
18 |
def boolean_string(s):
    """Map the exact strings 'True' / 'False' to the matching bool.

    Raises:
        ValueError: for any other string.
    """
    literals = {'True': True, 'False': False}
    try:
        return literals[s]
    except KeyError:
        raise ValueError('Not a valid boolean string')
23 |
24 |
class AverageMeter(object):
    """Keeps a weighted running sum and count to report an average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.sum = 0
        self.count = 0

    def update(self, val, n):
        """Record ``val`` as the mean of ``n`` new observations."""
        self.sum += val * n
        self.count += n

    def avg(self):
        """Return the weighted mean, or 0 when nothing was recorded."""
        if self.count != 0:
            return float(self.sum) / self.count
        return 0
43 |
44 |
def mini_batch_deep_features(model, total_x, num):
    """
    Compute deep features with mini-batches.

    The model is switched to eval mode for the computation and its training
    mode is restored afterwards, even if feature extraction raises.

    Args:
        model (object): neural network. Its ``features`` method is used when
            present; otherwise all children except the last one are chained
            as the feature extractor.
        total_x (tensor): data tensor.
        num (int): number of data; only the first ``num`` rows are used.
    Returns
        deep_features (tensor): deep feature representation of data tensor,
            flattened to one row per input.
    """
    was_training = model.training
    if was_training:
        model.eval()
    try:
        if hasattr(model, "features"):
            extract = model.features
        else:
            # delete the last fully connected layer to get a feature extractor
            backbone = torch.nn.Sequential(*list(model.children())[:-1])

            def extract(batch):
                return torch.squeeze(backbone(batch))

        with torch.no_grad():
            bs = 64
            deep_features_list = []
            for sid in range(0, num, bs):
                eid = min(sid + bs, num)
                batch_x = total_x[sid:eid]
                # One flat feature row per sample of the batch.
                deep_features_list.append(
                    extract(batch_x).reshape((batch_x.size(0), -1)))

        if len(deep_features_list) == 1:
            deep_features_ = deep_features_list[0]
        else:
            deep_features_ = torch.cat(deep_features_list, 0)
    finally:
        # Never leave the caller's model stuck in eval mode.
        if was_training:
            model.train()
    return deep_features_
91 |
92 |
def euclidean_distance(u, v):
    """Return the squared difference of `u` and `v` summed over dim 1.

    Note that no square root is taken, so this is the squared distance.
    """
    return torch.sum(torch.pow(u - v, 2), 1)
96 |
97 |
def ohe_label(label_tensor, dim, device="cpu"):
    """Return the one-hot encoding of `label_tensor` as a long tensor.

    The result has one row per label and `dim` columns and lives on `device`.
    """
    n_labels = label_tensor.size(0)
    column_index = label_tensor.reshape((n_labels, 1))
    encoded = torch.zeros((n_labels, dim), device=device, dtype=torch.long)
    # Write a 1 at each row's label column.
    encoded.scatter_(1, column_index, 1)
    return encoded
103 |
104 |
def nonzero_indices(bool_mask_tensor):
    """Return the first-dimension indices of the nonzero mask entries."""
    first_axis, *_ = bool_mask_tensor.nonzero(as_tuple=True)
    return first_axis
108 |
109 |
class EarlyStopping():
    """Signals a stop once the score failed to improve `patience` times."""

    def __init__(self, min_delta, patience, cumulative_delta):
        """
        Args:
            min_delta: margin a score must exceed the best score by to count
                as an improvement.
            patience: number of non-improving steps that triggers the stop.
            cumulative_delta: when false, the best score still follows small
                gains that stay within `min_delta`.
        """
        self.min_delta = min_delta
        self.patience = patience
        self.cumulative_delta = cumulative_delta
        self.counter = 0
        self.best_score = None

    def step(self, score):
        """Record `score`; return True when training should stop."""
        if self.best_score is None:
            # First observation only initialises the reference score.
            self.best_score = score
            return False

        stalled = score <= self.best_score + self.min_delta
        if not stalled:
            self.best_score = score
            self.counter = 0
            return False

        if score > self.best_score and not self.cumulative_delta:
            self.best_score = score
        self.counter += 1
        if self.counter >= self.patience:
            return True
        return False

    def reset(self):
        """Clear the counter and the best score."""
        self.counter = 0
        self.best_score = None
135 |
--------------------------------------------------------------------------------
/agents/agem.py:
--------------------------------------------------------------------------------
1 | from agents.base import ContinualLearner
2 | from continuum.data_utils import dataset_transform
3 | from utils.setup_elements import transforms_match
4 | from torch.utils import data
5 | from utils.buffer.buffer import Buffer
6 | from utils.utils import maybe_cuda, AverageMeter
7 | import torch
8 |
9 |
class AGEM(ContinualLearner):
    """A-GEM learner: projects the batch gradient using a memory gradient."""

    def __init__(self, model, opt, params):
        """Create the learner and its replay buffer.

        Args:
            model: network to train.
            opt: optimizer stepping the model.
            params: run configuration; ``mem_size``, ``eps_mem_batch`` and
                ``mem_iters`` are read here.
        """
        super(AGEM, self).__init__(model, opt, params)
        self.buffer = Buffer(model, params)
        self.mem_size = params.mem_size
        self.eps_mem_batch = params.eps_mem_batch
        self.mem_iters = params.mem_iters

    def train_learner(self, x_train, y_train):
        """Train on one task, constraining updates with memory gradients.

        Args:
            x_train: task inputs, forwarded to ``dataset_transform``.
            y_train: task labels, forwarded to ``dataset_transform``.
        """
        self.before_train(x_train, y_train)

        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # set up model
        self.model = self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                for j in range(self.mem_iters):
                    logits = self.forward(batch_x)
                    loss = self.criterion(logits, batch_y)
                    if self.params.trick['kd_trick']:
                        loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                               self.kd_manager.get_kd_loss(logits, batch_x)
                    if self.params.trick['kd_trick_star']:
                        loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \
                               (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x)
                    _, pred_label = torch.max(logits, 1)
                    correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                    # update tracker
                    acc_batch.update(correct_cnt, batch_y.size(0))
                    # Track a plain float: accumulating the loss tensor would
                    # keep every iteration's autograd graph alive.
                    losses_batch.update(loss.item(), batch_y.size(0))
                    # backward
                    self.opt.zero_grad()
                    loss.backward()

                    if self.task_seen > 0:
                        # sample from memory of previous tasks
                        mem_x, mem_y = self.buffer.retrieve()
                        if mem_x.size(0) > 0:
                            params = [p for p in self.model.parameters() if p.requires_grad]
                            # gradient computed using current batch
                            grad = [p.grad.clone() for p in params]
                            mem_x = maybe_cuda(mem_x, self.cuda)
                            mem_y = maybe_cuda(mem_y, self.cuda)
                            mem_logits = self.forward(mem_x)
                            loss_mem = self.criterion(mem_logits, mem_y)
                            self.opt.zero_grad()
                            loss_mem.backward()
                            # gradient computed using memory samples
                            grad_ref = [p.grad.clone() for p in params]

                            # inner product of grad and grad_ref
                            prod = sum([torch.sum(g * g_r) for g, g_r in zip(grad, grad_ref)])
                            if prod < 0:
                                prod_ref = sum([torch.sum(g_r ** 2) for g_r in grad_ref])
                                # do projection
                                grad = [g - prod / prod_ref * g_r for g, g_r in zip(grad, grad_ref)]
                            # replace params' grad: the memory backward pass
                            # overwrote it, so restore the (possibly
                            # projected) current-batch gradient before stepping
                            for g, p in zip(grad, params):
                                p.grad.data.copy_(g)
                    self.opt.step()
                # update mem
                self.buffer.update(batch_x, batch_y)

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                            .format(i, losses_batch.avg(), acc_batch.avg())
                    )
        self.after_train()
--------------------------------------------------------------------------------
/continuum/dataset_scripts/openloris.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from PIL import Image
3 | import numpy as np
4 | from continuum.dataset_scripts.dataset_base import DatasetBase
5 | import time
6 | from continuum.data_utils import shuffle_data
7 |
8 |
class OpenLORIS(DatasetBase):
    """
    tasks_nums is predefined and it depends on the ns_type.
    """

    def __init__(self, scenario, params):  # scenario refers to "ni" or "nc"
        """Configure the dataset for the non-stationarity in ``params.ns_type``.

        ``ns_type`` can be illumination, occlusion, pixel, clutter or
        sequence; it selects the number of tasks via ``openloris_ntask``.
        """
        dataset = 'openloris'
        self.ns_type = params.ns_type
        task_nums = openloris_ntask[self.ns_type]
        super(OpenLORIS, self).__init__(dataset, scenario, task_nums, params.num_runs, params)

    @staticmethod
    def _load_images(paths):
        """Load each image as a 50x50 RGB array, closing every file handle."""
        images = []
        for path in paths:
            # The context manager releases the file as soon as it is decoded.
            with Image.open(path) as img:
                images.append(np.array(img.convert('RGB').resize((50, 50))))
        return images

    def download_load(self):
        """Read the train and test images of every task from disk.

        Appends one ``(images, labels)`` pair per task to ``self.train_set``
        and ``self.test_set``; the label is the index of the object folder
        in ``datapath``.
        """
        s = time.time()
        self.train_set = []
        root = 'datasets/openloris/' + self.ns_type
        for batch_num in range(1, self.task_nums + 1):
            train_x = []
            train_y = []
            test_x = []
            test_y = []
            for i, folder in enumerate(datapath):
                train_temp = glob.glob(root + '/train/task{}/{}/*.jpg'.format(batch_num, folder))
                train_x.extend(self._load_images(train_temp))
                train_y.extend([i] * len(train_temp))

                test_temp = glob.glob(root + '/test/task{}/{}/*.jpg'.format(batch_num, folder))
                test_x.extend(self._load_images(test_temp))
                test_y.extend([i] * len(test_temp))

            print(" --> batch{}'-dataset consisting of {} samples".format(batch_num, len(train_x)))
            print(" --> test'-dataset consisting of {} samples".format(len(test_x)))
            self.train_set.append((np.array(train_x), np.array(train_y)))
            self.test_set.append((np.array(test_x), np.array(test_y)))
        e = time.time()
        print('loading time: {}'.format(str(e - s)))

    def new_run(self, **kwargs):
        """Nothing to prepare per run."""
        pass

    def new_task(self, cur_task, **kwargs):
        """Shuffle task ``cur_task`` and split off its validation part.

        The first ``params.val_size`` fraction of the shuffled data is
        appended to ``self.val_set``.

        Returns:
            ``(train_data, train_labels, labels)`` where ``labels`` is the
            set of labels present in the training part.
        """
        train_x, train_y = self.train_set[cur_task]
        # get val set
        train_x_rdm, train_y_rdm = shuffle_data(train_x, train_y)
        val_size = int(len(train_x_rdm) * self.params.val_size)
        val_data_rdm, val_label_rdm = train_x_rdm[:val_size], train_y_rdm[:val_size]
        train_data_rdm, train_label_rdm = train_x_rdm[val_size:], train_y_rdm[val_size:]
        self.val_set.append((val_data_rdm, val_label_rdm))
        labels = set(train_label_rdm)
        return train_data_rdm, train_label_rdm, labels

    def setup(self, **kwargs):
        """Nothing to set up."""
        pass
63 |
64 |
65 |
# Number of tasks available for each kind of non-stationarity.
openloris_ntask = {
    'illumination': 9,
    'occlusion': 9,
    'pixel': 9,
    'clutter': 9,
    'sequence': 12,
}

# Object folders; the position of a folder in this list is its class label.
datapath = [
    'bottle_01', 'bottle_02', 'bottle_03', 'bottle_04',
    'bowl_01', 'bowl_02', 'bowl_03', 'bowl_04', 'bowl_05',
    'corkscrew_01',
    'cottonswab_01', 'cottonswab_02',
    'cup_01', 'cup_02', 'cup_03', 'cup_04', 'cup_05',
    'cup_06', 'cup_07', 'cup_08', 'cup_10',
    'cushion_01', 'cushion_02', 'cushion_03',
    'glasses_01', 'glasses_02', 'glasses_03', 'glasses_04',
    'knife_01',
    'ladle_01', 'ladle_02', 'ladle_03', 'ladle_04',
    'mask_01', 'mask_02', 'mask_03', 'mask_04', 'mask_05',
    'paper_cutter_01', 'paper_cutter_02', 'paper_cutter_03', 'paper_cutter_04',
    'pencil_01', 'pencil_02', 'pencil_03', 'pencil_04', 'pencil_05',
    'plasticbag_01', 'plasticbag_02', 'plasticbag_03',
    'plug_01', 'plug_02', 'plug_03', 'plug_04',
    'pot_01',
    'scissors_01', 'scissors_02', 'scissors_03',
    'stapler_01', 'stapler_02', 'stapler_03',
    'thermometer_01', 'thermometer_02', 'thermometer_03',
    'toy_01', 'toy_02', 'toy_03', 'toy_04', 'toy_05',
    'nail_clippers_01', 'nail_clippers_02', 'nail_clippers_03',
    'bracelet_01', 'bracelet_02', 'bracelet_03',
    'comb_01', 'comb_02', 'comb_03',
    'umbrella_01', 'umbrella_02', 'umbrella_03',
    'socks_01', 'socks_02', 'socks_03',
    'toothpaste_01', 'toothpaste_02', 'toothpaste_03',
    'wallet_01', 'wallet_02', 'wallet_03',
    'headphone_01', 'headphone_02', 'headphone_03',
    'key_01', 'key_02', 'key_03',
    'battery_01', 'battery_02',
    'mouse_01',
    'pencilcase_01', 'pencilcase_02',
    'tape_01',
    'chopsticks_01', 'chopsticks_02', 'chopsticks_03',
    'notebook_01', 'notebook_02', 'notebook_03',
    'spoon_01', 'spoon_02', 'spoon_03',
    'tissue_01', 'tissue_02', 'tissue_03',
    'clamp_01', 'clamp_02',
    'hat_01', 'hat_02',
    'u_disk_01', 'u_disk_02',
    'swimming_glasses_01',
]
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/agents/ewc_pp.py:
--------------------------------------------------------------------------------
1 | from agents.base import ContinualLearner
2 | from continuum.data_utils import dataset_transform
3 | from utils.setup_elements import transforms_match
4 | from torch.utils import data
5 | from utils.utils import maybe_cuda, AverageMeter
6 | import torch
7 |
class EWC_pp(ContinualLearner):
    """Learner that adds a Fisher-weighted quadratic penalty to the task loss.

    Three per-parameter dictionaries of tensors are maintained:
    ``tmp_fisher`` accumulates squared gradients batch by batch,
    ``running_fisher`` is an exponential moving average of those
    accumulations, and ``normalized_fisher`` is the min-max rescaled running
    estimate that weights the penalty in ``total_loss``.
    """

    def __init__(self, model, opt, params):
        """Store the EWC++ hyper-parameters and zero-initialise the Fisher state.

        Args:
            model: network whose trainable parameters are regularised.
            opt: optimizer stepping the model parameters.
            params: configuration object; ``lambda_``, ``alpha`` and
                ``fisher_update_after`` are read here.
        """
        super().__init__(model, opt, params)
        # Only trainable parameters take part in the regularisation.
        self.weights = {
            name: param
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
        self.lambda_ = params.lambda_
        self.alpha = params.alpha
        self.fisher_update_after = params.fisher_update_after
        # Snapshot of the parameters taken at the end of the previous task.
        self.prev_params = {}
        self.running_fisher = self.init_fisher()
        self.tmp_fisher = self.init_fisher()
        self.normalized_fisher = self.init_fisher()

    def train_learner(self, x_train, y_train):
        """Train on one task, then snapshot parameters and refresh the Fisher.

        Args:
            x_train: training inputs of the current task.
            y_train: training labels of the current task.
        """
        self.before_train(x_train, y_train)
        # set up loader
        task_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(
            task_dataset,
            batch_size=self.batch,
            shuffle=True,
            num_workers=0,
            drop_last=True,
        )
        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        # set up model
        self.model.train()

        for ep in range(self.epoch):
            for i, (batch_x, batch_y) in enumerate(train_loader):
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                # Fold the accumulated squared gradients into the running
                # Fisher every `fisher_update_after` iterations.
                step = ep * len(train_loader) + i + 1
                if step % self.fisher_update_after == 0:
                    self.update_running_fisher()

                out = self.forward(batch_x)
                loss = self.total_loss(out, batch_y)
                if self.params.trick['kd_trick']:
                    task_weight = 1 / (self.task_seen + 1)
                    loss = task_weight * loss + \
                        (1 - task_weight) * self.kd_manager.get_kd_loss(out, batch_x)
                if self.params.trick['kd_trick_star']:
                    task_weight = 1 / ((self.task_seen + 1) ** 0.5)
                    loss = task_weight * loss + \
                        (1 - task_weight) * self.kd_manager.get_kd_loss(out, batch_x)

                # update tracker
                n_samples = batch_y.size(0)
                losses_batch.update(loss.item(), n_samples)
                _, pred_label = torch.max(out, 1)
                acc = (pred_label == batch_y).sum().item() / n_samples
                acc_batch.update(acc, n_samples)

                # backward
                self.opt.zero_grad()
                loss.backward()

                # The gradients must be read before the optimizer step.
                self.accum_fisher()
                self.opt.step()

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                        .format(i, losses_batch.avg(), acc_batch.avg())
                    )

        # save params for current task
        for name, param in self.weights.items():
            self.prev_params[name] = param.clone().detach()

        # Min-max normalise the running Fisher over all parameters jointly.
        fisher_tensors = self.running_fisher.values()
        max_fisher = max([torch.max(t) for t in fisher_tensors])
        min_fisher = min([torch.min(t) for t in fisher_tensors])
        fisher_span = max_fisher - min_fisher + 1e-32
        for name, fisher in self.running_fisher.items():
            self.normalized_fisher[name] = (fisher - min_fisher) / fisher_span
        self.after_train()

    def total_loss(self, inputs, targets):
        """Return the criterion loss plus the Fisher-weighted penalty.

        The penalty is added only once a parameter snapshot exists, i.e.
        after ``train_learner`` has completed at least once.

        Args:
            inputs: model outputs passed to ``self.criterion``.
            targets: labels passed to ``self.criterion``.

        Returns:
            The scalar loss.
        """
        loss = self.criterion(inputs, targets)
        if len(self.prev_params) == 0:
            return loss
        # add regularization loss
        reg_loss = 0
        for name, param in self.weights.items():
            drift = param - self.prev_params[name]
            reg_loss += (self.normalized_fisher[name] * drift ** 2).sum()
        loss += self.lambda_ * reg_loss
        return loss

    def init_fisher(self):
        """Return a zero tensor per trainable parameter, keyed by name."""
        zeros = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                zeros[name] = param.clone().detach().fill_(0)
        return zeros

    def update_running_fisher(self):
        """Blend the accumulated Fisher into the running average, then reset it."""
        scale = 1. / self.fisher_update_after
        for name, running in self.running_fisher.items():
            self.running_fisher[name] = (1. - self.alpha) * running \
                + scale * self.alpha * self.tmp_fisher[name]
        # reset the accumulated fisher
        self.tmp_fisher = self.init_fisher()

    def accum_fisher(self):
        """Add the current squared gradients to the temporary Fisher, in place."""
        for name, accumulated in self.tmp_fisher.items():
            accumulated += self.weights[name].grad ** 2
--------------------------------------------------------------------------------
/agents/exp_replay.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import data
3 | from utils.buffer.buffer import Buffer
4 | from agents.base import ContinualLearner
5 | from continuum.data_utils import dataset_transform
6 | from utils.setup_elements import transforms_match
7 | from utils.utils import maybe_cuda, AverageMeter
8 |
9 |
class ExperienceReplay(ContinualLearner):
    """Experience-replay learner: trains on the incoming batch plus memory samples.

    For every incoming batch the model is updated ``mem_iters`` times. Each
    update back-propagates the loss on the incoming batch and, when the
    buffer returns samples, the loss on the retrieved memory batch. The
    incoming batch is then offered to the buffer.
    """

    def __init__(self, model, opt, params):
        """Create the memory buffer and copy replay hyper-parameters from ``params``.

        Args:
            model: network to train; also handed to the buffer.
            opt: optimizer stepping the model parameters.
            params: configuration object; ``mem_size``, ``agent``,
                ``eps_mem_batch`` and ``mem_iters`` are read here.
        """
        super(ExperienceReplay, self).__init__(model, opt, params)
        self.buffer = Buffer(model, params)
        self.mem_size = params.mem_size
        self.agent = params.agent
        self.eps_mem_batch = params.eps_mem_batch
        self.mem_iters = params.mem_iters

    def _blend_kd(self, loss, logits, x):
        """Mix ``loss`` with the KD loss according to the enabled KD tricks.

        Both tricks are applied in sequence when both flags are set, exactly
        as the two consecutive ``if`` statements they replace.

        Args:
            loss: loss tensor computed on ``logits``.
            logits: model outputs for ``x``.
            x: inputs the KD manager is evaluated on.

        Returns:
            The (possibly re-weighted) loss tensor.
        """
        if self.params.trick['kd_trick']:
            loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                self.kd_manager.get_kd_loss(logits, x)
        if self.params.trick['kd_trick_star']:
            loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \
                (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, x)
        return loss

    def train_learner(self, x_train, y_train):
        """Train on one task with replay from the memory buffer.

        Args:
            x_train: training inputs of the current task.
            y_train: training labels of the current task.
        """
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # set up model
        self.model = self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        losses_mem = AverageMeter()
        acc_batch = AverageMeter()
        acc_mem = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                for _ in range(self.mem_iters):
                    # The model's forward returns a pair here; only the first
                    # element (the logits) is used.
                    logits, _ = self.model.forward(batch_x)
                    loss = self._blend_kd(self.criterion(logits, batch_y), logits, batch_x)
                    _, pred_label = torch.max(logits, 1)
                    correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                    # update tracker; record a plain float so the meter does
                    # not hold on to the loss tensor (and its autograd graph)
                    acc_batch.update(correct_cnt, batch_y.size(0))
                    losses_batch.update(loss.item(), batch_y.size(0))
                    # backward
                    self.opt.zero_grad()
                    loss.backward()

                    # mem update
                    mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y)
                    if mem_x.size(0) > 0:
                        mem_x = maybe_cuda(mem_x, self.cuda)
                        mem_y = maybe_cuda(mem_y, self.cuda)
                        mem_logits, _ = self.model.forward(mem_x)
                        loss_mem = self._blend_kd(self.criterion(mem_logits, mem_y), mem_logits, mem_x)
                        # update tracker
                        losses_mem.update(loss_mem.item(), mem_y.size(0))
                        _, pred_label = torch.max(mem_logits, 1)
                        correct_cnt = (pred_label == mem_y).sum().item() / mem_y.size(0)
                        acc_mem.update(correct_cnt, mem_y.size(0))

                        # Gradients add to those of the incoming-batch loss.
                        loss_mem.backward()

                    if self.params.update == 'ASER' or self.params.retrieve == 'ASER':
                        # ASER discards the gradients computed above and takes
                        # a single step on the concatenated memory + incoming
                        # batch instead.
                        self.opt.zero_grad()
                        combined_batch = torch.cat((mem_x, batch_x))
                        combined_labels = torch.cat((mem_y, batch_y))
                        # forward returns a pair (see above); the criterion
                        # needs the logits only.
                        combined_logits, _ = self.model.forward(combined_batch)
                        loss_combined = self.criterion(combined_logits, combined_labels)
                        loss_combined.backward()
                        self.opt.step()
                    else:
                        self.opt.step()

                # update mem
                self.buffer.update(batch_x, batch_y)

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                        .format(i, losses_batch.avg(), acc_batch.avg())
                    )
                    print(
                        '==>>> it: {}, mem avg. loss: {:.6f}, '
                        'running mem acc: {:.3f}'
                        .format(i, losses_mem.avg(), acc_mem.avg())
                    )
        self.after_train()
--------------------------------------------------------------------------------
/continuum/dataset_scripts/core50.py:
--------------------------------------------------------------------------------
1 | import os
2 | from continuum.dataset_scripts.dataset_base import DatasetBase
3 | import pickle as pkl
4 | import logging
5 | from hashlib import md5
6 | import numpy as np
7 | from PIL import Image
8 | from continuum.data_utils import shuffle_data, load_task_with_labels
9 | import time
10 |
# Number of tasks for each supported CORe50 scenario name.
core50_ntask = dict(
    ni=8,
    nc=9,
    nic=79,
    nicv2_79=79,
    nicv2_196=196,
    nicv2_391=391,
)
19 |
20 |
class CORE50(DatasetBase):
    """CORe50 dataset wrapper driven by pre-computed path/label lookup pickles."""

    def __init__(self, scenario, params):
        """Validate the run count and initialise the base dataset.

        Args:
            scenario: scenario name; must be a key of ``core50_ntask``.
            params: configuration object; ``num_runs`` is read here.

        Raises:
            Exception: if ``params.num_runs`` is an int greater than 10.
            KeyError: if ``scenario`` is not a key of ``core50_ntask``.
        """
        if isinstance(params.num_runs, int) and params.num_runs > 10:
            raise Exception('the max number of runs for CORE50 is 10')
        dataset = 'core50'
        task_nums = core50_ntask[scenario]
        super(CORE50, self).__init__(dataset, scenario, task_nums, params.num_runs, params)

    def download_load(self):
        """Load the path list, the index look-up table and the labels from ``self.root``.

        NOTE(review): the files are read with ``pickle``; only load them from
        a trusted location, as unpickling can execute arbitrary code.
        """
        print("Loading paths...")
        with open(os.path.join(self.root, 'paths.pkl'), 'rb') as f:
            self.paths = pkl.load(f)

        print("Loading LUP...")
        with open(os.path.join(self.root, 'LUP.pkl'), 'rb') as f:
            self.LUP = pkl.load(f)

        print("Loading labels...")
        with open(os.path.join(self.root, 'labels.pkl'), 'rb') as f:
            self.labels = pkl.load(f)

    def setup(self, cur_run):
        """Load the test images of ``cur_run`` and build the per-scenario test sets.

        ``self.test_set`` is populated only for the 'nc' scenario (one entry
        per training task, restricted to that task's labels) and the 'ni'
        scenario (a single entry with the full test data); for any other
        scenario it stays empty.

        Args:
            cur_run: index of the run whose look-up entries are used.
        """
        self.val_set = []
        self.test_set = []
        print('Loading test set...')
        # The last entry of a run holds the test split.
        test_idx_list = self.LUP[self.scenario][cur_run][-1]

        # test paths
        test_paths = [os.path.join(self.root, self.paths[idx]) for idx in test_idx_list]

        # test imgs
        self.test_data = self.get_batch_from_paths(test_paths)
        self.test_label = np.asarray(self.labels[self.scenario][cur_run][-1])

        if self.scenario == 'nc':
            self.task_labels = self.labels[self.scenario][cur_run][:-1]
            for labels in self.task_labels:
                labels = list(set(labels))
                x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels)
                self.test_set.append((x_test, y_test))
        elif self.scenario == 'ni':
            self.test_set = [(self.test_data, self.test_label)]

    def new_task(self, cur_task, **kwargs):
        """Load the training data of a task and split off a validation part.

        Args:
            cur_task: index of the task inside the current run.
            **kwargs: must contain ``cur_run``, the index of the current run.

        Returns:
            Tuple ``(train_x, train_y, labels)``: the shuffled training images
            and labels remaining after the validation split, and the set of
            labels present in them.
        """
        cur_run = kwargs['cur_run']
        s = time.time()
        train_idx_list = self.LUP[self.scenario][cur_run][cur_task]
        print("Loading data...")
        # Getting the actual paths
        train_paths = [os.path.join(self.root, self.paths[idx]) for idx in train_idx_list]
        # loading imgs
        train_x = self.get_batch_from_paths(train_paths)
        train_y = np.asarray(self.labels[self.scenario][cur_run][cur_task])
        # get val set: the first `val_size` fraction of the shuffled data
        train_x_rdm, train_y_rdm = shuffle_data(train_x, train_y)
        val_size = int(len(train_x_rdm) * self.params.val_size)
        val_data_rdm, val_label_rdm = train_x_rdm[:val_size], train_y_rdm[:val_size]
        train_data_rdm, train_label_rdm = train_x_rdm[val_size:], train_y_rdm[val_size:]
        self.val_set.append((val_data_rdm, val_label_rdm))
        e = time.time()
        print('loading time {}'.format(str(e - s)))
        return train_data_rdm, train_label_rdm, set(train_label_rdm)

    def new_run(self, **kwargs):
        """Prepare the test data for the run given by ``kwargs['cur_run']``."""
        cur_run = kwargs['cur_run']
        self.setup(cur_run)

    @staticmethod
    def get_batch_from_paths(paths, compress=False, snap_dir='',
                             on_the_fly=True, verbose=False):
        """Return a uint8 array of shape (len(paths), 128, 128, 3) with all images.

        When ``on_the_fly`` is False, a cache file named after the MD5 digest
        of the concatenated paths is looked up in ``snap_dir`` and used if it
        exists; otherwise the images are read and the cache file is written.

        Args:
            paths: list of image file paths.
            compress: use a compressed ``.npz`` cache instead of a raw
                ``_x.bin`` dump.
            snap_dir: prefix prepended to the cache file name.
            on_the_fly: when True (default), never read or write the cache.
            verbose: print a progress line while loading images.

        Returns:
            The image array.
        """
        # Getting root logger
        log = logging.getLogger('mylogger')

        num_imgs = len(paths)
        hexdigest = md5(''.join(paths).encode('utf-8')).hexdigest()
        log.debug("Paths Hex: " + str(hexdigest))
        loaded = False
        x = None
        file_path = None

        if compress:
            file_path = snap_dir + hexdigest + ".npz"
            if os.path.exists(file_path) and not on_the_fly:
                loaded = True
                with open(file_path, 'rb') as f:
                    npzfile = np.load(f)
                    # The archive is written below with the single key 'x',
                    # so the stored array must be read back whole rather
                    # than unpacked into two names.
                    x = npzfile['x']
        else:
            x_file_path = snap_dir + hexdigest + "_x.bin"
            if os.path.exists(x_file_path) and not on_the_fly:
                loaded = True
                with open(x_file_path, 'rb') as f:
                    x = np.fromfile(f, dtype=np.uint8) \
                        .reshape(num_imgs, 128, 128, 3)

        # Here we actually load the images.
        if not loaded:
            # Pre-allocate numpy arrays
            x = np.zeros((num_imgs, 128, 128, 3), dtype=np.uint8)

            for i, path in enumerate(paths):
                if verbose:
                    print("\r" + path + " processed: " + str(i + 1), end='')
                # Close each image file as soon as its pixels are copied.
                with Image.open(path) as img:
                    x[i] = np.array(img)

            if verbose:
                print()

            if not on_the_fly:
                # Then we save x
                if compress:
                    with open(file_path, 'wb') as g:
                        np.savez_compressed(g, x=x)
                else:
                    x.tofile(snap_dir + hexdigest + "_x.bin")

        assert (x is not None), 'Problems loading data. x is None!'

        return x
159 |
--------------------------------------------------------------------------------
/agents/exp_replay_dvc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import data
3 | from utils.buffer.buffer import Buffer
4 | from agents.base import ContinualLearner
5 | from continuum.data_utils import dataset_transform
6 | import torch.nn as nn
7 | from utils.setup_elements import transforms_match, input_size_match
8 | from utils.utils import maybe_cuda, AverageMeter
9 | from torchvision.transforms import RandomResizedCrop, RandomHorizontalFlip, ColorJitter, RandomGrayscale
10 | from loss import agmax_loss, cross_entropy_loss
11 |
12 |
class ExperienceReplay_DVC(ContinualLearner):
    """Experience replay trained on pairs of (original, augmented) views.

    Each update feeds a batch and its augmented copy through the model and
    optimises the sum of the cross-entropy term and the two terms returned
    by ``agmax_loss``, for the incoming batch and for the retrieved memory
    batch.
    """

    def __init__(self, model, opt, params):
        """Create the memory buffer and the augmentation pipeline.

        Args:
            model: network to train; also handed to the buffer.
            opt: optimizer stepping the model parameters.
            params: configuration object; ``mem_size``, ``agent``,
                ``dl_weight``, ``eps_mem_batch``, ``mem_iters`` and ``data``
                are read here.
        """
        super(ExperienceReplay_DVC, self).__init__(model, opt, params)
        self.buffer = Buffer(model, params)
        self.mem_size = params.mem_size
        self.agent = params.agent
        self.dl_weight = params.dl_weight

        self.eps_mem_batch = params.eps_mem_batch
        self.mem_iters = params.mem_iters
        # Entries 1 and 2 of the dataset's input-size tuple give the crop size.
        self.transform = nn.Sequential(
            RandomResizedCrop(size=(input_size_match[self.params.data][1], input_size_match[self.params.data][2]), scale=(0.2, 1.)),
            RandomHorizontalFlip(),
            ColorJitter(0.4, 0.4, 0.4, 0.1),
            RandomGrayscale(p=0.2)
        )
        self.L2loss = torch.nn.MSELoss()

    def _blend_kd(self, loss, logits, x):
        """Mix ``loss`` with the KD loss according to the enabled KD tricks.

        Both tricks are applied in sequence when both flags are set.

        Args:
            loss: loss tensor computed for ``x``.
            logits: model outputs for ``x`` used by the KD manager.
            x: inputs the KD manager is evaluated on.

        Returns:
            The (possibly re-weighted) loss tensor.
        """
        if self.params.trick['kd_trick']:
            loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                self.kd_manager.get_kd_loss(logits, x)
        if self.params.trick['kd_trick_star']:
            loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \
                (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, x)
        return loss

    def train_learner(self, x_train, y_train):
        """Train on one task with replay from the memory buffer.

        Args:
            x_train: training inputs of the current task.
            y_train: training labels of the current task.
        """
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # set up model
        self.model = self.model.train()
        # Device placement goes through the same helper and flag as the data
        # tensors below, instead of an unconditional .cuda() call.
        self.transform = maybe_cuda(self.transform, self.cuda)
        # setup tracker
        losses_batch = AverageMeter()
        losses_mem = AverageMeter()
        acc_batch = AverageMeter()
        acc_mem = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_x_aug = self.transform(batch_x)
                batch_y = maybe_cuda(batch_y, self.cuda)
                for _ in range(self.mem_iters):
                    outputs = self.model(batch_x, batch_x_aug)
                    # The first two outputs are the predictions for the
                    # original and the augmented view.
                    z, zt, _, _ = outputs
                    ce = cross_entropy_loss(z, zt, batch_y, label_smoothing=0)

                    agreement_loss, dl = agmax_loss(outputs, batch_y, dl_weight=self.dl_weight)
                    loss = self._blend_kd(ce + agreement_loss + dl, z, batch_x)

                    _, pred_label = torch.max(z, 1)
                    correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                    # update tracker; record a plain float so the meter does
                    # not hold on to the loss tensor (and its autograd graph)
                    acc_batch.update(correct_cnt, batch_y.size(0))
                    losses_batch.update(loss.item(), batch_y.size(0))
                    # backward
                    self.opt.zero_grad()
                    loss.backward()

                    # mem update
                    if self.params.retrieve == 'MGI':
                        # MGI retrieval already returns the augmented views.
                        mem_x, mem_x_aug, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y)
                    else:
                        mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y)
                        if mem_x.size(0) > 0:
                            mem_x_aug = self.transform(mem_x)

                    if mem_x.size(0) > 0:
                        mem_x = maybe_cuda(mem_x, self.cuda)
                        mem_x_aug = maybe_cuda(mem_x_aug, self.cuda)
                        mem_y = maybe_cuda(mem_y, self.cuda)
                        outputs = self.model(mem_x, mem_x_aug)
                        z, zt, _, _ = outputs
                        ce = cross_entropy_loss(z, zt, mem_y, label_smoothing=0)
                        agreement_loss, dl = agmax_loss(outputs, mem_y, dl_weight=self.dl_weight)
                        loss_mem = self._blend_kd(ce + agreement_loss + dl, z, mem_x)

                        # update tracker
                        losses_mem.update(loss_mem.item(), mem_y.size(0))
                        _, pred_label = torch.max(z, 1)
                        correct_cnt = (pred_label == mem_y).sum().item() / mem_y.size(0)
                        acc_mem.update(correct_cnt, mem_y.size(0))

                        # Gradients add to those of the incoming-batch loss.
                        loss_mem.backward()
                    self.opt.step()

                # update mem
                self.buffer.update(batch_x, batch_y)

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                        .format(i, losses_batch.avg(), acc_batch.avg())
                    )
                    print(
                        '==>>> it: {}, mem avg. loss: {:.6f}, '
                        'running mem acc: {:.3f}'
                        .format(i, losses_mem.avg(), acc_mem.avg())
                    )
        self.after_train()
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
--------------------------------------------------------------------------------
/utils/buffer/gss_greedy_update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from utils.buffer.buffer_utils import get_grad_vector, cosine_similarity
5 | from utils.utils import maybe_cuda
6 |
class GSSGreedyUpdate(object):
    """Buffer update rule that scores samples by gradient cosine similarity.

    Every stored sample has a score in ``self.buffer_score``. While the
    buffer has free slots, incoming samples are appended together with their
    score; once it is full, incoming samples may randomly replace stored ones.
    """

    def __init__(self, params):
        """Read the GSS settings and allocate one zero score per memory slot.

        Args:
            params: configuration object; ``gss_mem_strength``,
                ``gss_batch_size`` and ``mem_size`` are read here.
        """
        super().__init__()
        # the number of gradient vectors to estimate new samples similarity, line 5 in alg.2
        self.mem_strength = params.gss_mem_strength
        self.gss_batch_size = params.gss_batch_size
        self.buffer_score = maybe_cuda(torch.FloatTensor(params.mem_size).fill_(0))

    def update(self, buffer, x, y, **kwargs):
        """Insert the batch ``(x, y)`` into ``buffer`` following the GSS-greedy rule.

        The model is switched to eval mode for the gradient computations and
        back to train mode before returning.

        Args:
            buffer: memory buffer exposing ``model``, ``buffer_img``,
                ``buffer_label`` and ``current_index``.
            x: batch images.
            y: batch labels.
            **kwargs: accepted for interface compatibility; unused.
        """
        buffer.model.eval()

        grad_dims = [param.data.numel() for param in buffer.model.parameters()]

        place_left = buffer.buffer_img.size(0) - buffer.current_index
        if place_left > 0:
            self._append_to_buffer(buffer, grad_dims, place_left, x, y)
        else:
            # buffer is full
            self._replace_in_buffer(buffer, grad_dims, x, y)
        buffer.model.train()

    def _replace_in_buffer(self, buffer, grad_dims, x, y):
        """Randomly swap stored samples for incoming ones when the buffer is full."""
        batch_sim, mem_grads = self.get_batch_sim(buffer, grad_dims, x, y)
        # Replacement is only attempted when the batch score is negative.
        if not batch_sim < 0:
            return
        stored_score = self.buffer_score[:buffer.current_index].cpu()
        lowest = torch.min(stored_score)
        buffer_sim = (stored_score - lowest) / \
            ((torch.max(stored_score) - torch.min(stored_score)) + 0.01)
        # draw candidates for replacement from the buffer
        index = torch.multinomial(buffer_sim, x.size(0), replacement=False)
        # estimate the similarity of each sample in the received batch
        # to the randomly drawn samples from the buffer.
        batch_item_sim = self.get_each_batch_sample_sim(buffer, grad_dims, mem_grads, x, y)
        # map both scores from [-1, 1] to [0, 1]
        incoming_weight = ((batch_item_sim + 1) / 2).unsqueeze(1)
        stored_weight = ((self.buffer_score[index] + 1) / 2).unsqueeze(1)
        # draw an event to decide on replacement decision
        outcome = torch.multinomial(torch.cat((incoming_weight, stored_weight), dim=1), 1,
                                    replacement=False)
        # replace samples with outcome =1
        added_indx = torch.arange(end=batch_item_sim.size(0))
        sub_index = outcome.squeeze(1).bool()
        slots = index[sub_index]
        chosen = added_indx[sub_index]
        buffer.buffer_img[slots] = x[chosen].clone()
        buffer.buffer_label[slots] = y[chosen].clone()
        self.buffer_score[slots] = batch_item_sim[chosen].clone()

    def _append_to_buffer(self, buffer, grad_dims, place_left, x, y):
        """Copy as many incoming samples as fit into the free slots and score them."""
        offset = min(place_left, x.size(0))
        x, y = x[:offset], y[:offset]
        if buffer.current_index != 0:
            # draw random samples from buffer
            mem_grads = self.get_rand_mem_grads(buffer, grad_dims)
            # estimate a score for each added sample
            batch_sample_memory_cos = self.get_each_batch_sample_sim(buffer, grad_dims, mem_grads, x, y)
        else:
            # first buffer insertion: nothing to compare against yet
            batch_sample_memory_cos = torch.zeros(x.size(0)) + 0.1
        start = buffer.current_index
        stop = start + offset
        buffer.buffer_img[start:stop].data.copy_(x)
        buffer.buffer_label[start:stop].data.copy_(y)
        self.buffer_score[start:stop].data.copy_(batch_sample_memory_cos)
        buffer.current_index += offset

    def get_batch_sim(self, buffer, grad_dims, batch_x, batch_y):
        """Score the whole incoming batch against random memory subsets.

        Args:
            buffer: memory buffer
            grad_dims: gradient dimensions
            batch_x: batch images
            batch_y: batch labels
        Returns: score of current batch, gradient from memory subsets
        """
        mem_grads = self.get_rand_mem_grads(buffer, grad_dims)
        model = buffer.model
        model.zero_grad()
        batch_loss = F.cross_entropy(model.forward(batch_x), batch_y)
        batch_loss.backward()
        batch_grad = get_grad_vector(model.parameters, grad_dims).unsqueeze(0)
        batch_sim = max(cosine_similarity(mem_grads, batch_grad))
        return batch_sim, mem_grads

    def get_rand_mem_grads(self, buffer, grad_dims):
        """Compute one gradient vector per random, disjoint subset of the memory.

        Args:
            buffer: memory buffer
            grad_dims: gradient dimensions
        Returns: gradient from memory subsets
        """
        filled = buffer.current_index
        gss_batch_size = min(self.gss_batch_size, filled)
        num_mem_subs = min(self.mem_strength, filled // gss_batch_size)
        mem_grads = maybe_cuda(torch.zeros(num_mem_subs, sum(grad_dims), dtype=torch.float32))
        shuffled_inds = torch.randperm(filled)
        for sub in range(num_mem_subs):
            begin = sub * gss_batch_size
            random_batch_inds = shuffled_inds[begin:begin + gss_batch_size]
            sub_x = buffer.buffer_img[random_batch_inds]
            sub_y = buffer.buffer_label[random_batch_inds]
            buffer.model.zero_grad()
            sub_loss = F.cross_entropy(buffer.model.forward(sub_x), sub_y)
            sub_loss.backward()
            mem_grads[sub].data.copy_(get_grad_vector(buffer.model.parameters, grad_dims))
        return mem_grads

    def get_each_batch_sample_sim(self, buffer, grad_dims, mem_grads, batch_x, batch_y):
        """Score every incoming sample individually against the memory gradients.

        Args:
            buffer: memory buffer
            grad_dims: gradient dimensions
            mem_grads: gradient from memory subsets
            batch_x: batch images
            batch_y: batch labels
        Returns: score of each sample from current batch
        """
        cosine_sim = maybe_cuda(torch.zeros(batch_x.size(0)))
        for pos, (sample_x, sample_y) in enumerate(zip(batch_x, batch_y)):
            buffer.model.zero_grad()
            sample_loss = F.cross_entropy(buffer.model.forward(sample_x.unsqueeze(0)),
                                          sample_y.unsqueeze(0))
            sample_loss.backward()
            # the sample's score is its largest cosine similarity to any memory gradient
            sample_grad = get_grad_vector(buffer.model.parameters, grad_dims).unsqueeze(0)
            cosine_sim[pos] = max(cosine_similarity(mem_grads, sample_grad))
        return cosine_sim
123 |
--------------------------------------------------------------------------------
/utils/buffer/aser_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.utils import maybe_cuda, mini_batch_deep_features, euclidean_distance, nonzero_indices, ohe_label
3 | from utils.setup_elements import n_classes
4 | from utils.buffer.buffer_utils import ClassBalancedRandomSampling
5 |
6 |
def compute_knn_sv(model, eval_x, eval_y, cand_x, cand_y, k, device="cpu"):
    """
        Compute KNN SV of candidate data w.r.t. evaluation data.
            Args:
                model (object): neural network.
                eval_x (tensor): evaluation data tensor.
                eval_y (tensor): evaluation label tensor.
                cand_x (tensor): candidate data tensor.
                cand_y (tensor): candidate label tensor.
                k (int): number of nearest neighbours.
                device (str): device for tensor allocation.
            Returns
                sv_matrix (tensor): KNN Shapley value matrix of candidate data w.r.t. evaluation data,
                    one row per evaluation sample and one column per candidate sample.
    """
    n_eval, n_cand = eval_x.size(0), cand_x.size(0)
    # Output matrix, zero-initialised; every entry is overwritten below
    sv_matrix = torch.zeros((n_eval, n_cand), device=device)

    # For each evaluation sample, candidate indices ordered by distance in deep feature space
    eval_df, cand_df = deep_features(model, eval_x, n_eval, cand_x, n_cand)
    sorted_ind_mat = sorted_cand_ind(eval_df, cand_df, n_eval, n_cand)

    # Evaluation labels broadcast over candidates vs. candidate labels in sorted order
    eval_label_mat = eval_y.repeat([n_cand, 1]).T
    sorted_cand_label = cand_y[sorted_ind_mat]

    # Label-match indicator and the same indicator shifted one rank to the left
    match = (eval_label_mat == sorted_cand_label).float()
    match_next = torch.zeros_like(match, device=device)
    match_next[:, 0:n_cand - 1] = match[:, 1:]
    match_diff = match - match_next

    # Per-rank weight: min(rank, k) / (k * rank), except 1 / n_cand for the last rank
    rank = torch.arange(n_cand, dtype=torch.float, device=device) + 1
    denominator = rank.clone()
    denominator[:n_cand - 1] = denominator[:n_cand - 1] * k
    numerator = rank.clone()
    numerator[k:n_cand - 1] = k
    numerator[n_cand - 1] = 1
    rank_weight = numerator / denominator

    # Suffix sums over ranks give the value at every rank
    weighted_diff = match_diff * rank_weight
    suffix_sum = weighted_diff.flip(1).cumsum(1).flip(1)

    # Scatter the rank-ordered values back to the original candidate positions
    row_ind = torch.arange(n_eval, device=device)
    row_mat = torch.repeat_interleave(row_ind, n_cand).reshape([n_eval, n_cand])
    sv_matrix[row_mat, sorted_ind_mat] = suffix_sum

    return sv_matrix
62 |
63 |
def deep_features(model, eval_x, n_eval, cand_x, n_cand):
    """
        Compute deep features of evaluation and candidate data.
            Args:
                model (object): neural network.
                eval_x (tensor): evaluation data tensor.
                n_eval (int): number of evaluation data.
                cand_x (tensor): candidate data tensor, or None to use evaluation data only.
                n_cand (int): number of candidate data.
            Returns
                eval_df (tensor): deep features of evaluation data.
                cand_df (tensor): deep features of candidate data
                    (the empty remainder when cand_x is None).
    """
    # Stack evaluation and candidate inputs so that both go through one batched pass
    if cand_x is not None:
        total_x = torch.cat((eval_x, cand_x), 0)
        num = n_eval + n_cand
    else:
        total_x = eval_x
        num = n_eval

    # compute deep features with mini-batches
    features = mini_batch_deep_features(model, maybe_cuda(total_x), num)

    # The first n_eval rows belong to the evaluation data, the rest to the candidates
    return features[0:n_eval], features[n_eval:]
92 |
93 |
def sorted_cand_ind(eval_df, cand_df, n_eval, n_cand):
    """
        Sort indices of candidate data according to
            their Euclidean distance to each evaluation data in deep feature space.
            Args:
                eval_df (tensor): deep features of evaluation data.
                cand_df (tensor): deep features of candidate data.
                n_eval (int): number of evaluation data.
                n_cand (int): number of candidate data.
            Returns
                sorted_cand_ind (tensor): sorted indices of candidate set w.r.t. each evaluation data.
    """
    feature_dim = eval_df.shape[1]
    # Pair every evaluation vector with every candidate vector:
    # each evaluation row is repeated n_cand times in a row, the candidate block is tiled n_eval times
    eval_rows = eval_df.repeat([1, n_cand]).reshape([n_eval * n_cand, feature_dim])
    cand_rows = cand_df.repeat([n_eval, 1])
    # One distance per (evaluation, candidate) pair, arranged as an n_eval x n_cand matrix
    pair_distance = euclidean_distance(eval_rows, cand_rows)
    distance_matrix = pair_distance.reshape((n_eval, n_cand))
    # Candidate indices ordered by increasing distance, per evaluation row
    return distance_matrix.argsort(1)
117 |
118 |
def add_minority_class_input(cur_x, cur_y, mem_size, num_class):
    """
    Find input instances from minority classes, so that they can be concatenated to the
    evaluation data/label tensors later.
    This facilitates the inclusion of minority class samples into memory when ASER's update
    method is used under the online class-incremental setting.

    More details:

    The evaluation set may not contain any samples from minority classes (i.e., classes with
    very few corresponding samples stored in the memory).
    This happens after task changes in the online class-incremental setting.
    Minority class samples can then get very low or negative KNN-SV, making it difficult to
    store any of them in the memory.

    By identifying minority class samples in the current input batch and concatenating them to
    the evaluation set, the KNN-SV of the minority class samples can be artificially boosted
    (i.e., positive value with larger magnitude).
    This allows new class samples to be accommodated quickly in the memory right after task changes.

    The threshold for being a minority class is a hyper-parameter related to the class proportion.
    In this implementation, it is randomly selected between 0 and 1 / number of all classes
    for each current input batch.

        Args:
            cur_x (tensor): current input data tensor.
            cur_y (tensor): current input label tensor.
            mem_size (int): memory size.
            num_class (int): number of classes in dataset.
        Returns
            minority_batch_x (tensor): subset of current input data from minority class.
            minority_batch_y (tensor): subset of current input label from minority class.
    """
    # Random threshold drawn uniformly from [0, 1 / num_class)
    threshold = torch.tensor(1).float().uniform_(0, 1 / num_class).item()

    # Share of the memory occupied by each class, from the cached per-class counts
    class_share = ClassBalancedRandomSampling.class_num_cache.float() / mem_size
    # A sample is a minority sample when its class holds less than the threshold share
    is_minority = class_share[cur_y] < threshold
    minority_ind = nonzero_indices(is_minority)

    return cur_x[minority_ind], cur_y[minority_ind]
158 |
--------------------------------------------------------------------------------
/models/ndpm/ndpm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler
4 |
5 | from utils.utils import maybe_cuda
6 | from utils.global_vars import *
7 | from .expert import Expert
8 | from .priors import CumulativePrior
9 |
10 |
class Ndpm(nn.Module):
    """Mixture of experts that grows from a short-term memory (STM).

    Index 0 of ``self.experts`` is a placeholder slot: samples whose best
    destination is index 0 are stored in the STM instead of being used to
    train an expert.  When the STM holds ``stm_capacity`` samples,
    :meth:`sleep` adds a new expert and trains it on the buffered data.
    """

    def __init__(self, params):
        """Create the placeholder expert, the empty STM and the prior.

        Args:
            params: configuration object; ``params.stm_capacity`` is read
                here and the object is forwarded to ``Expert`` and
                ``CumulativePrior``.
        """
        super().__init__()
        self.params = params
        self.experts = nn.ModuleList([Expert(params)])
        self.stm_capacity = params.stm_capacity
        # Short-term memory: lists of per-sample CPU tensors.
        self.stm_x = []
        self.stm_y = []
        self.prior = CumulativePrior(params)

    def get_experts(self):
        """Return the current experts as a tuple."""
        return tuple(self.experts.children())

    def forward(self, x):
        """Combine the experts' outputs for ``x`` without tracking gradients.

        Args:
            x: input batch (moved with ``maybe_cuda``).

        Returns:
            Tensor of shape ``[B, C]``, or ``[B]`` when the trailing
            dimension has size one (e.g. the discriminative part is
            disabled).  The batch dimension is always preserved.

        Raises:
            RuntimeError: if only the placeholder expert exists.
        """
        with torch.no_grad():
            if len(self.experts) == 1:
                raise RuntimeError('There\'s no expert to run on the input')
            x = maybe_cuda(x)
            log_evid = -self.experts[-1].g.collect_nll(x)  # [B, 1+K]
            log_evid = log_evid[:, 1:].unsqueeze(2)  # [B, K, 1]
            log_prior = -self.prior.nl_prior()[1:]  # [K]
            # Renormalise the prior over the real experts only.
            log_prior -= torch.logsumexp(log_prior, dim=0)
            log_prior = log_prior.unsqueeze(0).unsqueeze(2)  # [1, K, 1]
            log_joint = log_prior + log_evid  # [B, K, 1]
            if not MODELS_NDPM_NDPM_DISABLE_D:
                log_pred = self.experts[-1].d.collect_forward(x)  # [B, 1+K, C]
                log_pred = log_pred[:, 1:, :]  # [B, K, C]
                log_joint = log_joint + log_pred  # [B, K, C]

            log_joint = log_joint.logsumexp(dim=1)  # [B, 1] or [B, C]
            # Only drop the trailing singleton axis.  A bare squeeze() would
            # also remove the batch axis for a batch of one sample.
            if log_joint.size(-1) == 1:
                log_joint = log_joint.squeeze(-1)  # [B]
            return log_joint

    def learn(self, x, y):
        """Route a batch to the STM and/or the experts, then sleep if full.

        Args:
            x: input batch.
            y: labels for ``x``.
        """
        x, y = maybe_cuda(x), maybe_cuda(y)

        if MODELS_NDPM_NDPM_SEND_TO_STM_ALWAYS:
            self.stm_x.extend(torch.unbind(x.cpu()))
            self.stm_y.extend(torch.unbind(y.cpu()))
        else:
            # Determine the destination of each data point
            nll = self.experts[-1].collect_nll(x, y)  # [B, 1+K]
            nl_prior = self.prior.nl_prior()  # [1+K]
            nl_joint = self._nl_joint(nl_prior, nll)  # [B, 1+K]

            # Save to short-term memory
            destination = maybe_cuda(torch.argmin(nl_joint, dim=1))  # [B]
            to_stm = destination == 0  # [B]
            self.stm_x.extend(torch.unbind(x[to_stm].cpu()))
            self.stm_y.extend(torch.unbind(y[to_stm].cpu()))

            # Train expert
            with torch.no_grad():
                # Subtract the row minimum before exponentiating for
                # numerical stability.
                min_joint = nl_joint.min(dim=1)[0].view(-1, 1)
                to_expert = torch.exp(-nl_joint + min_joint)  # [B, 1+K]
                to_expert[:, 0] = 0.  # [B, 1+K]
                to_expert = \
                    to_expert / (to_expert.sum(dim=1).view(-1, 1) + 1e-7)

            # Compute losses per expert (samples sent to the STM are masked)
            nll_for_train = nll * (1. - to_stm.float()).unsqueeze(1)  # [B,1+K]
            losses = (nll_for_train * to_expert).sum(0)  # [1+K]

            # Record expert usage
            expert_usage = to_expert.sum(dim=0)  # [K+1]
            self.prior.record_usage(expert_usage)

            # Do lr_decay implicitly
            if MODELS_NDPM_NDPM_IMPLICIT_LR_DECAY:
                losses = losses \
                    * self.params.stm_capacity / (self.prior.counts + 1e-8)
            loss = losses.sum()

            if loss.requires_grad:
                update_threshold = 0
                for k, usage in enumerate(expert_usage):
                    if usage > update_threshold:
                        self.experts[k].zero_grad()
                loss.backward()
                for k, usage in enumerate(expert_usage):
                    if usage > update_threshold:
                        self.experts[k].clip_grad()
                        self.experts[k].optimizer_step()
                        self.experts[k].lr_scheduler_step()

        # Sleep
        if len(self.stm_x) >= self.stm_capacity:
            dream_dataset = TensorDataset(
                torch.stack(self.stm_x), torch.stack(self.stm_y))
            self.sleep(dream_dataset)
            self.stm_x = []
            self.stm_y = []

    def sleep(self, dream_dataset):
        """Add a new expert and train it on the short-term memory.

        Args:
            dream_dataset: kept for interface compatibility.  The value is
                not used; the training set is rebuilt from ``self.stm_x`` /
                ``self.stm_y`` with
                ``MODELS_NDPM_NDPM_SLEEP_SLEEP_VAL_SIZE`` randomly chosen
                samples left out.
        """
        print('\nGoing to sleep...')
        # Add new expert and optimizer
        expert = Expert(self.params, self.get_experts())
        self.experts.append(expert)
        self.prior.add_expert()

        stacked_stm_x = torch.stack(self.stm_x)
        stacked_stm_y = torch.stack(self.stm_y)
        indices = torch.randperm(stacked_stm_x.size(0))
        train_size = stacked_stm_x.size(0) - MODELS_NDPM_NDPM_SLEEP_SLEEP_VAL_SIZE
        dream_dataset = TensorDataset(
            stacked_stm_x[indices[:train_size]],
            stacked_stm_y[indices[:train_size]])

        # Prepare data iterator
        self.prior.record_usage(len(dream_dataset), index=-1)
        dream_iterator = iter(DataLoader(
            dream_dataset,
            batch_size=MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE,
            num_workers=MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS,
            sampler=RandomSampler(
                dream_dataset,
                replacement=True,
                num_samples=(
                    MODELS_NDPM_NDPM_SLEEP_STEP_G *
                    MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE
                ))
        ))

        # Train generative component
        for step, (x, y) in enumerate(dream_iterator):
            step += 1
            x, y = maybe_cuda(x), maybe_cuda(y)
            g_loss = expert.g.nll(x, y, step=step)
            g_loss = (g_loss + MODELS_NDPM_NDPM_WEIGHT_DECAY
                      * expert.g.weight_decay_loss())
            expert.g.zero_grad()
            g_loss.mean().backward()
            expert.g.clip_grad()
            expert.g.optimizer.step()

            if step % MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP == 0:
                print('\r [Sleep-G %6d] loss: %5.1f' % (
                    step, g_loss.mean()
                ), end='')
        print()

        dream_iterator = iter(DataLoader(
            dream_dataset,
            batch_size=MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE,
            num_workers=MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS,
            sampler=RandomSampler(
                dream_dataset,
                replacement=True,
                num_samples=(
                    MODELS_NDPM_NDPM_SLEEP_STEP_D *
                    MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE)
            )
        ))

        # Train discriminative component
        if not MODELS_NDPM_NDPM_DISABLE_D:
            for step, (x, y) in enumerate(dream_iterator):
                step += 1
                x, y = maybe_cuda(x), maybe_cuda(y)
                d_loss = expert.d.nll(x, y, step=step)
                d_loss = (d_loss + MODELS_NDPM_NDPM_WEIGHT_DECAY
                          * expert.d.weight_decay_loss())
                expert.d.zero_grad()
                d_loss.mean().backward()
                expert.d.clip_grad()
                expert.d.optimizer.step()

                if step % MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP == 0:
                    print('\r [Sleep-D %6d] loss: %5.1f' % (
                        step, d_loss.mean()
                    ), end='')

        # NOTE(review): the scheduler is stepped twice in the original code;
        # kept as-is -- confirm whether this is intentional.
        expert.lr_scheduler_step()
        expert.lr_scheduler_step()
        expert.eval()
        print()

    @staticmethod
    def _nl_joint(nl_prior, nll):
        """Add the prior term ``nl_prior`` ([1+K]) to every row of ``nll`` ([B, 1+K])."""
        batch = nll.size(0)
        nl_prior = nl_prior.unsqueeze(0).expand(batch, -1)  # [B, 1+K]
        return nll + nl_prior

    def train(self, mode=True):
        """Mode switching is disabled: no module flags are changed.

        Returns ``self`` so that ``model.train()`` / ``model.eval()`` can
        still be chained, as with any other ``nn.Module``.
        """
        return self
198 |
--------------------------------------------------------------------------------
/features/vgg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | #from .utils import load_state_dict_from_url
3 |
4 | from . import extractor
5 |
6 | __all__ = [
7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
8 | 'vgg19_bn', 'vgg19',
9 | ]
10 |
11 |
# Download locations of the torchvision VGG checkpoints, keyed by
# architecture name.
model_urls = dict(
    vgg11='https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    vgg13='https://download.pytorch.org/models/vgg13-c768596a.pth',
    vgg16='https://download.pytorch.org/models/vgg16-397923af.pth',
    vgg19='https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    vgg11_bn='https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    vgg13_bn='https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    vgg16_bn='https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    vgg19_bn='https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
)
22 |
23 |
class VGG(extractor.BaseModule):
    """Convolutional VGG trunk without pooling/classifier head.

    ``forward`` returns the output of the convolutional stack unchanged.
    """

    def __init__(self, config, name):
        """Build the convolutional stack described by ``config``.

        Args:
            config: mapping with the keys ``"cfg"`` (key into ``cfgs``),
                ``"channels"`` (input channels) and ``"batch_norm"``.
            name: stored on the instance as ``self.name``.
        """
        super(VGG, self).__init__()

        self.name = name
        layout_key = config["cfg"]
        input_channels = config["channels"]
        use_batch_norm = config["batch_norm"]
        self.features = make_layers(
            cfgs[layout_key],
            batch_norm=use_batch_norm,
            in_channels=input_channels,
        )
        self.n_features = 512

    def forward(self, x):
        """Run ``x`` through the convolutional stack."""
        return self.features(x)

    def _initialize_weights(self):
        """Initialise Linear, BatchNorm2d and Conv2d sub-modules in place."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
69 |
70 |
def make_layers(cfg, batch_norm=False, in_channels=3):
    """Translate a VGG layout list into an ``nn.Sequential``.

    Args:
        cfg: sequence whose entries are either ``'M'`` (2x2 max-pool with
            stride 2) or an int (output channels of a 3x3 convolution with
            padding 1, followed by ReLU).
        batch_norm: insert ``BatchNorm2d`` between convolution and ReLU.
        in_channels: channel count of the network input.

    Returns:
        The assembled ``nn.Sequential``.
    """
    stack = []
    channels = in_channels
    for spec in cfg:
        if spec == 'M':
            stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        stack.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            stack.append(nn.BatchNorm2d(spec))
        stack.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*stack)
85 |
86 |
# VGG layer layouts: ints are convolution widths, 'M' is a max-pool.
# The last two stages of every layout are identical, hence the "* 2".
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M'] + [512, 512, 'M'] * 2,
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M'] + [512, 512, 'M'] * 2,
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M'] + [512, 512, 512, 'M'] * 2,
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M'] + [512, 512, 512, 512, 'M'] * 2,
}
93 |
94 |
def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    """Build a :class:`VGG` trunk and optionally load pretrained weights.

    The previous body called ``VGG(features, **kwargs)``, which does not
    match ``VGG.__init__(config, name)`` and therefore always raised
    ``TypeError``.

    Args:
        arch: architecture name; key into ``model_urls`` and default model
            name.
        cfg: key into ``cfgs`` (``'A'``, ``'B'``, ``'D'`` or ``'E'``).
        batch_norm: whether BatchNorm layers are inserted.
        pretrained: if True, download the checkpoint for ``arch`` and load
            the weights that exist in this model.
        progress: show a download progress bar.
        **kwargs: optional ``in_channels`` (default 3) and ``name``
            (default ``arch``).  A legacy ``init_weights`` entry is
            accepted and ignored.

    Raises:
        TypeError: on any other keyword argument.
    """
    # torchvision's VGG took ``init_weights``; this VGG has no such option.
    kwargs.pop('init_weights', None)
    config = {
        "cfg": cfg,
        "channels": kwargs.pop('in_channels', 3),
        "batch_norm": batch_norm,
    }
    name = kwargs.pop('name', arch)
    if kwargs:
        raise TypeError(
            '_vgg() got unexpected keyword arguments: {}'.format(sorted(kwargs)))

    model = VGG(config, name)
    if pretrained:
        # Imported locally: the module-level import is commented out.
        from torch.hub import load_state_dict_from_url
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        # The published checkpoints also contain ``classifier.*`` weights,
        # which this head-less model does not have.
        model.load_state_dict(state_dict, strict=False)
    return model
104 |
105 |
def vgg11(pretrained=False, progress=True, **kwargs):
    r"""VGG 11-layer model (configuration "A") from
    "Very Deep Convolutional Networks For Large-Scale Image Recognition"
    (arXiv:1409.1556).

    Args:
        pretrained (bool): forwarded to ``_vgg``; requests ImageNet weights
        progress (bool): forwarded to ``_vgg``; download progress bar
        **kwargs: forwarded to ``_vgg``
    """
    return _vgg(arch='vgg11', cfg='A', batch_norm=False,
                pretrained=pretrained, progress=progress, **kwargs)
115 |
116 |
def vgg11_bn(pretrained=False, progress=True, **kwargs):
    r"""VGG 11-layer model (configuration "A") with batch normalization, from
    "Very Deep Convolutional Networks For Large-Scale Image Recognition"
    (arXiv:1409.1556).

    Args:
        pretrained (bool): forwarded to ``_vgg``; requests ImageNet weights
        progress (bool): forwarded to ``_vgg``; download progress bar
        **kwargs: forwarded to ``_vgg``
    """
    return _vgg(arch='vgg11_bn', cfg='A', batch_norm=True,
                pretrained=pretrained, progress=progress, **kwargs)
126 |
127 |
def vgg13(pretrained=False, progress=True, **kwargs):
    r"""VGG 13-layer model (configuration "B") from
    "Very Deep Convolutional Networks For Large-Scale Image Recognition"
    (arXiv:1409.1556).

    Args:
        pretrained (bool): forwarded to ``_vgg``; requests ImageNet weights
        progress (bool): forwarded to ``_vgg``; download progress bar
        **kwargs: forwarded to ``_vgg``
    """
    return _vgg(arch='vgg13', cfg='B', batch_norm=False,
                pretrained=pretrained, progress=progress, **kwargs)
137 |
138 |
def vgg13_bn(pretrained=False, progress=True, **kwargs):
    r"""VGG 13-layer model (configuration "B") with batch normalization, from
    "Very Deep Convolutional Networks For Large-Scale Image Recognition"
    (arXiv:1409.1556).

    Args:
        pretrained (bool): forwarded to ``_vgg``; requests ImageNet weights
        progress (bool): forwarded to ``_vgg``; download progress bar
        **kwargs: forwarded to ``_vgg``
    """
    return _vgg(arch='vgg13_bn', cfg='B', batch_norm=True,
                pretrained=pretrained, progress=progress, **kwargs)
148 |
149 |
def vgg16(pretrained=False, progress=True, **kwargs):
    r"""VGG 16-layer model (configuration "D") from
    "Very Deep Convolutional Networks For Large-Scale Image Recognition"
    (arXiv:1409.1556).

    Args:
        pretrained (bool): forwarded to ``_vgg``; requests ImageNet weights
        progress (bool): forwarded to ``_vgg``; download progress bar
        **kwargs: forwarded to ``_vgg``
    """
    return _vgg(arch='vgg16', cfg='D', batch_norm=False,
                pretrained=pretrained, progress=progress, **kwargs)
159 |
160 |
def vgg16_bn(pretrained=False, progress=True, **kwargs):
    r"""VGG 16-layer model (configuration "D") with batch normalization, from
    "Very Deep Convolutional Networks For Large-Scale Image Recognition"
    (arXiv:1409.1556).

    Args:
        pretrained (bool): forwarded to ``_vgg``; requests ImageNet weights
        progress (bool): forwarded to ``_vgg``; download progress bar
        **kwargs: forwarded to ``_vgg``
    """
    return _vgg(arch='vgg16_bn', cfg='D', batch_norm=True,
                pretrained=pretrained, progress=progress, **kwargs)
170 |
171 |
def vgg19(pretrained=False, progress=True, **kwargs):
    r"""VGG 19-layer model (configuration "E") from
    "Very Deep Convolutional Networks For Large-Scale Image Recognition"
    (arXiv:1409.1556).

    Args:
        pretrained (bool): forwarded to ``_vgg``; requests ImageNet weights
        progress (bool): forwarded to ``_vgg``; download progress bar
        **kwargs: forwarded to ``_vgg``
    """
    return _vgg(arch='vgg19', cfg='E', batch_norm=False,
                pretrained=pretrained, progress=progress, **kwargs)
181 |
182 |
def vgg19_bn(pretrained=False, progress=True, **kwargs):
    r"""VGG 19-layer model (configuration "E") with batch normalization, from
    "Very Deep Convolutional Networks For Large-Scale Image Recognition"
    (arXiv:1409.1556).

    Args:
        pretrained (bool): forwarded to ``_vgg``; requests ImageNet weights
        progress (bool): forwarded to ``_vgg``; download progress bar
        **kwargs: forwarded to ``_vgg``
    """
    return _vgg(arch='vgg19_bn', cfg='E', batch_norm=True,
                pretrained=pretrained, progress=progress, **kwargs)
192 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from https://github.com/facebookresearch/GradientEpisodicMemory
3 | &
4 | https://github.com/kuangliu/pytorch-cifar
5 | """
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 | from torch.nn.functional import relu, avg_pool2d
9 | import torch
10 | from features.extractor import BaseModule
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
14 |
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by BatchNorm."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """
        Args:
            in_planes: channels of the block input.
            planes: channels produced by both convolutions.
            stride: stride of the first convolution (and of the shortcut).
        """
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        if stride != 1 or in_planes != out_planes:
            # Shapes differ: project the input with a strided 1x1 conv.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )
        else:
            # Identity shortcut.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return relu(out)
39 |
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions.

    The block outputs ``expansion * planes`` channels.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        """
        Args:
            in_planes: channels of the block input.
            planes: channels of the two inner convolutions.
            stride: stride of the 3x3 convolution (and of the shortcut).
        """
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride != 1 or in_planes != out_planes:
            # Shapes differ: project the input with a strided 1x1 conv.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )
        else:
            # Identity shortcut.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = relu(self.bn1(self.conv1(x)))
        hidden = relu(self.bn2(self.conv2(hidden)))
        out = self.bn3(self.conv3(hidden))
        out += self.shortcut(x)
        return relu(out)
69 |
class ResNet(nn.Module):
    """ResNet with a 3x3 stem, four residual stages and one linear layer."""

    def __init__(self, block, num_blocks, num_classes, nf, bias):
        """
        Args:
            block: residual block class exposing an ``expansion`` attribute.
            num_blocks: number of blocks in each of the four stages.
            num_classes: output size of the final linear layer.
            nf: base width; the stages use ``nf``, ``2*nf``, ``4*nf``, ``8*nf``.
            bias: whether the final linear layer has a bias term.
        """
        super(ResNet, self).__init__()
        self.in_planes = nf
        self.conv1 = conv3x3(3, nf * 1)
        self.bn1 = nn.BatchNorm2d(nf * 1)
        self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)
        self.linear = nn.Linear(nf * 8 * block.expansion, num_classes, bias=bias)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produces.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def features(self, x):
        '''Features before FC layers'''
        out = relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = avg_pool2d(out, 4)
        return out.contiguous().view(out.size(0), -1)

    def logits(self, x):
        '''Apply the last FC linear mapping to get logits'''
        return self.linear(x)

    def forward(self, x):
        """Return the pair ``(logits, features)`` for ``x``."""
        feats = self.features(x)
        return self.logits(feats), feats
111 |
112 |
class QNet(BaseModule):
    """Two-layer MLP mapping ``2 * n_classes`` inputs to ``n_classes`` outputs."""

    def __init__(self,
                 n_units,
                 n_classes):
        """
        Args:
            n_units: width of the hidden layer.
            n_classes: output size; the input size is twice this value.
        """
        super(QNet, self).__init__()

        layers = [
            nn.Linear(2 * n_classes, n_units),
            nn.ReLU(True),
            nn.Linear(n_units, n_classes),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, zcat):
        """Apply the MLP to ``zcat``."""
        return self.model(zcat)
128 |
129 |
class DVCNet(BaseModule):
    """Runs a backbone on two batches at once and optionally a ``QNet`` on top."""

    def __init__(self,
                 backbone,
                 n_units,
                 n_classes,
                 has_mi_qnet=True):
        """
        Args:
            backbone: module returning ``(logits, features)``.
            n_units: hidden width of the ``QNet``.
            n_classes: number of classes for the ``QNet``.
            has_mi_qnet: create and use the ``QNet`` head.
        """
        super(DVCNet, self).__init__()

        self.backbone = backbone
        self.has_mi_qnet = has_mi_qnet

        if has_mi_qnet:
            self.qnet = QNet(n_units=n_units,
                             n_classes=n_classes)

    def forward(self, x, xt):
        """Process ``x`` and ``xt`` in one backbone pass.

        Returns:
            ``(z, zt, None)`` when ``has_mi_qnet`` is False; otherwise
            ``(z, zt, zzt, [l1_z, l1_zt])`` where the last entry holds the
            per-sample sums of absolute feature values, each reshaped to a
            column.  NOTE(review): the two branches return tuples of
            different length.
        """
        size = x.size(0)
        zz, fea = self.backbone(torch.cat((x, xt)))
        z, zt = zz[0:size], zz[size:]
        fea_z, fea_zt = fea[0:size], fea[size:]

        if not self.has_mi_qnet:
            return z, zt, None

        zzt = self.qnet(torch.cat((z, zt), dim=1))
        l1_norms = [
            torch.sum(torch.abs(feature), 1).reshape(-1, 1)
            for feature in (fea_z, fea_zt)
        ]
        return z, zt, zzt, l1_norms
162 |
163 |
def Reduced_ResNet18_DVC(nclasses, nf=20, bias=True):
    """
    Reduced ResNet18 (nf=20 by default) wrapped in a DVCNet with a 128-unit QNet.
    """
    backbone = ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf, bias)
    return DVCNet(
        backbone=backbone,
        n_units=128,
        n_classes=nclasses,
        has_mi_qnet=True,
    )
170 |
171 |
def Reduced_ResNet18(nclasses, nf=20, bias=True):
    """
    ResNet18 layout (BasicBlock, [2, 2, 2, 2]) with a reduced base width (nf=20 by default).
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, nclasses, nf, bias)
177 |
178 |
179 |
def ResNet18(nclasses, nf=64, bias=True):
    """ResNet built from BasicBlock with stage depths [2, 2, 2, 2]."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, nclasses, nf, bias)
182 |
183 | '''
184 | See https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
185 | '''
186 |
def ResNet34(nclasses, nf=64, bias=True):
    """ResNet built from BasicBlock with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, nclasses, nf, bias)
189 |
def ResNet50(nclasses, nf=64, bias=True):
    """ResNet built from Bottleneck with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, nclasses, nf, bias)
192 |
193 |
def ResNet101(nclasses, nf=64, bias=True):
    """ResNet built from Bottleneck with stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, nclasses, nf, bias)
196 |
197 |
def ResNet152(nclasses, nf=64, bias=True):
    """ResNet built from Bottleneck with stage depths [3, 8, 36, 3]."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths, nclasses, nf, bias)
200 |
201 |
class SupConResNet(nn.Module):
    """backbone + projection head"""

    def __init__(self, dim_in=160, head='mlp', feat_dim=128):
        """
        Args:
            dim_in: size of the encoder features fed to the head.
            head: ``'linear'``, ``'mlp'`` or the string ``'None'`` (no head).
            feat_dim: output size of the projection head.

        Raises:
            NotImplementedError: for any other ``head`` value.
        """
        super(SupConResNet, self).__init__()
        # The encoder is always a 100-class Reduced_ResNet18.
        self.encoder = Reduced_ResNet18(100)
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        elif head == 'None':
            self.head = None
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        """Return L2-normalised embeddings of ``x``.

        The encoder features go through the projection head when one is
        configured.  The encoder's linear classifier is not evaluated here:
        its output was computed but never used.
        """
        feat = self.encoder.features(x)
        # Explicit None test instead of relying on module truthiness.
        if self.head is not None:
            feat = self.head(feat)
        return F.normalize(feat, dim=1)

    def features(self, x):
        """Return the raw encoder features (no head, no normalisation)."""
        return self.encoder.features(x)
--------------------------------------------------------------------------------
/continuum/non_stationary.py:
--------------------------------------------------------------------------------
1 | import random
2 | from copy import deepcopy
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | from skimage.filters import gaussian
6 | from continuum.data_utils import train_val_test_split_ni
7 |
8 |
class Original(object):
    """Unmodified task data; base class of the non-stationary generators.

    ``x``/``y`` hold the source data and ``next_x``/``next_y`` the data of
    the most recently generated task.
    """

    def __init__(self, x, y, unroll=False, color=False):
        """
        Args:
            x: input array; divided by 255.0 when ``color`` is True.
            y: label array.
            unroll: flatten each sample in :meth:`create_output`.
            color: rescale ``x`` by 1/255.
        """
        if color:
            self.x = x / 255.0
        else:
            self.x = x
        self.next_x = self.x
        self.next_y = y
        self.y = y
        self.unroll = unroll

    def get_dims(self):
        """Print and return ``(x.shape[1], y.shape[1])``."""
        # Get data input and output dimensions
        dims = self.x.shape[1], self.y.shape[1]
        print("input size {}\noutput size {}".format(*dims))
        return dims

    def show_sample(self, num_plot=1):
        """Plot the first ``num_plot`` samples next to their transformed version."""
        for i in range(num_plot):
            # A trailing channel axis of size one is squeezed for imshow.
            single_channel = self.x[i].shape[2] == 1

            plt.subplot(1, 2, 1)
            plt.imshow(np.squeeze(self.x[i]) if single_channel else self.x[i])
            plt.title("original task image")

            plt.subplot(1, 2, 2)
            plt.imshow(np.squeeze(self.next_x[i]) if single_channel else self.next_x[i])
            plt.title(self.get_name())
            plt.axis('off')
            plt.show()

    def create_output(self):
        """Return ``(next_x, next_y)``; ``next_x`` is flattened when ``unroll`` is set."""
        if self.unroll:
            return self.next_x.reshape((-1, self.x.shape[1] ** 2)), self.next_y
        return self.next_x, self.next_y

    @staticmethod
    def clip_minmax(l, min_=0., max_=1.):
        """Return ``l`` clipped to ``[min_, max_]``."""
        return np.clip(l, min_, max_)

    def get_name(self):
        """Return ``'<ClassName>_<factor>'``, or just the class name when no
        factor has been set yet (previously this case returned ``None``)."""
        name = str(self.__class__.__name__)
        if hasattr(self, 'factor'):
            return name + '_' + str(self.factor)
        return name

    def next_task(self, *args):
        """Reset ``next_x``/``next_y`` to the source data; extra arguments are ignored."""
        self.next_x = self.x
        self.next_y = self.y
        return self.create_output()
62 |
63 |
class Noisy(Original):
    """Task generator that adds noise to the source images."""

    def __init__(self, x, y, full=False, color=False):
        """``full`` is passed to ``Original`` as its ``unroll`` flag."""
        super(Noisy, self).__init__(x, y, full, color)

    def next_task(self, noise_factor=0.8, sig=0.1, noise_type='Gaussian'):
        """Generate the noisy task.

        Args:
            noise_factor: multiplier applied to the sampled noise.
            sig: standard deviation of the Gaussian noise.
            noise_type: ``'Gaussian'``; ``'S&P'`` is not implemented yet
                and, like any other value, yields the data without noise.

        Returns:
            The output of ``create_output()`` after clipping to ``[0, 1]``.
        """
        next_x = deepcopy(self.x)
        self.factor = noise_factor
        if noise_type == 'Gaussian':
            self.next_x = next_x + noise_factor * np.random.normal(loc=0.0, scale=sig, size=next_x.shape)
        else:
            # The original branch compared ``noise_factor`` (a number) with
            # 'S&P' instead of ``noise_type``.
            # TODO implement S&P
            # Use the clean copy rather than a stale result of an earlier call.
            self.next_x = next_x

        self.next_x = super().clip_minmax(self.next_x, 0, 1)

        return super().create_output()
80 |
81 |
class Blurring(Original):
    """Task generator that blurs the source images."""

    def __init__(self, x, y, full=False, color=False):
        """``full`` is passed to ``Original`` as its ``unroll`` flag."""
        super(Blurring, self).__init__(x, y, full, color)

    def next_task(self, blurry_factor=0.6, blurry_type='Gaussian'):
        """Generate the blurred task.

        Args:
            blurry_factor: sigma of the Gaussian kernel on the image axes.
            blurry_type: ``'Gaussian'``; ``'Average'`` is not implemented
                yet and, like any other value, yields the data unblurred.

        Returns:
            The output of ``create_output()`` after clipping to ``[0, 1]``.
        """
        next_x = deepcopy(self.x)
        self.factor = blurry_factor
        if blurry_type == 'Gaussian':
            # A scalar sigma is applied to every non-channel axis, which for
            # a batch would also blur along the leading sample axis and mix
            # neighbouring images.  Use sigma 0 on that axis.
            # assumes x is laid out as (sample, ..., channel) -- TODO confirm
            sigma = [0] + [blurry_factor] * (next_x.ndim - 2)
            self.next_x = gaussian(next_x, sigma=sigma, multichannel=True)
        else:
            # TODO implement average
            # Use the clean copy rather than a stale result of an earlier call.
            self.next_x = next_x

        self.next_x = super().clip_minmax(self.next_x, 0, 1)

        return super().create_output()
98 |
99 |
class Occlusion(Original):
    """Task generator that covers a square patch of every image."""

    def __init__(self, x, y, full=False, color=False):
        """``full`` is passed to ``Original`` as its ``unroll`` flag."""
        super(Occlusion, self).__init__(x, y, full, color)

    def next_task(self, occlusion_factor=0.2):
        """Generate the occluded task.

        One patch position is drawn and the same patch is set to 1 in every
        sample.

        Args:
            occlusion_factor: patch side length as a fraction of
                ``x.shape[1]``.

        Returns:
            The output of ``create_output()`` after clipping to ``[0, 1]``.
        """
        next_x = deepcopy(self.x)
        self.factor = occlusion_factor
        self.image_size = next_x.shape[1]

        occlusion_size = int(occlusion_factor * self.image_size)
        half_size = occlusion_size // 2
        # min/max keep the randint bounds ordered even for large patches.
        lowest = min(half_size, self.image_size - half_size)
        highest = max(half_size, self.image_size - half_size)
        occlusion_x = random.randint(lowest, highest)
        occlusion_y = random.randint(lowest, highest)

        rows = slice(max(occlusion_x - half_size, 0),
                     min(occlusion_x + half_size, self.image_size))
        cols = slice(max(occlusion_y - half_size, 0),
                     min(occlusion_y + half_size, self.image_size))
        next_x[:, rows, cols] = 1

        # The clipped array used to be discarded; store it, as the sibling
        # generators do.
        self.next_x = super().clip_minmax(next_x, 0, 1)

        return super().create_output()
125 |
126 |
def test_ns(x, y, ns_type, change_factor):
    """Apply one non-stationary transform to ``(x, y)`` and plot ten samples.

    Args:
        x: input array (rescaled by 1/255 by the generator).
        y: label array.
        ns_type: ``'noise'``, ``'occlusion'`` or ``'blur'``.
        change_factor: strength passed to the generator's ``next_task``.
    """
    generators = {'noise': Noisy, 'occlusion': Occlusion, 'blur': Blurring}
    task = generators[ns_type](x, y, color=True)
    task.next_task(change_factor)
    task.show_sample(10)
133 |
134 |
# Maps a non-stationarity type name to its generator class.
ns_match = dict(noise=Noisy, occlusion=Occlusion, blur=Blurring)
136 |
137 |
def construct_ns_single(train_x_split, train_y_split, test_x_split, test_y_split, ns_type, change_factor, ns_task,
                        plot=True):
    """Build train/test task lists that alternate normal and changed segments.

    ``ns_task`` lists segment lengths: segments at even positions use
    ``Original``, segments at odd positions use the generator selected by
    ``ns_type``.

    Args:
        train_x_split, train_y_split: per-task training inputs and labels.
        test_x_split, test_y_split: per-task test inputs and labels.
        ns_type: key into ``ns_match``.
        change_factor: sequence of factors; a one-element sequence is
            unwrapped to its single value.
        ns_task: number of tasks in each consecutive segment.
        plot: show a sample of every generated task.

    Returns:
        ``(train_list, test_list)`` with one ``next_task`` output per task.
    """
    # Data splits
    train_list = []
    test_list = []
    change = ns_match[ns_type]
    if len(change_factor) == 1:
        change_factor = change_factor[0]

    task_id = 0
    for segment, n_tasks in enumerate(ns_task):
        changed = segment % 2 != 0
        generator = change if changed else Original
        label = 'change' if changed else 'normal'
        # Only the changed generators receive the factor.
        task_args = (change_factor,) if changed else ()

        for _ in range(n_tasks):
            print(task_id, label)
            # train
            train_task = generator(train_x_split[task_id], train_y_split[task_id], color=True)
            train_list.append(train_task.next_task(*task_args))
            if plot:
                train_task.show_sample()

            # test
            test_task = generator(test_x_split[task_id], test_y_split[task_id], color=True)
            test_list.append(test_task.next_task(*task_args))
            if plot:
                test_task.show_sample()

            task_id += 1
    return train_list, test_list
180 |
181 |
def construct_ns_multiple(train_x_split, train_y_split, val_x_rdm_split, val_y_rdm_split, test_x_split,
                          test_y_split, ns_type, change_factors, plot):
    """Build train/val/test task lists, one task per entry of ``change_factors``.

    A factor of 0 selects ``Original``; any other value selects the
    generator registered for ``ns_type`` in ``ns_match``.

    Args:
        train_x_split, train_y_split: per-task training inputs and labels.
        val_x_rdm_split, val_y_rdm_split: per-task validation inputs and labels.
        test_x_split, test_y_split: per-task test inputs and labels.
        ns_type: key into ``ns_match``.
        change_factors: one factor per task.
        plot: show a sample of every training task.

    Returns:
        ``(train_list, val_list, test_list)``.
    """
    train_list, val_list, test_list = [], [], []
    for task_id, factor in enumerate(change_factors):
        generator = Original if factor == 0 else ns_match[ns_type]
        print(task_id, factor)

        # train
        train_task = generator(train_x_split[task_id], train_y_split[task_id], color=True)
        train_list.append(train_task.next_task(factor))
        if plot:
            train_task.show_sample()

        # validation
        val_task = generator(val_x_rdm_split[task_id], val_y_rdm_split[task_id], color=True)
        val_list.append(val_task.next_task(factor))

        # test
        test_task = generator(test_x_split[task_id], test_y_split[task_id], color=True)
        test_list.append(test_task.next_task(factor))
    return train_list, val_list, test_list
207 |
208 |
def construct_ns_multiple_wrapper(train_data, train_label, test_data, est_label, task_nums, img_size,
                                  val_size, ns_type, ns_factor, plot):
    """Split the data with ``train_val_test_split_ni`` and build the task lists.

    The six arrays returned by the split are forwarded, together with
    ``ns_type``, ``ns_factor`` and ``plot``, to ``construct_ns_multiple``.

    Returns:
        ``(train_set, val_set, test_set)``.
    """
    splits = train_val_test_split_ni(
        train_data, train_label, test_data, est_label, task_nums, img_size,
        val_size)
    (train_x, train_y,
     val_x, val_y,
     test_x, test_y) = splits
    return construct_ns_multiple(train_x, train_y,
                                 val_x, val_y,
                                 test_x, test_y,
                                 ns_type,
                                 ns_factor,
                                 plot=plot)
221 |
--------------------------------------------------------------------------------
/models/ndpm/classifier.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from utils.utils import maybe_cuda
6 | from .component import ComponentD
7 | from utils.global_vars import *
8 | from utils.setup_elements import n_classes
9 |
10 |
class Classifier(ComponentD, ABC):
    """Abstract discriminative component that outputs log-probabilities."""

    def __init__(self, params, experts):
        super().__init__(params, experts)
        # Per-sample NLL: forward() already returns log-probabilities.
        self.ce_loss = nn.NLLLoss(reduction='none')

    @abstractmethod
    def forward(self, x):
        """Output log P(y|x)"""

    def nll(self, x, y, step=None):
        """Return the per-sample loss of ``x`` against labels ``y``.

        The returned value equals the loss computed on log-probabilities
        divided by ``params.classifier_chill`` and re-normalised, while the
        gradient is that of the unmodified loss.  ``step`` is unused.
        """
        x, y = maybe_cuda(x), maybe_cuda(y)
        log_probs = self.forward(x)
        plain_loss = self.ce_loss(log_probs, y)

        # Classifier chilling
        chilled_log_probs = F.log_softmax(
            log_probs / self.params.classifier_chill, dim=1)
        chilled_loss = self.ce_loss(chilled_log_probs, y)

        # Value with chill & gradient without chill
        return plain_loss - plain_loss.detach() + chilled_loss.detach()
36 |
37 |
class SharingClassifier(Classifier, ABC):
    """Classifier that can also collect the predictions of earlier experts."""

    @abstractmethod
    def forward(self, x, collect=False):
        """Output log P(y|x); ``collect_forward`` calls this with
        ``collect=True`` and unpacks a ``(list, _)`` pair from the result."""

    def collect_forward(self, x):
        """Stack the output of ``experts[0]`` and the collected predictions
        along dim 1."""
        placeholder_pred = self.experts[0](x)
        collected, _ = self.forward(x, collect=True)
        return torch.stack([placeholder_pred] + collected, dim=1)

    def collect_nll(self, x, y, step=None):
        """Return the chilled per-sample loss of every collected prediction,
        stacked along dim 1.  ``step`` is unused."""
        all_preds = self.collect_forward(x)  # [B, 1+K, C]
        per_expert_losses = []
        for log_probs in all_preds.unbind(dim=1):
            plain_loss = self.ce_loss(log_probs, y)

            # Classifier chilling
            chilled_log_probs = F.log_softmax(
                log_probs / self.params.classifier_chill, dim=1)
            chilled_loss = self.ce_loss(chilled_log_probs, y)

            # Value with chill & gradient without chill
            per_expert_losses.append(
                plain_loss - plain_loss.detach() + chilled_loss.detach())
        return torch.stack(per_expert_losses, dim=1)
65 |
66 |
class BasicBlock(nn.Module):
    """Residual block with two convolutions and an optional resampled shortcut."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1,
                 downsample=None, upsample=None,
                 dilation=1, norm_layer=nn.BatchNorm2d):
        """
        Args:
            inplanes: channels of the block input.
            planes: channels of the block output.
            stride: stride of the first convolution.
            downsample: optional module applied to the shortcut.
            upsample: optional module applied to the shortcut when
                ``downsample`` is None; together with ``stride != 1`` it
                switches the first convolution to a transposed one.
            dilation: must be 1.
            norm_layer: factory for the normalisation layers.

        Raises:
            NotImplementedError: if ``dilation > 1``.
        """
        super(BasicBlock, self).__init__()
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock"
            )
        use_transposed = upsample is not None and stride != 1
        if use_transposed:
            self.conv1 = conv4x4t(inplanes, planes, stride)
        else:
            self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.upsample = upsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # downsample takes precedence over upsample; otherwise identity.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)
        elif self.upsample is not None:
            shortcut = self.upsample(x)

        out += shortcut
        return self.relu(out)
110 |
111 |
def conv4x4t(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """4x4 transposed convolution without bias; padding equals ``dilation``."""
    options = dict(
        kernel_size=4,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.ConvTranspose2d(in_planes, out_planes, **options)
118 |
119 |
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution without bias; padding equals ``dilation``."""
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **options)
126 |
127 |
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without bias."""
    options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)
133 |
134 |
class ResNetSharingClassifier(SharingClassifier):
    """ResNet-style classifier that reuses the features of earlier experts.

    The classifiers (``expert.d``) of every expert after the first are kept
    in ``self.precursors``. With no precursors the network is a plain
    four-stage residual network. Otherwise the detached per-stage features
    of the latest precursor are concatenated, along dim 1, in front of this
    network's own features before each subsequent stage.
    """

    block = BasicBlock
    num_blocks = [2, 2, 2, 2]
    norm_layer = nn.InstanceNorm2d

    def __init__(self, params, experts):
        super().__init__(params, experts)
        self.precursors = [expert.d for expert in self.experts[1:]]
        n_precursors = len(self.precursors)
        is_first = n_precursors == 0

        # Global configuration overrides the class-level block counts.
        num_blocks = MODELS_NDPM_CLASSIFIER_NUM_BLOCKS
        if num_blocks is None:
            num_blocks = self.num_blocks

        # Without a configured name the norm layer is BatchNorm2d.
        norm_name = MODELS_NDPM_CLASSIFIER_NORM_LAYER
        if norm_name is None:
            self.norm_layer = nn.BatchNorm2d
        else:
            self.norm_layer = getattr(nn, norm_name)

        num_classes = n_classes[params.data]
        base_width = MODELS_NDPM_CLASSIFIER_CLS_NF_BASE
        ext_width = MODELS_NDPM_CLASSIFIER_CLS_NF_EXT
        # nf: width added by this component; nf_cat: width after
        # concatenation with the features of all precursors.
        nf = base_width if is_first else ext_width
        nf_cat = base_width + n_precursors * ext_width
        self.nf = nf
        self.nf_cat = nf_cat

        self.layer0 = nn.Sequential(
            nn.Conv2d(
                3, nf, kernel_size=3, stride=1, padding=1, bias=False
            ),
            self.norm_layer(nf),
            nn.ReLU()
        )

        # (input multiplier of nf_cat, output multiplier of nf, stride)
        # for layer1 .. layer4, built in that order.
        stage_specs = ((1, 1, 1), (1, 2, 2), (2, 4, 2), (4, 8, 2))
        for idx, (in_mult, out_mult, stride) in enumerate(stage_specs):
            stage = self._make_layer(
                nf_cat * in_mult, nf * out_mult, num_blocks[idx],
                stride=stride)
            setattr(self, 'layer{}'.format(idx + 1), stage)

        self.predict = nn.Sequential(
            nn.Linear(nf_cat * 8, num_classes),
            nn.LogSoftmax(dim=1)
        )
        self.setup_optimizer()

    def _make_layer(self, nf_in, nf_out, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks mapping nf_in -> nf_out.

        Only the first block is strided and changes the channel count; it
        gets a 1x1-convolution shortcut whenever the stride is not 1 or the
        channel counts differ.
        """
        norm_layer = self.norm_layer
        block_cls = self.block

        shape_preserved = stride == 1 and nf_in == nf_out
        if shape_preserved:
            shortcut = None
        else:
            shortcut = nn.Sequential(
                conv1x1(nf_in, nf_out, stride),
                norm_layer(nf_out),
            )

        blocks = [block_cls(
            nf_in, nf_out, stride,
            downsample=shortcut,
            norm_layer=norm_layer
        )]
        remaining = num_blocks - 1
        while remaining > 0:
            blocks.append(block_cls(nf_out, nf_out, norm_layer=norm_layer))
            remaining -= 1

        return nn.Sequential(*blocks)

    def forward(self, x, collect=False):
        """Classify ``x``.

        Args:
            x: input batch (moved with ``maybe_cuda``).
            collect: when true, also return the predictions of all
                precursors and the detached features of every stage.

        Returns:
            The output of ``self.predict`` when ``collect`` is false,
            otherwise ``(preds, features)`` where ``preds`` is the list of
            precursor predictions followed by this one, and ``features`` is
            the list of five detached (concatenated) stage outputs.
        """
        x = maybe_cuda(x)

        # Second or later component: run the latest precursor first.
        if self.precursors:
            preds, inherited = self.precursors[-1](x, collect=True)
        else:
            preds, inherited = [], None

        def merge(level, own):
            # Prepend the precursor features of this level, if any.
            if inherited is None:
                return own
            return torch.cat([inherited[level], own], dim=1)

        acts = [merge(0, self.layer0(x))]
        stages = (self.layer1, self.layer2, self.layer3)
        for level, stage in enumerate(stages, start=1):
            acts.append(merge(level, stage(acts[-1])))

        top = self.layer4(acts[-1])
        top = F.avg_pool2d(top, top.size(2)).view(top.size(0), -1)
        acts.append(merge(4, top))
        pred = self.predict(acts[-1])

        if not collect:
            return pred

        preds.append(pred)
        return preds, [act.detach() for act in acts]
244 |
--------------------------------------------------------------------------------
/models/ndpm/vae.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from itertools import accumulate
3 | import torch
4 | import torch.nn as nn
5 |
6 | from utils.utils import maybe_cuda
7 | from .loss import bernoulli_nll, logistic_nll, gaussian_nll, laplace_nll
8 | from .component import ComponentG
9 | from .utils import Lambda
10 | from utils.global_vars import *
11 | from utils.setup_elements import input_size_match
12 |
class Vae(ComponentG, ABC):
    """Abstract VAE; subclasses provide ``encode`` and ``decode``.

    ``log_var_param`` holds the log-variance handed to the reconstruction
    loss: ``None`` for the bernoulli loss, a trainable ``nn.Parameter``
    when ``MODELS_NDPM_VAE_LEARN_X_LOG_VAR`` is set, and a fixed tensor
    otherwise. It has ``x_c`` entries, where ``x_c`` is the first element of
    ``input_size_match[params.data]`` (presumably the channel count).
    """

    def __init__(self, params, experts):
        super().__init__(params, experts)
        x_c, _, _ = input_size_match[params.data]

        if MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli':
            log_var_param = None
        elif MODELS_NDPM_VAE_LEARN_X_LOG_VAR:
            log_var_param = nn.Parameter(
                torch.ones([x_c]) * MODELS_NDPM_VAE_X_LOG_VAR_PARAM,
                requires_grad=True
            )
        else:
            ones = maybe_cuda(torch.ones([x_c]))
            log_var_param = ones * MODELS_NDPM_VAE_X_LOG_VAR_PARAM
        self.log_var_param = log_var_param

    def forward(self, x):
        """Encode ``x``, draw one latent sample per input and decode it."""
        x = maybe_cuda(x)
        z_mean, z_log_var = self.encode(x)
        return self.decode(self.reparameterize(z_mean, z_log_var, 1))

    def nll(self, x, y=None, step=None):
        """Return the per-example VAE loss (reconstruction + KL).

        The reconstruction loss is summed over everything but the batch and
        sample axes and averaged over ``MODELS_NDPM_VAE_Z_SAMPLES`` latent
        samples. ``y`` and ``step`` are accepted but not used here.
        """
        x = maybe_cuda(x)
        n_samples = MODELS_NDPM_VAE_Z_SAMPLES
        batch = x.size(0)

        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var, n_samples)
        x_mean = self.decode(z).view(batch, n_samples, *x.shape[1:])

        if MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli':
            x_log_var = None
        else:
            x_log_var = self.log_var.view(1, 1, -1, 1, 1)

        recon = self.reconstruction_loss(x, x_mean, x_log_var)
        recon = recon.view(batch, n_samples, -1).sum(2).mean(1)
        kl = self.gaussian_kl(z_mean, z_log_var)
        return recon + kl

    def sample(self, n=1):
        """Decode ``n`` latent vectors drawn from a standard normal."""
        noise = maybe_cuda(torch.randn(n, MODELS_NDPM_VAE_Z_DIM))
        return self.decode(noise)

    def reconstruction_loss(self, x, x_mean, x_log_var=None):
        """Apply the configured reconstruction NLL.

        ``x`` gains a dim at position 1 when ``x_mean`` has more dims than
        ``x``. ``x_log_var`` is forwarded only when it is not None.

        Raises:
            ValueError: if ``MODELS_NDPM_VAE_RECON_LOSS`` is not one of
                'bernoulli', 'gaussian', 'laplace' or 'logistic'.
        """
        loss_type = MODELS_NDPM_VAE_RECON_LOSS
        if loss_type == 'bernoulli':
            loss = bernoulli_nll
        elif loss_type == 'gaussian':
            loss = gaussian_nll
        elif loss_type == 'laplace':
            loss = laplace_nll
        elif loss_type == 'logistic':
            loss = logistic_nll
        else:
            loss = None

        if loss is None:
            raise ValueError('Unknown recon_loss type: {}'.format(loss_type))

        if x_mean.dim() > x.dim():
            x = x.unsqueeze(1)

        if x_log_var is None:
            return loss(x, x_mean)
        return loss(x, x_mean, x_log_var)

    @staticmethod
    def gaussian_kl(q_mean, q_log_var, p_mean=None, p_log_var=None):
        """Return KL(q || p) for diagonal Gaussians, summed over dim 1.

        ``p`` defaults to N(0, 1): a missing mean or log-variance is
        replaced by zeros shaped like ``q_mean``.
        """
        zeros = torch.zeros_like(q_mean)
        if p_mean is None:
            p_mean = zeros
        if p_log_var is None:
            p_log_var = zeros

        # calculate KL(q, p)
        sq_diff = (q_mean - p_mean) ** 2
        ratio = (q_log_var.exp() + sq_diff) / p_log_var.exp()
        kld = 0.5 * (p_log_var - q_log_var + ratio - 1)
        return kld.sum(1)

    @staticmethod
    def reparameterize(z_mean, z_log_var, num_samples=1):
        """Draw ``num_samples`` latent samples per row.

        Returns a tensor with ``z_mean.size(0) * num_samples`` rows; the
        samples belonging to one input row are adjacent.
        """
        std = (z_log_var * 0.5).exp()
        std = std.unsqueeze(1).expand(-1, num_samples, -1)
        mean = z_mean.unsqueeze(1).expand(-1, num_samples, -1)
        noise = torch.randn_like(std)
        samples = mean + noise * std
        return samples.view(-1, std.size(2))

    @abstractmethod
    def encode(self, x):
        """Map ``x`` to ``(z_mean, z_log_var)``; see ``forward``/``nll``."""
        pass

    @abstractmethod
    def decode(self, x):
        """Map latent vectors to reconstructions."""
        pass

    @property
    def log_var(self):
        """The stored log-variance (``None`` when none is stored)."""
        return self.log_var_param
117 |
118 |
class SharingVae(Vae, ABC):
    """VAE whose encoder can also return the latents of its precursors."""

    def collect_nll(self, x, y=None, step=None):
        """Collect NLL values

        Column 0 is the NLL of ``self.experts[0].g`` (the dummy VAE). The
        remaining columns come from the generators of ``self.experts[1:]``
        followed by this VAE, in that order.

        When ``MODELS_NDPM_VAE_PRECURSOR_CONDITIONED_DECODER`` is set, every
        decoder produces logits, and the logits of component k are the sum
        of the detached accumulated logits of components < k and its own
        logits; the sigmoid of that sum is used as the reconstruction mean.

        Returns:
            loss_vae: Tensor of shape [B, 1+K]
        """
        x = maybe_cuda(x)

        # Dummy VAE
        dummy_nll = self.experts[0].g.nll(x, y, step)

        # Encode (the collected features are not needed here)
        z_means, z_log_vars, _ = self.encode(x, collect=True)

        # Decode
        loss_vaes = [dummy_nll]
        vaes = [expert.g for expert in self.experts[1:]] + [self]
        x_logits = []
        for z_mean, z_log_var, vae in zip(z_means, z_log_vars, vaes):
            z = self.reparameterize(
                z_mean, z_log_var, MODELS_NDPM_VAE_Z_SAMPLES)
            if MODELS_NDPM_VAE_PRECURSOR_CONDITIONED_DECODER:
                x_logits.append(vae.decode(z, as_logit=True))
                continue
            loss_vaes.append(
                self._component_nll(x, vae.decode(z), z_mean, z_log_var))

        # Precursor-conditioned decoding: x_logits is empty otherwise.
        x_logits = accumulate(
            x_logits, func=(lambda prev, cur: prev.detach() + cur)
        )
        # Pair every accumulated logit with the posterior it was sampled
        # from. Previously the KL term reused the z_mean / z_log_var left
        # over from the last iteration of the loop above, so every column
        # received the KL of the newest component.
        for x_logit, z_mean, z_log_var in zip(x_logits, z_means, z_log_vars):
            loss_vaes.append(self._component_nll(
                x, torch.sigmoid(x_logit), z_mean, z_log_var))

        return torch.stack(loss_vaes, dim=1)

    def _component_nll(self, x, x_mean, z_mean, z_log_var):
        """Per-example reconstruction + KL loss for one component."""
        n_samples = MODELS_NDPM_VAE_Z_SAMPLES
        batch = x.size(0)
        x_mean = x_mean.view(batch, n_samples, *x.shape[1:])
        # NOTE(review): this VAE's own log_var is applied to the decoders
        # of all precursors as well -- confirm that this is intended.
        x_log_var = (
            None if MODELS_NDPM_VAE_RECON_LOSS == 'bernoulli' else
            self.log_var.view(1, 1, -1, 1, 1)
        )
        loss_recon = self.reconstruction_loss(x, x_mean, x_log_var)
        loss_recon = loss_recon.view(batch, n_samples, -1).sum(2).mean(1)
        loss_kl = self.gaussian_kl(z_mean, z_log_var)
        return loss_recon + loss_kl

    @abstractmethod
    def encode(self, x, collect=False):
        """Encode ``x``.

        With ``collect=True`` implementations return
        ``(z_means, z_log_vars, features)`` as consumed by ``collect_nll``.
        """
        pass

    @abstractmethod
    def decode(self, z, as_logit=False):
        """
        Decoders do not share parameters.

        With ``as_logit=True`` the result is passed through a sigmoid by
        ``collect_nll`` before being used as the reconstruction mean.
        """
        pass
191 |
192 |
class CnnSharingVae(SharingVae):
    """Convolutional VAE whose encoder reuses precursor features.

    The generators (``expert.g``) of every expert after the first are kept
    in ``self.precursors``. The encoder concatenates the detached features
    of the latest precursor in front of its own features after each of its
    three stages; the decoder is built from ``nf_base`` only and takes no
    precursor input.
    """

    def __init__(self, params, experts):
        super().__init__(params, experts)
        self.precursors = [expert.g for expert in self.experts[1:]]
        n_precursors = len(self.precursors)

        # NOTE(review): "MODLES_" is spelled this way on purpose here; it
        # presumably matches the definition in utils.global_vars -- verify
        # before renaming.
        nf_base, nf_ext = MODLES_NDPM_VAE_NF_BASE, MODELS_NDPM_VAE_NF_EXT
        # nf: width added by this component; nf_cat: width after
        # concatenation with the features of all precursors.
        nf = nf_ext if n_precursors else nf_base
        nf_cat = nf_base + n_precursors * nf_ext

        x_c, x_h, x_w = input_size_match[params.data]
        # Two 2x2 max-pools shrink each spatial side by a factor of 4.
        grid_h, grid_w = x_h // 4, x_w // 4
        z_dim = MODELS_NDPM_VAE_Z_DIM

        self.fc_dim = 4 * nf
        feature_volume = grid_h * grid_w * (2 * nf_cat)

        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(x_c, 1 * nf, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU()
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(1 * nf_cat, 2 * nf, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU(),
            Lambda(lambda h: h.view(h.size(0), -1))
        )
        self.enc3 = nn.Sequential(
            nn.Linear(feature_volume, 4 * nf),
            nn.ReLU()
        )
        self.enc_z_mean = nn.Linear(4 * nf_cat, z_dim)
        self.enc_z_log_var = nn.Linear(4 * nf_cat, z_dim)

        # Decoder
        dec_shape = (2 * nf_base, grid_h, grid_w)
        self.dec_z = nn.Sequential(
            nn.Linear(z_dim, 4 * nf_base),
            nn.ReLU()
        )
        self.dec3 = nn.Sequential(
            nn.Linear(4 * nf_base, grid_h * grid_w * 2 * nf_base),
            nn.ReLU()
        )
        self.dec2 = nn.Sequential(
            Lambda(lambda h: h.view(h.size(0), *dec_shape)),
            nn.ConvTranspose2d(2 * nf_base, 1 * nf_base,
                               kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        )
        self.dec1 = nn.ConvTranspose2d(1 * nf_base, x_c,
                                       kernel_size=4, stride=2, padding=1)

        self.setup_optimizer()

    def encode(self, x, collect=False):
        """Encode ``x`` into the parameters of the latent posterior.

        Returns:
            ``(z_mean, z_log_var)`` when ``collect`` is false. Otherwise
            ``(z_means, z_log_vars, features)``: the lists returned by the
            latest precursor extended by this component's values, and the
            three detached (concatenated) encoder features.
        """
        # Second or later component: run the latest precursor first.
        if self.precursors:
            z_means, z_log_vars, inherited = \
                self.precursors[-1].encode(x, collect=True)
        else:
            z_means, z_log_vars, inherited = [], [], None

        feats = []
        h = x
        for level, stage in enumerate((self.enc1, self.enc2, self.enc3)):
            h = stage(h)
            if inherited is not None:
                h = torch.cat([inherited[level], h], dim=1)
            feats.append(h)

        z_mean = self.enc_z_mean(h)
        z_log_var = self.enc_z_log_var(h)

        if not collect:
            return z_mean, z_log_var

        z_means.append(z_mean)
        z_log_vars.append(z_log_var)
        return z_means, z_log_vars, [feat.detach() for feat in feats]

    def decode(self, z, as_logit=False):
        """Decode ``z``; return raw logits if ``as_logit`` else sigmoid."""
        hidden = z
        for stage in (self.dec_z, self.dec3, self.dec2, self.dec1):
            hidden = stage(hidden)
        if as_logit:
            return hidden
        return torch.sigmoid(hidden)
298 |
--------------------------------------------------------------------------------