├── LICENSE
├── README.md
├── requirements.txt
└── src
    ├── approach
    │   ├── .DS_Store
    │   ├── __init__.py
    │   ├── aux_loss.py
    │   ├── bic.py
    │   ├── dmc.py
    │   ├── eeil.py
    │   ├── ewc.py
    │   ├── finetuning.py
    │   ├── freezing.py
    │   ├── il2m.py
    │   ├── incremental_learning.py
    │   ├── joint.py
    │   ├── lucir.py
    │   ├── lucir_cwd.py
    │   ├── lucir_oracle.py
    │   ├── lucir_utils.py
    │   ├── lwf.py
    │   ├── lwm.py
    │   ├── mas.py
    │   ├── path_integral.py
    │   ├── r_walk.py
    │   └── utils.py
    ├── data
    │   └── imagenet
    │       └── gen_lst_imagenet.py
    ├── datasets
    │   ├── base_dataset.py
    │   ├── data_loader.py
    │   ├── dataset_config.py
    │   ├── exemplars_dataset.py
    │   ├── exemplars_selection.py
    │   └── memory_dataset.py
    ├── exp_cifar_lucir.sh
    ├── exp_cifar_lucir_cwd.sh
    ├── exp_im100_joint.sh
    ├── exp_im100_lucir.sh
    ├── exp_im100_lucir_cwd.sh
    ├── exp_im100_lucir_oracle.sh
    ├── gridsearch.py
    ├── gridsearch_config.py
    ├── last_layer_analysis.py
    ├── loggers
    │   ├── disk_logger.py
    │   ├── exp_logger.py
    │   └── tensorboard_logger.py
    ├── main_incremental.py
    ├── networks
    │   ├── __init__.py
    │   ├── network.py
    │   ├── resnet18.py
    │   └── resnet18_cifar.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Yujun Shi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # (CVPR 2022) Mimicking the Oracle: An Initial Phase Decorrelation Approach for Class Incremental Learning [ArXiv](https://arxiv.org/abs/2112.04731)
2 | This repo contains the official implementation of our CVPR 2022 paper: Mimicking the Oracle: An Initial Phase Decorrelation Approach for Class Incremental Learning.
3 |
4 |
5 |
6 | ### 1. Abstract
7 |
8 | Class Incremental Learning (CIL) aims at learning a classifier in a phase-by-phase manner, in which only data of a subset of the classes are provided at each phase. Previous works mainly focus on mitigating forgetting in phases after the initial one. However, we find that improving CIL at its initial phase is also a promising direction. Specifically, we experimentally show that directly encouraging the CIL learner at the initial phase to output similar representations as the model jointly trained on all classes can greatly boost the CIL performance. Motivated by this, we study the difference between a naïvely-trained initial-phase model and the oracle model. Specifically, since one major difference between these two models is the number of training classes, we investigate how such difference affects the model representations. We find that, with fewer training classes, the data representations of each class lie in a long and narrow region; with more training classes, the representations of each class scatter more uniformly. Inspired by this observation, we propose **C**lass-**w**ise **D**ecorrelation (**CwD**) that effectively regularizes representations of each class to scatter more uniformly, thus mimicking the model jointly trained with all classes (i.e., the oracle model). Our CwD is simple to implement and easy to plug into existing methods. Extensive experiments on various benchmark datasets show that CwD consistently and significantly improves the performance of existing state-of-the-art methods by around 1% to 3%.
9 |
10 |
11 |
12 |
13 |
14 | ### 2. Instructions to Run Our Code
15 |
16 | The current codebase only contains experiments on [LUCIR](https://openaccess.thecvf.com/content_CVPR_2019/papers/Hou_Learning_a_Unified_Classifier_Incrementally_via_Rebalancing_CVPR_2019_paper.pdf) with CIFAR100 and ImageNet100. Code reproducing the results on [PODNet](https://github.com/arthurdouillard/incremental_learning.pytorch) and [AANet](https://github.com/yaoyao-liu/class-incremental-learning) builds on their respective repos and is coming soon!
17 |
18 |
19 |
20 | #### CIFAR100 Experiments w/ LUCIR
21 |
22 | There is no need to download the dataset; it is handled automatically.
23 |
24 | For the LUCIR baseline, navigate to the "src" folder and run:
25 |
26 | ```bash
27 | bash exp_cifar_lucir.sh
28 | ```
29 |
30 | For LUCIR + CwD, navigate to the "src" folder and run:
31 |
32 | ```bash
33 | bash exp_cifar_lucir_cwd.sh
34 | ```
35 |
36 | #### ImageNet100 Experiments w/ LUCIR
37 |
38 | To run ImageNet100, follow these two steps:
39 |
40 | Step 1:
41 |
42 | Download and extract the ImageNet dataset under the "src/data/imagenet" folder.
43 |
44 | Then, under "src/data/imagenet", run:
45 |
46 | ```bash
47 | python3 gen_lst_imagenet.py
48 | ```
49 |
50 | This command generates two lists that determine the order of classes for class incremental learning. The class order is shuffled with seed 1993, as in most previous works.
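
As a rough illustration of what a seed-fixed class order means, the sketch below draws a reproducible permutation of 100 class indices with NumPy. Whether `gen_lst_imagenet.py` uses these exact calls is an assumption, so the order printed here is not guaranteed to match the generated lists; `make_class_order` and its parameters are hypothetical names.

```python
import numpy as np

def make_class_order(num_classes=100, seed=1993):
    """Return a reproducible permutation of class indices (hypothetical helper)."""
    # A local RandomState keeps the global NumPy RNG untouched.
    rng = np.random.RandomState(seed)
    return rng.permutation(num_classes).tolist()

order = make_class_order()
print(order[:10])  # first ten classes of the incremental sequence
```

Calling the helper twice with the same seed yields the same order, which is what makes results comparable across methods.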
51 |
52 |
53 |
54 | Step 2:
55 |
56 | For the LUCIR baseline, navigate to the "src" folder and run:
57 |
58 | ```bash
59 | bash exp_im100_lucir.sh
60 | ```
61 |
62 | For LUCIR + CwD, navigate to the "src" folder and run:
63 |
64 | ```bash
65 | bash exp_im100_lucir_cwd.sh
66 | ```
67 |
68 |
69 |
70 | #### Some Comments on Running Scripts.
71 |
72 | The "SEED" variable in the scripts is not the seed used to shuffle the class order; it is the seed that determines model initialisation, data loader sampling, etc. We vary "SEED" over 0, 1, 2 and average the resulting Average Incremental Accuracy values to obtain the results reported in the paper.
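
The averaging described above can be sketched as follows, taking Average Incremental Accuracy to be the mean of the accuracies measured after each phase. The accuracy values are placeholders chosen for easy arithmetic, not results from the paper, and `average_incremental_accuracy` is a hypothetical helper name.

```python
from statistics import mean

def average_incremental_accuracy(phase_accuracies):
    """Mean of the accuracies measured after each incremental phase."""
    return mean(phase_accuracies)

# Placeholder per-phase accuracies for SEED = 0, 1, 2 (illustration only).
runs = {
    0: [80.0, 70.0, 60.0],
    1: [82.0, 71.0, 60.0],
    2: [81.0, 72.0, 63.0],
}

per_seed = {seed: average_incremental_accuracy(accs) for seed, accs in runs.items()}
reported = mean(list(per_seed.values()))  # number that would go into a results table
print(per_seed, reported)
```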
73 |
74 |
75 |
76 |
77 |
78 | ### 3. For customized usage
79 |
80 | To use our CwD loss in your own project, simply copy the CwD loss implemented in "src/approach/aux\_loss.py".
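
A minimal sketch of plugging such a loss into a training step is shown below. `cwd_loss` is a functional re-write that mirrors the single-GPU path of `DecorrelateLossClass`; the random tensors, the `cwd_coeff` weight and the function names are assumptions for illustration, not values or APIs taken from this repo.

```python
import torch
import torch.nn.functional as F

def off_diagonal(m):
    """Flattened view of the off-diagonal entries of a square matrix."""
    n = m.shape[0]
    return m.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

def cwd_loss(features, labels, eps=1e-8, min_count=2):
    """Class-wise decorrelation penalty on a batch of features."""
    total, n_used = features.new_zeros(()), 0
    for label in labels.unique():
        z = features[labels == label]
        if z.shape[0] < min_count:  # too few samples to estimate correlations
            continue
        # Standardize each feature dimension within the class.
        z = z - z.mean(dim=0, keepdim=True)
        z = z / torch.sqrt(eps + z.var(dim=0, keepdim=True))
        corr = z.t() @ z
        total = total + off_diagonal(corr).pow(2).mean()
        n_used += z.shape[0]
    return total / n_used if n_used > 0 else total

torch.manual_seed(0)
features = torch.randn(32, 16, requires_grad=True)  # stand-in for backbone outputs
logits = torch.randn(32, 4, requires_grad=True)     # stand-in for classifier outputs
labels = torch.arange(32) % 4                       # 8 samples per class
cwd_coeff = 0.5                                     # hypothetical weight, tune per setup

loss = F.cross_entropy(logits, labels) + cwd_coeff * cwd_loss(features, labels)
loss.backward()
```

In a real project, `features` would be the penultimate-layer representations of the initial-phase model, and the penalty is simply added to the existing classification loss.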
81 |
82 |
83 |
84 |
85 |
86 | ### 4. Citation
87 |
88 | If you find our repo/paper helpful, please consider citing our work :)
89 | ```
90 | @inproceedings{shi2022mimicking,
91 |   title={Mimicking the oracle: an initial phase decorrelation approach for class incremental learning},
92 |   author={Shi, Yujun and Zhou, Kuangqi and Liang, Jian and Jiang, Zihang and Feng, Jiashi and Torr, Philip HS and Bai, Song and Tan, Vincent YF},
93 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
94 |   pages={16722--16731},
95 |   year={2022}
96 | }
97 | ```
98 |
99 |
100 |
101 | ### 5. Contact
102 |
103 | Yujun Shi (shi.yujun@u.nus.edu)
104 |
105 |
106 |
107 | ### 6. Acknowledgements
108 |
109 | Our code is based on [FACIL](https://github.com/mmasana/FACIL), one of the most well-written CIL libraries in my opinion :)
110 |
111 |
112 |
113 | ### 7. Some Additional Remarks
114 |
115 | Based on the original implementation of FACIL, I also implemented Distributed Data Parallel to enable multi-GPU training. However, it seems that the performance is not as good as single card training (about 0.5% lower). Therefore, in all experiments, I still use single card training.
116 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # NOTE: Earlier versions of pytorch and torchvision might also work,
2 | # but we haven't tested them yet
3 | torch>=1.7.1
4 | torchvision>=0.8.2
5 | matplotlib
6 | numpy
7 | tensorboard
--------------------------------------------------------------------------------
/src/approach/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # list all approaches available
4 | __all__ = list(
5 |     map(lambda x: x[:-3],
6 |         filter(lambda x: x not in ['__init__.py', 'aux_loss.py', 'incremental_learning.py'] and x.endswith('.py'),
7 |                os.listdir(os.path.dirname(__file__))
8 |                )
9 |         )
10 | )
11 |
--------------------------------------------------------------------------------
/src/approach/aux_loss.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from .utils import *
6 | from torch.distributions import Normal, Independent
7 | from torch import distributed as dist
8 |
9 | # function credit to https://github.com/facebookresearch/barlowtwins/blob/main/main.py
10 | def off_diagonal(x):
11 |     # return a flattened view of the off-diagonal elements of a square matrix
12 |     n, m = x.shape
13 |     assert n == m
14 |     return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
15 |
16 | class DecorrelateLossClass(nn.Module):
17 |
18 |     def __init__(self, reject_threshold=1, ddp=False):
19 |         super(DecorrelateLossClass, self).__init__()
20 |         self.eps = 1e-8
21 |         self.reject_threshold = reject_threshold
22 |         self.ddp = ddp
23 |
24 |     def forward(self, x, y):
25 |         _, C = x.shape
26 |         if self.ddp:
27 |             # if DDP
28 |             # first gather all x and labels from the world
29 |             x = torch.cat(GatherLayer.apply(x), dim=0)
30 |             y = global_gather(y)
31 |
32 |         loss = 0.0
33 |         uniq_l, uniq_c = y.unique(return_counts=True)
34 |         n_count = 0
35 |         for i, label in enumerate(uniq_l):
36 |             if uniq_c[i] <= self.reject_threshold:
37 |                 continue
38 |             x_label = x[y==label, :]
39 |             x_label = x_label - x_label.mean(dim=0, keepdim=True)
40 |             x_label = x_label / torch.sqrt(self.eps + x_label.var(dim=0, keepdim=True))
41 |
42 |             N = x_label.shape[0]
43 |             corr_mat = torch.matmul(x_label.t(), x_label)
44 |
45 |             # Notice that here the implementation is a little bit different
46 |             # from the paper as we extract only the off-diagonal terms for regularization.
47 |             # Mathematically, these two are the same thing since diagonal terms are all constant 1.
48 |             # However, we find that this implementation is more numerically stable.
49 |             loss += (off_diagonal(corr_mat).pow(2)).mean()
50 |
51 |             n_count += N
52 |
53 |         if n_count == 0:
54 |             # there is no effective class to compute correlation matrix
55 |             return 0
56 |         else:
57 |             loss = loss / n_count
58 |             return loss
59 |
--------------------------------------------------------------------------------
/src/approach/dmc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from copy import deepcopy
4 | from argparse import ArgumentParser
5 |
6 | from datasets.data_loader import get_loaders
7 | from .incremental_learning import Inc_Learning_Appr
8 | from datasets.exemplars_dataset import ExemplarsDataset
9 |
10 |
11 | class Appr(Inc_Learning_Appr):
12 |     """ Class implementing the Deep Model Consolidation (DMC) approach
13 |     described in https://arxiv.org/abs/1903.07864
14 |     Original code available at https://github.com/juntingzh/incremental-learning-baselines
15 |     """
16 |
17 |     def __init__(self, model, device, nepochs=160, lr=0.1, lr_min=1e-4, lr_factor=10, lr_patience=8, clipgrad=10000,
18 |                  momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
19 |                  logger=None, exemplars_dataset=None, aux_dataset='imagenet_32', aux_batch_size=128):
20 |         super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
21 |                                    multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
22 |                                    exemplars_dataset)
23 |         self.model_old = None
24 |         self.model_new = None
25 |         self.aux_dataset = aux_dataset
26 |         self.aux_batch_size = aux_batch_size
27 |         # get dataloader for auxiliary dataset
28 |         aux_trn_ldr, _, aux_val_ldr, _ = get_loaders([self.aux_dataset], num_tasks=1, nc_first_task=None, validation=0,
29 |                                                      batch_size=self.aux_batch_size, num_workers=4, pin_memory=False)
30 |         self.aux_trn_loader = aux_trn_ldr[0]
31 |         self.aux_val_loader = aux_val_ldr[0]
32 |         # Since an auxiliary dataset is available, using exemplars could be redundant
33 |         have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class
34 |         assert (have_exemplars == 0), 'Warning: DMC does not use exemplars. Comment this line to force it.'
35 |
36 |     @staticmethod
37 |     def exemplars_dataset_class():
38 |         return ExemplarsDataset
39 |
40 |     @staticmethod
41 |     def extra_parser(args):
42 |         """Returns a parser containing the approach specific parameters"""
43 |         parser = ArgumentParser()
44 |         # Sec. 4.2.1 "We use ImageNet32x32 dataset as the source for auxiliary data in the model consolidation stage."
45 |         parser.add_argument('--aux-dataset', default='imagenet_32_reduced', type=str, required=False,
46 |                             help='Auxiliary dataset (default=%(default)s)')
47 |         parser.add_argument('--aux-batch-size', default=128, type=int, required=False,
48 |                             help='Batch size for auxiliary dataset (default=%(default)s)')
49 |         return parser.parse_known_args(args)
50 |
51 |     def _get_optimizer(self):
52 |         """Returns the optimizer"""
53 |         if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
54 |             # if there are no exemplars, previous heads are not modified
55 |             params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
56 |         else:
57 |             params = self.model.parameters()
58 |         return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)
59 |
60 |     def pre_train_process(self, t, trn_loader):
61 |         """Runs before training all epochs of the task (before the train session)"""
62 |         if t > 0:
63 |             # Re-initialize model
64 |             for m in self.model.modules():
65 |                 if isinstance(m, (nn.Conv2d, nn.BatchNorm2d, nn.Linear)):
66 |                     m.reset_parameters()
67 |             # Get new model
68 |             self.model_new = deepcopy(self.model)
69 |             for h in self.model_new.heads[:-1]:
70 |                 with torch.no_grad():
71 |                     h.weight.zero_()
72 |                     h.bias.zero_()
73 |                 for p in h.parameters():
74 |                     p.requires_grad = False
75 |         else:
76 |             self.model_new = self.model
77 |
78 |     def train_loop(self, t, trn_loader, val_loader):
79 |         """Contains the epochs loop"""
80 |         if t > 0:
81 |             # Args for the new data trainer and for the student trainer are the same
82 |             dmc_args = dict(nepochs=self.nepochs, lr=self.lr, lr_min=self.lr_min, lr_factor=self.lr_factor,
83 |                             lr_patience=self.lr_patience, clipgrad=self.clipgrad, momentum=self.momentum,
84 |                             wd=self.wd, multi_softmax=self.multi_softmax, wu_nepochs=self.warmup_epochs,
85 |                             wu_lr_factor=self.warmup_lr, fix_bn=self.fix_bn, logger=self.logger)
86 |             # Train new model in new data
87 |             new_trainer = NewTaskTrainer(self.model_new, self.device, **dmc_args)
88 |             new_trainer.train_loop(t, trn_loader, val_loader)
89 |             self.model_new.eval()
90 |             self.model_new.freeze_all()
91 |             print('=' * 108)
92 |             print("Training of student")
93 |             print('=' * 108)
94 |             # Train student model using both old and new model
95 |             student_trainer = StudentTrainer(self.model, self.model_new, self.model_old, self.device, **dmc_args)
96 |             student_trainer.train_loop(t, self.aux_trn_loader, self.aux_val_loader)
97 |         else:
98 |             # FINETUNING TRAINING -- contains the epochs loop
99 |             super().train_loop(t, trn_loader, val_loader)
100 |
101 |     def post_train_process(self, t, trn_loader):
102 |         """Runs after training all the epochs of the task (after the train session)"""
103 |
104 |         # Restore best and save model for future tasks
105 |         self.model_old = deepcopy(self.model)
106 |         self.model_old.eval()
107 |         self.model_old.freeze_all()
108 |
109 |
110 | class NewTaskTrainer(Inc_Learning_Appr):
111 |     def __init__(self, model, device, nepochs=160, lr=0.1, lr_min=1e-4, lr_factor=10, lr_patience=8, clipgrad=10000,
112 |                  momentum=0.9, wd=5e-4, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False,
113 |                  eval_on_train=False, logger=None):
114 |         super(NewTaskTrainer, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad,
115 |                                              momentum, wd, multi_softmax, wu_nepochs, wu_lr_factor, fix_bn,
116 |                                              eval_on_train, logger)
117 |
118 |
119 | class StudentTrainer(Inc_Learning_Appr):
120 |     def __init__(self, model, model_new, model_old, device, nepochs=160, lr=0.1, lr_min=1e-4, lr_factor=10,
121 |                  lr_patience=8, clipgrad=10000, momentum=0.9, wd=5e-4, multi_softmax=False, wu_nepochs=0,
122 |                  wu_lr_factor=1, fix_bn=False, eval_on_train=False, logger=None):
123 |         super(StudentTrainer, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad,
124 |                                              momentum, wd, multi_softmax, wu_nepochs, wu_lr_factor, fix_bn,
125 |                                              eval_on_train, logger)
126 |
127 |         self.model_old = model_old
128 |         self.model_new = model_new
129 |
130 |     # Runs a single epoch of student's training
131 |     def train_epoch(self, t, trn_loader):
132 |         self.model.train()
133 |         if self.fix_bn and t > 0:
134 |             self.model.freeze_bn()
135 |         for images, targets in trn_loader:
136 |             images, targets = images.to(self.device), targets.to(self.device)
137 |             # Forward old and new model
138 |             targets_old = self.model_old(images)
139 |             targets_new = self.model_new(images)
140 |             # Forward current model
141 |             outputs = self.model(images)
142 |             loss = self.criterion(t, outputs, targets_old, targets_new)
143 |             # Backward
144 |             self.optimizer.zero_grad()
145 |             loss.backward()
146 |             torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
147 |             self.optimizer.step()
148 |
149 |     # Contains the evaluation code for evaluating the student
150 |     def eval(self, t, val_loader):
151 |         with torch.no_grad():
152 |             total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0
153 |             self.model.eval()
154 |             for images, targets in val_loader:
155 |                 images = images.to(self.device)
156 |                 # Forward old and new model
157 |                 targets_old = self.model_old(images)
158 |                 targets_new = self.model_new(images)
159 |                 # Forward current model
160 |                 outputs = self.model(images)
161 |                 loss = self.criterion(t, outputs, targets_old, targets_new)
162 |                 # Log
163 |                 total_loss += loss.item() * len(targets)
164 |                 total_num += len(targets)
165 |         return total_loss / total_num, -1, -1
166 |
167 |     # Returns the loss value for the student
168 |     def criterion(self, t, outputs, targets_old, targets_new=None):
169 |         # Eq. 2: Model Consolidation
170 |         with torch.no_grad():
171 |             # Eq. 4: "The regression target of the consolidated model is the concatenation of normalized logits of
172 |             # the two specialist models."
173 |             targets = torch.cat(targets_old[:t] + [targets_new[t]], dim=1)
174 |             targets -= targets.mean(0)
175 |         # Eq. 3: Double Distillation Loss
176 |         return torch.nn.functional.mse_loss(torch.cat(outputs, dim=1), targets.detach(), reduction='mean')
177 |
--------------------------------------------------------------------------------
/src/approach/eeil.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import warnings
3 | from copy import deepcopy
4 | from argparse import ArgumentParser
5 | from torch.nn import functional as F
6 | from torch.utils.data import DataLoader
7 |
8 | from .incremental_learning import Inc_Learning_Appr
9 | from datasets.exemplars_dataset import ExemplarsDataset
10 |
11 |
12 | class Appr(Inc_Learning_Appr):
13 |     """Class implementing the End-to-end Incremental Learning (EEIL) approach described in
14 |     http://openaccess.thecvf.com/content_ECCV_2018/papers/Francisco_M._Castro_End-to-End_Incremental_Learning_ECCV_2018_paper.pdf
15 |     Original code available at https://github.com/fmcp/EndToEndIncrementalLearning
16 |     Helpful code from https://github.com/arthurdouillard/incremental_learning.pytorch
17 |     """
18 |
19 |     def __init__(self, model, device, nepochs=90, lr=0.1, lr_min=1e-6, lr_factor=10, lr_patience=5, clipgrad=10000,
20 |                  momentum=0.9, wd=0.0001, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False,
21 |                  eval_on_train=False, logger=None, exemplars_dataset=None, lamb=1.0, T=2, lr_finetuning_factor=0.1,
22 |                  nepochs_finetuning=40, noise_grad=False):
23 |         super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
24 |                                    multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
25 |                                    exemplars_dataset)
26 |         self.model_old = None
27 |         self.lamb = lamb
28 |         self.T = T
29 |         self.lr_finetuning_factor = lr_finetuning_factor
30 |         self.nepochs_finetuning = nepochs_finetuning
31 |         self.noise_grad = noise_grad
32 |
33 |         self._train_epoch = 0
34 |         self._finetuning_balanced = None
35 |
36 |         # EEIL is expected to be used with exemplars. If needed to be used without exemplars, overwrite here the
37 |         # `_get_optimizer` function with the one in LwF and update the criterion
38 |         have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class
39 |         if not have_exemplars:
40 |             warnings.warn("Warning: EEIL is expected to use exemplars. Check documentation.")
41 |
42 |     @staticmethod
43 |     def exemplars_dataset_class():
44 |         return ExemplarsDataset
45 |
46 |     @staticmethod
47 |     def extra_parser(args):
48 |         """Returns a parser containing the approach specific parameters"""
49 |         parser = ArgumentParser()
50 |         # Added trade-off between the terms of Eq. 1 -- L = L_C + lamb * L_D
51 |         parser.add_argument('--lamb', default=1.0, type=float, required=False,
52 |                             help='Forgetting-intransigence trade-off (default=%(default)s)')
53 |         # Page 6: "Based on our empirical results, we set T to 2 for all our experiments"
54 |         parser.add_argument('--T', default=2.0, type=float, required=False,
55 |                             help='Temperature scaling (default=%(default)s)')
56 |         # "The same reduction is used in the case of fine-tuning, except that the starting rate is 0.01."
57 |         parser.add_argument('--lr-finetuning-factor', default=0.01, type=float, required=False,
58 |                             help='Finetuning learning rate factor (default=%(default)s)')
59 |         # Number of epochs for balanced training
60 |         parser.add_argument('--nepochs-finetuning', default=40, type=int, required=False,
61 |                             help='Number of epochs for balanced training (default=%(default)s)')
62 |         # the addition of noise to the gradients
63 |         parser.add_argument('--noise-grad', action='store_true',
64 |                             help='Add noise to gradients (default=%(default)s)')
65 |         return parser.parse_known_args(args)
66 |
67 |     def _train_unbalanced(self, t, trn_loader, val_loader):
68 |         """Unbalanced training"""
69 |         self._finetuning_balanced = False
70 |         self._train_epoch = 0
71 |         loader = self._get_train_loader(trn_loader, False)
72 |         super().train_loop(t, loader, val_loader)
73 |         return loader
74 |
75 |     def _train_balanced(self, t, trn_loader, val_loader):
76 |         """Balanced finetuning"""
77 |         self._finetuning_balanced = True
78 |         self._train_epoch = 0
79 |         orig_lr = self.lr
80 |         self.lr *= self.lr_finetuning_factor
81 |         orig_nepochs = self.nepochs
82 |         self.nepochs = self.nepochs_finetuning
83 |         loader = self._get_train_loader(trn_loader, True)
84 |         super().train_loop(t, loader, val_loader)
85 |         self.lr = orig_lr
86 |         self.nepochs = orig_nepochs
87 |
88 |     def _get_train_loader(self, trn_loader, balanced=False):
89 |         """Modify loader to be balanced or unbalanced"""
90 |         exemplars_ds = self.exemplars_dataset
91 |         trn_dataset = trn_loader.dataset
92 |         if balanced:
93 |             indices = torch.randperm(len(trn_dataset))
94 |             trn_dataset = torch.utils.data.Subset(trn_dataset, indices[:len(exemplars_ds)])
95 |         ds = exemplars_ds + trn_dataset
96 |         return DataLoader(ds, batch_size=trn_loader.batch_size,
97 |                           shuffle=True,
98 |                           num_workers=trn_loader.num_workers,
99 |                           pin_memory=trn_loader.pin_memory)
100 |
101 |     def _noise_grad(self, parameters, iteration, eta=0.3, gamma=0.55):
102 |         """Add noise to the gradients"""
103 |         parameters = list(filter(lambda p: p.grad is not None, parameters))
104 |         variance = eta / ((1 + iteration) ** gamma)
105 |         for p in parameters:
106 |             p.grad.add_(torch.randn(p.grad.shape, device=p.grad.device) * variance)
107 |
108 |     def train_loop(self, t, trn_loader, val_loader):
109 |         """Contains the epochs loop"""
110 |         if t == 0:  # First task is simple training
111 |             super().train_loop(t, trn_loader, val_loader)
112 |             loader = trn_loader
113 |         else:
114 |             # Page 4: "4. Incremental Learning" -- Only modification is that instead of preparing exemplars before
115 |             # training, we do it online using the stored old model.
116 |
117 |             # Training process (new + old) - unbalanced training
118 |             loader = self._train_unbalanced(t, trn_loader, val_loader)
119 |             # Balanced fine-tuning (new + old)
120 |             self._train_balanced(t, trn_loader, val_loader)
121 |
122 |         # After task training: update exemplars
123 |         self.exemplars_dataset.collect_exemplars(self.model, loader, val_loader.dataset.transform)
124 |
125 |     def post_train_process(self, t, trn_loader):
126 |         """Runs after training all the epochs of the task (after the train session)"""
127 |
128 |         # Save old model to extract features later
129 |         self.model_old = deepcopy(self.model)
130 |         self.model_old.eval()
131 |         self.model_old.freeze_all()
132 |
133 |     def train_epoch(self, t, trn_loader):
134 |         """Runs a single epoch"""
135 |         self.model.train()
136 |         if self.fix_bn and t > 0:
137 |             self.model.freeze_bn()
138 |         for images, targets in trn_loader:
139 |             images = images.to(self.device)
140 |             # Forward old model
141 |             outputs_old = None
142 |             if t > 0:
143 |                 outputs_old = self.model_old(images)
144 |             # Forward current model
145 |             outputs = self.model(images)
146 |             loss = self.criterion(t, outputs, targets.to(self.device), outputs_old)
147 |             # Backward
148 |             self.optimizer.zero_grad()
149 |             loss.backward()
150 |             # Page 8: "We apply L2-regularization and random noise [21] (with parameters eta = 0.3, gamma = 0.55)
151 |             # on the gradients to minimize overfitting"
152 |             # https://github.com/fmcp/EndToEndIncrementalLearning/blob/master/cnn_train_dag_exemplars.m#L367
153 |             torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
154 |             if self.noise_grad:
155 |                 self._noise_grad(self.model.parameters(), self._train_epoch)
156 |             self.optimizer.step()
157 |         self._train_epoch += 1
158 |
159 |     def criterion(self, t, outputs, targets, outputs_old=None):
160 |         """Returns the loss value"""
161 |
162 |         # Classification loss for new classes
163 |         loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
164 |         # Distillation loss
165 |         if t > 0 and outputs_old:
166 |             # take into account current head when doing balanced finetuning
167 |             last_head_idx = t if self._finetuning_balanced else (t - 1)
168 |             for i in range(last_head_idx):
169 |                 loss += self.lamb * F.binary_cross_entropy(F.softmax(outputs[i] / self.T, dim=1),
170 |                                                            F.softmax(outputs_old[i] / self.T, dim=1))
171 |         return loss
172 |
--------------------------------------------------------------------------------
/src/approach/ewc.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch
3 | import itertools
4 | from argparse import ArgumentParser
5 |
6 | from datasets.exemplars_dataset import ExemplarsDataset
7 | from datasets.exemplars_selection import override_dataset_transform
8 | from .incremental_learning import Inc_Learning_Appr
9 | from torch.distributions.categorical import Categorical
10 |
11 | from torch.utils.data import DataLoader
12 |
13 | class Appr(Inc_Learning_Appr):
14 |     """Class implementing the Elastic Weight Consolidation (EWC) approach
15 |     described in http://arxiv.org/abs/1612.00796
16 |     """
17 |
18 |     def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], clipgrad=10000,
19 |                  momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, ddp=False,
20 |                  logger=None, exemplars_dataset=None, lamb=5000, alpha=0.5, fi_sampling_type='max_pred',
21 |                  fi_num_samples=-1, save_fisher=False):
22 |         super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, clipgrad, momentum, wd,
23 |                                    multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
24 |                                    exemplars_dataset)
25 |         self.lamb = lamb
26 |         self.alpha = alpha
27 |         self.sampling_type = fi_sampling_type
28 |         self.num_samples = fi_num_samples
29 |
30 |         # In all cases, we only keep importance weights for the model, but not for the heads.
31 |         feat_ext = self.model.model
32 |         # Store current parameters as the initial parameters before first task starts
33 |         self.older_params = {n: p.clone().detach() for n, p in feat_ext.named_parameters() if p.requires_grad}
34 |         # Store fisher information weight importance
35 |         self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters()
36 |                        if p.requires_grad}
37 |         self.save_fisher = save_fisher
38 |
39 |     @staticmethod
40 |     def exemplars_dataset_class():
41 |         return ExemplarsDataset
42 |
43 |     @staticmethod
44 |     def extra_parser(args):
45 |         """Returns a parser containing the approach specific parameters"""
46 |         parser = ArgumentParser()
47 |         # Eq. 3: "lambda sets how important the old task is compared to the new one"
48 |         parser.add_argument('--lamb', default=5000, type=float, required=False,
49 |                             help='Forgetting-intransigence trade-off (default=%(default)s)')
50 |         # Define how old and new fisher is fused, by default it is a 50-50 fusion
51 |         parser.add_argument('--alpha', default=0.5, type=float, required=False,
52 |                             help='EWC alpha (default=%(default)s)')
53 |         parser.add_argument('--fi-sampling-type', default='max_pred', type=str, required=False,
54 |                             choices=['true', 'max_pred', 'multinomial'],
55 |                             help='Sampling type for Fisher information (default=%(default)s)')
56 |         parser.add_argument('--fi-num-samples', default=-1, type=int, required=False,
57 |                             help='Number of samples for Fisher information (-1: all available) (default=%(default)s)')
58 |         parser.add_argument('--save-fisher', action='store_true',
59 |                             help='whether to save Fisher information')
60 |         return parser.parse_known_args(args)
61 |
62 |     def _get_optimizer(self):
63 |         """Returns the optimizer"""
64 |         if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
65 |             # if there are no exemplars, previous heads are not modified
66 |             params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
67 |         else:
68 |             params = self.model.parameters()
69 |         return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)
70 |
71 | # def compute_fisher_matrix_diag(self, trn_loader):
72 | # # Store Fisher Information
73 | # fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters()
74 | # if p.requires_grad}
75 | # # Compute fisher information for specified number of samples -- rounded to the batch size
76 | # n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \
77 | # else (len(trn_loader.dataset) // trn_loader.batch_size)
78 | # # Do forward and backward pass to compute the fisher information
79 | # self.model.train()
80 | # for images, targets in itertools.islice(trn_loader, n_samples_batches):
81 | # outputs = self.model.forward(images.to(self.device))
82 |
83 | # if self.sampling_type == 'true':
84 | # # Use the labels to compute the gradients based on the CE-loss with the ground truth
85 | # preds = targets.to(self.device)
86 | # elif self.sampling_type == 'max_pred':
87 | # # Do not use labels; compute the gradients with respect to the model's own top prediction
88 | # preds = torch.cat(outputs, dim=1).argmax(1).flatten()
89 | # elif self.sampling_type == 'multinomial':
90 | # # Use a multinomial sampling to compute the gradients
91 | # # probs = torch.nn.functional.softmax(torch.cat(outputs, dim=1), dim=1)
92 | # # preds = torch.multinomial(probs, len(targets)).flatten()
93 | # preds = Categorical(logits=outputs[-1]).sample()
94 |
95 | # loss = torch.nn.functional.cross_entropy(outputs[-1], preds)
96 | # self.optimizer.zero_grad()
97 | # loss.backward()
98 | # # Accumulate all gradients from loss with regularization
99 | # for n, p in self.model.model.named_parameters():
100 | # if p.grad is not None:
101 | # fisher[n] += p.grad.pow(2) * len(targets)
102 | # # Apply mean across all samples
103 | # n_samples = n_samples_batches * trn_loader.batch_size
104 | # fisher = {n: (p / n_samples) for n, p in fisher.items()}
105 | # return fisher
106 |
107 | def compute_fisher_matrix_diag(self, trn_loader, val_loader):
108 | # Store Fisher Information
109 | fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters()
110 | if p.requires_grad}
111 | # Do forward and backward pass to compute the fisher information
112 | self.model.eval()
113 | with override_dataset_transform(trn_loader.dataset, val_loader.dataset.transform) as _ds:
114 | fisher_loader = DataLoader(_ds, batch_size=1, shuffle=False,
115 | num_workers=trn_loader.num_workers, pin_memory=trn_loader.pin_memory)
116 | n_samples = 0
117 | for images, targets in tqdm(fisher_loader):
118 | images, targets = images.to(self.device), targets.to(self.device)
119 |
120 | outputs = self.model.forward(images)
121 | if self.sampling_type == 'true':
122 | preds = targets
123 | elif self.sampling_type == 'max_pred':
124 | preds = torch.cat(outputs, dim=1).argmax(1).flatten()
125 | elif self.sampling_type == 'multinomial':
126 | preds = Categorical(logits=outputs[-1]).sample()
127 |
128 | loss = torch.nn.functional.cross_entropy(outputs[-1], preds)
129 | self.optimizer.zero_grad()
130 | loss.backward()
131 | # Accumulate all gradients from loss with regularization
132 | for n, p in self.model.model.named_parameters():
133 | if p.grad is not None:
134 | fisher[n] += p.grad.pow(2) * len(targets)
135 | n_samples += len(targets)
136 |
137 | fisher = {n: (p / n_samples) for n, p in fisher.items()}
138 | return fisher
139 |
140 | def train_loop(self, t, trn_loader, val_loader):
141 | """Contains the epochs loop"""
142 |
143 | # add exemplars to train_loader
144 | if len(self.exemplars_dataset) > 0 and t > 0:
145 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
146 | batch_size=trn_loader.batch_size,
147 | shuffle=True,
148 | num_workers=trn_loader.num_workers,
149 | pin_memory=trn_loader.pin_memory)
150 |
151 | # FINETUNING TRAINING -- contains the epochs loop
152 | super().train_loop(t, trn_loader, val_loader)
153 |
154 | # EXEMPLAR MANAGEMENT -- select training subset
155 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)
156 |
157 | def post_train_process(self, t, trn_loader, val_loader):
158 | """Runs after training all the epochs of the task (after the train session)"""
159 |
160 | # Store current parameters for the next task
161 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad}
162 |
163 | # calculate Fisher information
164 | curr_fisher = self.compute_fisher_matrix_diag(trn_loader, val_loader)
165 | # merge fisher information, we do not want to keep fisher information for each task in memory
166 | for n in self.fisher.keys():
167 | # Added option to accumulate fisher over time with a pre-fixed growing alpha
168 | # if self.alpha == -1:
169 | # alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device)
170 | # self.fisher[n] = alpha * self.fisher[n] + (1 - alpha) * curr_fisher[n]
171 | # else:
172 | # self.fisher[n] = (self.alpha * self.fisher[n] + (1 - self.alpha) * curr_fisher[n])
173 |
174 | # directly adding more constraint
175 | self.fisher[n] = self.fisher[n] + curr_fisher[n]
176 |
177 | if self.save_fisher:
178 | torch.save(self.fisher, './fisher/lamb_{}_task_{}.pt'.format(self.lamb, t))
179 |
180 | def criterion(self, t, outputs, targets):
181 | """Returns the loss value"""
182 | loss = 0
183 | if t > 0:
184 | loss_reg = 0
185 | # Eq. 3: elastic weight consolidation quadratic penalty
186 | for n, p in self.model.model.named_parameters():
187 | if n in self.fisher.keys():
188 | loss_reg += torch.sum(self.fisher[n] * (p - self.older_params[n]).pow(2)) / 2
189 | loss += self.lamb * loss_reg
190 | # Current cross-entropy loss -- with exemplars use all heads
191 | if len(self.exemplars_dataset) > 0:
192 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
193 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
194 |
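The Fisher accumulation and the quadratic penalty in `criterion` above can be sketched without PyTorch. This is a minimal illustration on plain Python lists, not the repository's implementation: `diag_fisher`, `merge_fisher`, and `ewc_penalty` are hypothetical helper names, and the per-sample gradients are made-up numbers.

```python
def diag_fisher(per_sample_grads):
    """Diagonal Fisher estimate: mean of squared per-sample gradients."""
    n_samples = len(per_sample_grads)
    fisher = [0.0] * len(per_sample_grads[0])
    for grad in per_sample_grads:
        for i, g in enumerate(grad):
            fisher[i] += g * g
    return [f / n_samples for f in fisher]


def merge_fisher(old_fisher, new_fisher):
    """Plain summation across tasks, as in the 'directly adding' branch."""
    return [a + b for a, b in zip(old_fisher, new_fisher)]


def ewc_penalty(params, old_params, fisher, lamb):
    """lamb * sum_i F_i * (theta_i - theta_old_i)^2 / 2."""
    reg = sum(f * (p - o) ** 2 for f, p, o in zip(fisher, params, old_params)) / 2
    return lamb * reg


# Two samples, two parameters (illustrative values only).
grads = [[1.0, 0.0], [3.0, 2.0]]
fisher = diag_fisher(grads)
merged = merge_fisher([1.0, 1.0], fisher)
penalty = ewc_penalty([1.0, 1.0], [0.0, 3.0], fisher, lamb=2.0)
```

Parameters with a large Fisher value are pulled back toward their old values more strongly, which is why summing the Fisher terms across tasks tightens the constraint over time.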
--------------------------------------------------------------------------------
/src/approach/finetuning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import warnings
4 | from .incremental_learning import Inc_Learning_Appr
5 | from datasets.exemplars_dataset import ExemplarsDataset
6 |
7 | import torch.nn.functional as F
8 | from argparse import ArgumentParser
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | from .lucir_utils import CosineLinear, BasicBlockNoRelu, BottleneckNoRelu
11 |
12 |
13 | class Appr(Inc_Learning_Appr):
14 | """Class implementing the finetuning baseline. The code mirrors the IL2M approach in il2m.py, described in
15 | https://openaccess.thecvf.com/content_ICCV_2019/papers/Belouadah_IL2M_Class_Incremental_Learning_With_Dual_Memory_ICCV_2019_paper.pdf,
16 | but omits the IL2M score rectification."""
17 |
18 | def __init__(self, model, device, nepochs=160, lr=0.1, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000,
19 | momentum=0.9, wd=5e-4, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False,
20 | eval_on_train=False, ddp=False, local_rank=0, logger=None, exemplars_dataset=None,
21 | first_task_lr=0.1, first_task_bz=128):
22 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd,
23 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, ddp, local_rank,
24 | logger, exemplars_dataset)
25 | self.init_classes_means = []
26 | self.current_classes_means = []
27 | self.models_confidence = []
28 | # FLAG to not do scores rectification while finetuning training
29 | self.ft_train = False
30 |
31 | self.first_task_lr = first_task_lr
32 | self.first_task_bz = first_task_bz
33 | self.first_task = True
34 |
35 | have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class
36 | assert (have_exemplars > 0), 'Error: IL2M needs exemplars.'
37 |
38 | @staticmethod
39 | def exemplars_dataset_class():
40 | return ExemplarsDataset
41 |
42 | @staticmethod
43 | def extra_parser(args):
44 | """Returns a parser containing the approach specific parameters"""
45 | parser = ArgumentParser()
46 | parser.add_argument('--first-task-lr', default=0.1, type=float)
47 | parser.add_argument('--first-task-bz', default=32, type=int)
48 | return parser.parse_known_args(args)
49 |
50 | def _get_optimizer(self):
51 | """Returns the optimizer"""
52 | if self.ddp:
53 | model = self.model.module
54 | else:
55 | model = self.model
56 |
57 | params = model.parameters()
58 |
59 | if self.first_task:
60 | self.first_task = False
61 | optimizer = torch.optim.SGD(params, lr=self.first_task_lr, weight_decay=self.wd, momentum=self.momentum)
62 | else:
63 | optimizer = torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)
64 | print(optimizer.param_groups[0]['lr'])
65 | return optimizer
66 |
67 | def pre_train_process(self, t, trn_loader):
68 | """Runs before training all epochs of the task (before the train session)"""
69 | if self.ddp:
70 | model = self.model.module
71 | else:
72 | model = self.model
73 |
74 | if t == 0:
75 | # Sec. 4.1: "the ReLU in the penultimate layer is removed to allow the features to take both positive and
76 | # negative values"
77 | if model.model.__class__.__name__ == 'ResNetCifar':
78 | old_block = model.model.layer3[-1]
79 | model.model.layer3[-1] = BasicBlockNoRelu(old_block.conv1, old_block.bn1, old_block.relu,
80 | old_block.conv2, old_block.bn2, old_block.downsample)
81 | elif model.model.__class__.__name__ == 'ResNet':
82 | old_block = model.model.layer4[-1]
83 | model.model.layer4[-1] = BasicBlockNoRelu(old_block.conv1, old_block.bn1, old_block.relu,
84 | old_block.conv2, old_block.bn2, old_block.downsample)
85 | elif model.model.__class__.__name__ == 'ResNetBottleneck':
86 | old_block = model.model.layer4[-1]
87 | model.model.layer4[-1] = BottleneckNoRelu(old_block.conv1, old_block.bn1,
88 | old_block.relu, old_block.conv2, old_block.bn2,
89 | old_block.conv3, old_block.bn3, old_block.downsample)
90 | else:
91 | warnings.warn("Warning: ReLU not removed from last block.")
92 |
93 | # Changes the new head to a CosineLinear
94 | model.heads[-1] = CosineLinear(model.heads[-1].in_features, model.heads[-1].out_features)
95 | model.to(self.device)
96 | # if t > 0:
97 | # Share sigma (Eta in paper) between all the heads
98 | # Yujun: according to il2m, since we'll correct this with model confidence
99 | # maybe we shouldn't share sigma here.
100 | # model.heads[-1].sigma = model.heads[-2].sigma
101 |
102 | # and we probably shouldn't freeze sigma here.
103 | # for h in model.heads[:-1]:
104 | # for param in h.parameters():
105 | # param.requires_grad = False
106 | # model.heads[-1].sigma.requires_grad = True
107 |
108 | # if ddp option is activated, need to re-wrap the ddp model
109 | if self.ddp:
110 | self.model = DDP(self.model.module, device_ids=[self.local_rank])
111 |
112 | # The original code has an option called "imprint weights" that seems to initialize the new head.
113 | # However, this is not mentioned in the paper and doesn't seem to make a significant difference.
114 | super().pre_train_process(t, trn_loader)
115 |
116 | def train_loop(self, t, trn_loader, val_loader):
117 | """Contains the epochs loop"""
118 | if t == 0:
119 | dset = trn_loader.dataset
120 | trn_loader = torch.utils.data.DataLoader(dset,
121 | batch_size=self.first_task_bz,
122 | sampler=trn_loader.sampler,
123 | num_workers=trn_loader.num_workers,
124 | pin_memory=trn_loader.pin_memory)
125 |
126 | # add exemplars to train_loader
127 | if len(self.exemplars_dataset) > 0 and t > 0:
128 | dset = trn_loader.dataset + self.exemplars_dataset
129 | if self.ddp:
130 | trn_sampler = torch.utils.data.DistributedSampler(dset, shuffle=True)
131 | trn_loader = torch.utils.data.DataLoader(dset,
132 | batch_size=trn_loader.batch_size,
133 | sampler=trn_sampler,
134 | num_workers=trn_loader.num_workers,
135 | pin_memory=trn_loader.pin_memory)
136 | else:
137 | trn_loader = torch.utils.data.DataLoader(dset,
138 | batch_size=trn_loader.batch_size,
139 | shuffle=True,
140 | num_workers=trn_loader.num_workers,
141 | pin_memory=trn_loader.pin_memory)
142 |
143 |
144 | # FINETUNING TRAINING -- contains the epochs loop
145 | self.ft_train = True
146 | super().train_loop(t, trn_loader, val_loader)
147 | self.ft_train = False
148 |
149 | if self.ddp:
150 | # need to change the trainloader to the original version without distributed sampler
151 | dset = trn_loader.dataset
152 | trn_loader = torch.utils.data.DataLoader(dset,
153 | batch_size=200, shuffle=False, num_workers=trn_loader.num_workers,
154 | pin_memory=trn_loader.pin_memory)
155 |
156 | # EXEMPLAR MANAGEMENT -- select training subset
157 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform, self.ddp)
158 |
159 | def criterion(self, t, outputs, targets):
160 | """Returns the loss value"""
161 | if self.ddp:
162 | model = self.model.module
163 | else:
164 | model = self.model
165 |
166 | if isinstance(outputs[0], dict):
167 | outputs = [o['wsigma'] for o in outputs]
168 |
169 | if len(self.exemplars_dataset) > 0:
170 | return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
171 | return torch.nn.functional.cross_entropy(outputs[t], targets - model.task_offset[t])
172 |
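The `targets - model.task_offset[t]` term in `criterion` converts a global class label into an index local to head `t`. A minimal sketch of that mapping follows, assuming `task_offset[t]` is the number of classes in all tasks before `t` (the network code is not shown here, so this is an assumption); `task_offsets` and `locate` are hypothetical names.

```python
from itertools import accumulate


def task_offsets(task_cls):
    """Offset of task t = total number of classes in tasks 0..t-1 (assumed)."""
    return [0] + list(accumulate(task_cls))[:-1]


def locate(label, task_cls):
    """Map a global label to (task index, label local to that task's head)."""
    bounds = list(accumulate(task_cls))
    # Same idea as (task_cls.cumsum(0) <= label).sum() in the approaches.
    task = sum(1 for b in bounds if b <= label)
    return task, label - task_offsets(task_cls)[task]


task_cls = [50, 10, 10]  # illustrative split: 50 base classes, then 10 per task
offsets = task_offsets(task_cls)
first = locate(0, task_cls)
middle = locate(55, task_cls)
last = locate(69, task_cls)
```

With exemplars the heads are concatenated and the global label is used directly, so the offset subtraction only applies to the single-head branch.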
--------------------------------------------------------------------------------
/src/approach/freezing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from argparse import ArgumentParser
3 |
4 | from .incremental_learning import Inc_Learning_Appr
5 | from datasets.exemplars_dataset import ExemplarsDataset
6 |
7 |
8 | class Appr(Inc_Learning_Appr):
9 | """Class implementing the freezing baseline"""
10 |
11 | def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000,
12 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
13 | ddp=False, local_rank=0, logger=None, exemplars_dataset=None, freeze_after=0, all_outputs=False):
14 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd,
15 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, ddp, local_rank,
16 | logger, exemplars_dataset)
17 | self.freeze_after = freeze_after
18 | self.all_out = all_outputs
19 |
20 | @staticmethod
21 | def exemplars_dataset_class():
22 | return ExemplarsDataset
23 |
24 | @staticmethod
25 | def extra_parser(args):
26 | """Returns a parser containing the approach specific parameters"""
27 | parser = ArgumentParser()
28 | parser.add_argument('--freeze-after', default=0, type=int, required=False,
29 | help='Freeze model except current head after the specified task (default=%(default)s)')
30 | parser.add_argument('--all-outputs', action='store_true', required=False,
31 | help='Allow all weights related to all outputs to be modified (default=%(default)s)')
32 | return parser.parse_known_args(args)
33 |
34 | def _get_optimizer(self):
35 | """Returns the optimizer"""
36 | return torch.optim.SGD(self._train_parameters(), lr=self.lr, weight_decay=self.wd, momentum=self.momentum)
37 |
38 | def _has_exemplars(self):
39 | """Returns True in case exemplars are being used"""
40 | return self.exemplars_dataset is not None and len(self.exemplars_dataset) > 0
41 |
42 | def post_train_process(self, t, trn_loader, val_loader=None):
43 | """Runs after training all the epochs of the task (after the train session)"""
44 | if t >= self.freeze_after:
45 | self.model.freeze_backbone()
46 |
47 | def train_loop(self, t, trn_loader, val_loader):
48 | """Contains the epochs loop"""
49 |
50 | # add exemplars to train_loader
51 | if t > 0 and self._has_exemplars():
52 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
53 | batch_size=trn_loader.batch_size,
54 | shuffle=True,
55 | num_workers=trn_loader.num_workers,
56 | pin_memory=trn_loader.pin_memory)
57 |
58 | # FINETUNING TRAINING -- contains the epochs loop
59 | super().train_loop(t, trn_loader, val_loader)
60 |
61 | # EXEMPLAR MANAGEMENT -- select training subset
62 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)
63 |
64 | def train_epoch(self, t, trn_loader):
65 | """Runs a single epoch"""
66 | self._model_train(t)
67 | for images, targets in trn_loader:
68 | # Forward current model
69 | outputs = self.model(images.to(self.device))
70 | loss = self.criterion(t, outputs, targets.to(self.device))
71 | # Backward
72 | self.optimizer.zero_grad()
73 | loss.backward()
74 | torch.nn.utils.clip_grad_norm_(self._train_parameters(), self.clipgrad)
75 | self.optimizer.step()
76 |
77 | def _model_train(self, t):
78 | """Freezes the necessary weights"""
79 | if self.fix_bn and t > 0:
80 | self.model.freeze_bn()
81 | if self.freeze_after >= 0 and t <= self.freeze_after: # non-frozen task - whole model to train
82 | self.model.train()
83 | else:
84 | self.model.model.eval()
85 | if self._has_exemplars():
86 | # with exemplars - use all heads
87 | for head in self.model.heads:
88 | head.train()
89 | else:
90 | # no exemplars - use current head
91 | self.model.heads[-1].train()
92 |
93 | def _train_parameters(self):
94 | """Includes the necessary weights to the optimizer"""
95 | if len(self.model.heads) <= (self.freeze_after + 1):
96 | return self.model.parameters()
97 | else:
98 | if self._has_exemplars():
99 | return [p for head in self.model.heads for p in head.parameters()]
100 | else:
101 | return self.model.heads[-1].parameters()
102 |
103 | def criterion(self, t, outputs, targets):
104 | """Returns the loss value"""
105 | if self.all_out or self._has_exemplars():
106 | return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
107 | return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
108 |
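The branching in `_train_parameters` above decides which parts of the network reach the optimizer. The sketch below restates that decision with string labels instead of real parameters; `trainable_parts` and the `"backbone"`/`"headN"` labels are hypothetical and only illustrate the logic.

```python
def trainable_parts(num_heads, freeze_after, has_exemplars):
    """Return labels of the parts that would be optimized for the current task."""
    if num_heads <= freeze_after + 1:
        # Not frozen yet: the whole model is trained.
        return ["backbone"] + [f"head{i}" for i in range(num_heads)]
    if has_exemplars:
        # Backbone frozen, exemplars available: all heads are trained.
        return [f"head{i}" for i in range(num_heads)]
    # Backbone frozen, no exemplars: only the newest head is trained.
    return [f"head{num_heads - 1}"]


first_task = trainable_parts(num_heads=1, freeze_after=0, has_exemplars=False)
later_with_memory = trainable_parts(num_heads=3, freeze_after=0, has_exemplars=True)
later_without_memory = trainable_parts(num_heads=3, freeze_after=0, has_exemplars=False)
```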
--------------------------------------------------------------------------------
/src/approach/il2m.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import warnings
4 | from .incremental_learning import Inc_Learning_Appr
5 | from datasets.exemplars_dataset import ExemplarsDataset
6 |
7 | import torch.nn.functional as F
8 | from argparse import ArgumentParser
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | from .lucir_utils import CosineLinear, BasicBlockNoRelu, BottleneckNoRelu
11 |
12 |
13 | class Appr(Inc_Learning_Appr):
14 | """Class implementing the Class Incremental Learning With Dual Memory (IL2M) approach described in
15 | https://openaccess.thecvf.com/content_ICCV_2019/papers/Belouadah_IL2M_Class_Incremental_Learning_With_Dual_Memory_ICCV_2019_paper.pdf
16 | """
17 |
18 | def __init__(self, model, device, nepochs=160, lr=0.1, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000,
19 | momentum=0.9, wd=5e-4, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False,
20 | eval_on_train=False, ddp=False, local_rank=0, logger=None, exemplars_dataset=None,
21 | first_task_lr=0.1, first_task_bz=128):
22 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd,
23 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, ddp, local_rank,
24 | logger, exemplars_dataset)
25 | self.init_classes_means = []
26 | self.current_classes_means = []
27 | self.models_confidence = []
28 | # FLAG to not do scores rectification while finetuning training
29 | self.ft_train = False
30 |
31 | self.first_task_lr = first_task_lr
32 | self.first_task_bz = first_task_bz
33 | self.first_task = True
34 |
35 | have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class
36 | assert (have_exemplars > 0), 'Error: IL2M needs exemplars.'
37 |
38 | @staticmethod
39 | def exemplars_dataset_class():
40 | return ExemplarsDataset
41 |
42 | @staticmethod
43 | def extra_parser(args):
44 | """Returns a parser containing the approach specific parameters"""
45 | parser = ArgumentParser()
46 | parser.add_argument('--first-task-lr', default=0.1, type=float)
47 | parser.add_argument('--first-task-bz', default=32, type=int)
48 | return parser.parse_known_args(args)
49 |
50 | def _get_optimizer(self):
51 | """Returns the optimizer"""
52 | if self.ddp:
53 | model = self.model.module
54 | else:
55 | model = self.model
56 |
57 | params = model.parameters()
58 |
59 | if self.first_task:
60 | self.first_task = False
61 | optimizer = torch.optim.SGD(params, lr=self.first_task_lr, weight_decay=self.wd, momentum=self.momentum)
62 | else:
63 | optimizer = torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)
64 | print(optimizer.param_groups[0]['lr'])
65 | return optimizer
66 |
67 | def pre_train_process(self, t, trn_loader):
68 | """Runs before training all epochs of the task (before the train session)"""
69 | if self.ddp:
70 | model = self.model.module
71 | else:
72 | model = self.model
73 |
74 | if t == 0:
75 | # Sec. 4.1: "the ReLU in the penultimate layer is removed to allow the features to take both positive and
76 | # negative values"
77 | if model.model.__class__.__name__ == 'ResNetCifar':
78 | old_block = model.model.layer3[-1]
79 | model.model.layer3[-1] = BasicBlockNoRelu(old_block.conv1, old_block.bn1, old_block.relu,
80 | old_block.conv2, old_block.bn2, old_block.downsample)
81 | elif model.model.__class__.__name__ == 'ResNet':
82 | old_block = model.model.layer4[-1]
83 | model.model.layer4[-1] = BasicBlockNoRelu(old_block.conv1, old_block.bn1, old_block.relu,
84 | old_block.conv2, old_block.bn2, old_block.downsample)
85 | elif model.model.__class__.__name__ == 'ResNetBottleneck':
86 | old_block = model.model.layer4[-1]
87 | model.model.layer4[-1] = BottleneckNoRelu(old_block.conv1, old_block.bn1,
88 | old_block.relu, old_block.conv2, old_block.bn2,
89 | old_block.conv3, old_block.bn3, old_block.downsample)
90 | else:
91 | warnings.warn("Warning: ReLU not removed from last block.")
92 |
93 | # Changes the new head to a CosineLinear
94 | model.heads[-1] = CosineLinear(model.heads[-1].in_features, model.heads[-1].out_features)
95 | model.to(self.device)
96 | # if t > 0:
97 | # Share sigma (Eta in paper) between all the heads
98 | # Yujun: according to il2m, since we'll correct this with model confidence
99 | # maybe we shouldn't share sigma here.
100 | # model.heads[-1].sigma = model.heads[-2].sigma
101 |
102 | # and we probably shouldn't freeze sigma here.
103 | # for h in model.heads[:-1]:
104 | # for param in h.parameters():
105 | # param.requires_grad = False
106 | # model.heads[-1].sigma.requires_grad = True
107 |
108 | # if ddp option is activated, need to re-wrap the ddp model
109 | if self.ddp:
110 | self.model = DDP(self.model.module, device_ids=[self.local_rank])
111 |
112 | # The original code has an option called "imprint weights" that seems to initialize the new head.
113 | # However, this is not mentioned in the paper and doesn't seem to make a significant difference.
114 | super().pre_train_process(t, trn_loader)
115 |
116 | # assumes trn_loader uses a plain (non-distributed) sampler
117 | def il2m(self, t, trn_loader):
118 | """Compute and store statistics for score rectification"""
119 | if self.ddp:
120 | model = self.model.module
121 | else:
122 | model = self.model
123 |
124 | old_classes_number = sum(model.task_cls[:t])
125 | classes_counts = [0 for _ in range(sum(model.task_cls))]
126 | models_counts = 0
127 |
128 | # to store statistics for the classes as learned in the current incremental state
129 | self.current_classes_means = [0 for _ in range(old_classes_number)]
130 | # to store statistics for past classes as learned in their initial states
131 | for cls in range(old_classes_number, old_classes_number + model.task_cls[t]):
132 | self.init_classes_means.append(0)
133 | # to store statistics for model confidence in different states (i.e. avg top-1 pred scores)
134 | self.models_confidence.append(0)
135 |
136 | # compute the mean prediction scores that will be used to rectify scores in subsequent tasks
137 | with torch.no_grad():
138 | self.model.eval()
139 | for images, targets in trn_loader:
140 | outputs = self.model(images.to(self.device))
141 | scores = np.array(torch.cat(outputs, dim=1).data.cpu().numpy(), dtype=np.float64)
142 | for m in range(len(targets)):
143 | if targets[m] < old_classes_number:
144 | # computation of class means for past classes of the current state.
145 | self.current_classes_means[targets[m]] += scores[m, targets[m]]
146 | classes_counts[targets[m]] += 1
147 | else:
148 | # compute the mean prediction scores for the new classes of the current state
149 | self.init_classes_means[targets[m]] += scores[m, targets[m]]
150 | classes_counts[targets[m]] += 1
151 | # compute the mean top scores for the new classes of the current state
152 | self.models_confidence[t] += np.max(scores[m, ])
153 | models_counts += 1
154 | # Normalize by corresponding number of images
155 | for cls in range(old_classes_number):
156 | self.current_classes_means[cls] /= classes_counts[cls]
157 | for cls in range(old_classes_number, old_classes_number + model.task_cls[t]):
158 | self.init_classes_means[cls] /= classes_counts[cls]
159 | self.models_confidence[t] /= models_counts
160 |
161 | def train_loop(self, t, trn_loader, val_loader):
162 | """Contains the epochs loop"""
163 | if t == 0:
164 | dset = trn_loader.dataset
165 | trn_loader = torch.utils.data.DataLoader(dset,
166 | batch_size=self.first_task_bz,
167 | sampler=trn_loader.sampler,
168 | num_workers=trn_loader.num_workers,
169 | pin_memory=trn_loader.pin_memory)
170 |
171 | # add exemplars to train_loader
172 | if len(self.exemplars_dataset) > 0 and t > 0:
173 | dset = trn_loader.dataset + self.exemplars_dataset
174 | if self.ddp:
175 | trn_sampler = torch.utils.data.DistributedSampler(dset, shuffle=True)
176 | trn_loader = torch.utils.data.DataLoader(dset,
177 | batch_size=trn_loader.batch_size,
178 | sampler=trn_sampler,
179 | num_workers=trn_loader.num_workers,
180 | pin_memory=trn_loader.pin_memory)
181 | else:
182 | trn_loader = torch.utils.data.DataLoader(dset,
183 | batch_size=trn_loader.batch_size,
184 | shuffle=True,
185 | num_workers=trn_loader.num_workers,
186 | pin_memory=trn_loader.pin_memory)
187 |
188 |
189 | # FINETUNING TRAINING -- contains the epochs loop
190 | self.ft_train = True
191 | super().train_loop(t, trn_loader, val_loader)
192 | self.ft_train = False
193 |
194 | if self.ddp:
195 | # need to change the trainloader to the original version without distributed sampler
196 | dset = trn_loader.dataset
197 | trn_loader = torch.utils.data.DataLoader(dset,
198 | batch_size=200, shuffle=False, num_workers=trn_loader.num_workers,
199 | pin_memory=trn_loader.pin_memory)
200 |
201 | # IL2M outputs rectification
202 | self.il2m(t, trn_loader)
203 |
204 | # EXEMPLAR MANAGEMENT -- select training subset
205 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform, self.ddp)
206 |
207 | def calculate_metrics(self, outputs, targets):
208 | """Contains the main Task-Aware and Task-Agnostic metrics"""
209 | if self.ft_train:
210 | # no score rectification while training
211 | hits_taw, hits_tag = super().calculate_metrics(outputs, targets)
212 | else:
213 | if self.ddp:
214 | model = self.model.module
215 | else:
216 | model = self.model
217 | # Task-Aware Multi-Head
218 | pred = torch.zeros_like(targets.to(self.device))
219 | for m in range(len(pred)):
220 | this_task = (model.task_cls.cumsum(0) <= targets[m]).sum()
221 | pred[m] = outputs[this_task][m].argmax() + model.task_offset[this_task]
222 | hits_taw = (pred == targets.to(self.device)).float()
223 | # Task-Agnostic Multi-Head
224 | if self.multi_softmax:
225 | outputs = [torch.nn.functional.log_softmax(output, dim=1) for output in outputs]
226 | # Eq. 1: rectify predicted scores
227 | old_classes_number = sum(model.task_cls[:-1])
228 | for m in range(len(targets)):
229 | rectified_outputs = torch.cat(outputs, dim=1)
230 | pred[m] = rectified_outputs[m].argmax()
231 | if old_classes_number:
232 | # if the top-1 class predicted by the network is a new one, rectify the score
233 | if int(pred[m]) >= old_classes_number:
234 | for o in range(old_classes_number):
235 | o_task = int((model.task_cls.cumsum(0) <= o).sum())
236 | rectified_outputs[m, o] *= (self.init_classes_means[o] / self.current_classes_means[o]) * \
237 | (self.models_confidence[-1] / self.models_confidence[o_task])
238 | pred[m] = rectified_outputs[m].argmax()
239 | # otherwise, rectification is not done because an old class is directly predicted
240 | hits_tag = (pred == targets.to(self.device)).float()
241 | return hits_taw, hits_tag
242 |
243 | def criterion(self, t, outputs, targets):
244 | """Returns the loss value"""
245 | if self.ddp:
246 | model = self.model.module
247 | else:
248 | model = self.model
249 |
250 | if isinstance(outputs[0], dict):
251 | outputs = [o['wsigma'] for o in outputs]
252 |
253 | if len(self.exemplars_dataset) > 0:
254 | return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
255 | return torch.nn.functional.cross_entropy(outputs[t], targets - model.task_offset[t])
256 |
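The score rectification in `calculate_metrics` above (Eq. 1 in the code comments) can be illustrated on a single sample with plain Python. This is a simplified sketch rather than the repository's code: `rectify` is a hypothetical helper, `class_to_task` stands in for the cumulative-sum lookup of `o_task`, and all numbers are invented.

```python
def rectify(scores, n_old, init_means, current_means, confidences, class_to_task):
    """Rescale old-class scores only when the raw top-1 prediction is a new class."""
    scores = list(scores)
    top = max(range(len(scores)), key=scores.__getitem__)
    if n_old == 0 or top < n_old:
        # An old class is already predicted, so nothing is rectified.
        return scores
    for o in range(n_old):
        class_ratio = init_means[o] / current_means[o]
        model_ratio = confidences[-1] / confidences[class_to_task[o]]
        scores[o] *= class_ratio * model_ratio
    return scores


# Two old classes (both from task 0) and one new class.
rectified = rectify(
    scores=[2.0, 1.0, 3.0],
    n_old=2,
    init_means=[4.0, 2.0],     # mean scores when the classes were first learned
    current_means=[2.0, 2.0],  # mean scores of the same classes in the current state
    confidences=[2.0, 4.0],    # mean top-1 score per incremental state
    class_to_task=[0, 0],
)
untouched = rectify([5.0, 1.0, 3.0], 2, [4.0, 2.0], [2.0, 2.0], [2.0, 4.0], [0, 0])
prediction = max(range(len(rectified)), key=rectified.__getitem__)
```

In this toy case the raw top-1 is the new class (index 2), but after rescaling the first old class wins, which is the bias correction the method aims for.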
--------------------------------------------------------------------------------
/src/approach/incremental_learning.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import time
3 | import torch
4 | import numpy as np
5 | from argparse import ArgumentParser
6 |
7 | from loggers.exp_logger import ExperimentLogger
8 | from datasets.exemplars_dataset import ExemplarsDataset
9 |
10 | from .utils import reduce_tensor_mean, reduce_tensor_sum
11 |
12 | class Inc_Learning_Appr:
13 | """Basic class for implementing incremental learning approaches"""
14 |
15 | def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120],
16 | lr_decay=0.1, clipgrad=10000,
17 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False,
18 | eval_on_train=False, ddp=False, local_rank=0,
19 | logger: ExperimentLogger = None, exemplars_dataset: ExemplarsDataset = None):
20 | self.model = model
21 | self.device = device
22 | self.nepochs = nepochs
23 | self.lr = lr
24 | self.decay_mile_stone = decay_mile_stone
25 | self.lr_decay = lr_decay
26 | self.clipgrad = clipgrad
27 | self.momentum = momentum
28 | self.wd = wd
29 | self.multi_softmax = multi_softmax
30 | self.logger = logger
31 | self.exemplars_dataset = exemplars_dataset
32 | self.warmup_epochs = wu_nepochs
33 | self.warmup_lr = lr * wu_lr_factor
34 | self.warmup_loss = torch.nn.CrossEntropyLoss()
35 | self.fix_bn = fix_bn
36 | self.eval_on_train = eval_on_train
37 | self.ddp = ddp
38 | self.local_rank = local_rank
39 | self.optimizer = None
40 |
41 | @staticmethod
42 | def extra_parser(args):
43 | """Returns a parser containing the approach specific parameters"""
44 | parser = ArgumentParser()
45 | return parser.parse_known_args(args)
46 |
47 | @staticmethod
48 | def exemplars_dataset_class():
49 | """Returns the exemplars dataset class to use during training if the approach needs it
50 | :return: ExemplarDataset class or None
51 | """
52 | return None
53 |
54 | def _get_optimizer(self):
55 | """Returns the optimizer"""
56 | return torch.optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.wd, momentum=self.momentum)
57 |
58 | def train(self, t, trn_loader, val_loader):
59 | """Main train structure"""
60 | self.pre_train_process(t, trn_loader)
61 | self.train_loop(t, trn_loader, val_loader)
62 | self.post_train_process(t, trn_loader, val_loader)
63 |
64 | def pre_train_process(self, t, trn_loader):
65 | """Runs before training all epochs of the task (before the train session)"""
66 |
67 | # Warm-up phase
68 | if self.warmup_epochs and t > 0:
69 | self.optimizer = torch.optim.SGD(self.model.heads[-1].parameters(), lr=self.warmup_lr)
70 | # Loop epochs -- train warm-up head
71 | for e in range(self.warmup_epochs):
72 | warmupclock0 = time.time()
73 | self.model.heads[-1].train()
74 | for images, targets in trn_loader:
75 | outputs = self.model(images.to(self.device))
76 | loss = self.warmup_loss(outputs[t], targets.to(self.device) - self.model.task_offset[t])
77 | self.optimizer.zero_grad()
78 | loss.backward()
79 | torch.nn.utils.clip_grad_norm_(self.model.heads[-1].parameters(), self.clipgrad)
80 | self.optimizer.step()
81 | warmupclock1 = time.time()
82 | with torch.no_grad():
83 | total_loss, total_acc_taw = 0, 0
84 | self.model.eval()
85 | for images, targets in trn_loader:
86 | outputs = self.model(images.to(self.device))
87 | loss = self.warmup_loss(outputs[t], targets.to(self.device) - self.model.task_offset[t])
88 | pred = torch.zeros_like(targets.to(self.device))
89 | for m in range(len(pred)):
90 | this_task = (self.model.task_cls.cumsum(0) <= targets[m]).sum()
91 | pred[m] = outputs[this_task][m].argmax() + self.model.task_offset[this_task]
92 | hits_taw = (pred == targets.to(self.device)).float()
93 | total_loss += loss.item() * len(targets)
94 | total_acc_taw += hits_taw.sum().item()
95 | total_num = len(trn_loader.dataset.labels)
96 | trn_loss, trn_acc = total_loss / total_num, total_acc_taw / total_num
97 | warmupclock2 = time.time()
98 | if self.local_rank == 0:
99 | print('| Warm-up Epoch {:3d}, time={:5.1f}s/{:5.1f}s | Train: loss={:.3f}, TAw acc={:5.1f}% |'.format(
100 | e + 1, warmupclock1 - warmupclock0, warmupclock2 - warmupclock1, trn_loss, 100 * trn_acc))
101 | self.logger.log_scalar(task=t, iter=e + 1, name="loss", value=trn_loss, group="warmup")
102 | self.logger.log_scalar(task=t, iter=e + 1, name="acc", value=100 * trn_acc, group="warmup")
103 |
104 |
105 | def train_loop(self, t, trn_loader, val_loader):
106 | """Contains the epochs loop"""
107 | #######################
108 | # best_acc = 0
109 | # if self.ddp:
110 | # best_model = self.model.module.state_dict()
111 | # else:
112 | # best_model = self.model.state_dict()
113 | #######################
114 |
115 | self.optimizer = self._get_optimizer()
116 | scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.decay_mile_stone, gamma=self.lr_decay)
117 |
118 | # Loop epochs
119 | for e in range(self.nepochs):
120 | # Train
121 | clock0 = time.time()
122 | self.train_epoch(t, trn_loader)
123 | clock1 = time.time()
124 | if self.eval_on_train:
125 | train_loss, train_acc, _ = self.eval(t, trn_loader)
126 | clock2 = time.time()
127 | if self.local_rank == 0:
128 | print('| Epoch {:3d}, time={:5.1f}s/{:5.1f}s | Train: loss={:.3f}, TAw acc={:5.1f}% |'.format(
129 | e + 1, clock1 - clock0, clock2 - clock1, train_loss, 100 * train_acc), end='')
130 | else:
131 | if self.local_rank == 0:
132 | print('| Epoch {:3d}, time={:5.1f}s | Train: skip eval |'.format(e + 1, clock1 - clock0), end='')
133 |
134 | # Valid
135 | clock3 = time.time()
136 | valid_loss, valid_acc, _ = self.eval(t, val_loader)
137 | clock4 = time.time()
138 | if self.local_rank == 0:
139 | print(' Valid: time={:5.1f}s loss={:.3f}, TAw acc={:5.1f}% |'.format(
140 | clock4 - clock3, valid_loss, 100 * valid_acc), end='')
141 |
142 | scheduler.step()
143 | #######################
144 | # if valid_acc > best_acc:
145 | # if self.ddp:
146 | # best_model = deepcopy(self.model.module.state_dict())
147 | # else:
148 | # best_model = deepcopy(self.model.state_dict())
149 | # best_acc = valid_acc
150 | #######################
151 | if self.local_rank == 0:
152 | print()
153 |
154 | #######################
155 | # if self.ddp:
156 | # self.model.module.set_state_dict(best_model)
157 | # else:
158 | # self.model.set_state_dict(best_model)
159 | #######################
160 |
161 | def post_train_process(self, t, trn_loader, val_loader):
162 | """Runs after training all the epochs of the task (after the train session)"""
163 | pass
164 |
165 | def train_epoch(self, t, trn_loader):
166 | """Runs a single epoch"""
167 | self.model.train()
168 | if self.fix_bn and t > 0:
169 | self.model.freeze_bn()
170 | for images, targets in trn_loader:
171 | # Forward current model
172 | outputs = self.model(images.to(self.device))
173 | loss = self.criterion(t, outputs, targets.to(self.device))
174 | # Backward
175 | self.optimizer.zero_grad()
176 | loss.backward()
177 | # clipgrad <= 0 disables gradient clipping
178 | if self.clipgrad > 0:
179 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
180 | self.optimizer.step()
181 |
182 | def eval(self, t, val_loader):
183 | """Contains the evaluation code"""
184 | with torch.no_grad():
185 | total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0
186 | self.model.eval()
187 | for images, targets in val_loader:
188 | # Forward current model
189 | outputs = self.model(images.to(self.device))
190 | loss = self.criterion(t, outputs, targets.to(self.device))
191 | hits_taw, hits_tag = self.calculate_metrics(outputs, targets)
192 |
193 | # if self.ddp:
194 | # hits_taw, hits_tag = reduce_tensor_mean(hits_taw, self.world_size), reduce_tensor_mean(hits_tag, self.world_size)
195 | # loss = reduce_tensor_mean(loss, self.world_size)
196 |
197 | total_loss += loss.item() * len(targets)
198 | total_acc_taw += hits_taw.sum().item()
199 | total_acc_tag += hits_tag.sum().item()
200 | total_num += len(targets)
201 | return total_loss / total_num, total_acc_taw / total_num, total_acc_tag / total_num
202 |
203 | def calculate_metrics(self, outputs, targets):
204 | """Contains the main Task-Aware and Task-Agnostic metrics"""
205 | pred = torch.zeros_like(targets.to(self.device))
206 | # Task-Aware Multi-Head
207 | if self.ddp:
208 | for m in range(len(pred)):
209 | this_task = (self.model.module.task_cls.cumsum(0) <= targets[m]).sum()
210 | pred[m] = outputs[this_task][m].argmax() + self.model.module.task_offset[this_task]
211 | else:
212 | for m in range(len(pred)):
213 | this_task = (self.model.task_cls.cumsum(0) <= targets[m]).sum()
214 | pred[m] = outputs[this_task][m].argmax() + self.model.task_offset[this_task]
215 | hits_taw = (pred == targets.to(self.device)).float()
216 | # Task-Agnostic Multi-Head
217 | if self.multi_softmax:
218 | outputs = [torch.nn.functional.log_softmax(output, dim=1) for output in outputs]
219 | pred = torch.cat(outputs, dim=1).argmax(1)
220 | else:
221 | pred = torch.cat(outputs, dim=1).argmax(1)
222 | hits_tag = (pred == targets.to(self.device)).float()
223 | return hits_taw, hits_tag
224 |
225 | def criterion(self, t, outputs, targets):
226 | """Returns the loss value"""
227 | return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
228 |
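The task-aware branch of `calculate_metrics` above picks the head that owns a label by counting how many cumulative class boundaries the label has reached, then adds that head's class offset to the head-local argmax. A minimal sketch of that lookup in plain Python (no torch), assuming hypothetical per-task class counts and assuming `task_offset` holds the first global class index of each task; the helper names are illustrative and not part of the repository:

```python
from itertools import accumulate


def task_of_label(task_cls, label):
    """Index of the task whose class range contains `label`.

    Mirrors `(task_cls.cumsum(0) <= label).sum()` from calculate_metrics.
    """
    return sum(1 for bound in accumulate(task_cls) if bound <= label)


def task_offsets(task_cls):
    """First global class index of every task (assumed meaning of `task_offset`)."""
    bounds = list(accumulate(task_cls))
    return [0] + bounds[:-1]


def task_aware_prediction(task_cls, head_outputs, label):
    """Predict a global class using only the head of the ground-truth task.

    `head_outputs` holds one list of logits per head, for a single sample.
    """
    t = task_of_label(task_cls, label)
    logits = head_outputs[t]
    # argmax inside the selected head, then shift to the global class index
    local = max(range(len(logits)), key=logits.__getitem__)
    return local + task_offsets(task_cls)[t]
```

With `task_cls = [5, 3, 2]`, labels 0 to 4 map to head 0, labels 5 to 7 to head 1, and labels 8 to 9 to head 2, so a task-aware prediction can never land outside the label's own task.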
--------------------------------------------------------------------------------
/src/approach/joint.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import torch
3 | from argparse import ArgumentParser
4 | from torch.nn.parallel import DistributedDataParallel as DDP
5 | from torch.utils.data import DataLoader, Dataset
6 | from .incremental_learning import Inc_Learning_Appr
7 | from datasets.exemplars_dataset import ExemplarsDataset
8 | from .lucir_utils import CosineLinear, BasicBlockNoRelu, BottleneckNoRelu
9 | class Appr(Inc_Learning_Appr):
10 | """Class implementing the joint baseline"""
11 |
12 | def __init__(self, model, device, nepochs=160, lr=0.1, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000,
13 | momentum=0.9, wd=5e-4, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False,
14 | eval_on_train=False, ddp=False, local_rank=0, logger=None, exemplars_dataset=None,
15 | lamb=5., lamb_mr=1., dist=0.5, K=2):
16 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd,
17 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, ddp, local_rank,
18 | logger, exemplars_dataset)
19 | self.trn_datasets = []
20 | self.val_datasets = []
21 |
22 | have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class
23 | assert (have_exemplars == 0), 'Warning: Joint does not use exemplars. Comment this line to force it.'
24 |
25 | @staticmethod
26 | def exemplars_dataset_class():
27 | return ExemplarsDataset
28 |
29 | @staticmethod
30 | def extra_parser(args):
31 | """Returns a parser containing the approach specific parameters"""
32 | parser = ArgumentParser()
33 | return parser.parse_known_args(args)
34 |
35 | def pre_train_process(self, t, trn_loader):
36 | """Runs before training all epochs of the task (before the train session)"""
37 | if self.ddp:
38 | model = self.model.module
39 | else:
40 | model = self.model
41 |
42 | if t == 0:
43 | # Sec. 4.1: "the ReLU in the penultimate layer is removed to allow the features to take both positive and
44 | # negative values"
45 | if model.model.__class__.__name__ == 'ResNetCifar':
46 | old_block = model.model.layer3[-1]
47 | model.model.layer3[-1] = BasicBlockNoRelu(old_block.conv1, old_block.bn1, old_block.relu,
48 | old_block.conv2, old_block.bn2, old_block.downsample)
49 | elif model.model.__class__.__name__ == 'ResNet':
50 | old_block = model.model.layer4[-1]
51 | model.model.layer4[-1] = BasicBlockNoRelu(old_block.conv1, old_block.bn1, old_block.relu,
52 | old_block.conv2, old_block.bn2, old_block.downsample)
53 | elif model.model.__class__.__name__ == 'ResNetBottleneck':
54 | old_block = model.model.layer4[-1]
55 | model.model.layer4[-1] = BottleneckNoRelu(old_block.conv1, old_block.bn1,
56 | old_block.relu, old_block.conv2, old_block.bn2,
57 | old_block.conv3, old_block.bn3, old_block.downsample)
58 | else:
59 | warnings.warn("Warning: ReLU not removed from last block.")
60 | # Changes the new head to a CosineLinear
61 | model.heads[-1] = CosineLinear(model.heads[-1].in_features, model.heads[-1].out_features)
62 | model.to(self.device)
63 |
64 | # if ddp option is activated, need to re-wrap the ddp model
65 | # yujun: debug to make sure this one is ok
66 | if self.ddp:
67 | self.model = DDP(self.model.module, device_ids=[self.local_rank])
68 | # The original code has an option called "imprint weights" that seems to initialize the new head.
69 | # However, this is not mentioned in the paper and doesn't seem to make a significant difference.
70 | super().pre_train_process(t, trn_loader)
71 |
72 | def post_train_process(self, t, trn_loader, val_loader):
73 | """Runs after training all the epochs of the task (after the train session)"""
74 | pass
75 |
76 | def train_loop(self, t, trn_loader, val_loader):
77 | """Contains the epochs loop"""
78 |
79 | # add new datasets to existing cumulative ones
80 | self.trn_datasets.append(trn_loader.dataset)
81 | self.val_datasets.append(val_loader.dataset)
82 | trn_dset = JointDataset(self.trn_datasets)
83 | val_dset = JointDataset(self.val_datasets)
84 | trn_loader = DataLoader(trn_dset,
85 | batch_size=trn_loader.batch_size,
86 | shuffle=True,
87 | num_workers=trn_loader.num_workers,
88 | pin_memory=trn_loader.pin_memory)
89 | val_loader = DataLoader(val_dset,
90 | batch_size=val_loader.batch_size,
91 | shuffle=False,
92 | num_workers=val_loader.num_workers,
93 | pin_memory=val_loader.pin_memory)
94 | # continue training as usual
95 | super().train_loop(t, trn_loader, val_loader)
96 |
97 | def train_epoch(self, t, trn_loader):
98 | """Runs a single epoch"""
99 | self.model.train()
100 | if self.fix_bn and t > 0:
101 | self.model.freeze_bn()
102 |
103 | for images, targets in trn_loader:
104 | images, targets = images.to(self.device), targets.to(self.device)
105 | # Forward current model
106 | outputs = self.model(images)
107 | loss = self.criterion(t, outputs, targets)
108 | # Backward
109 | self.optimizer.zero_grad()
110 | loss.backward()
111 | self.optimizer.step()
112 |
113 | def criterion(self, t, outputs, targets):
114 | """Returns the loss value"""
115 | if isinstance(outputs[0], dict):
116 | outputs = torch.cat([o['wsigma'] for o in outputs], dim=1)
117 | else:
118 | outputs = torch.cat(outputs, dim=1)
119 | return torch.nn.functional.cross_entropy(outputs, targets)
120 |
121 | class JointDataset(Dataset):
122 | """Characterizes a dataset for PyTorch -- this dataset accumulates each task dataset incrementally"""
123 |
124 | def __init__(self, datasets):
125 | self.datasets = datasets
126 | self._len = sum([len(d) for d in self.datasets])
127 |
128 | def __len__(self):
129 | 'Denotes the total number of samples'
130 | return self._len
131 |
132 | def __getitem__(self, index):
133 | for d in self.datasets:
134 | if len(d) <= index:
135 | index -= len(d)
136 | else:
137 | x, y = d[index]
138 | return x, y
139 |
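`JointDataset.__getitem__` resolves a global index by walking the task datasets in order and subtracting each length until the index fits. A small sketch of the same lookup on plain length lists, making no assumption about the real dataset classes; `locate` and `locate_bisect` are hypothetical helpers, and the bisect variant is an alternative shown for comparison, not what the repository does:

```python
from bisect import bisect_right
from itertools import accumulate


def locate(lengths, index):
    """Linear walk, as in JointDataset.__getitem__.

    Returns (dataset_id, local_index) for a non-negative global index.
    """
    for d, n in enumerate(lengths):
        if n <= index:
            index -= n
        else:
            return d, index
    # The original method falls through here and implicitly returns None.
    raise IndexError("index out of range")


def locate_bisect(lengths, index):
    """Same mapping using a binary search over cumulative lengths."""
    bounds = list(accumulate(lengths))
    total = bounds[-1] if bounds else 0
    if not 0 <= index < total:
        raise IndexError("index out of range")
    d = bisect_right(bounds, index)
    return d, index - (bounds[d - 1] if d else 0)
```

The linear walk costs one comparison per accumulated task, which is negligible for the handful of tasks used here; the binary search only matters if the number of concatenated datasets grows large.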
--------------------------------------------------------------------------------
/src/approach/lucir_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from torch.nn import Module, Parameter
6 |
7 |
8 | # Sec 3.2: This class implements the cosine normalizing linear layer module using Eq. 4
9 | class CosineLinear(Module):
10 | def __init__(self, in_features, out_features, sigma=True):
11 | super(CosineLinear, self).__init__()
12 | self.in_features = in_features
13 | self.out_features = out_features
14 | self.weight = Parameter(torch.Tensor(out_features, in_features))
15 | if sigma:
16 | self.sigma = Parameter(torch.Tensor(1))
17 | else:
18 | self.register_parameter('sigma', None)
19 | self.reset_parameters()
20 |
21 | def reset_parameters(self):
22 | stdv = 1. / math.sqrt(self.weight.size(1))
23 | self.weight.data.uniform_(-stdv, stdv)
24 | if self.sigma is not None:
25 | self.sigma.data.fill_(1) # for initialization of sigma
26 |
27 | def forward(self, input):
28 | out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))
29 | if self.sigma is not None:
30 | out_s = self.sigma * out
31 | else:
32 | out_s = out
33 | if self.training:
34 | return {'wsigma': out_s, 'wosigma': out}
35 | else:
36 | return out_s
37 |
38 |
39 | # This class implements a ResNet Basic Block without the final ReLU in the forward
40 | class BasicBlockNoRelu(nn.Module):
41 | expansion = 1
42 |
43 | def __init__(self, conv1, bn1, relu, conv2, bn2, downsample):
44 | super(BasicBlockNoRelu, self).__init__()
45 | self.conv1 = conv1
46 | self.bn1 = bn1
47 | self.relu = relu
48 | self.conv2 = conv2
49 | self.bn2 = bn2
50 | self.downsample = downsample
51 |
52 | def forward(self, x):
53 | residual = x
54 | out = self.relu(self.bn1(self.conv1(x)))
55 | out = self.bn2(self.conv2(out))
56 | if self.downsample is not None:
57 | residual = self.downsample(x)
58 | out += residual
59 | # Removed final ReLU
60 | return out
61 |
62 | class BottleneckNoRelu(nn.Module):
63 | expansion = 4
64 |
65 | def __init__(self, conv1, bn1, relu, conv2, bn2, conv3, bn3, downsample):
66 | super(BottleneckNoRelu, self).__init__()
67 | self.conv1 = conv1
68 | self.bn1 = bn1
69 | self.conv2 = conv2
70 | self.bn2 = bn2
71 | self.conv3 = conv3
72 | self.bn3 = bn3
73 | self.relu = relu
74 | self.downsample = downsample
75 |
76 | def forward(self, x):
77 | identity = x
78 |
79 | out = self.conv1(x)
80 | out = self.bn1(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv2(out)
84 | out = self.bn2(out)
85 | out = self.relu(out)
86 |
87 | out = self.conv3(out)
88 | out = self.bn3(out)
89 |
90 | if self.downsample is not None:
91 | identity = self.downsample(x)
92 |
93 | out += identity
94 | # Removed final ReLU
95 | return out
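`CosineLinear` scores each class by the cosine between the feature vector and that class's weight row, then scales the result by the learnable `sigma`. A plain-Python sketch for a single sample, assuming no batching and no torch; `l2_normalize` and `cosine_logits` are hypothetical names, and the `eps` floor on the norm follows the default of `F.normalize`:

```python
import math


def l2_normalize(v, eps=1e-12):
    """Divide a vector by its L2 norm, with the norm floored at eps."""
    norm = max(math.sqrt(sum(x * x for x in v)), eps)
    return [x / norm for x in v]


def cosine_logits(features, weight, sigma=1.0):
    """One logit per weight row: sigma * cos(features, row)."""
    f = l2_normalize(features)
    return [
        sigma * sum(a * b for a, b in zip(f, l2_normalize(row)))
        for row in weight
    ]
```

Because both operands are normalized, the logits depend only on direction: rescaling the features or a weight row leaves them unchanged, and `sigma` alone controls how peaked the resulting softmax can become.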
--------------------------------------------------------------------------------
/src/approach/lwf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 | from argparse import ArgumentParser
4 |
5 | from .incremental_learning import Inc_Learning_Appr
6 | from datasets.exemplars_dataset import ExemplarsDataset
7 |
8 |
9 | class Appr(Inc_Learning_Appr):
10 | """Class implementing the Learning Without Forgetting (LwF) approach
11 | described in https://arxiv.org/abs/1606.09282
12 | """
13 |
14 | # Weight decay of 0.0005 is used in the original article (page 4).
15 | # Page 4: "The warm-up step greatly enhances fine-tuning’s old-task performance, but is not so crucial to either our
16 | # method or the compared Less Forgetting Learning (see Table 2(b))."
17 | def __init__(self, model, device, nepochs=160, lr=0.1, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000,
18 | momentum=0.9, wd=5e-4, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False,
19 | eval_on_train=False, ddp=False, local_rank=0, logger=None, exemplars_dataset=None,
20 | lamb=1, T=2):
21 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd,
22 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, ddp, local_rank,
23 | logger, exemplars_dataset)
24 | self.model_old = None
25 | self.lamb = lamb
26 | self.T = T
27 |
28 | @staticmethod
29 | def exemplars_dataset_class():
30 | return ExemplarsDataset
31 |
32 | @staticmethod
33 | def extra_parser(args):
34 | """Returns a parser containing the approach specific parameters"""
35 | parser = ArgumentParser()
36 | # Page 5: "lambda is a loss balance weight, set to 1 for most our experiments. Making lambda larger will favor
37 | # the old task performance over the new task’s, so we can obtain a old-task-new-task performance line by
38 | # changing lambda."
39 | parser.add_argument('--lamb', default=1, type=float, required=False,
40 | help='Forgetting-intransigence trade-off (default=%(default)s)')
41 | # Page 5: "We use T=2 according to a grid search on a held out set, which aligns with the authors’
42 | # recommendations." -- Using a higher value for T produces a softer probability distribution over classes.
43 | parser.add_argument('--T', default=2, type=int, required=False,
44 | help='Temperature scaling (default=%(default)s)')
45 | return parser.parse_known_args(args)
46 |
47 | def _get_optimizer(self):
48 | """Returns the optimizer"""
49 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
50 | # if there are no exemplars, previous heads are not modified
51 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
52 | else:
53 | params = self.model.parameters()
54 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)
55 |
56 | def train_loop(self, t, trn_loader, val_loader):
57 | """Contains the epochs loop"""
58 |
59 | # add exemplars to train_loader
60 | if len(self.exemplars_dataset) > 0 and t > 0:
61 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
62 | batch_size=trn_loader.batch_size,
63 | shuffle=True,
64 | num_workers=trn_loader.num_workers,
65 | pin_memory=trn_loader.pin_memory)
66 |
67 | # FINETUNING TRAINING -- contains the epochs loop
68 | super().train_loop(t, trn_loader, val_loader)
69 |
70 | # EXEMPLAR MANAGEMENT -- select training subset
71 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform, self.ddp)
72 |
73 | def post_train_process(self, t, trn_loader, val_loader):
74 | """Runs after training all the epochs of the task (after the train session)"""
75 |
76 | # Save a frozen copy of the current model to distill from in future tasks
77 | self.model_old = deepcopy(self.model)
78 | self.model_old.eval()
79 | self.model_old.freeze_all()
80 |
81 | def train_epoch(self, t, trn_loader):
82 | """Runs a single epoch"""
83 | self.model.train()
84 | if self.fix_bn and t > 0:
85 | self.model.freeze_bn()
86 | for images, targets in trn_loader:
87 | # Forward old model
88 | targets_old = None
89 | if t > 0:
90 | targets_old = self.model_old(images.to(self.device))
91 | # Forward current model
92 | outputs = self.model(images.to(self.device))
93 | loss = self.criterion(t, outputs, targets.to(self.device), targets_old)
94 | # Backward
95 | self.optimizer.zero_grad()
96 | loss.backward()
97 | self.optimizer.step()
98 |
99 | def eval(self, t, val_loader):
100 | """Contains the evaluation code"""
101 | with torch.no_grad():
102 | total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0
103 | self.model.eval()
104 | for images, targets in val_loader:
105 | # Forward old model
106 | targets_old = None
107 | if t > 0:
108 | targets_old = self.model_old(images.to(self.device))
109 | # Forward current model
110 | outputs = self.model(images.to(self.device))
111 | loss = self.criterion(t, outputs, targets.to(self.device), targets_old)
112 | hits_taw, hits_tag = self.calculate_metrics(outputs, targets)
113 | # Log
114 | total_loss += loss.item() * len(targets)
115 | total_acc_taw += hits_taw.sum().item()
116 | total_acc_tag += hits_tag.sum().item()
117 | total_num += len(targets)
118 | return total_loss / total_num, total_acc_taw / total_num, total_acc_tag / total_num
119 |
120 | def cross_entropy(self, outputs, targets, exp=1.0, size_average=True, eps=1e-5):
121 | """Calculates cross-entropy with temperature scaling"""
122 | out = torch.nn.functional.softmax(outputs, dim=1)
123 | tar = torch.nn.functional.softmax(targets, dim=1)
124 | if exp != 1:
125 | out = out.pow(exp)
126 | out = out / out.sum(1).view(-1, 1).expand_as(out)
127 | tar = tar.pow(exp)
128 | tar = tar / tar.sum(1).view(-1, 1).expand_as(tar)
129 | out = out + eps / out.size(1)
130 | out = out / out.sum(1).view(-1, 1).expand_as(out)
131 | ce = -(tar * out.log()).sum(1)
132 | if size_average:
133 | ce = ce.mean()
134 | return ce
135 |
136 | def criterion(self, t, outputs, targets, outputs_old=None):
137 | """Returns the loss value"""
138 | loss = 0
139 | if t > 0:
140 | # Knowledge distillation loss for all previous tasks
141 | loss += self.lamb * self.cross_entropy(torch.cat(outputs[:t], dim=1),
142 | torch.cat(outputs_old[:t], dim=1), exp=1.0 / self.T)
143 | # Current cross-entropy loss -- with exemplars use all heads
144 | if len(self.exemplars_dataset) > 0:
145 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
146 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
147 |
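The `cross_entropy` helper above applies temperature by raising softmax probabilities to the power `1/T` and renormalizing, which yields the same distribution as a softmax over logits divided by `T`. A plain-Python sketch for a single sample that follows the same steps, including the `eps` smoothing; the function names are illustrative and not part of the repository:

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def temper(probs, exp):
    """Raise probabilities to `exp` and renormalize them to sum to 1."""
    powered = [p ** exp for p in probs]
    total = sum(powered)
    return [p / total for p in powered]


def distill_ce(new_logits, old_logits, T=2.0, eps=1e-5):
    """Cross-entropy between tempered old-model and new-model distributions."""
    out = temper(softmax(new_logits), 1.0 / T)
    tar = temper(softmax(old_logits), 1.0 / T)
    # Smooth the prediction so log() never sees an exact zero, then renormalize
    out = [p + eps / len(out) for p in out]
    total = sum(out)
    out = [p / total for p in out]
    return -sum(t * math.log(o) for t, o in zip(tar, out))
```

A larger `T` flattens both distributions, so the distillation term pays more attention to how the old model ranked the non-maximal classes.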
--------------------------------------------------------------------------------
/src/approach/mas.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import itertools
3 | from argparse import ArgumentParser
4 |
5 | from .incremental_learning import Inc_Learning_Appr
6 | from datasets.exemplars_dataset import ExemplarsDataset
7 |
8 |
9 | class Appr(Inc_Learning_Appr):
10 | """Class implementing the Memory Aware Synapses (MAS) approach (global version)
11 | described in https://arxiv.org/abs/1711.09601
12 | Original code available at https://github.com/rahafaljundi/MAS-Memory-Aware-Synapses
13 | """
14 |
15 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
16 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
17 | logger=None, exemplars_dataset=None, lamb=1, alpha=0.5, fi_num_samples=-1):
18 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
19 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
20 | exemplars_dataset)
21 | self.lamb = lamb
22 | self.alpha = alpha
23 | self.num_samples = fi_num_samples
24 |
25 | # In all cases, we only keep importance weights for the model, but not for the heads.
26 | feat_ext = self.model.model
27 | # Store current parameters as the initial parameters before first task starts
28 | self.older_params = {n: p.clone().detach() for n, p in feat_ext.named_parameters() if p.requires_grad}
29 | # Store MAS parameter importance weights
30 | self.importance = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters()
31 | if p.requires_grad}
32 |
33 | @staticmethod
34 | def exemplars_dataset_class():
35 | return ExemplarsDataset
36 |
37 | @staticmethod
38 | def extra_parser(args):
39 | """Returns a parser containing the approach specific parameters"""
40 | parser = ArgumentParser()
41 | # Eq. 3: lambda is the regularizer trade-off -- In original code: MAS.ipynb block [4]: lambda set to 1
42 | parser.add_argument('--lamb', default=1, type=float, required=False,
43 | help='Forgetting-intransigence trade-off (default=%(default)s)')
44 | # Define how old and new importance is fused, by default it is a 50-50 fusion
45 | parser.add_argument('--alpha', default=0.5, type=float, required=False,
46 | help='MAS alpha (default=%(default)s)')
47 | # Number of samples from train for estimating importance
48 | parser.add_argument('--fi-num-samples', default=-1, type=int, required=False,
49 | help='Number of samples for importance estimation (-1: all available) (default=%(default)s)')
50 | return parser.parse_known_args(args)
51 |
52 | def _get_optimizer(self):
53 | """Returns the optimizer"""
54 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
55 | # if there are no exemplars, previous heads are not modified
56 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
57 | else:
58 | params = self.model.parameters()
59 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)
60 |
61 | # Section 4.1: MAS (global) is implemented since the paper shows it is more efficient than l-MAS (local)
62 | def estimate_parameter_importance(self, trn_loader):
63 | # Initialize importance matrices
64 | importance = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters()
65 | if p.requires_grad}
66 | # Estimate importance on the specified number of samples -- rounded to the batch size
67 | n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \
68 | else (len(trn_loader.dataset) // trn_loader.batch_size)
69 | # Do forward and backward pass to accumulate L2-loss gradients
70 | self.model.train()
71 | for images, targets in itertools.islice(trn_loader, n_samples_batches):
72 | # MAS allows any unlabeled data to do the estimation, we choose the current data as in main experiments
73 | outputs = self.model.forward(images.to(self.device))
74 | # Page 6: labels not required, "...use the gradients of the squared L2-norm of the learned function output."
75 | loss = torch.norm(torch.cat(outputs, dim=1), p=2, dim=1).mean()
76 | self.optimizer.zero_grad()
77 | loss.backward()
78 | # Eq. 2: accumulate the gradients over the inputs to obtain importance weights
79 | for n, p in self.model.model.named_parameters():
80 | if p.grad is not None:
81 | importance[n] += p.grad.abs() * len(targets)
82 | # Eq. 2: divide by N total number of samples
83 | n_samples = n_samples_batches * trn_loader.batch_size
84 | importance = {n: (p / n_samples) for n, p in importance.items()}
85 | return importance
86 |
87 | def train_loop(self, t, trn_loader, val_loader):
88 | """Contains the epochs loop"""
89 |
90 | # add exemplars to train_loader
91 | if len(self.exemplars_dataset) > 0 and t > 0:
92 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
93 | batch_size=trn_loader.batch_size,
94 | shuffle=True,
95 | num_workers=trn_loader.num_workers,
96 | pin_memory=trn_loader.pin_memory)
97 |
98 | # FINETUNING TRAINING -- contains the epochs loop
99 | super().train_loop(t, trn_loader, val_loader)
100 |
101 | # EXEMPLAR MANAGEMENT -- select training subset
102 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)
103 |
104 | def post_train_process(self, t, trn_loader, val_loader):
105 | """Runs after training all the epochs of the task (after the train session)"""
106 |
107 | # Store current parameters for the next task
108 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad}
109 |
110 | # estimate parameter importance on the current task
111 | curr_importance = self.estimate_parameter_importance(trn_loader)
112 | # merge importance estimates, we do not want to keep a separate estimate for each task in memory
113 | for n in self.importance.keys():
114 | # Added option to accumulate importance over time with a pre-fixed growing alpha
115 | if self.alpha == -1:
116 | alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device)
117 | self.importance[n] = alpha * self.importance[n] + (1 - alpha) * curr_importance[n]
118 | else:
119 | # As in original code: MAS_utils/MAS_based_Training.py line 638 -- just add prev and new
120 | self.importance[n] = self.alpha * self.importance[n] + (1 - self.alpha) * curr_importance[n]
121 |
122 | def criterion(self, t, outputs, targets):
123 | """Returns the loss value"""
124 | loss = 0
125 | if t > 0:
126 | loss_reg = 0
127 | # Eq. 3: memory aware synapses regularizer penalty
128 | for n, p in self.model.model.named_parameters():
129 | if n in self.importance.keys():
130 | loss_reg += torch.sum(self.importance[n] * (p - self.older_params[n]).pow(2)) / 2
131 | loss += self.lamb * loss_reg
132 | # Current cross-entropy loss -- with exemplars use all heads
133 | if len(self.exemplars_dataset) > 0:
134 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
135 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
136 |
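After each task, MAS blends the stored importance with the newly estimated one, and during later tasks it penalizes squared drift from the stored parameters in proportion to that importance. A sketch with scalar parameters held in plain dicts, assuming the same formulas as `post_train_process` and `criterion` above; `fuse_importance` and `mas_penalty` are hypothetical names:

```python
def fuse_importance(old, new, alpha=0.5):
    """Weighted blend of previous and current importance, per parameter name."""
    return {n: alpha * old[n] + (1 - alpha) * new[n] for n in old}


def mas_penalty(params, older_params, importance, lamb=1.0):
    """lamb * sum_n importance[n] * (params[n] - older_params[n])**2 / 2."""
    reg = sum(
        importance[n] * (params[n] - older_params[n]) ** 2 / 2
        for n in importance
    )
    return lamb * reg
```

With `alpha=0.5` the blend is an even average of old and new importance; a parameter whose importance is zero can move freely, while one with high importance is pulled back toward its value at the end of the previous task.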
--------------------------------------------------------------------------------
/src/approach/path_integral.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from argparse import ArgumentParser
3 |
4 | from .incremental_learning import Inc_Learning_Appr
5 | from datasets.exemplars_dataset import ExemplarsDataset
6 |
7 |
8 | class Appr(Inc_Learning_Appr):
9 | """Class implementing the Path Integral (aka Synaptic Intelligence) approach
10 | described in http://proceedings.mlr.press/v70/zenke17a.html
11 | Original code available at https://github.com/ganguli-lab/pathint
12 | """
13 |
14 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
15 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
16 | logger=None, exemplars_dataset=None, lamb=0.1, damping=0.1):
17 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
18 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
19 | exemplars_dataset)
20 | self.lamb = lamb
21 | self.damping = damping
22 |
23 | # In all cases, we only keep importance weights for the model, but not for the heads.
24 | feat_ext = self.model.model
25 | # Page 3, following Eq. 3: "The w now have an intuitive interpretation as the parameter specific contribution to
26 | # changes in the total loss."
27 | self.w = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() if p.requires_grad}
28 | # Store current parameters as the initial parameters before first task starts
29 | self.older_params = {n: p.clone().detach().to(self.device) for n, p in feat_ext.named_parameters()
30 | if p.requires_grad}
31 | # Store importance weights matrices
32 | self.importance = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters()
33 | if p.requires_grad}
34 |
35 | @staticmethod
36 | def exemplars_dataset_class():
37 | return ExemplarsDataset
38 |
39 | @staticmethod
40 | def extra_parser(args):
41 | """Returns a parser containing the approach specific parameters"""
42 | parser = ArgumentParser()
43 | # Eq. 4: lamb is the 'c' trade-off parameter from the surrogate loss -- 1e-3 < c < 0.1
44 | parser.add_argument('--lamb', default=0.1, type=float, required=False,
45 | help='Forgetting-intransigence trade-off (default=%(default)s)')
46 | # Eq. 5: damping parameter is set to 0.1 in the MNIST case
47 | parser.add_argument('--damping', default=0.1, type=float, required=False,
48 | help='Damping (default=%(default)s)')
49 | return parser.parse_known_args(args)
50 |
51 | def _get_optimizer(self):
52 | """Returns the optimizer"""
53 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
54 | # if there are no exemplars, previous heads are not modified
55 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
56 | else:
57 | params = self.model.parameters()
58 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)
59 |
60 | def train_loop(self, t, trn_loader, val_loader):
61 | """Contains the epochs loop"""
62 |
63 | # add exemplars to train_loader
64 | if len(self.exemplars_dataset) > 0 and t > 0:
65 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
66 | batch_size=trn_loader.batch_size,
67 | shuffle=True,
68 | num_workers=trn_loader.num_workers,
69 | pin_memory=trn_loader.pin_memory)
70 |
71 | # FINETUNING TRAINING -- contains the epochs loop
72 | super().train_loop(t, trn_loader, val_loader)
73 |
74 | # EXEMPLAR MANAGEMENT -- select training subset
75 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)
76 |
77 | def post_train_process(self, t, trn_loader):
78 | """Runs after training all the epochs of the task (after the train session)"""
79 |
80 | # Eq. 5: accumulate Omega regularization strength (importance matrix)
81 | with torch.no_grad():
82 | curr_params = {n: p for n, p in self.model.model.named_parameters() if p.requires_grad}
83 | for n, p in self.importance.items():
84 | p += self.w[n] / ((curr_params[n] - self.older_params[n]) ** 2 + self.damping)
85 | self.w[n].zero_()
86 |
87 | # Store current parameters for the next task
88 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad}
89 |
90 | def train_epoch(self, t, trn_loader):
91 | """Runs a single epoch"""
92 | self.model.train()
93 | if self.fix_bn and t > 0:
94 | self.model.freeze_bn()
95 | for images, targets in trn_loader:
96 | # store current model without heads
97 | curr_feat_ext = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad}
98 |
99 | # Forward current model
100 | outputs = self.model(images.to(self.device))
101 | # in theory this is the correct loss for 2 tasks; for more tasks, however, it may need to be the current loss
102 | # check https://github.com/ganguli-lab/pathint/blob/master/pathint/optimizers.py line 123
103 | # cross-entropy loss on current task
104 | if len(self.exemplars_dataset) == 0:
105 | loss = torch.nn.functional.cross_entropy(outputs[t], targets.to(self.device) - self.model.task_offset[t])
106 | else:
107 | # with exemplars we check output from all heads (train data has all labels)
108 | loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets.to(self.device))
109 | self.optimizer.zero_grad()
110 | loss.backward(retain_graph=True)
111 | # store gradients without regularization term
112 | unreg_grads = {n: p.grad.clone().detach() for n, p in self.model.model.named_parameters()
113 | if p.grad is not None}
114 | # apply loss with path integral regularization
115 | loss = self.criterion(t, outputs, targets.to(self.device))
116 |
117 | # Backward
118 | self.optimizer.zero_grad()
119 | loss.backward()
120 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
121 | self.optimizer.step()
122 |
123 | # Eq. 3: accumulate w, compute the path integral -- "In practice, we can approximate w online as the running
124 | # sum of the product of the gradient with the parameter update".
125 | with torch.no_grad():
126 | for n, p in self.model.model.named_parameters():
127 | if n in unreg_grads.keys():
128 | # w[n] >=0, but minus for loss decrease
129 | self.w[n] -= unreg_grads[n] * (p.detach() - curr_feat_ext[n])
130 |
131 | def criterion(self, t, outputs, targets):
132 | """Returns the loss value"""
133 | loss = 0
134 | if t > 0:
135 | loss_reg = 0
136 | # Eq. 4: quadratic surrogate loss
137 | for n, p in self.model.model.named_parameters():
138 | loss_reg += torch.sum(self.importance[n] * (p - self.older_params[n]).pow(2))
139 | loss += self.lamb * loss_reg
140 | # Current cross-entropy loss -- with exemplars use all heads
141 | if len(self.exemplars_dataset) > 0:
142 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
143 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
144 |
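The bookkeeping in train_epoch and post_train_process above can be illustrated on a toy problem. This is a minimal sketch in plain Python, not the repository's code: the function name si_importance, the toy gradient, and the learning rate are assumptions made for illustration. It accumulates w as the running sum of minus the gradient times the parameter update, then turns w into a damped importance value as in Eq. 5.

```python
def si_importance(theta0, grad_fn, lr=0.1, steps=50, damping=0.1):
    """Run plain gradient descent; return final params, path integral w, importance omega."""
    theta = list(theta0)
    w = [0.0] * len(theta)
    for _ in range(steps):
        grads = grad_fn(theta)
        new_theta = [p - lr * g for p, g in zip(theta, grads)]
        # running sum of -(gradient * parameter update), one entry per parameter
        w = [wi - g * (q - p) for wi, g, p, q in zip(w, grads, theta, new_theta)]
        theta = new_theta
    # damped normalization by the squared total displacement over the task
    omega = [wi / ((p - p0) ** 2 + damping) for wi, p, p0 in zip(w, theta, theta0)]
    return theta, w, omega


def toy_grad(theta):
    """Gradient of the hypothetical loss 0.5 * (4 * a**2 + 0.01 * b**2)."""
    a, b = theta
    return [4.0 * a, 0.01 * b]


theta, w, omega = si_importance([1.0, 1.0], toy_grad)
```

In this toy setting the steep coordinate accounts for most of the loss decrease, so it ends up with the larger importance value.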
--------------------------------------------------------------------------------
/src/approach/r_walk.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import itertools
3 | from argparse import ArgumentParser
4 | from torch.utils.data import DataLoader
5 |
6 | from .incremental_learning import Inc_Learning_Appr
7 | from datasets.exemplars_dataset import ExemplarsDataset
8 |
9 |
10 | class Appr(Inc_Learning_Appr):
11 | """Class implementing the Riemannian Walk (RWalk) approach described in
12 | http://openaccess.thecvf.com/content_ECCV_2018/papers/Arslan_Chaudhry__Riemannian_Walk_ECCV_2018_paper.pdf
13 | """
14 |
15 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
16 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
17 | logger=None, exemplars_dataset=None, lamb=1, alpha=0.5, damping=0.1, fim_sampling_type='max_pred',
18 | fim_num_samples=-1):
19 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
20 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
21 | exemplars_dataset)
22 | self.lamb = lamb
23 | self.alpha = alpha
24 | self.damping = damping
25 | self.sampling_type = fim_sampling_type
26 | self.num_samples = fim_num_samples
27 |
28 | # In all cases, we only keep importance weights for the model, but not for the heads.
29 | feat_ext = self.model.model
30 | # Page 7: "task-specific parameter importance over the entire training trajectory."
31 | self.w = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() if p.requires_grad}
32 | # Store current parameters as the initial parameters before first task starts
33 | self.older_params = {n: p.clone().detach().to(self.device) for n, p in feat_ext.named_parameters()
34 | if p.requires_grad}
35 | # Store scores and fisher information
36 | self.scores = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters()
37 | if p.requires_grad}
38 | self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters()
39 | if p.requires_grad}
40 |
41 | @staticmethod
42 | def exemplars_dataset_class():
43 | return ExemplarsDataset
44 |
45 | @staticmethod
46 | def extra_parser(args):
47 | """Returns a parser containing the approach specific parameters"""
48 | parser = ArgumentParser()
49 | # Eq. 5 and 8: "regularization hyperparameter lambda being less sensitive to the number of tasks. Whereas,
50 | # EWC and Path Integral are highly sensitive to lambda, making them relatively less reliable for IL"
51 | parser.add_argument('--lamb', default=1, type=float, required=False,
52 | help='Forgetting-intransigence trade-off (default=%(default)s)')
53 | # Define how old and new fisher is fused, by default it is a 50-50 fusion
54 | parser.add_argument('--alpha', default=0.5, type=float, required=False,
55 | help='RWalk alpha (default=%(default)s)')
56 | # Damping parameter as in Path Integral
57 | parser.add_argument('--damping', default=0.1, type=float, required=False,
58 | help='Damping (default=%(default)s)')
59 | parser.add_argument('--fim_sampling_type', default='max_pred', type=str, required=False,
60 | choices=['true', 'max_pred', 'multinomial'],
61 | help='Sampling type for Fisher information (default=%(default)s)')
62 | parser.add_argument('--fim_num_samples', default=-1, type=int, required=False,
63 | help='Number of samples for Fisher information (-1: all available) (default=%(default)s)')
64 | return parser.parse_known_args(args)
65 |
66 | def _get_optimizer(self):
67 | """Returns the optimizer"""
68 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
69 | # if there are no exemplars, previous heads are not modified
70 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
71 | else:
72 | params = self.model.parameters()
73 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)
74 |
75 | def compute_fisher_matrix_diag(self, trn_loader):
76 | # Store Fisher Information
77 | fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters()
78 | if p.requires_grad}
79 | # Compute fisher information for specified number of samples -- rounded to the batch size
80 | n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \
81 | else (len(trn_loader.dataset) // trn_loader.batch_size)
82 | # Do forward and backward pass to compute the fisher information
83 | self.model.train()
84 | for images, targets in itertools.islice(trn_loader, n_samples_batches):
85 | outputs = self.model.forward(images.to(self.device))
86 |
87 | if self.sampling_type == 'true':
88 | # Use the labels to compute the gradients based on the CE-loss with the ground truth
89 | preds = targets.to(self.device)
90 | elif self.sampling_type == 'max_pred':
91 | # Do not use the labels; compute the gradients with respect to the class the model predicts
92 | preds = torch.cat(outputs, dim=1).argmax(1).flatten()
93 | elif self.sampling_type == 'multinomial':
94 | # Sample one label per image from the predicted distribution to compute the gradients
95 | probs = torch.nn.functional.softmax(torch.cat(outputs, dim=1), dim=1)
96 | preds = torch.multinomial(probs, 1).flatten()
97 |
98 | loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), preds)
99 | self.optimizer.zero_grad()
100 | loss.backward()
101 | # Page 6: "the Fisher component [...] is the expected square of the loss gradient w.r.t the i-th parameter."
102 | for n, p in self.model.model.named_parameters():
103 | if p.grad is not None:
104 | fisher[n] += p.grad.pow(2) * len(targets)
105 | # Apply mean across all samples
106 | n_samples = n_samples_batches * trn_loader.batch_size
107 | fisher = {n: (p / n_samples) for n, p in fisher.items()}
108 | return fisher
109 |
110 | def train_loop(self, t, trn_loader, val_loader):
111 | """Contains the epochs loop"""
112 |
113 | # add exemplars to train_loader
114 | if len(self.exemplars_dataset) > 0 and t > 0:
115 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
116 | batch_size=trn_loader.batch_size,
117 | shuffle=True,
118 | num_workers=trn_loader.num_workers,
119 | pin_memory=trn_loader.pin_memory)
120 |
121 | # FINETUNING TRAINING -- contains the epochs loop
122 | super().train_loop(t, trn_loader, val_loader)
123 |
124 | # EXEMPLAR MANAGEMENT -- select training subset
125 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)
126 |
127 | def post_train_process(self, t, trn_loader):
128 | """Runs after training all the epochs of the task (after the train session)"""
129 |
130 | # calculate Fisher Information Matrix
131 | curr_fisher = self.compute_fisher_matrix_diag(trn_loader)
132 |
133 | # Eq. 10: efficiently update Fisher Information Matrix
134 | for n in self.fisher.keys():
135 | # Added option to accumulate fisher over time with a pre-fixed growing alpha
136 | if self.alpha == -1:
137 | alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device)
138 | self.fisher[n] = alpha * self.fisher[n] + (1 - alpha) * curr_fisher[n]
139 | else:
140 | self.fisher[n] = self.alpha * self.fisher[n] + (1 - self.alpha) * curr_fisher[n]
141 | # Page 7: Optimization Path-based Parameter Importance: importance scores computation
142 | curr_score = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters()
143 | if p.requires_grad}
144 | with torch.no_grad():
145 | curr_params = {n: p for n, p in self.model.model.named_parameters() if p.requires_grad}
146 | for n, p in self.scores.items():
147 | curr_score[n] = self.w[n] / (
148 | self.fisher[n] * ((curr_params[n] - self.older_params[n]) ** 2) + self.damping)
149 | self.w[n].zero_()
150 | # Page 7: "Since we care about positive influence of the parameters, negative scores are set to zero."
151 | curr_score[n] = torch.nn.functional.relu(curr_score[n])
152 | # Page 8: alleviating regularization getting increasingly rigid by averaging scores
153 | for n, p in self.scores.items():
154 | self.scores[n] = (self.scores[n] + curr_score[n]) / 2
155 |
156 | # Store current parameters for the next task
157 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad}
158 |
159 | def train_epoch(self, t, trn_loader):
160 | """Runs a single epoch"""
161 | self.model.train()
162 | if self.fix_bn and t > 0:
163 | self.model.freeze_bn()
164 | for images, targets in trn_loader:
165 | # store current model
166 | curr_feat_ext = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad}
167 |
168 | # Forward current model
169 | outputs = self.model(images.to(self.device))
170 | # cross-entropy loss on current task
171 | loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets.to(self.device))
172 | self.optimizer.zero_grad()
173 | loss.backward(retain_graph=True)
174 | # store gradients without regularization term
175 | unreg_grads = {n: p.grad.clone().detach() for n, p in self.model.model.named_parameters()
176 | if p.grad is not None}
177 | # apply loss with RWalk regularization (Fisher information plus path-based scores)
178 | loss = self.criterion(t, outputs, targets.to(self.device))
179 |
180 | # Backward
181 | self.optimizer.zero_grad()
182 | loss.backward()
183 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
184 | self.optimizer.step()
185 |
186 | # Page 7: "accumulate task-specific parameter importance over the entire training trajectory"
187 | # "the parameter importance is defined as the ratio of the change in the loss function to the distance
188 | # between the conditional likelihood distributions per step in the parameter space."
189 | with torch.no_grad():
190 | for n, p in self.model.model.named_parameters():
191 | if n in unreg_grads.keys():
192 | self.w[n] -= unreg_grads[n] * (p.detach() - curr_feat_ext[n])
193 |
194 | def criterion(self, t, outputs, targets):
195 | """Returns the loss value"""
196 | loss = 0
197 | if t > 0:
198 | loss_reg = 0
199 | # Eq. 9: final objective function
200 | for n, p in self.model.model.named_parameters():
201 | if n in self.fisher.keys():
202 | loss_reg += torch.sum((self.fisher[n] + self.scores[n]) * (p - self.older_params[n]).pow(2))
203 | loss += self.lamb * loss_reg
204 | # Current cross-entropy loss -- with exemplars use all heads
205 | if len(self.exemplars_dataset) > 0:
206 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
207 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
208 |
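The two update rules in post_train_process above (the moving average of the Fisher information and the averaged, clipped importance scores) can be sketched on plain lists. This is an illustrative sketch only: fuse_fisher and update_scores are hypothetical helper names, and the inputs are made-up numbers rather than values from a real run.

```python
def fuse_fisher(old_fisher, curr_fisher, alpha=0.5):
    """Moving average of the Fisher diagonal: alpha * old + (1 - alpha) * current."""
    return [alpha * o + (1.0 - alpha) * c for o, c in zip(old_fisher, curr_fisher)]


def update_scores(old_scores, w, fisher, delta, damping=0.1):
    """Compute the current score, clip negatives to zero, then average with the old score."""
    curr = [max(wi / (f * d ** 2 + damping), 0.0) for wi, f, d in zip(w, fisher, delta)]
    return [(o + c) / 2.0 for o, c in zip(old_scores, curr)]


fisher = fuse_fisher([0.0, 4.0], [2.0, 0.0])
scores = update_scores([0.0, 1.0], w=[0.2, -0.3], fisher=[1.0, 1.0], delta=[0.0, 1.0])
```

Averaging the scores rather than summing them keeps the penalty from growing without bound as tasks accumulate, which is the motivation given in the comment that cites page 8.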
--------------------------------------------------------------------------------
/src/approach/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision import transforms
6 | from torch import distributed as dist
7 |
8 |
9 | # ------------------- SSL utils -----------------------
10 | # Supervised learning and SSL
11 | class TransformSLAndSSL:
12 | def __init__(self, orig_transform, transform_ssl):
13 | self.orig_transform = orig_transform
14 | self.transform_ssl = transform_ssl
15 |
16 | def __call__(self, inp):
17 | out = self.orig_transform(inp)
18 | out_ssl_1 = self.transform_ssl(inp)
19 | out_ssl_2 = self.transform_ssl(inp)
20 | return out, out_ssl_1, out_ssl_2
21 |
22 | # (H,W): data input size
23 | def get_simclr_transforms(H, W, orig_transform):
24 | simclr_aug = transforms.Compose([
25 | transforms.RandomResizedCrop((H,W)),
26 | transforms.RandomHorizontalFlip(), # with 0.5 probability
27 | transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
28 | transforms.RandomGrayscale(p=0.2),
29 | transforms.ToTensor(),])
30 | return TransformSLAndSSL(orig_transform, simclr_aug)
31 | # -----------------------------------------------------
32 |
33 | # transform scalar class labels to one-hot vectors
34 | # targets: (N,)
35 | def scalar2onehot(targets, num_class):
36 | N = targets.shape[0]
37 | onehot_target = torch.zeros(N, num_class).to(targets.device).scatter_(1, targets.unsqueeze(-1), 1)
38 | return onehot_target
39 |
40 | def scalar2SmoothOneHot(targets, num_class):
41 | N = targets.shape[0]
42 | hot_prob = 0.9
43 | smooth_prob = (1.0 - hot_prob) / (num_class - 1)
44 | onehot_target = (smooth_prob*torch.ones(N, num_class)).to(targets.device).scatter_(1, targets.unsqueeze(-1), hot_prob)
45 | return onehot_target
46 |
47 | # Cutmix
48 | # assuming the targets are one-hot vectors
49 | def cut_and_mix(images, targets):
50 | N,_,H,W = images.shape
51 | assert H==W, 'only square images are supported right now'
52 |
53 | cut_len = int(np.floor(np.random.rand()*H))
54 | cut_len = max(min(cut_len, H-1), 1)
55 | mix_ratio = float(cut_len*cut_len) / (H*W)
56 |
57 | top = np.random.randint(0, H - cut_len)
58 | left = np.random.randint(0, W - cut_len)
59 | bottom = top + cut_len
60 | right = left + cut_len
61 |
62 | # cut and mix
63 | # shuffled batch images
64 | rp = torch.randperm(N)
65 | shuffled_images = images[rp,:,:,:]
66 | shuffled_targets = targets[rp,:]
67 | images[:, :, top:bottom, left:right] = shuffled_images[:, :, top:bottom, left:right]
68 |
69 | # adjust the target
70 | targets = (1-mix_ratio)*targets + mix_ratio*shuffled_targets
71 |
72 | return images, targets
73 |
74 | # Cutmix w/ small window
75 | # assuming the targets are one-hot vectors
76 | def cut_and_mix_small_window(images, targets):
77 | N,_,H,W = images.shape
78 | assert H==W, 'only square images are supported right now'
79 |
80 | cut_len = int(np.floor(0.4*np.random.rand()*H))
81 | cut_len = max(min(cut_len, H-1), 1)
82 | mix_ratio = float(cut_len*cut_len) / (H*W)
83 |
84 | top = np.random.randint(0, H - cut_len)
85 | left = np.random.randint(0, W - cut_len)
86 | bottom = top + cut_len
87 | right = left + cut_len
88 |
89 | # cut and mix
90 | # shuffled batch images
91 | rp = torch.randperm(N)
92 | shuffled_images = images[rp,:,:,:]
93 | # shuffled_targets = targets[rp,:]
94 | images[:, :, top:bottom, left:right] = shuffled_images[:, :, top:bottom, left:right]
95 |
96 | # do not adjust the target
97 | # targets = (1-mix_ratio)*targets + mix_ratio*shuffled_targets
98 |
99 | return images, targets
100 |
101 | # Mixup
102 | # assuming the targets are one-hot vectors
103 | def mixup(images, targets):
104 | lam = np.random.beta(1.0, 1.0)
105 | N = images.shape[0]
106 | rp = torch.randperm(N)
107 | shuffled_images = images[rp,:,:,:]
108 | shuffled_targets = targets[rp,:]
109 |
110 | images = lam * images + (1-lam)*shuffled_images
111 | targets = lam * targets + (1-lam)*shuffled_targets
112 |
113 | return images, targets
114 |
115 | class MultiLabelCrossEntropyLoss(nn.Module):
116 |
117 | def __init__(self):
118 | super(MultiLabelCrossEntropyLoss, self).__init__()
119 |
120 | # logit: (N, C)
121 | # label: (N, C)
122 | def forward(self, logits, label):
123 | loss = -(label*F.log_softmax(logits, dim=1)).sum(dim=-1).mean()
124 | return loss
125 |
126 | class BalancedCrossEntropy(nn.Module):
127 |
128 | def __init__(self, tao):
129 | super(BalancedCrossEntropy, self).__init__()
130 | self.tao = tao
131 | self.eps = 1e-8
132 |
133 | def forward(self, logits, label):
134 | num_classes = logits.shape[1]
135 | label_onehot = scalar2onehot(label, num_classes)
136 |
137 | # v1 (undesired solution)
138 | # loss = -self.tao*(logits*label).sum(dim=-1) + (2.0 - self.tao)*torch.log(self.eps + logits.exp().sum(dim=-1))
139 |
140 | # v2 (undesired solution)
141 | loss = -self.tao*(logits*label_onehot).sum(dim=-1) + torch.log(self.eps + logits.exp().sum(dim=-1))
142 |
143 | # v3 (broken solution)
144 | # pos_loss = -self.tao*(logits*label).sum(dim=-1)
145 | # neg_loss = (2.0 - self.tao)*torch.log(self.eps + ((1.-label)*logits.exp()).sum(dim=-1))
146 | # loss = pos_loss + neg_loss
147 |
148 | # v4 (broken solution)
149 | # pos_loss = -self.tao*(logits*label).sum(dim=-1)
150 | # reweight = 1. - (1. - self.tao)*label
151 | # neg_loss = torch.log(self.eps + (reweight*logits).exp().sum(dim=-1))
152 | # loss = pos_loss + neg_loss
153 | # return loss.mean()
154 |
155 | # v5
156 | # reweight = 1. - (1. - self.tao)*label_onehot
157 | # logits = reweight*logits
158 | # loss = nn.CrossEntropyLoss(None)(logits, label)
159 | # return loss
160 |
161 | # pos_loss = -(logits*label_onehot).sum(dim=-1)
162 | # reweight = (1. - label_onehot)*self.tao + label_onehot
163 | # neg_loss = torch.log(self.eps + (reweight*logits.exp()).sum(dim=-1))
164 | # loss = pos_loss + neg_loss
165 |
166 | return loss.mean()
167 |
168 | class AugCrossEntropy(nn.Module):
169 |
170 | def __init__(self, n_aug):
171 | super(AugCrossEntropy, self).__init__()
172 | self.n_aug = n_aug
173 | self.ce = nn.CrossEntropyLoss()
174 |
175 | def forward(self, logits, features, targets):
176 | N,C = features.shape
177 | device = features.device
178 | # first generating random fake class embeddings
179 | pseudo_embeddings = F.normalize(torch.rand(self.n_aug, C, device=device) - 0.5, dim=-1)
180 | pseudo_logits = torch.matmul(features, pseudo_embeddings.t())
181 | cat_logits = torch.cat([logits, pseudo_logits], dim=-1)
182 | loss = self.ce(cat_logits, targets)
183 | return loss
184 |
185 |
186 | # -------------- DDP utils -------------------
187 | def reduce_tensor_mean(tensor, n):
188 | rt = tensor.clone()
189 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
190 | rt /= n
191 | return rt
192 |
193 | def reduce_tensor_sum(tensor):
194 | rt = tensor.clone()
195 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
196 | return rt
197 |
198 | def global_gather(x):
199 | all_x = [torch.ones_like(x) for _ in range(dist.get_world_size())]
200 | dist.all_gather(all_x, x, async_op=False)
201 | return torch.cat(all_x, dim=0)
202 |
203 | # differentiable gather layer
204 | class GatherLayer(torch.autograd.Function):
205 | """Gather tensors from all processes, supporting backward propagation."""
206 |
207 | @staticmethod
208 | def forward(ctx, input):
209 | ctx.save_for_backward(input)
210 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
211 | dist.all_gather(output, input)
212 | return tuple(output)
213 |
214 | @staticmethod
215 | def backward(ctx, *grads):
216 | (input,) = ctx.saved_tensors
217 | grad_out = torch.zeros_like(input)
218 | grad_out[:] = grads[dist.get_rank()]
219 | return grad_out
220 | # --------------------------------------------
221 |
222 | class SAM(torch.optim.Optimizer):
223 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
224 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
225 |
226 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
227 | super(SAM, self).__init__(params, defaults)
228 |
229 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
230 | self.param_groups = self.base_optimizer.param_groups
231 |
232 | @torch.no_grad()
233 | def first_step(self, zero_grad=False):
234 | grad_norm = self._grad_norm()
235 | for group in self.param_groups:
236 | scale = group["rho"] / (grad_norm + 1e-12)
237 |
238 | for p in group["params"]:
239 | if p.grad is None: continue
240 | self.state[p]["old_p"] = p.data.clone()
241 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
242 | p.add_(e_w) # climb to the local maximum "w + e(w)"
243 |
244 | if zero_grad: self.zero_grad()
245 |
246 | @torch.no_grad()
247 | def second_step(self, zero_grad=False):
248 | for group in self.param_groups:
249 | for p in group["params"]:
250 | if p.grad is None: continue
251 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)"
252 |
253 | self.base_optimizer.step() # do the actual "sharpness-aware" update
254 |
255 | if zero_grad: self.zero_grad()
256 |
257 | @torch.no_grad()
258 | def step(self, closure=None):
259 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
260 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
261 |
262 | self.first_step(zero_grad=True)
263 | closure()
264 | self.second_step()
265 |
266 | def _grad_norm(self):
267 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
268 | norm = torch.norm(
269 | torch.stack([
270 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
271 | for group in self.param_groups for p in group["params"]
272 | if p.grad is not None
273 | ]),
274 | p=2
275 | )
276 | return norm
277 |
278 | def load_state_dict(self, state_dict):
279 | super().load_state_dict(state_dict)
280 | self.base_optimizer.param_groups = self.param_groups
281 |
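The target adjustment in cut_and_mix above is simple arithmetic: the pasted square covers a fraction cut_len^2 / (H * W) of the image, and the soft target is mixed in the same proportion. A minimal sketch in plain Python, assuming one-hot targets given as lists; mix_targets is a hypothetical helper name, not part of this file.

```python
def mix_targets(target_a, target_b, cut_len, height, width):
    """Blend two one-hot targets by the area fraction of the pasted square."""
    mix_ratio = float(cut_len * cut_len) / (height * width)
    return [(1.0 - mix_ratio) * a + mix_ratio * b for a, b in zip(target_a, target_b)]


# A 16x16 patch pasted into a 32x32 image covers a quarter of the pixels.
mixed = mix_targets([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], cut_len=16, height=32, width=32)
```

Because the weights sum to one, the mixed target remains a valid probability vector and can be passed to a soft-label loss such as MultiLabelCrossEntropyLoss.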
--------------------------------------------------------------------------------
/src/data/imagenet/gen_lst_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | subset_num = 100
5 |
6 | root_dir = 'train'
7 | with open(root_dir+'_'+str(subset_num)+'.txt', 'w') as f:
8 | classes = sorted(entry.name for entry in os.scandir(root_dir) if entry.is_dir())
9 | seed = 1993
10 | np.random.seed(seed)
11 | subset_classes = np.random.choice(classes, subset_num, replace=False)
12 |
13 | for class_id, class_name in enumerate(subset_classes):
14 | folder_name = os.path.join(root_dir, class_name)
15 | for img_name in sorted(os.listdir(folder_name)):
16 | write_line = os.path.join(root_dir, class_name, img_name)
17 | write_line += ' ' + str(class_id) + '\n'
18 | f.write(write_line)
19 |
20 | root_dir = 'val'
21 | with open(root_dir+'_'+str(subset_num)+'.txt', 'w') as f:
22 | classes = sorted(entry.name for entry in os.scandir(root_dir) if entry.is_dir())
23 | seed = 1993
24 | np.random.seed(seed)
25 | subset_classes = np.random.choice(classes, subset_num, replace=False)
26 |
27 | for class_id, class_name in enumerate(subset_classes):
28 | folder_name = os.path.join(root_dir, class_name)
29 | for img_name in sorted(os.listdir(folder_name)):
30 | write_line = os.path.join(root_dir, class_name, img_name)
31 | write_line += ' ' + str(class_id) + '\n'
32 | f.write(write_line)
33 |
--------------------------------------------------------------------------------
/src/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class BaseDataset(Dataset):
9 | """Characterizes a dataset for PyTorch -- this dataset pre-loads all paths in memory"""
10 |
11 | def __init__(self, data, transform, class_indices=None):
12 | """Initialization"""
13 | self.labels = data['y']
14 | self.images = data['x']
15 | self.transform = transform
16 | self.class_indices = class_indices
17 |
18 | def __len__(self):
19 | """Denotes the total number of samples"""
20 | return len(self.images)
21 |
22 | def __getitem__(self, index):
23 | """Generates one sample of data"""
24 | x = Image.open(self.images[index]).convert('RGB')
25 | x = self.transform(x)
26 | y = self.labels[index]
27 | return x, y
28 |
29 |
30 | def get_data(path, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None, trn_lst=None, tst_lst=None):
31 | """Prepare data: dataset splits, task partition, class order"""
32 |
33 | data = {}
34 | taskcla = []
35 |
36 | # read filenames and labels
37 | if trn_lst is None and tst_lst is None:
38 | trn_lines = np.loadtxt(os.path.join(path, 'train.txt'), dtype=str)
39 | tst_lines = np.loadtxt(os.path.join(path, 'test.txt'), dtype=str)
40 | else:
41 | trn_lines = np.loadtxt(os.path.join(path, trn_lst), dtype=str)
42 | tst_lines = np.loadtxt(os.path.join(path, tst_lst), dtype=str)
43 |
44 | if class_order is None:
45 | num_classes = len(np.unique(trn_lines[:, 1]))
46 | class_order = list(range(num_classes))
47 | else:
48 | num_classes = len(class_order)
49 | class_order = class_order.copy()
50 | # yujun: a little hack here, in no case shall we shuffle the classes
51 | # if shuffle_classes:
52 | # np.random.shuffle(class_order)
53 |
54 | # compute classes per task and num_tasks
55 | if nc_first_task is None:
56 | cpertask = np.array([num_classes // num_tasks] * num_tasks)
57 | for i in range(num_classes % num_tasks):
58 | cpertask[i] += 1
59 | else:
60 | assert nc_first_task < num_classes, "first task wants more classes than exist"
61 | remaining_classes = num_classes - nc_first_task
62 | assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task" # better minimum 2
63 | cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1))
64 | for i in range(remaining_classes % (num_tasks - 1)):
65 | cpertask[i + 1] += 1
66 |
67 | assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes"
68 | cpertask_cumsum = np.cumsum(cpertask)
69 | init_class = np.concatenate(([0], cpertask_cumsum[:-1]))
70 |
71 | # initialize data structure
72 | for tt in range(num_tasks):
73 | data[tt] = {}
74 | data[tt]['name'] = 'task-' + str(tt)
75 | data[tt]['trn'] = {'x': [], 'y': []}
76 | data[tt]['val'] = {'x': [], 'y': []}
77 | data[tt]['tst'] = {'x': [], 'y': []}
78 |
79 | # ALL OR TRAIN
80 | for this_image, this_label in trn_lines:
81 | if not os.path.isabs(this_image):
82 | this_image = os.path.join(path, this_image)
83 | this_label = int(this_label)
84 | if this_label not in class_order:
85 | continue
86 | # If shuffling is false, it won't change the class number
87 | this_label = class_order.index(this_label)
88 |
89 | # add it to the corresponding split
90 | this_task = (this_label >= cpertask_cumsum).sum()
91 | data[this_task]['trn']['x'].append(this_image)
92 | data[this_task]['trn']['y'].append(this_label - init_class[this_task])
93 |
94 | # ALL OR TEST
95 | for this_image, this_label in tst_lines:
96 | if not os.path.isabs(this_image):
97 | this_image = os.path.join(path, this_image)
98 | this_label = int(this_label)
99 | if this_label not in class_order:
100 | continue
101 | # If shuffling is false, it won't change the class number
102 | this_label = class_order.index(this_label)
103 |
104 | # add it to the corresponding split
105 | this_task = (this_label >= cpertask_cumsum).sum()
106 | data[this_task]['tst']['x'].append(this_image)
107 | data[this_task]['tst']['y'].append(this_label - init_class[this_task])
108 |
109 | # check classes
110 | for tt in range(num_tasks):
111 | data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y']))
112 | assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes"
113 |
114 | # validation
115 | if validation > 0.0:
116 | for tt in data.keys():
117 | for cc in range(data[tt]['ncla']):
118 | cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0])
119 | rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation)))
120 | rnd_img.sort(reverse=True)
121 | for ii in range(len(rnd_img)):
122 | data[tt]['val']['x'].append(data[tt]['trn']['x'][rnd_img[ii]])
123 | data[tt]['val']['y'].append(data[tt]['trn']['y'][rnd_img[ii]])
124 | data[tt]['trn']['x'].pop(rnd_img[ii])
125 | data[tt]['trn']['y'].pop(rnd_img[ii])
126 |
127 | # other
128 | n = 0
129 | for t in data.keys():
130 | taskcla.append((t, data[t]['ncla']))
131 | n += data[t]['ncla']
132 | data['ncla'] = n
133 |
134 | return data, taskcla, class_order
135 |
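The class-to-task partition computed in get_data above can be reproduced without NumPy. This is a sketch under the same rules (an even split with the remainder given to the earliest tasks, or a fixed-size first task followed by an even split of the rest); classes_per_task and task_of_label are hypothetical names introduced here for illustration.

```python
import itertools


def classes_per_task(num_classes, num_tasks, nc_first_task=None):
    """Return the number of classes assigned to each task."""
    if nc_first_task is None:
        base, extra = divmod(num_classes, num_tasks)
        return [base + 1 if i < extra else base for i in range(num_tasks)]
    remaining = num_classes - nc_first_task
    base, extra = divmod(remaining, num_tasks - 1)
    return [nc_first_task] + [base + 1 if i < extra else base for i in range(num_tasks - 1)]


def task_of_label(label, cpertask):
    """Index of the task that owns a (re-ordered) class label."""
    cumsum = list(itertools.accumulate(cpertask))
    # count how many task boundaries the label has reached or passed
    return sum(label >= c for c in cumsum)


split = classes_per_task(100, 6, nc_first_task=50)
```

With 100 classes, 6 tasks, and 50 classes in the first task, the remaining 50 classes are spread evenly over the other 5 tasks, so label 49 belongs to task 0 and label 50 to task 1.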
--------------------------------------------------------------------------------
/src/datasets/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.utils import data
5 | import torchvision.transforms as transforms
6 | from torchvision.datasets import MNIST as TorchVisionMNIST
7 | from torchvision.datasets import CIFAR100 as TorchVisionCIFAR100
8 | from torchvision.datasets import SVHN as TorchVisionSVHN
9 |
10 | from . import base_dataset as basedat
11 | from . import memory_dataset as memd
12 | from .dataset_config import dataset_config
13 |
14 |
15 | def get_loaders(datasets, num_tasks, nc_first_task, batch_size, num_workers, pin_memory, validation=.1, ddp=False):
16 | """Apply transformations to Datasets and create the DataLoaders for each task"""
17 |
18 | trn_load, val_load, tst_load = [], [], []
19 | taskcla = []
20 | dataset_offset = 0
21 | for idx_dataset, cur_dataset in enumerate(datasets, 0):
22 | # get configuration for current dataset
23 | dc = dataset_config[cur_dataset]
24 |
25 | # transformations
26 | trn_transform, tst_transform = get_transforms(resize=dc['resize'],
27 | resize_test=dc['resize_test'],
28 | pad=dc['pad'],
29 | crop=dc['crop'],
30 | cifar_crop=dc['cifar_crop'],
31 | flip=dc['flip'],
32 | normalize=dc['normalize'],
33 | extend_channel=dc['extend_channel'])
34 |
35 | # datasets
36 | trn_dset, val_dset, tst_dset, curtaskcla = get_datasets(cur_dataset, dc['path'], num_tasks, nc_first_task,
37 | validation=validation,
38 | trn_transform=trn_transform,
39 | tst_transform=tst_transform,
40 | class_order=dc['class_order'])
41 |
42 | # apply offsets in case of multiple datasets
43 | if idx_dataset > 0:
44 | for tt in range(num_tasks):
45 | trn_dset[tt].labels = [elem + dataset_offset for elem in trn_dset[tt].labels]
46 | val_dset[tt].labels = [elem + dataset_offset for elem in val_dset[tt].labels]
47 | tst_dset[tt].labels = [elem + dataset_offset for elem in tst_dset[tt].labels]
48 | dataset_offset = dataset_offset + sum([tc[1] for tc in curtaskcla])
49 |
50 | # reassign task idx for multiple dataset case
51 | curtaskcla = [(tc[0] + idx_dataset * num_tasks, tc[1]) for tc in curtaskcla]
52 |
53 | # extend final taskcla list
54 | taskcla.extend(curtaskcla)
55 |
56 | # loaders
57 | if ddp:
58 | for tt in range(num_tasks):
59 | trn_sampler = torch.utils.data.DistributedSampler(trn_dset[tt], shuffle=True)
60 | trn_load.append(data.DataLoader(trn_dset[tt], batch_size=batch_size, num_workers=num_workers,
61 | sampler=trn_sampler, pin_memory=pin_memory))
62 | val_load.append(data.DataLoader(val_dset[tt], batch_size=batch_size, shuffle=False,
63 | num_workers=num_workers, pin_memory=pin_memory))
64 | tst_load.append(data.DataLoader(tst_dset[tt], batch_size=batch_size, shuffle=False,
65 | num_workers=num_workers, pin_memory=pin_memory))
66 | else:
67 | for tt in range(num_tasks):
68 | trn_load.append(data.DataLoader(trn_dset[tt], batch_size=batch_size, shuffle=True, num_workers=num_workers,
69 | pin_memory=pin_memory))
70 | val_load.append(data.DataLoader(val_dset[tt], batch_size=batch_size, shuffle=False, num_workers=num_workers,
71 | pin_memory=pin_memory))
72 | tst_load.append(data.DataLoader(tst_dset[tt], batch_size=batch_size, shuffle=False, num_workers=num_workers,
73 | pin_memory=pin_memory))
74 | return trn_load, val_load, tst_load, taskcla
75 |
76 |
77 | def get_datasets(dataset, path, num_tasks, nc_first_task, validation, trn_transform, tst_transform, class_order=None):
78 | """Extract datasets and create Dataset class"""
79 |
80 | trn_dset, val_dset, tst_dset = [], [], []
81 |
82 | if 'mnist' in dataset:
83 | tvmnist_trn = TorchVisionMNIST(path, train=True, download=True)
84 | tvmnist_tst = TorchVisionMNIST(path, train=False, download=True)
85 | trn_data = {'x': tvmnist_trn.data.numpy(), 'y': tvmnist_trn.targets.tolist()}
86 | tst_data = {'x': tvmnist_tst.data.numpy(), 'y': tvmnist_tst.targets.tolist()}
87 | # compute splits
88 | all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation,
89 | num_tasks=num_tasks, nc_first_task=nc_first_task,
90 | shuffle_classes=class_order is None, class_order=class_order)
91 | # set dataset type
92 | Dataset = memd.MemoryDataset
93 |
94 | elif 'cifar100' in dataset:
95 | tvcifar_trn = TorchVisionCIFAR100(path, train=True, download=True)
96 | tvcifar_tst = TorchVisionCIFAR100(path, train=False, download=True)
97 | trn_data = {'x': tvcifar_trn.data, 'y': tvcifar_trn.targets}
98 | tst_data = {'x': tvcifar_tst.data, 'y': tvcifar_tst.targets}
99 | # compute splits
100 | all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation,
101 | num_tasks=num_tasks, nc_first_task=nc_first_task,
102 | shuffle_classes=class_order is None, class_order=class_order)
103 | # set dataset type
104 | Dataset = memd.MemoryDataset
105 |
106 | elif dataset == 'svhn':
107 | tvsvhn_trn = TorchVisionSVHN(path, split='train', download=True)
108 | tvsvhn_tst = TorchVisionSVHN(path, split='test', download=True)
109 | trn_data = {'x': tvsvhn_trn.data.transpose(0, 2, 3, 1), 'y': tvsvhn_trn.labels}
110 | tst_data = {'x': tvsvhn_tst.data.transpose(0, 2, 3, 1), 'y': tvsvhn_tst.labels}
111 | # Notice that SVHN in Torchvision has an extra training set in case needed
112 | # tvsvhn_xtr = TorchVisionSVHN(path, split='extra', download=True)
113 | # xtr_data = {'x': tvsvhn_xtr.data.transpose(0, 2, 3, 1), 'y': tvsvhn_xtr.labels}
114 |
115 | # compute splits
116 | all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation,
117 | num_tasks=num_tasks, nc_first_task=nc_first_task,
118 | shuffle_classes=class_order is None, class_order=class_order)
119 | # set dataset type
120 | Dataset = memd.MemoryDataset
121 |
122 | elif 'imagenet_32' in dataset:
123 | import pickle
124 | # load data
125 | x_trn, y_trn = [], []
126 | for i in range(1, 11):
127 | with open(os.path.join(path, 'train_data_batch_{}'.format(i)), 'rb') as f:
128 | d = pickle.load(f)
129 | x_trn.append(d['data'])
130 | y_trn.append(np.array(d['labels']) - 1) # labels from 0 to 999
131 | with open(os.path.join(path, 'val_data'), 'rb') as f:
132 | d = pickle.load(f)
133 | x_trn.append(d['data'])
134 | y_tst = np.array(d['labels']) - 1 # labels from 0 to 999
135 | # reshape data
136 | for i, d in enumerate(x_trn, 0):
137 | x_trn[i] = d.reshape(d.shape[0], 3, 32, 32).transpose(0, 2, 3, 1)
138 | x_tst = x_trn[-1]
139 | x_trn = np.vstack(x_trn[:-1])
140 | y_trn = np.concatenate(y_trn)
141 | trn_data = {'x': x_trn, 'y': y_trn}
142 | tst_data = {'x': x_tst, 'y': y_tst}
143 | # compute splits
144 | all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation,
145 | num_tasks=num_tasks, nc_first_task=nc_first_task,
146 | shuffle_classes=class_order is None, class_order=class_order)
147 | # set dataset type
148 | Dataset = memd.MemoryDataset
149 |
150 | elif dataset == 'imagenet_100':
151 | # read data paths and compute splits -- path needs to have a train_100.txt and a val_100.txt with image-label pairs
152 | all_data, taskcla, class_indices = basedat.get_data(path, num_tasks=num_tasks, nc_first_task=nc_first_task,
153 | validation=validation, shuffle_classes=class_order is None,
154 | class_order=class_order, trn_lst='train_100.txt', tst_lst='val_100.txt')
155 | # set dataset type
156 | Dataset = basedat.BaseDataset
157 |
158 | elif dataset == 'imagenet_1000':
159 | # read data paths and compute splits -- path needs to have a train_1000.txt and a val_1000.txt with image-label pairs
160 | all_data, taskcla, class_indices = basedat.get_data(path, num_tasks=num_tasks, nc_first_task=nc_first_task,
161 | validation=validation, shuffle_classes=class_order is None,
162 | class_order=class_order, trn_lst='train_1000.txt', tst_lst='val_1000.txt')
163 | # set dataset type
164 | Dataset = basedat.BaseDataset
165 |
166 | else:
167 | # read data paths and compute splits -- path needs to have a train.txt and a test.txt with image-label pairs
168 | all_data, taskcla, class_indices = basedat.get_data(path, num_tasks=num_tasks, nc_first_task=nc_first_task,
169 | validation=validation, shuffle_classes=class_order is None,
170 | class_order=class_order)
171 | # set dataset type
172 | Dataset = basedat.BaseDataset
173 |
174 | # get datasets, apply correct label offsets for each task
175 | offset = 0
176 | for task in range(num_tasks):
177 | all_data[task]['trn']['y'] = [label + offset for label in all_data[task]['trn']['y']]
178 | all_data[task]['val']['y'] = [label + offset for label in all_data[task]['val']['y']]
179 | all_data[task]['tst']['y'] = [label + offset for label in all_data[task]['tst']['y']]
180 | trn_dset.append(Dataset(all_data[task]['trn'], trn_transform, class_indices))
181 | val_dset.append(Dataset(all_data[task]['val'], tst_transform, class_indices))
182 | tst_dset.append(Dataset(all_data[task]['tst'], tst_transform, class_indices))
183 | offset += taskcla[task][1]
184 |
185 | return trn_dset, val_dset, tst_dset, taskcla
186 |
187 |
188 | def get_transforms(resize, resize_test, pad, crop, cifar_crop, flip, normalize, extend_channel):
189 | """Unpack transformations and apply to train or test splits"""
190 |
191 | trn_transform_list = []
192 | tst_transform_list = []
193 |
194 | # resize
195 | if resize is not None:
196 | trn_transform_list.append(transforms.Resize(resize))
197 | if resize_test is not None:
198 | tst_transform_list.append(transforms.Resize(resize_test))
199 |
200 | # padding
201 | if pad is not None:
202 | trn_transform_list.append(transforms.Pad(pad))
203 | # tst_transform_list.append(transforms.Pad(pad))
204 |
205 | # crop
206 | if crop is not None:
207 | trn_transform_list.append(transforms.RandomResizedCrop(crop))
208 | tst_transform_list.append(transforms.CenterCrop(crop))
209 |
210 | if cifar_crop is not None:
211 | trn_transform_list.append(transforms.RandomCrop(cifar_crop))
212 |
213 | # flips
214 | if flip:
215 | trn_transform_list.append(transforms.RandomHorizontalFlip())
216 |
217 | # to tensor
218 | trn_transform_list.append(transforms.ToTensor())
219 | tst_transform_list.append(transforms.ToTensor())
220 |
221 | # normalization
222 | if normalize is not None:
223 | trn_transform_list.append(transforms.Normalize(mean=normalize[0], std=normalize[1]))
224 | tst_transform_list.append(transforms.Normalize(mean=normalize[0], std=normalize[1]))
225 |
226 | # gray to rgb
227 | if extend_channel is not None:
228 | trn_transform_list.append(transforms.Lambda(lambda x: x.repeat(extend_channel, 1, 1)))
229 | tst_transform_list.append(transforms.Lambda(lambda x: x.repeat(extend_channel, 1, 1)))
230 |
231 | print(trn_transform_list)
232 | print(tst_transform_list)
233 | return transforms.Compose(trn_transform_list), \
234 | transforms.Compose(tst_transform_list)
235 |
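
The per-task label offset applied at the end of `get_datasets` turns task-local labels (each task starting at 0) into global class indices. A minimal sketch of that step follows; `to_global_labels` is a hypothetical name used only for illustration, and plain lists replace the dataset objects.

```python
def to_global_labels(task_labels, classes_per_task):
    """Shift each task's local labels by the number of classes seen in earlier tasks."""
    offset = 0
    shifted = []
    for labels, ncla in zip(task_labels, classes_per_task):
        shifted.append([label + offset for label in labels])
        offset += ncla  # the next task starts after all classes seen so far
    return shifted


# three tasks with 2, 3 and 2 classes respectively
print(to_global_labels([[0, 1, 1], [0, 2], [1]], [2, 3, 2]))
```

The same idea is applied one level up in `get_loaders`, where a second dataset's labels are shifted by the total class count of the datasets before it.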
--------------------------------------------------------------------------------
/src/datasets/dataset_config.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 |
3 | _BASE_DATA_PATH = "../data"
4 |
5 | dataset_config = {
6 | 'mnist': {
7 | 'path': join(_BASE_DATA_PATH, 'mnist'),
8 | 'normalize': ((0.1307,), (0.3081,)),
9 | # Use the next 3 lines to use MNIST with a 3x32x32 input
10 | # 'extend_channel': 3,
11 | # 'pad': 2,
12 | # 'normalize': ((0.1,), (0.2752,)) # values including padding
13 | },
14 | 'svhn': {
15 | 'path': join(_BASE_DATA_PATH, 'svhn'),
16 | 'resize': (224, 224),
17 | 'crop': None,
18 | 'flip': False,
19 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
20 | },
21 | 'cifar100': {
22 | 'path': join(_BASE_DATA_PATH, 'cifar100'),
23 | 'resize': None,
24 | 'pad': 4,
25 | 'crop': 32,
26 | 'flip': True,
27 | 'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))
28 | },
29 | 'cifar100_icarl': {
30 | 'path': join(_BASE_DATA_PATH, 'cifar100'),
31 | # 'resize': None,
32 | 'pad': 4,
33 | # 'crop': 32,
34 | 'cifar_crop': 32,
35 | 'flip': True,
36 | 'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)),
37 | 'class_order': [
38 | 68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50,
39 | 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,
40 | 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69,
41 | 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33
42 | ]
43 | },
44 | 'vggface2': {
45 | 'path': join(_BASE_DATA_PATH, 'VGGFace2'),
46 | 'resize': 256,
47 | 'crop': 224,
48 | 'flip': True,
49 | 'normalize': ((0.5199, 0.4116, 0.3610), (0.2604, 0.2297, 0.2169))
50 | },
51 | 'imagenet_1000': {
52 | 'path': join(_BASE_DATA_PATH, 'imagenet'),
53 | 'resize': None,
54 | 'resize_test': 256,
55 | 'crop': 224,
56 | 'flip': True,
57 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
58 | },
59 | 'imagenet_100': {
60 | 'path': join(_BASE_DATA_PATH, 'imagenet'),
61 | 'resize': None,
62 | 'resize_test': 256,
63 | 'crop': 224,
64 | 'flip': True,
65 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
66 | },
67 | 'cars': {
68 | 'path': join(_BASE_DATA_PATH, 'cars'),
69 | 'resize': None,
70 | 'resize_test': 256,
71 | 'crop': 224,
72 | 'flip': True,
73 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
74 | },
75 | 'aircraft': {
76 | 'path': join(_BASE_DATA_PATH, 'aircraft'),
77 | 'resize': None,
78 | 'resize_test': 256,
79 | 'crop': 224,
80 | 'flip': True,
81 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
82 | },
83 | 'imagenet_32_reduced': {
84 | 'path': join(_BASE_DATA_PATH, 'ILSVRC12_32'),
85 | 'resize': None,
86 | 'pad': 4,
87 | 'crop': 32,
88 | 'flip': True,
89 | 'normalize': ((0.481, 0.457, 0.408), (0.260, 0.253, 0.268)),
90 | 'class_order': [
91 | 472, 46, 536, 806, 547, 976, 662, 12, 955, 651, 492, 80, 999, 996, 788, 471, 911, 907, 680, 126, 42, 882,
92 | 327, 719, 716, 224, 918, 647, 808, 261, 140, 908, 833, 925, 57, 388, 407, 215, 45, 479, 525, 641, 915, 923,
93 | 108, 461, 186, 843, 115, 250, 829, 625, 769, 323, 974, 291, 438, 50, 825, 441, 446, 200, 162, 373, 872, 112,
94 | 212, 501, 91, 672, 791, 370, 942, 172, 315, 959, 636, 635, 66, 86, 197, 182, 59, 736, 175, 445, 947, 268,
95 | 238, 298, 926, 851, 494, 760, 61, 293, 696, 659, 69, 819, 912, 486, 706, 343, 390, 484, 282, 729, 575, 731,
96 | 530, 32, 534, 838, 466, 734, 425, 400, 290, 660, 254, 266, 551, 775, 721, 134, 886, 338, 465, 236, 522, 655,
97 | 209, 861, 88, 491, 985, 304, 981, 560, 405, 902, 521, 909, 763, 455, 341, 905, 280, 776, 113, 434, 274, 581,
98 | 158, 738, 671, 702, 147, 718, 148, 35, 13, 585, 591, 371, 745, 281, 956, 935, 346, 352, 284, 604, 447, 415,
99 | 98, 921, 118, 978, 880, 509, 381, 71, 552, 169, 600, 334, 171, 835, 798, 77, 249, 318, 419, 990, 335, 374,
100 | 949, 316, 755, 878, 946, 142, 299, 863, 558, 306, 183, 417, 64, 765, 565, 432, 440, 939, 297, 805, 364, 735,
101 | 251, 270, 493, 94, 773, 610, 278, 16, 363, 92, 15, 593, 96, 468, 252, 699, 377, 95, 799, 868, 820, 328, 756,
102 | 81, 991, 464, 774, 584, 809, 844, 940, 720, 498, 310, 384, 619, 56, 406, 639, 285, 67, 634, 792, 232, 54,
103 | 664, 818, 513, 349, 330, 207, 361, 345, 279, 549, 944, 817, 353, 228, 312, 796, 193, 179, 520, 451, 871,
104 | 692, 60, 481, 480, 929, 499, 673, 331, 506, 70, 645, 759, 744, 459]
105 | },
106 | 'mini_imagenet': {
107 | 'path': join(_BASE_DATA_PATH, 'mini_imagenet'),
108 | 'resize': None,
109 | 'crop': 84,
110 | 'flip': True,
111 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
112 | 'class_order': [26, 86, 2, 55, 75, 93, 16, 73, 54, 95, 53, 92, 78, 13, 7, 30, 22,
113 | 24, 33, 8, 43, 62, 3, 71, 45, 48, 6, 99, 82, 76, 60, 80, 90, 68,
114 | 51, 27, 18, 56, 63, 74, 1, 61, 42, 41, 4, 15, 17, 40, 38, 5, 91,
115 | 59, 0, 34, 28, 50, 11, 35, 23, 52, 10, 31, 66, 57, 79, 85, 32, 84,
116 | 14, 89, 19, 29, 49, 97, 98, 69, 20, 94, 72, 77, 25, 37, 81, 46, 39,
117 | 65, 58, 12, 88, 70, 87, 36, 21, 83, 9, 96, 67, 64, 47, 44]
118 | },
119 |
120 | }
121 |
122 | # Add missing keys:
123 | for dset in dataset_config.keys():
124 | for k in ['resize', 'pad', 'crop', 'normalize', 'class_order', 'extend_channel', 'resize_test', 'cifar_crop']:
125 | if k not in dataset_config[dset].keys():
126 | dataset_config[dset][k] = None
127 | if 'flip' not in dataset_config[dset].keys():
128 | dataset_config[dset]['flip'] = False
129 |
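
The "add missing keys" loop above guarantees that every dataset entry exposes the same set of options, so `get_loaders` can index them without checks. A compact equivalent using `dict.setdefault` is sketched below; `fill_missing_keys` and the `toy` entry are assumptions made for the example, not part of the repository.

```python
OPTIONAL_KEYS = ('resize', 'pad', 'crop', 'normalize', 'class_order',
                 'extend_channel', 'resize_test', 'cifar_crop')


def fill_missing_keys(config):
    """Give every dataset entry the full option set without overwriting existing values."""
    for entry in config.values():
        for key in OPTIONAL_KEYS:
            entry.setdefault(key, None)
        entry.setdefault('flip', False)  # flipping is opt-in
    return config


toy_config = fill_missing_keys({'toy': {'path': '../data/toy', 'crop': 32}})
print(toy_config['toy'])
```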
--------------------------------------------------------------------------------
/src/datasets/exemplars_dataset.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from argparse import ArgumentParser
3 |
4 | from datasets.memory_dataset import MemoryDataset
5 |
6 |
7 | class ExemplarsDataset(MemoryDataset):
8 | """Exemplar storage for approaches with an interface of Dataset"""
9 |
10 | def __init__(self, transform, class_indices,
11 | num_exemplars=0, num_exemplars_per_class=0, exemplar_selection='random'):
12 | super().__init__({'x': [], 'y': []}, transform, class_indices=class_indices)
13 | self.max_num_exemplars_per_class = num_exemplars_per_class
14 | self.max_num_exemplars = num_exemplars
15 | assert (num_exemplars_per_class == 0) or (num_exemplars == 0), 'Cannot use both limits at once!'
16 | cls_name = "{}ExemplarsSelector".format(exemplar_selection.capitalize())
17 | selector_cls = getattr(importlib.import_module(name='datasets.exemplars_selection'), cls_name)
18 | self.exemplars_selector = selector_cls(self)
19 |
20 | # Parses the exemplar-management arguments and returns (known args, remaining args)
21 | @staticmethod
22 | def extra_parser(args):
23 | parser = ArgumentParser("Exemplars Management Parameters")
24 | _group = parser.add_mutually_exclusive_group()
25 | _group.add_argument('--num-exemplars', default=0, type=int, required=False,
26 | help='Fixed memory, total number of exemplars (default=%(default)s)')
27 | _group.add_argument('--num-exemplars-per-class', default=0, type=int, required=False,
28 | help='Growing memory, number of exemplars per class (default=%(default)s)')
29 | parser.add_argument('--exemplar-selection', default='random', type=str,
30 | choices=['herding', 'random', 'entropy', 'distance'],
31 | required=False, help='Exemplar selection strategy (default=%(default)s)')
32 | return parser.parse_known_args(args)
33 |
34 | def _is_active(self):
35 | return self.max_num_exemplars_per_class > 0 or self.max_num_exemplars > 0
36 |
37 | def collect_exemplars(self, model, trn_loader, selection_transform, ddp):
38 | if self._is_active():
39 | self.images, self.labels = self.exemplars_selector(model, trn_loader, selection_transform, ddp)
40 |
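
The two memory limits are mutually exclusive: `--num-exemplars` fixes the total budget, while `--num-exemplars-per-class` lets memory grow with the number of classes. The sketch below shows how a per-class count can be derived from either limit, following the ceiling division used by the selector; `exemplars_per_class` is a hypothetical standalone function, and `num_classes_seen` stands in for `model.task_cls.sum()`.

```python
import math


def exemplars_per_class(num_exemplars, num_exemplars_per_class, num_classes_seen):
    """Return how many exemplars to keep per class under either memory limit."""
    assert num_exemplars == 0 or num_exemplars_per_class == 0, \
        'Cannot use both limits at once!'
    if num_exemplars_per_class:
        # growing memory: the per-class count is fixed
        return num_exemplars_per_class
    # fixed memory: divide the budget over all classes seen so far, rounding up
    per_class = math.ceil(num_exemplars / num_classes_seen)
    assert per_class > 0, 'Not enough exemplars to cover all classes!'
    return per_class


print(exemplars_per_class(2000, 0, 50))
print(exemplars_per_class(2000, 0, 30))
```

Because the division rounds up, the stored total can slightly exceed the nominal budget when it does not divide evenly: 30 classes with a budget of 2000 give 67 per class, or 2010 exemplars in total.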
--------------------------------------------------------------------------------
/src/datasets/exemplars_selection.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 | from contextlib import contextmanager
4 | from typing import Iterable
5 |
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import DataLoader, ConcatDataset
9 | from torchvision.transforms import Lambda
10 |
11 | from datasets.exemplars_dataset import ExemplarsDataset
12 | from networks.network import LLL_Net
13 |
14 |
15 | class ExemplarsSelector:
16 | """Exemplar selector for approaches with an interface of Dataset"""
17 |
18 | def __init__(self, exemplars_dataset: ExemplarsDataset):
19 | self.exemplars_dataset = exemplars_dataset
20 |
21 | def __call__(self, model: LLL_Net, trn_loader: DataLoader, transform, ddp=False):
22 | clock0 = time.time()
23 | exemplars_per_class = self._exemplars_per_class_num(model)
24 | with override_dataset_transform(trn_loader.dataset, transform) as ds_for_selection:
25 | # rebuild the loader without shuffling (shuffle=False) so sample indices keep a stable order; uses the selection (eval) transforms
26 | sel_loader = DataLoader(ds_for_selection, batch_size=trn_loader.batch_size, shuffle=False,
27 | num_workers=trn_loader.num_workers, pin_memory=trn_loader.pin_memory)
28 | selected_indices = self._select_indices(model, sel_loader, exemplars_per_class, transform)
29 | if ddp:
30 | # make sure all processes use the same exemplar set
31 | selected_indices = torch.from_numpy(np.array(selected_indices)).cuda()
32 | torch.distributed.broadcast(selected_indices, src=0)
33 | selected_indices = selected_indices.cpu().tolist()
34 |
35 | with override_dataset_transform(trn_loader.dataset, Lambda(lambda x: np.array(x))) as ds_for_raw:
36 | x, y = zip(*(ds_for_raw[idx] for idx in selected_indices))
37 | clock1 = time.time()
38 | print('| Selected {:d} train exemplars, time={:5.1f}s'.format(len(x), clock1 - clock0))
39 | return x, y
40 |
41 | def _exemplars_per_class_num(self, model: LLL_Net):
42 | if self.exemplars_dataset.max_num_exemplars_per_class:
43 | return self.exemplars_dataset.max_num_exemplars_per_class
44 |
45 | num_cls = model.task_cls.sum().item()
46 | num_exemplars = self.exemplars_dataset.max_num_exemplars
47 | exemplars_per_class = int(np.ceil(num_exemplars / num_cls))
48 | assert exemplars_per_class > 0, \
49 | "Not enough exemplars to cover all classes!\n" \
50 | "Number of classes so far: {}. " \
51 | "Limit of exemplars: {}".format(num_cls,
52 | num_exemplars)
53 | return exemplars_per_class
54 |
55 | def _select_indices(self, model: LLL_Net, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable:
56 | pass
57 |
58 |
59 | class RandomExemplarsSelector(ExemplarsSelector):
60 | """Selection of new samples. This is based on random selection, which produces a random list of samples."""
61 |
62 | def __init__(self, exemplars_dataset):
63 | super().__init__(exemplars_dataset)
64 |
65 | def _select_indices(self, model: LLL_Net, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable:
66 | num_cls = sum(model.task_cls)
67 | result = []
68 | labels = self._get_labels(sel_loader)
69 | for curr_cls in range(num_cls):
70 | # get all indices from current class -- check if there are exemplars from previous task in the loader
71 | cls_ind = np.where(labels == curr_cls)[0]
72 | assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls)
73 | assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store"
74 | # select the exemplars randomly
75 | result.extend(random.sample(list(cls_ind), exemplars_per_class))
76 | return result
77 |
78 | def _get_labels(self, sel_loader):
79 | if hasattr(sel_loader.dataset, 'labels'): # BaseDataset, MemoryDataset
80 | labels = np.asarray(sel_loader.dataset.labels)
81 | elif isinstance(sel_loader.dataset, ConcatDataset):
82 | labels = []
83 | for ds in sel_loader.dataset.datasets:
84 | labels.extend(ds.labels)
85 | labels = np.array(labels)
86 | else:
87 | raise RuntimeError("Unsupported dataset: {}".format(sel_loader.dataset.__class__.__name__))
88 | return labels
89 |
90 |
91 | class HerdingExemplarsSelector(ExemplarsSelector):
92 | """Selection of new samples. This is based on herding selection, which produces a sorted list of samples of one
93 | class based on the distance to the mean sample of that class. From iCaRL algorithm 4 and 5:
94 | https://openaccess.thecvf.com/content_cvpr_2017/papers/Rebuffi_iCaRL_Incremental_Classifier_CVPR_2017_paper.pdf
95 | """
96 | def __init__(self, exemplars_dataset):
97 | super().__init__(exemplars_dataset)
98 |
99 | def _select_indices(self, model: LLL_Net, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable:
100 | model_device = next(model.parameters()).device # we assume here that whole model is on a single device
101 |
102 | # extract outputs from the model for all train samples
103 | extracted_features = []
104 | extracted_targets = []
105 | with torch.no_grad():
106 | model.eval()
107 | for images, targets in sel_loader:
108 | feats = model(images.to(model_device), return_features=True)[1]
109 | feats = feats / feats.norm(dim=1).view(-1, 1) # Feature normalization
110 | extracted_features.append(feats)
111 | extracted_targets.extend(targets)
112 | extracted_features = (torch.cat(extracted_features)).cpu()
113 | extracted_targets = np.array(extracted_targets)
114 | result = []
115 | # iterate through all classes
116 | for curr_cls in np.unique(extracted_targets):
117 | # get all indices from current class
118 | cls_ind = np.where(extracted_targets == curr_cls)[0]
119 | assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls)
120 | assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store"
121 | # get all extracted features for current class
122 | cls_feats = extracted_features[cls_ind]
123 | # calculate the mean
124 | cls_mu = cls_feats.mean(0)
125 | # select the exemplars closer to the mean of each class
126 | selected = []
127 | selected_feat = []
128 | for k in range(exemplars_per_class):
129 | # fix this to the dimension of the model features
130 | sum_others = torch.zeros(cls_feats.shape[1])
131 | for j in selected_feat:
132 | sum_others += j / (k + 1)
133 | dist_min = np.inf
134 | # choose the closest to the mean of the current class
135 | for item in cls_ind:
136 | if item not in selected:
137 | feat = extracted_features[item]
138 | dist = torch.norm(cls_mu - feat / (k + 1) - sum_others)
139 | if dist < dist_min:
140 | dist_min = dist
141 | newone = item
142 | newonefeat = feat
143 | selected_feat.append(newonefeat)
144 | selected.append(newone)
145 | result.extend(selected)
146 | return result
147 |
148 |
149 | class EntropyExemplarsSelector(ExemplarsSelector):
150 | """Selection of new samples. This is based on entropy selection, which produces a sorted list of samples of one
151 | class based on entropy of each sample. From RWalk http://arxiv-export-lb.library.cornell.edu/pdf/1801.10112
152 | """
153 | def __init__(self, exemplars_dataset):
154 | super().__init__(exemplars_dataset)
155 |
156 | def _select_indices(self, model: LLL_Net, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable:
157 | model_device = next(model.parameters()).device # we assume here that whole model is on a single device
158 |
159 | # extract outputs from the model for all train samples
160 | extracted_logits = []
161 | extracted_targets = []
162 | with torch.no_grad():
163 | model.eval()
164 | for images, targets in sel_loader:
165 | extracted_logits.append(torch.cat(model(images.to(model_device)), dim=1))
166 | extracted_targets.extend(targets)
167 | extracted_logits = (torch.cat(extracted_logits)).cpu()
168 | extracted_targets = np.array(extracted_targets)
169 | result = []
170 | # iterate through all classes
171 | for curr_cls in np.unique(extracted_targets):
172 | # get all indices from current class
173 | cls_ind = np.where(extracted_targets == curr_cls)[0]
174 | assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls)
175 | assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store"
176 | # get all extracted features for current class
177 | cls_logits = extracted_logits[cls_ind]
178 | # select the exemplars with higher entropy (lower: -entropy)
179 | probs = torch.softmax(cls_logits, dim=1)
180 | log_probs = torch.log(probs)
181 | minus_entropy = (probs * log_probs).sum(1) # change sign of this variable for inverse order
182 | selected = cls_ind[minus_entropy.sort()[1][:exemplars_per_class]]
183 | result.extend(selected)
184 | return result
185 |
186 |
187 | class DistanceExemplarsSelector(ExemplarsSelector):
188 | """Selection of new samples. This is based on distance-based selection, which produces a sorted list of samples of
189 | one class based on closeness to decision boundary of each sample. From RWalk
190 | http://arxiv-export-lb.library.cornell.edu/pdf/1801.10112
191 | """
192 | def __init__(self, exemplars_dataset):
193 | super().__init__(exemplars_dataset)
194 |
195 | def _select_indices(self, model: LLL_Net, sel_loader: DataLoader, exemplars_per_class: int,
196 | transform) -> Iterable:
197 | model_device = next(model.parameters()).device # we assume here that whole model is on a single device
198 |
199 | # extract outputs from the model for all train samples
200 | extracted_logits = []
201 | extracted_targets = []
202 | with torch.no_grad():
203 | model.eval()
204 | for images, targets in sel_loader:
205 | extracted_logits.append(torch.cat(model(images.to(model_device)), dim=1))
206 | extracted_targets.extend(targets)
207 | extracted_logits = (torch.cat(extracted_logits)).cpu()
208 | extracted_targets = np.array(extracted_targets)
209 | result = []
210 | # iterate through all classes
211 | for curr_cls in np.unique(extracted_targets):
212 | # get all indices from current class
213 | cls_ind = np.where(extracted_targets == curr_cls)[0]
214 | assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls)
215 | assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store"
216 | # get all extracted features for current class
217 | cls_logits = extracted_logits[cls_ind]
218 | # select the exemplars closer to boundary
219 | distance = cls_logits[:, curr_cls] # change sign of this variable for inverse order
220 | selected = cls_ind[distance.sort()[1][:exemplars_per_class]]
221 | result.extend(selected)
222 | return result
223 |
224 |
225 | def dataset_transforms(dataset, transform_to_change):
226 | if isinstance(dataset, ConcatDataset):
227 | r = []
228 | for ds in dataset.datasets:
229 | r += dataset_transforms(ds, transform_to_change)
230 | return r
231 | else:
232 | old_transform = dataset.transform
233 | dataset.transform = transform_to_change
234 | return [(dataset, old_transform)]
235 |
236 |
237 | @contextmanager
238 | def override_dataset_transform(dataset, transform):
239 | try:
240 | datasets_with_orig_transform = dataset_transforms(dataset, transform)
241 | yield dataset
242 | finally:
243 | # get back original transformations
244 | for ds, orig_transform in datasets_with_orig_transform:
245 | ds.transform = orig_transform
246 |
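
The greedy loop in `HerdingExemplarsSelector` can be hard to follow in tensor form. The sketch below restates it with plain Python lists, under stated assumptions: `herding_order` is a hypothetical name, the features belong to a single class, and the L2 normalization applied in the selector is omitted. At step `k` it picks the sample whose inclusion brings the mean of the selected set closest to the class mean.

```python
import math


def herding_order(features, num_exemplars):
    """Greedy herding over one class: return the indices of the chosen samples in order."""
    dim = len(features[0])
    mean = [sum(f[d] for f in features) / len(features) for d in range(dim)]
    selected = []
    running_sum = [0.0] * dim  # sum of the features selected so far
    for k in range(num_exemplars):
        best_idx, best_dist = None, math.inf
        for idx, feat in enumerate(features):
            if idx in selected:
                continue
            # mean of the selected set if this sample were added
            candidate = [(running_sum[d] + feat[d]) / (k + 1) for d in range(dim)]
            dist = math.dist(mean, candidate)
            if dist < best_dist:
                best_idx, best_dist = idx, dist
        selected.append(best_idx)
        running_sum = [running_sum[d] + features[best_idx][d] for d in range(dim)]
    return selected


print(herding_order([[1.0], [1.2], [-1.0], [3.0]], 3))
```

In this toy input the class mean is 1.05. Ranking samples by distance to the mean alone would choose index 3 third, whereas herding chooses index 2, because a low-valued sample better compensates for the two already selected.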
--------------------------------------------------------------------------------
/src/datasets/memory_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class MemoryDataset(Dataset):
8 | """Characterizes a dataset for PyTorch -- this dataset pre-loads all images in memory"""
9 |
10 | def __init__(self, data, transform, class_indices=None):
11 | """Initialization"""
12 | self.labels = data['y']
13 | self.images = data['x']
14 | self.transform = transform
15 | self.class_indices = class_indices
16 |
17 | def __len__(self):
18 | """Denotes the total number of samples"""
19 | return len(self.images)
20 |
21 | def __getitem__(self, index):
22 | """Generates one sample of data"""
23 | x = Image.fromarray(self.images[index])
24 | x = self.transform(x)
25 | y = self.labels[index]
26 | return x, y
27 |
28 |
29 | def get_data(trn_data, tst_data, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None):
30 | """Prepare data: dataset splits, task partition, class order"""
31 |
32 | data = {}
33 | taskcla = []
34 | if class_order is None:
35 | num_classes = len(np.unique(trn_data['y']))
36 | class_order = list(range(num_classes))
37 | else:
38 | num_classes = len(class_order)
39 | class_order = class_order.copy()
40 | if shuffle_classes:
41 | np.random.shuffle(class_order)
42 |
43 | # compute classes per task and num_tasks
44 | if nc_first_task is None:
45 | cpertask = np.array([num_classes // num_tasks] * num_tasks)
46 | for i in range(num_classes % num_tasks):
47 | cpertask[i] += 1
48 | else:
49 | assert nc_first_task < num_classes, "first task wants more classes than exist"
50 | remaining_classes = num_classes - nc_first_task
51 | assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task" # better minimum 2
52 | cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1))
53 | for i in range(remaining_classes % (num_tasks - 1)):
54 | cpertask[i + 1] += 1
55 |
56 | assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes"
57 | cpertask_cumsum = np.cumsum(cpertask)
58 | init_class = np.concatenate(([0], cpertask_cumsum[:-1]))
59 |
60 | # initialize data structure
61 | for tt in range(num_tasks):
62 | data[tt] = {}
63 | data[tt]['name'] = 'task-' + str(tt)
64 | data[tt]['trn'] = {'x': [], 'y': []}
65 | data[tt]['val'] = {'x': [], 'y': []}
66 | data[tt]['tst'] = {'x': [], 'y': []}
67 |
68 | # ALL OR TRAIN
69 | filtering = np.isin(trn_data['y'], class_order)
70 | if filtering.sum() != len(trn_data['y']):
71 | trn_data['x'] = trn_data['x'][filtering]
72 | trn_data['y'] = np.array(trn_data['y'])[filtering]
73 | for this_image, this_label in zip(trn_data['x'], trn_data['y']):
74 | # If shuffling is false, it won't change the class number
75 | this_label = class_order.index(this_label)
76 | # add it to the corresponding split
77 | this_task = (this_label >= cpertask_cumsum).sum()
78 | data[this_task]['trn']['x'].append(this_image)
79 | data[this_task]['trn']['y'].append(this_label - init_class[this_task])
80 |
81 | # ALL OR TEST
82 | filtering = np.isin(tst_data['y'], class_order)
83 | if filtering.sum() != len(tst_data['y']):
84 | tst_data['x'] = tst_data['x'][filtering]
85 | tst_data['y'] = np.array(tst_data['y'])[filtering]
86 | for this_image, this_label in zip(tst_data['x'], tst_data['y']):
87 | # If shuffling is false, it won't change the class number
88 | this_label = class_order.index(this_label)
89 | # add it to the corresponding split
90 | this_task = (this_label >= cpertask_cumsum).sum()
91 | data[this_task]['tst']['x'].append(this_image)
92 | data[this_task]['tst']['y'].append(this_label - init_class[this_task])
93 |
94 | # check classes
95 | for tt in range(num_tasks):
96 | data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y']))
97 | assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes"
98 |
99 | # validation
100 | if validation > 0.0:
101 | for tt in data.keys():
102 | for cc in range(data[tt]['ncla']):
103 | cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0])
104 | rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation)))
105 | rnd_img.sort(reverse=True)
106 | for ii in range(len(rnd_img)):
107 | data[tt]['val']['x'].append(data[tt]['trn']['x'][rnd_img[ii]])
108 | data[tt]['val']['y'].append(data[tt]['trn']['y'][rnd_img[ii]])
109 | data[tt]['trn']['x'].pop(rnd_img[ii])
110 | data[tt]['trn']['y'].pop(rnd_img[ii])
111 |
112 | # convert them to numpy arrays
113 | for tt in data.keys():
114 | for split in ['trn', 'val', 'tst']:
115 | data[tt][split]['x'] = np.asarray(data[tt][split]['x'])
116 |
117 | # other
118 | n = 0
119 | for t in data.keys():
120 | taskcla.append((t, data[t]['ncla']))
121 | n += data[t]['ncla']
122 | data['ncla'] = n
123 |
124 | return data, taskcla, class_order
125 |
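The task partition computed in `get_data` can be sketched on its own. This is a minimal, NumPy-free illustration rather than the repository's code: the helper names `classes_per_task` and `task_of_label` are hypothetical, and the `nc_first_task` branch assumes `num_tasks >= 2`, as the original does.

```python
from bisect import bisect_right
from itertools import accumulate


def classes_per_task(num_classes, num_tasks, nc_first_task=None):
    """Return the number of classes assigned to each task (hypothetical helper)."""
    if nc_first_task is None:
        # Even split; the first (num_classes % num_tasks) tasks get one extra class
        base, extra = divmod(num_classes, num_tasks)
        return [base + (1 if i < extra else 0) for i in range(num_tasks)]
    # The first task takes nc_first_task classes; the rest are split evenly
    remaining = num_classes - nc_first_task
    base, extra = divmod(remaining, num_tasks - 1)
    return [nc_first_task] + [base + (1 if i < extra else 0) for i in range(num_tasks - 1)]


def task_of_label(label, cpertask):
    """Map a re-ordered global label to (task index, label within that task)."""
    cumsum = list(accumulate(cpertask))
    # Number of task boundaries <= label, same as (label >= cpertask_cumsum).sum()
    task = bisect_right(cumsum, label)
    offset = cumsum[task - 1] if task > 0 else 0
    return task, label - offset


cpertask = classes_per_task(100, 6, nc_first_task=50)
print(cpertask)
print(task_of_label(55, cpertask))
```

With the settings used in the LUCIR scripts (`nc_first=50`, `ntask=6`), this gives 50 classes for the first task and 10 for each of the remaining five.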
--------------------------------------------------------------------------------
/src/exp_cifar_lucir.sh:
--------------------------------------------------------------------------------
1 | device_id=0
2 | SEED=0
3 | bz=64
4 | lr=0.1
5 | mom=0.9
6 | wd=5e-4
7 | data=cifar100_icarl
8 | network=resnet18_cifar
9 | nepochs=160
10 |
11 | appr=lucir
12 | lamb=5.0
13 | nc_first=50
14 | ntask=2
15 |
16 | first_task_bz=128
17 | first_task_lr=0.1
18 |
19 |
20 | CUDA_VISIBLE_DEVICES=$device_id python3 main_incremental.py --exp-name ${nc_first}_${ntask}_${SEED} \
21 | --datasets $data --num-tasks $ntask --nc-first-task $nc_first --network $network --seed $SEED \
22 | --nepochs $nepochs --batch-size $bz --lr $lr --momentum $mom --weight-decay $wd --decay-mile-stone 80 120 \
23 | --clipping -1 --results-path results --save-models \
24 | --approach $appr --lamb $lamb --first-task-bz $first_task_bz --first-task-lr $first_task_lr \
25 | --num-exemplars-per-class 20 --exemplar-selection herding
26 |
--------------------------------------------------------------------------------
/src/exp_cifar_lucir_cwd.sh:
--------------------------------------------------------------------------------
1 | device_id=0
2 | SEED=0
3 | bz=128
4 | lr=0.1
5 | mom=0.9
6 | wd=5e-4
7 | data=cifar100_icarl
8 | network=resnet18_cifar
9 | nepochs=160
10 | n_exemplar=20
11 |
12 | appr=lucir_cwd
13 | lamb=5.0
14 | nc_first=50
15 | ntask=6
16 |
17 | aux_coef=0.5
18 | rej_thresh=1
19 | first_task_lr=0.1
20 | first_task_bz=128
21 |
22 | CUDA_VISIBLE_DEVICES=$device_id python3 main_incremental.py --exp-name nc_first_${nc_first}_ntask_${ntask} \
23 | --datasets $data --num-tasks $ntask --nc-first-task $nc_first --network $network --seed $SEED \
24 | --nepochs $nepochs --batch-size $bz --lr $lr --momentum $mom --weight-decay $wd --decay-mile-stone 80 120 \
25 | --clipping -1 --results-path results --save-models \
26 | --approach $appr --lamb $lamb --num-exemplars-per-class $n_exemplar --exemplar-selection herding \
27 | --aux-coef $aux_coef --reject-threshold $rej_thresh \
28 | --first-task-lr $first_task_lr --first-task-bz $first_task_bz
29 |
30 |
--------------------------------------------------------------------------------
/src/exp_im100_joint.sh:
--------------------------------------------------------------------------------
1 | device_id=7
2 | SEED=1
3 |
4 | bz=128
5 | lr=0.1
6 | mom=0.9
7 | wd=1e-4
8 | data=imagenet_100
9 | network=resnet18
10 | nepochs=90
11 |
12 | appr=joint
13 |
14 | nc_first=10
15 | ntask=10
16 |
17 | CUDA_VISIBLE_DEVICES=$device_id python3 main_incremental.py --exp-name nc_first_${nc_first}_ntask_${ntask} \
18 | --datasets $data --num-tasks $ntask --nc-first-task $nc_first --network $network --seed $SEED \
19 | --nepochs $nepochs --batch-size $bz --lr $lr --momentum $mom --weight-decay $wd --decay-mile-stone 30 60 \
20 | --clipping -1 --results-path results --save-models \
21 | --approach $appr
22 |
23 |
24 |
--------------------------------------------------------------------------------
/src/exp_im100_lucir.sh:
--------------------------------------------------------------------------------
1 | device_id=0
2 | SEED=0
3 |
4 | bz=128
5 | lr=0.1
6 | mom=0.9
7 | wd=1e-4
8 | data=imagenet_100
9 | network=resnet18
10 | nepochs=90
11 | n_exemplar=20
12 |
13 | appr=lucir
14 | lamb=10.0
15 |
16 | nc_first=50
17 | ntask=6
18 |
19 | first_task_lr=0.1
20 | first_task_bz=128
21 |
22 | CUDA_VISIBLE_DEVICES=$device_id python3 main_incremental.py --exp-name nc_first_${nc_first}_ntask_${ntask} \
23 | --datasets $data --num-tasks $ntask --nc-first-task $nc_first --network $network --seed $SEED \
24 | --nepochs $nepochs --batch-size $bz --lr $lr --momentum $mom --weight-decay $wd --decay-mile-stone 30 60 \
25 | --clipping -1 --results-path results --save-models \
26 | --approach $appr --lamb $lamb \
27 | --num-exemplars-per-class $n_exemplar --exemplar-selection herding \
28 | --first-task-lr $first_task_lr --first-task-bz $first_task_bz
29 |
30 |
--------------------------------------------------------------------------------
/src/exp_im100_lucir_cwd.sh:
--------------------------------------------------------------------------------
1 | device_id=0
2 | SEED=0
3 | bz=128
4 | lr=0.1
5 | mom=0.9
6 | wd=1e-4
7 | data=imagenet_100
8 | network=resnet18
9 | nepochs=90
10 | n_exemplar=20
11 |
12 | appr=lucir_cwd
13 | lamb=10.0
14 | nc_first=50
15 | ntask=6
16 |
17 | aux_coef=0.75
18 | rej_thresh=1
19 | first_task_lr=0.2
20 | first_task_bz=128
21 |
22 | CUDA_VISIBLE_DEVICES=$device_id python3 main_incremental.py --exp-name nc_first_${nc_first}_ntask_${ntask} \
23 | --datasets $data --num-tasks $ntask --nc-first-task $nc_first --network $network --seed $SEED \
24 | --nepochs $nepochs --batch-size $bz --lr $lr --momentum $mom --weight-decay $wd --decay-mile-stone 30 60 \
25 | --clipping -1 --results-path results --save-models \
26 | --approach $appr --lamb $lamb --num-exemplars-per-class $n_exemplar --exemplar-selection herding \
27 | --aux-coef $aux_coef --reject-threshold $rej_thresh \
28 | --first-task-lr $first_task_lr --first-task-bz $first_task_bz
29 |
--------------------------------------------------------------------------------
/src/exp_im100_lucir_oracle.sh:
--------------------------------------------------------------------------------
1 | device_id=0
2 | SEED=0
3 |
4 | bz=128
5 | lr=0.1
6 | mom=0.9
7 | wd=1e-4
8 | data=imagenet_100
9 | network=resnet18
10 | nepochs=90
11 | n_exemplar=20
12 |
13 | appr=lucir_oracle
14 | lamb=10.0
15 |
16 | nc_first=50
17 | ntask=6
18 |
19 | first_task_lr=0.1
20 | first_task_bz=128
21 | aux_coef=10.0
22 | oracle_path=baselines/imagenet_subset_lucir_nc_first_99_ntask_2/models/task0.ckpt
23 |
24 | CUDA_VISIBLE_DEVICES=$device_id python3 main_incremental.py --exp-name nc_first_${nc_first}_ntask_${ntask} \
25 | --datasets $data --num-tasks $ntask --nc-first-task $nc_first --network $network --seed $SEED \
26 | --nepochs $nepochs --batch-size $bz --lr $lr --momentum $mom --weight-decay $wd --decay-mile-stone 30 60 \
27 | --clipping -1 --results-path results \
28 | --approach $appr --lamb $lamb \
29 | --num-exemplars-per-class $n_exemplar --exemplar-selection herding \
30 | --first-task-lr $first_task_lr --first-task-bz $first_task_bz \
31 | --aux-coef $aux_coef --oracle-path $oracle_path
32 |
--------------------------------------------------------------------------------
/src/gridsearch.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from argparse import ArgumentParser
4 |
5 | import utils
6 |
7 |
8 | class GridSearch:
9 | """Basic class for implementing hyperparameter grid search"""
10 |
11 | def __init__(self, appr_ft, seed, gs_config='gridsearch_config', acc_drop_thr=0.2, hparam_decay=0.5,
12 | max_num_searches=7):
13 | self.seed = seed
14 | GridSearchConfig = getattr(importlib.import_module(name=gs_config), 'GridSearchConfig')
15 | self.appr_ft = appr_ft
16 | self.gs_config = GridSearchConfig()
17 | self.acc_drop_thr = acc_drop_thr
18 | self.hparam_decay = hparam_decay
19 | self.max_num_searches = max_num_searches
20 | self.lr_first = 1.0
21 |
22 | @staticmethod
23 | def extra_parser(args):
24 | """Returns a parser containing the GridSearch specific parameters"""
25 | parser = ArgumentParser()
26 | # Configuration file with a GridSearchConfig class with all necessary args
27 | parser.add_argument('--gridsearch-config', type=str, default='gridsearch_config', required=False,
28 | help='Configuration file for GridSearch options (default=%(default)s)')
29 | # Accuracy threshold drop below which the search stops for that phase
30 | parser.add_argument('--gridsearch-acc-drop-thr', default=0.2, type=float, required=False,
31 | help='GridSearch accuracy drop threshold (default=%(default)f)')
32 | # Value at which hyperparameters decay
33 | parser.add_argument('--gridsearch-hparam-decay', default=0.5, type=float, required=False,
34 | help='GridSearch hyperparameter decay (default=%(default)f)')
35 | # Maximum number of searches before the search stops for that phase
36 | parser.add_argument('--gridsearch-max-num-searches', default=7, type=int, required=False,
37 | help='GridSearch maximum number of hyperparameter search (default=%(default)d)')
38 | return parser.parse_known_args(args)
39 |
40 | def search_lr(self, model, t, trn_loader, val_loader):
41 | """Search for accuracy and best LR on finetuning"""
42 | best_ft_acc = 0.0
43 | best_ft_lr = 0.0
44 |
45 | # Get general parameters and fix the ones with only one value
46 | gen_params = self.gs_config.get_params('general')
47 | for k, v in gen_params.items():
48 | if not isinstance(v, list):
49 | setattr(self.appr_ft, k, v)
50 | if t > 0:
51 | # LRs to search: the first 'lr_searches' entries of the (descending) LR list that are below 'lr_first'
52 | list_lr = [lr for lr in gen_params['lr'] if lr < self.lr_first][:gen_params['lr_searches'][0]]
53 | else:
54 | # For first task, try larger LR range
55 | list_lr = gen_params['lr_first']
56 |
57 | # Iterate through the other variable parameters
58 | for curr_lr in list_lr:
59 | utils.seed_everything(seed=self.seed)
60 | self.appr_ft.model = deepcopy(model)
61 | self.appr_ft.lr = curr_lr
62 | self.appr_ft.train(t, trn_loader, val_loader)
63 | _, ft_acc_taw, _ = self.appr_ft.eval(t, val_loader)
64 | if ft_acc_taw > best_ft_acc:
65 | best_ft_acc = ft_acc_taw
66 | best_ft_lr = curr_lr
67 | print('Current best LR: ' + str(best_ft_lr))
68 | self.gs_config.current_lr = best_ft_lr
69 | print('Current best acc: {:5.1f}'.format(best_ft_acc * 100))
70 | # After first task, keep LR used
71 | if t == 0:
72 | self.lr_first = best_ft_lr
73 |
74 | return best_ft_acc, best_ft_lr
75 |
76 | def search_tradeoff(self, appr_name, appr, t, trn_loader, val_loader, best_ft_acc):
77 | """Search for less-forgetting tradeoff with minimum accuracy loss"""
78 | best_tradeoff = None
79 | tradeoff_name = None
80 |
81 | # Get approach-specific parameters and fix all the ones that have only one option
82 | appr_params = self.gs_config.get_params(appr_name)
83 | for k, v in appr_params.items():
84 | if isinstance(v, list):
85 | # get tradeoff name as the only one with multiple values
86 | tradeoff_name = k
87 | else:
88 | # Any other hyperparameters are fixed
89 | setattr(appr, k, v)
90 |
91 | # If there is no tradeoff, no need to gridsearch more
92 | if tradeoff_name is not None and t > 0:
93 | # get starting value for trade-off hyperparameter
94 | best_tradeoff = appr_params[tradeoff_name][0]
95 | # iterate through decreasing trade-off values -- limit to `max_num_searches` searches
96 | num_searches = 0
97 | while num_searches < self.max_num_searches:
98 | utils.seed_everything(seed=self.seed)
99 | # Make deepcopy of the appr without duplicating the logger
100 | appr_gs = type(appr)(deepcopy(appr.model), appr.device, exemplars_dataset=appr.exemplars_dataset)
101 | for attr, value in vars(appr).items():
102 | if attr == 'logger':
103 | setattr(appr_gs, attr, value)
104 | else:
105 | setattr(appr_gs, attr, deepcopy(value))
106 |
107 | # update tradeoff value
108 | setattr(appr_gs, tradeoff_name, best_tradeoff)
109 | # train this iteration
110 | appr_gs.train(t, trn_loader, val_loader)
111 | _, curr_acc, _ = appr_gs.eval(t, val_loader)
112 | print('Current acc: ' + str(curr_acc) + ' for ' + tradeoff_name + '=' + str(best_tradeoff))
113 | # Check if accuracy is within acceptable threshold drop
114 | if curr_acc < ((1 - self.acc_drop_thr) * best_ft_acc):
115 | best_tradeoff = best_tradeoff * self.hparam_decay
116 | else:
117 | break
118 | num_searches += 1
119 | else:
120 | print('There is no trade-off to gridsearch.')
121 |
122 | return best_tradeoff, tradeoff_name
123 |
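The decay loop in `search_tradeoff` can be isolated from training. The sketch below is a simplified stand-in, not the class method: `evaluate` is a hypothetical callback that replaces "train a copy of the approach with this trade-off value and return its validation accuracy".

```python
def search_tradeoff_value(start_value, evaluate, best_ft_acc,
                          acc_drop_thr=0.2, hparam_decay=0.5, max_num_searches=7):
    """Decay the trade-off until accuracy is within the allowed drop (sketch)."""
    value = start_value
    for _ in range(max_num_searches):
        acc = evaluate(value)  # hypothetical: train + eval with this trade-off
        if acc >= (1 - acc_drop_thr) * best_ft_acc:
            # Accuracy is acceptable: keep the current trade-off
            break
        # Too much accuracy lost: weaken the regularization and retry
        value *= hparam_decay
    return value


# Toy accuracy model (an assumption for illustration): stronger trade-off, lower accuracy
best = search_tradeoff_value(10.0, lambda v: 0.9 - 0.05 * v, best_ft_acc=0.8)
print(best)
```

As in the original loop, if every attempt falls below the threshold the returned value has been decayed `max_num_searches` times and was never itself evaluated.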
--------------------------------------------------------------------------------
/src/gridsearch_config.py:
--------------------------------------------------------------------------------
1 | class GridSearchConfig():
2 | def __init__(self):
3 | self.params = {
4 | 'general': {
5 | 'lr_first': [5e-1, 1e-1, 5e-2],
6 | 'lr': [1e-1, 5e-2, 1e-2, 5e-3, 1e-3],
7 | 'lr_searches': [3],
8 | 'lr_min': 1e-4,
9 | 'lr_factor': 3,
10 | 'lr_patience': 10,
11 | 'clipping': 10000,
12 | 'momentum': 0.9,
13 | 'wd': 0.0002
14 | },
15 | 'finetuning': {
16 | },
17 | 'freezing': {
18 | },
19 | 'joint': {
20 | },
21 | 'lwf': {
22 | 'lamb': [10],
23 | 'T': 2
24 | },
25 | 'icarl': {
26 | 'lamb': [4]
27 | },
28 | 'dmc': {
29 | 'aux_dataset': 'imagenet_32_reduced',
30 | 'aux_batch_size': 128
31 | },
32 | 'il2m': {
33 | },
34 | 'eeil': {
35 | 'lamb': [10],
36 | 'T': 2,
37 | 'lr_finetuning_factor': 0.1,
38 | 'nepochs_finetuning': 40,
39 | 'noise_grad': False
40 | },
41 | 'bic': {
42 | 'T': 2,
43 | 'val_percentage': 0.1,
44 | 'bias_epochs': 200
45 | },
46 | 'lucir': {
47 | 'lamda_base': [10],
48 | 'lamda_mr': 1.0,
49 | 'dist': 0.5,
50 | 'K': 2
51 | },
52 | 'lwm': {
53 | 'beta': [2],
54 | 'gamma': 1.0
55 | },
56 | 'ewc': {
57 | 'lamb': [10000]
58 | },
59 | 'mas': {
60 | 'lamb': [400]
61 | },
62 | 'path_integral': {
63 | 'lamb': [10],
64 | },
65 | 'r_walk': {
66 | 'lamb': [20],
67 | },
68 | }
69 | self.current_lr = self.params['general']['lr'][0]
70 | self.current_tradeoff = 0
71 |
72 | def get_params(self, approach):
73 | return self.params[approach]
74 |
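How `GridSearch.search_lr` turns the `general` entries above into a candidate list can be illustrated with a small standalone function. The name `lr_candidates` is hypothetical, and the sketch assumes, as the config does, that the `lr` list is sorted in descending order.

```python
def lr_candidates(t, lr_list, lr_first_list, lr_first, lr_searches):
    """Return the learning rates to try for task t (sketch of the selection rule)."""
    if t > 0:
        # Keep only LRs below the one chosen for the first task,
        # then take the first `lr_searches` of them (the largest, if sorted descending)
        return [lr for lr in lr_list if lr < lr_first][:lr_searches]
    # First task: search the dedicated, larger range
    return list(lr_first_list)


general = {'lr_first': [5e-1, 1e-1, 5e-2], 'lr': [1e-1, 5e-2, 1e-2, 5e-3, 1e-3], 'lr_searches': [3]}
print(lr_candidates(0, general['lr'], general['lr_first'], 1.0, general['lr_searches'][0]))
print(lr_candidates(1, general['lr'], general['lr_first'], 1e-1, general['lr_searches'][0]))
```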
--------------------------------------------------------------------------------
/src/last_layer_analysis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | matplotlib.use('Agg')
6 |
7 |
8 | def last_layer_analysis(heads, task, taskcla, y_lim=False, sort_weights=False):
9 | """Plot last layer weight and bias analysis"""
10 | print('Plotting last layer analysis...')
11 | num_classes = sum([x for (_, x) in taskcla])
12 | weights, biases, indexes = [], [], []
13 | class_id = 0
14 | with torch.no_grad():
15 | for t in range(task + 1):
16 | n_classes_t = taskcla[t][1]
17 | indexes.append(np.arange(class_id, class_id + n_classes_t))
18 | if type(heads) == torch.nn.Linear: # Single head
19 | biases.append(heads.bias[class_id: class_id + n_classes_t].detach().cpu().numpy())
20 | weights.append((heads.weight[class_id: class_id + n_classes_t] ** 2).sum(1).sqrt().detach().cpu().numpy())
21 | else: # Multi-head
22 | weights.append((heads[t].weight ** 2).sum(1).sqrt().detach().cpu().numpy())
23 | if type(heads[t]) == torch.nn.Linear:
24 | biases.append(heads[t].bias.detach().cpu().numpy())
25 | else:
26 | biases.append(np.zeros(weights[-1].shape)) # For LUCIR
27 | class_id += n_classes_t
28 |
29 | # Figure weights
30 | f_weights = plt.figure(dpi=300)
31 | ax = f_weights.subplots(nrows=1, ncols=1)
32 | for i, (x, y) in enumerate(zip(indexes, weights), 0):
33 | if sort_weights:
34 | ax.bar(x, sorted(y, reverse=True), label="Task {}".format(i))
35 | else:
36 | ax.bar(x, y, label="Task {}".format(i))
37 | ax.set_xlabel("Classes", fontsize=11, fontfamily='serif')
38 | ax.set_ylabel("Weights L2-norm", fontsize=11, fontfamily='serif')
39 | if num_classes is not None:
40 | ax.set_xlim(0, num_classes)
41 | if y_lim:
42 | ax.set_ylim(0, 5)
43 | ax.legend(loc='upper left', fontsize='11') #, fontfamily='serif')
44 |
45 | # Figure biases
46 | f_biases = plt.figure(dpi=300)
47 | ax = f_biases.subplots(nrows=1, ncols=1)
48 | for i, (x, y) in enumerate(zip(indexes, biases), 0):
49 | if sort_weights:
50 | ax.bar(x, sorted(y, reverse=True), label="Task {}".format(i))
51 | else:
52 | ax.bar(x, y, label="Task {}".format(i))
53 | ax.set_xlabel("Classes", fontsize=11, fontfamily='serif')
54 | ax.set_ylabel("Bias values", fontsize=11, fontfamily='serif')
55 | if num_classes is not None:
56 | ax.set_xlim(0, num_classes)
57 | if y_lim:
58 | ax.set_ylim(-1.0, 1.0)
59 | ax.legend(loc='upper left', fontsize='11') #, fontfamily='serif')
60 |
61 | return f_weights, f_biases
62 |
--------------------------------------------------------------------------------
/src/loggers/disk_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import torch
5 | import numpy as np
6 | from datetime import datetime
7 |
8 | from loggers.exp_logger import ExperimentLogger
9 |
10 |
11 | class Logger(ExperimentLogger):
12 | """Characterizes a disk logger"""
13 |
14 | def __init__(self, log_path, exp_name, begin_time=None):
15 | super(Logger, self).__init__(log_path, exp_name, begin_time)
16 |
17 | self.begin_time_str = self.begin_time.strftime("%Y-%m-%d-%H-%M")
18 |
19 | # Duplicate standard outputs
20 | sys.stdout = FileOutputDuplicator(sys.stdout,
21 | os.path.join(self.exp_path, 'stdout-{}.txt'.format(self.begin_time_str)), 'w')
22 | sys.stderr = FileOutputDuplicator(sys.stderr,
23 | os.path.join(self.exp_path, 'stderr-{}.txt'.format(self.begin_time_str)), 'w')
24 |
25 | # Raw log file
26 | self.raw_log_file = open(os.path.join(self.exp_path, "raw_log-{}.txt".format(self.begin_time_str)), 'a')
27 |
28 | def log_scalar(self, task, iter, name, value, group=None, curtime=None):
29 | if curtime is None:
30 | curtime = datetime.now()
31 |
32 | # Raw dump
33 | entry = {"task": task, "iter": iter, "name": name, "value": value, "group": group,
34 | "time": curtime.strftime("%Y-%m-%d-%H-%M")}
35 | self.raw_log_file.write(json.dumps(entry, sort_keys=True) + "\n")
36 | self.raw_log_file.flush()
37 |
38 | def log_args(self, args):
39 | with open(os.path.join(self.exp_path, 'args-{}.txt'.format(self.begin_time_str)), 'w') as f:
40 | json.dump(args.__dict__, f, separators=(',\n', ' : '), sort_keys=True)
41 |
42 | def log_result(self, array, name, step):
43 | if array.ndim <= 1:
44 | array = array[None]
45 | np.savetxt(os.path.join(self.exp_path, 'results', '{}-{}.txt'.format(name, self.begin_time_str)),
46 | array, '%.6f', delimiter='\t')
47 |
48 | def log_figure(self, name, iter, figure, curtime=None):
49 | curtime = datetime.now() if curtime is None else curtime
50 | figure.savefig(os.path.join(self.exp_path, 'figures',
51 | '{}_{}-{}.png'.format(name, iter, curtime.strftime("%Y-%m-%d-%H-%M-%S"))))
52 | figure.savefig(os.path.join(self.exp_path, 'figures',
53 | '{}_{}-{}.pdf'.format(name, iter, curtime.strftime("%Y-%m-%d-%H-%M-%S"))))
54 |
55 | def save_model(self, state_dict, task):
56 | torch.save(state_dict, os.path.join(self.exp_path, "models", "task{}.ckpt".format(task)))
57 |
58 | def __del__(self):
59 | self.raw_log_file.close()
60 |
61 |
62 | class FileOutputDuplicator(object):
63 | def __init__(self, duplicate, fname, mode):
64 | self.file = open(fname, mode)
65 | self.duplicate = duplicate
66 |
67 | def __del__(self):
68 | self.file.close()
69 |
70 | def write(self, data):
71 | self.file.write(data)
72 | self.duplicate.write(data)
73 |
74 | def flush(self):
75 | self.file.flush()
76 |
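Since `log_scalar` writes one JSON object per line, the raw log can be read back with a few lines of standard-library code. This is a hedged sketch: `read_raw_log` is a hypothetical helper, not part of the repository, and the sample lines only mimic the keys written above.

```python
import json
from collections import defaultdict


def read_raw_log(lines):
    """Group raw-log entries into series keyed by (task, group, name)."""
    series = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        entry = json.loads(line)
        key = (entry["task"], entry["group"], entry["name"])
        series[key].append((entry["iter"], entry["value"]))
    return dict(series)


# Illustrative entries with the same keys that log_scalar dumps
sample = [
    '{"group": "train", "iter": 0, "name": "loss", "task": 0, "time": "2022-01-01-00-00", "value": 1.5}',
    '{"group": "train", "iter": 1, "name": "loss", "task": 0, "time": "2022-01-01-00-01", "value": 1.0}',
]
print(read_raw_log(sample))
```

In practice the `lines` argument could be an open file handle on a `raw_log-*.txt` file, since iterating a text file yields its lines.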
--------------------------------------------------------------------------------
/src/loggers/exp_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 | from datetime import datetime
4 |
5 |
6 | class ExperimentLogger:
7 | """Main class for experiment logging"""
8 |
9 | def __init__(self, log_path, exp_name, begin_time=None):
10 | self.log_path = log_path
11 | self.exp_name = exp_name
12 | self.exp_path = os.path.join(log_path, exp_name)
13 | if begin_time is None:
14 | self.begin_time = datetime.now()
15 | else:
16 | self.begin_time = begin_time
17 |
18 | def log_scalar(self, task, iter, name, value, group=None, curtime=None):
19 | pass
20 |
21 | def log_args(self, args):
22 | pass
23 |
24 | def log_result(self, array, name, step):
25 | pass
26 |
27 | def log_figure(self, name, iter, figure, curtime=None):
28 | pass
29 |
30 | def save_model(self, state_dict, task):
31 | pass
32 |
33 |
34 | class MultiLogger(ExperimentLogger):
35 | """This class allows using multiple loggers at once"""
36 |
37 | def __init__(self, log_path, exp_name, loggers=None, save_models=True):
38 | super(MultiLogger, self).__init__(log_path, exp_name)
39 | if os.path.exists(self.exp_path):
40 | print("WARNING: {} already exists!".format(self.exp_path))
41 | else:
42 | os.makedirs(os.path.join(self.exp_path, 'models'))
43 | os.makedirs(os.path.join(self.exp_path, 'results'))
44 | os.makedirs(os.path.join(self.exp_path, 'figures'))
45 |
46 | self.save_models = save_models
47 | self.loggers = []
48 | for l in loggers or []:
49 | lclass = getattr(importlib.import_module(name='loggers.' + l + '_logger'), 'Logger')
50 | self.loggers.append(lclass(self.log_path, self.exp_name))
51 |
52 | def log_scalar(self, task, iter, name, value, group=None, curtime=None):
53 | if curtime is None:
54 | curtime = datetime.now()
55 | for l in self.loggers:
56 | l.log_scalar(task, iter, name, value, group, curtime)
57 |
58 | def log_args(self, args):
59 | for l in self.loggers:
60 | l.log_args(args)
61 |
62 | def log_result(self, array, name, step):
63 | for l in self.loggers:
64 | l.log_result(array, name, step)
65 |
66 | def log_figure(self, name, iter, figure, curtime=None):
67 | if curtime is None:
68 | curtime = datetime.now()
69 | for l in self.loggers:
70 | l.log_figure(name, iter, figure, curtime)
71 |
72 | def save_model(self, state_dict, task):
73 | if self.save_models:
74 | for l in self.loggers:
75 | l.save_model(state_dict, task)
76 |
--------------------------------------------------------------------------------
/src/loggers/tensorboard_logger.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 |
3 | from loggers.exp_logger import ExperimentLogger
4 | import json
5 | import numpy as np
6 |
7 |
8 | class Logger(ExperimentLogger):
9 | """Characterizes a Tensorboard logger"""
10 |
11 | def __init__(self, log_path, exp_name, begin_time=None):
12 | super(Logger, self).__init__(log_path, exp_name, begin_time)
13 | self.tbwriter = SummaryWriter(self.exp_path)
14 |
15 | def log_scalar(self, task, iter, name, value, group=None, curtime=None):
16 | self.tbwriter.add_scalar(tag="t{}/{}_{}".format(task, group, name),
17 | scalar_value=value,
18 | global_step=iter)
19 | self.tbwriter.file_writer.flush()
20 |
21 | def log_figure(self, name, iter, figure, curtime=None):
22 | self.tbwriter.add_figure(tag=name, figure=figure, global_step=iter)
23 | self.tbwriter.file_writer.flush()
24 |
25 | def log_args(self, args):
26 | self.tbwriter.add_text(
27 | 'args',
28 | json.dumps(args.__dict__,
29 | separators=(',\n', ' : '),
30 | sort_keys=True))
31 | self.tbwriter.file_writer.flush()
32 |
33 | def log_result(self, array, name, step):
34 | if array.ndim == 1:
35 | # log as scalars
36 | self.tbwriter.add_scalar(f'results/{name}', array[step], step)
37 |
38 | elif array.ndim == 2:
39 | s = ""
40 | i = step
41 | # for i in range(array.shape[0]):
42 | for j in range(array.shape[1]):
43 | s += '{:5.1f}% '.format(100 * array[i, j])
44 | if np.trace(array) == 0.0:
45 | if i > 0:
46 | s += '\tAvg.:{:5.1f}% \n'.format(100 * array[i, :i].mean())
47 | else:
48 | s += '\tAvg.:{:5.1f}% \n'.format(100 * array[i, :i + 1].mean())
49 | self.tbwriter.add_text(f'results/{name}', s, step)
50 |
51 | def __del__(self):
52 | self.tbwriter.close()
53 |
--------------------------------------------------------------------------------
/src/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision import models
2 |
3 | from .resnet18 import resnet18
4 | from .resnet18_cifar import resnet18_cifar
5 |
6 | # available torchvision models
7 | tvmodels = ['alexnet',
8 | 'densenet121', 'densenet169', 'densenet201', 'densenet161',
9 | 'googlenet',
10 | 'inception_v3',
11 | 'mobilenet_v2',
12 | 'resnet34',
13 | 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
14 | 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
15 | 'squeezenet1_0', 'squeezenet1_1',
16 | 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19',
17 | 'wide_resnet50_2', 'wide_resnet101_2'
18 | ]
19 |
20 | allmodels = tvmodels + ['resnet18', 'resnet18_cifar']
21 |
22 |
23 | def set_tvmodel_head_var(model):
24 | if type(model) == models.AlexNet:
25 | model.head_var = 'classifier'
26 | elif type(model) == models.DenseNet:
27 | model.head_var = 'classifier'
28 | elif type(model) == models.Inception3:
29 | model.head_var = 'fc'
30 | elif type(model) == models.ResNet:
31 | model.head_var = 'fc'
32 | elif type(model) == models.VGG:
33 | model.head_var = 'classifier'
34 | elif type(model) == models.GoogLeNet:
35 | model.head_var = 'fc'
36 | elif type(model) == models.MobileNetV2:
37 | model.head_var = 'classifier'
38 | elif type(model) == models.ShuffleNetV2:
39 | model.head_var = 'fc'
40 | elif type(model) == models.SqueezeNet:
41 | model.head_var = 'classifier'
42 | else:
43 | raise ModuleNotFoundError
44 |
--------------------------------------------------------------------------------
/src/networks/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from copy import deepcopy
4 |
5 |
6 | class LLL_Net(nn.Module):
7 | """Basic class for implementing networks"""
8 |
9 | def __init__(self, model, remove_existing_head=False):
10 | head_var = model.head_var
11 | assert type(head_var) == str
12 | assert not remove_existing_head or hasattr(model, head_var), \
13 | "Given model does not have a variable called {}".format(head_var)
14 | assert not remove_existing_head or type(getattr(model, head_var)) in [nn.Sequential, nn.Linear], \
15 | "Given model's head {} is not an instance of nn.Sequential or nn.Linear".format(head_var)
16 | super(LLL_Net, self).__init__()
17 |
18 | self.model = model
19 | last_layer = getattr(self.model, head_var)
20 |
21 | if remove_existing_head:
22 | if type(last_layer) == nn.Sequential:
23 | self.out_size = last_layer[-1].in_features
24 | # strips off last linear layer of classifier
25 | del last_layer[-1]
26 | elif type(last_layer) == nn.Linear:
27 | self.out_size = last_layer.in_features
28 | # converts last layer into identity
29 | # setattr(self.model, head_var, nn.Identity())
30 | # WARNING: this is for when pytorch version is <1.2
31 | setattr(self.model, head_var, nn.Sequential())
32 | else:
33 | self.out_size = last_layer.out_features
34 |
35 | self.heads = nn.ModuleList()
36 | self.task_cls = []
37 | self.task_offset = []
38 | self._initialize_weights()
39 |
40 | def add_head(self, num_outputs):
41 | """Add a new head with the corresponding number of outputs. Also update the number of classes per task and the
42 | corresponding offsets
43 | """
44 | self.heads.append(nn.Linear(self.out_size, num_outputs, bias=False))
45 | # we re-compute instead of append in case an approach makes changes to the heads
46 | self.task_cls = torch.tensor([head.out_features for head in self.heads])
47 | self.task_offset = torch.cat([torch.LongTensor(1).zero_(), self.task_cls.cumsum(0)[:-1]])
48 |
49 | def forward(self, x, return_features=False):
50 | """Applies the forward pass
51 |
52 | Simplification to work on multi-head only -- returns all head outputs in a list
53 | Args:
54 | x (tensor): input images
55 | return_features (bool): return the representations before the heads
56 | """
57 | x = self.model(x)
58 | assert (len(self.heads) > 0), "Cannot access any head"
59 | y = []
60 | for head in self.heads:
61 | y.append(head(x))
62 | if return_features:
63 | return y, x
64 | else:
65 | return y
66 |
67 | # hard-coded interface specifically for podnet
68 | # output: prediction y, features x, pod_features
69 | def forward_pod(self, x):
70 | x, pod_features = self.model(x, return_pod=True)
71 | y = []
72 | for head in self.heads:
73 | y.append(head(x))
74 | return y, x, pod_features
75 |
76 | def forward_repres(self, x):
77 | repres = self.model(x)
78 | return repres
79 |
80 | def forward_cls(self, repres):
81 | y = []
82 | for head in self.heads:
83 | y.append(head(repres))
84 | return y
85 |
86 | def get_copy(self):
87 | """Get weights from the model"""
88 | return deepcopy(self.state_dict())
89 |
90 | def set_state_dict(self, state_dict):
91 | """Load weights into the model"""
92 | self.load_state_dict(deepcopy(state_dict))
93 | return
94 |
95 | def freeze_all(self):
96 | """Freeze all parameters from the model, including the heads"""
97 | for param in self.parameters():
98 | param.requires_grad = False
99 |
100 | def freeze_backbone(self):
101 | """Freeze all parameters from the main model, but not the heads"""
102 | for param in self.model.parameters():
103 | param.requires_grad = False
104 |
105 | def freeze_bn(self):
106 | """Freeze all Batch Normalization layers from the model and use them in eval() mode"""
107 | for m in self.model.modules():
108 | if isinstance(m, nn.BatchNorm2d):
109 | m.eval()
110 |
111 | def _initialize_weights(self):
112 | """Initialize weights using different strategies"""
113 | # TODO: add different initialization strategies
114 | pass
115 |
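The bookkeeping in `add_head` (`task_cls` and `task_offset`) amounts to an exclusive cumulative sum over the head sizes. The sketch below uses plain Python lists instead of tensors so it runs without PyTorch; `task_offsets` and `split_global_index` are hypothetical names, and at least one head is assumed.

```python
from bisect import bisect_right
from itertools import accumulate


def task_offsets(task_cls):
    """Offset of each head in the concatenated output: [0, n0, n0 + n1, ...]."""
    return [0] + list(accumulate(task_cls))[:-1]


def split_global_index(index, task_cls):
    """Map an index into the concatenated head outputs to (task, class within head)."""
    offsets = task_offsets(task_cls)
    # Last head whose offset is <= index
    task = bisect_right(offsets, index) - 1
    return task, index - offsets[task]


heads = [50, 10, 10]  # e.g. out_features of each head after three tasks
print(task_offsets(heads))
print(split_global_index(57, heads))
```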
--------------------------------------------------------------------------------
/src/networks/resnet18.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 | def conv3x3(in_planes, out_planes, stride=1):
5 |     """3x3 convolution with padding"""
6 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
7 |                      padding=1, bias=False)
8 |
9 |
10 | class BasicBlock(nn.Module):
11 |     expansion = 1
12 |
13 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
14 |         super(BasicBlock, self).__init__()
15 |         self.conv1 = conv3x3(inplanes, planes, stride)
16 |         self.bn1 = nn.BatchNorm2d(planes)
17 |         self.relu = nn.ReLU(inplace=True)
18 |         self.conv2 = conv3x3(planes, planes)
19 |         self.bn2 = nn.BatchNorm2d(planes)
20 |         self.downsample = downsample
21 |         self.stride = stride
22 |
23 |     def forward(self, x):
24 |         residual = x
25 |
26 |         out = self.conv1(x)
27 |         out = self.bn1(out)
28 |         out = self.relu(out)
29 |
30 |         out = self.conv2(out)
31 |         out = self.bn2(out)
32 |
33 |         if self.downsample is not None:
34 |             residual = self.downsample(x)
35 |
36 |         out += residual
37 |         out = self.relu(out)
38 |
39 |         return out
40 |
41 | class ResNet(nn.Module):
42 |
43 |     def __init__(self, block, layers, num_classes=1000):
44 |         self.inplanes = 64
45 |         super(ResNet, self).__init__()
46 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
47 |                                bias=False)
48 |         self.bn1 = nn.BatchNorm2d(64)
49 |         self.relu = nn.ReLU(inplace=True)
50 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
51 |         self.layer1 = self._make_layer(block, 64, layers[0])
52 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
53 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
54 |         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
55 |         self.avgpool = nn.AdaptiveAvgPool2d((1,1))
56 |
57 |         # last classifier layer (head) with as many outputs as classes
58 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
59 |         self.last_dim = self.fc.in_features
60 |         # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments
61 |         self.head_var = 'fc'
62 |
63 |         for m in self.modules():
64 |             if isinstance(m, nn.Conv2d):
65 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
66 |             elif isinstance(m, nn.BatchNorm2d):
67 |                 nn.init.constant_(m.weight, 1)
68 |                 nn.init.constant_(m.bias, 0)
69 |
70 |     def _make_layer(self, block, planes, blocks, stride=1):
71 |         downsample = None
72 |         if stride != 1 or self.inplanes != planes * block.expansion:
73 |             downsample = nn.Sequential(
74 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
75 |                           kernel_size=1, stride=stride, bias=False),
76 |                 nn.BatchNorm2d(planes * block.expansion),
77 |             )
78 |
79 |         layers = []
80 |         layers.append(block(self.inplanes, planes, stride, downsample))
81 |         self.inplanes = planes * block.expansion
82 |
83 |         for i in range(1, blocks):
84 |             layers.append(block(self.inplanes, planes))
85 |
86 |         return nn.Sequential(*layers)
87 |
88 |     def forward(self, x):
89 |         x = self.conv1(x)
90 |         x = self.bn1(x)
91 |         x = self.relu(x)
92 |         x = self.maxpool(x)
93 |
94 |         x = self.layer1(x)
95 |         x = self.layer2(x)
96 |         x = self.layer3(x)
97 |         x = self.layer4(x)
98 |
99 |         x = self.avgpool(x)
100 |         x = x.view(x.size(0), -1)
101 |         x = self.fc(x)
102 |
103 |         return x
104 |
105 |
106 | def resnet18(pretrained=False, **kwargs):
107 |     """Constructs a ResNet-18 model.
108 |
109 |     Args:
110 |         pretrained (bool): Currently ignored; no pre-trained weights are loaded
111 |     """
112 |     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
113 |     return model
114 |
--------------------------------------------------------------------------------
/src/networks/resnet18_cifar.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 | def conv3x3(in_planes, out_planes, stride=1):
5 |     """3x3 convolution with padding"""
6 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
7 |                      padding=1, bias=False)
8 |
9 |
10 | class BasicBlock(nn.Module):
11 |     expansion = 1
12 |
13 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
14 |         super(BasicBlock, self).__init__()
15 |         self.conv1 = conv3x3(inplanes, planes, stride)
16 |         self.bn1 = nn.BatchNorm2d(planes)
17 |         self.relu = nn.ReLU(inplace=True)
18 |         self.conv2 = conv3x3(planes, planes)
19 |         self.bn2 = nn.BatchNorm2d(planes)
20 |         self.downsample = downsample
21 |         self.stride = stride
22 |
23 |     def forward(self, x):
24 |         residual = x
25 |
26 |         out = self.conv1(x)
27 |         out = self.bn1(out)
28 |         out = self.relu(out)
29 |
30 |         out = self.conv2(out)
31 |         out = self.bn2(out)
32 |
33 |         if self.downsample is not None:
34 |             residual = self.downsample(x)
35 |
36 |         out += residual
37 |         out = self.relu(out)
38 |
39 |         return out
40 |
41 | class ResNet(nn.Module):
42 |
43 |     def __init__(self, block, layers, num_classes=1000):
44 |         self.inplanes = 64
45 |         super(ResNet, self).__init__()
46 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
47 |                                bias=False)
48 |         self.bn1 = nn.BatchNorm2d(64)
49 |         self.relu = nn.ReLU(inplace=True)
50 |         self.layer1 = self._make_layer(block, 64, layers[0])
51 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
52 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
53 |         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
54 |         self.avgpool = nn.AdaptiveAvgPool2d((1,1))
55 |
56 |         # last classifier layer (head) with as many outputs as classes
57 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
58 |         self.last_dim = self.fc.in_features
59 |         # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments
60 |         self.head_var = 'fc'
61 |
62 |         for m in self.modules():
63 |             if isinstance(m, nn.Conv2d):
64 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
65 |             elif isinstance(m, nn.BatchNorm2d):
66 |                 nn.init.constant_(m.weight, 1)
67 |                 nn.init.constant_(m.bias, 0)
68 |
69 |     def _make_layer(self, block, planes, blocks, stride=1):
70 |         downsample = None
71 |         if stride != 1 or self.inplanes != planes * block.expansion:
72 |             downsample = nn.Sequential(
73 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
74 |                           kernel_size=1, stride=stride, bias=False),
75 |                 nn.BatchNorm2d(planes * block.expansion),
76 |             )
77 |
78 |         layers = []
79 |         layers.append(block(self.inplanes, planes, stride, downsample))
80 |         self.inplanes = planes * block.expansion
81 |
82 |         for i in range(1, blocks):
83 |             layers.append(block(self.inplanes, planes))
84 |
85 |         return nn.Sequential(*layers)
86 |
87 |     def forward(self, x):
88 |         x = self.conv1(x)
89 |         x = self.bn1(x)
90 |         x = self.relu(x)
91 |
92 |         x = self.layer1(x)
93 |         x = self.layer2(x)
94 |         x = self.layer3(x)
95 |         x = self.layer4(x)
96 |
97 |         x = self.avgpool(x)
98 |         x = x.view(x.size(0), -1)
99 |         x = self.fc(x)
100 |
101 |         return x
102 |
103 |
104 | def resnet18_cifar(pretrained=False, **kwargs):
105 |     """Constructs a ResNet-18 model for CIFAR-sized inputs (3x3 stem convolution, no max-pooling).
106 |
107 |     Args:
108 |         pretrained (bool): Currently ignored; no pre-trained weights are loaded
109 |     """
110 |     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
111 |     return model
112 |
--------------------------------------------------------------------------------
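Note on the two ResNet-18 variants listed above: they differ only in the stem. `resnet18.py` uses a 7x7 stride-2 convolution followed by max-pooling, while `resnet18_cifar.py` uses a 3x3 stride-1 convolution and no pooling. The plain-Python sketch below traces the feature-map side length through both backbones with the usual output-size formula floor((n + 2p - k) / s) + 1. The input sizes 224 and 32 and the helper names `conv_out` and `feature_map_sizes` are assumptions made for illustration; they are not part of the repository.

```python
def conv_out(n, kernel, stride, padding):
    """Spatial size after a convolution or pooling layer (floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


def feature_map_sizes(n, cifar_stem):
    """Trace the feature-map side length through the backbone.

    Hypothetical helper: it mirrors the layer hyperparameters in the two
    files above but does not build any torch modules.
    """
    sizes = {}
    if cifar_stem:
        n = conv_out(n, kernel=3, stride=1, padding=1)  # conv1 in resnet18_cifar.py
    else:
        n = conv_out(n, kernel=7, stride=2, padding=3)  # conv1 in resnet18.py
        n = conv_out(n, kernel=3, stride=2, padding=1)  # maxpool (ImageNet variant only)
    sizes["stem"] = n
    # layer1 keeps the resolution; layer2..layer4 each start with a stride-2 3x3 conv.
    sizes["layer1"] = n
    for name in ("layer2", "layer3", "layer4"):
        n = conv_out(n, kernel=3, stride=2, padding=1)
        sizes[name] = n
    return sizes


imagenet = feature_map_sizes(224, cifar_stem=False)  # assumed 224x224 input
cifar = feature_map_sizes(32, cifar_stem=True)       # assumed 32x32 input
print(imagenet)
print(cifar)
```

Under these assumed input sizes the last stage ends at 7x7 and 4x4 respectively. In both files `nn.AdaptiveAvgPool2d((1,1))` then reduces the map to 1x1, so `fc` receives 512 features either way.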
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 | from sklearn.metrics import confusion_matrix
6 |
7 | cudnn_deterministic = True
8 |
9 |
10 | def seed_everything(seed=0):
11 |     """Fix all random seeds"""
12 |     random.seed(seed)
13 |     np.random.seed(seed)
14 |     torch.manual_seed(seed)
15 |     torch.cuda.manual_seed_all(seed)
16 |     os.environ['PYTHONHASHSEED'] = str(seed)
17 |     torch.backends.cudnn.deterministic = cudnn_deterministic
18 |
19 |
20 | def print_summary(taskcla, acc_taw, acc_tag, forg_taw, forg_tag):
21 |     """Print summary of results"""
22 |     tag_acc = []
23 |     for name, metric in zip(['TAw Acc', 'TAg Acc', 'TAw Forg', 'TAg Forg'], [acc_taw, acc_tag, forg_taw, forg_tag]):
24 |         print('*' * 108)
25 |         print(name)
26 |         for i in range(metric.shape[0]):
27 |             print('\t', end='')
28 |             for j in range(metric.shape[1]):
29 |                 print('{:5.1f}% '.format(100 * metric[i, j]), end='')
30 |
31 |             # calculate average
32 |             task_weight = np.array([ncla for _,ncla in taskcla[0:i+1]])
33 |             task_weight = task_weight / task_weight.sum()
34 |
35 |             if np.trace(metric) == 0.0:
36 |                 if i > 0:
37 |                     print('\tAvg.:{:5.1f}% '.format(100 * metric[i, :i].mean()), end='')
38 |             else:
39 |                 avg_metric = 100 * (metric[i, :i + 1]*task_weight).sum()
40 |                 print('\tAvg.:{:5.1f}% '.format(avg_metric), end='')
41 |                 if name == 'TAg Acc':
42 |                     tag_acc.append(avg_metric)
43 |             print()
44 |     print('*' * 108)
45 |     avg_tag_acc = np.array(tag_acc).mean()
46 |     print('Average Incremental Accuracy: ', avg_tag_acc)
47 |     print('done')
48 |
49 | # save results of ablation study
50 | def save_summary(save_path, taskcla, acc_taw, acc_tag, forg_taw, forg_tag, appr_args):
51 |     """Save summary of results"""
52 |     with open(save_path, 'w') as f:
53 |         for name, metric in zip(['TAw Acc', 'TAg Acc', 'TAw Forg', 'TAg Forg'], [acc_taw, acc_tag, forg_taw, forg_tag]):
54 |             f.write('*' * 108 + '\n')
55 |             f.write(name + '\n')
56 |             for i in range(metric.shape[0]):
57 |                 f.write('\t')
58 |                 for j in range(metric.shape[1]):
59 |                     f.write('{:5.1f}% '.format(100 * metric[i, j]))
60 |
61 |                 # calculate average
62 |                 task_weight = np.array([ncla for _,ncla in taskcla[0:i+1]])
63 |                 task_weight = task_weight / task_weight.sum()
64 |
65 |                 if np.trace(metric) == 0.0:
66 |                     if i > 0:
67 |                         f.write('\tAvg.:{:5.1f}% '.format(100 * metric[i, :i].mean()))
68 |                 else:
69 |                     f.write('\tAvg.:{:5.1f}% '.format(100 * (metric[i, :i + 1]*task_weight).sum()))
70 |                 f.write('\n')
71 |
72 |         # --------------- approach arguments ------------------
73 |         f.write('*' * 108 + '\n')
74 |         f.write('Approach arguments =\n')
75 |         for arg in np.sort(list(vars(appr_args).keys())):
76 |             f.write('\t' + arg + ': ' + str(getattr(appr_args, arg)) + '\n')
77 |         f.write('=' * 108 + '\n')
78 |         # -----------------------------------------------------
79 |
80 | # val_loaders: a list of data loaders
81 | def compute_confusion_matrix(model, val_loaders, num_classes):
82 |     with torch.no_grad():
83 |         model.eval()
84 |         num_classes = sum([head.out_features for head in model.heads])  # overrides the `num_classes` argument
85 |         cm = np.zeros((num_classes, num_classes))
86 |         for loader in val_loaders:
87 |             for images, targets in loader:
88 |                 images = images.cuda()
89 |                 outputs = model(images)
90 |                 outputs = torch.cat(outputs, dim=1)
91 |
92 |                 pred = outputs.argmax(dim=1).cpu().numpy()
93 |                 targets = targets.cpu().numpy()
94 |                 cm += confusion_matrix(targets, pred, labels=np.arange(num_classes))
95 |     return cm
96 |
--------------------------------------------------------------------------------
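Note on the `Avg.` column computed in `print_summary` and `save_summary` above: for accuracy matrices, row `i` is averaged over tasks `0..i` with weights proportional to each task's class count taken from `taskcla`, and the printed "Average Incremental Accuracy" is the mean of those per-step averages for the `TAg Acc` matrix. The sketch below reproduces that arithmetic in plain Python. The function name `weighted_average` and all numbers are made up for illustration, and it works in fractions rather than the percentages the original code stores.

```python
def weighted_average(row, taskcla, i):
    """Class-count-weighted mean of row[0..i], as in the `Avg.` column.

    `row` holds per-task accuracies after training task i; `taskcla` is a
    list of (task_id, num_classes) pairs. Hypothetical stand-alone helper.
    """
    counts = [ncla for _, ncla in taskcla[: i + 1]]
    total = sum(counts)
    return sum(acc * n / total for acc, n in zip(row[: i + 1], counts))


# Made-up example: a 50-class first task followed by two 10-class tasks.
taskcla = [(0, 50), (1, 10), (2, 10)]
acc_tag = [
    [0.80, 0.00, 0.00],
    [0.70, 0.60, 0.00],
    [0.60, 0.50, 0.40],
]

# One weighted average per incremental step, then their plain mean.
per_step = [weighted_average(acc_tag[i], taskcla, i) for i in range(len(taskcla))]
average_incremental_accuracy = sum(per_step) / len(per_step)
print(per_step, average_incremental_accuracy)
```

Because the first task here has five times as many classes as the later ones, its accuracy dominates each step's average, which is the effect the class-count weighting is meant to capture.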