├── LICENSE ├── README.md ├── requirements.txt └── src ├── approach ├── .DS_Store ├── __init__.py ├── aux_loss.py ├── bic.py ├── dmc.py ├── eeil.py ├── ewc.py ├── finetuning.py ├── freezing.py ├── il2m.py ├── incremental_learning.py ├── joint.py ├── lucir.py ├── lucir_cwd.py ├── lucir_oracle.py ├── lucir_utils.py ├── lwf.py ├── lwm.py ├── mas.py ├── path_integral.py ├── r_walk.py └── utils.py ├── data └── imagenet │ └── gen_lst_imagenet.py ├── datasets ├── base_dataset.py ├── data_loader.py ├── dataset_config.py ├── exemplars_dataset.py ├── exemplars_selection.py └── memory_dataset.py ├── exp_cifar_lucir.sh ├── exp_cifar_lucir_cwd.sh ├── exp_im100_joint.sh ├── exp_im100_lucir.sh ├── exp_im100_lucir_cwd.sh ├── exp_im100_lucir_oracle.sh ├── gridsearch.py ├── gridsearch_config.py ├── last_layer_analysis.py ├── loggers ├── disk_logger.py ├── exp_logger.py └── tensorboard_logger.py ├── main_incremental.py ├── networks ├── __init__.py ├── network.py ├── resnet18.py └── resnet18_cifar.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yujun Shi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # (CVPR 2022) Mimicking the Oracle: An Initial Phase Decorrelation Approach for Class Incremental Learning [ArXiv](https://arxiv.org/abs/2112.04731) 2 | This repo contains the official implementation of our CVPR 2022 paper: Mimicking the Oracle: An Initial Phase Decorrelation Approach for Class Incremental Learning. 3 | 4 | 5 | 6 | ### 1. Abstract 7 | 8 | Class Incremental Learning (CIL) aims at learning a classifier in a phase-by-phase manner, in which only data of a subset of the classes are provided at each phase. Previous works mainly focus on mitigating forgetting in phases after the initial one. However, we find that improving CIL at its initial phase is also a promising direction. Specifically, we experimentally show that directly encouraging the CIL learner at the initial phase to output representations similar to those of the model jointly trained on all classes can greatly boost the CIL performance. Motivated by this, we study the difference between a naïvely trained initial-phase model and the oracle model. Specifically, since one major difference between these two models is the number of training classes, we investigate how this difference affects the model representations. We find that, with fewer training classes, the data representations of each class lie in a long and narrow region; with more training classes, the representations of each class scatter more uniformly.
Inspired by this observation, we propose **C**lass-**w**ise **D**ecorrelation (**CwD**) that effectively regularizes representations of each class to scatter more uniformly, thus mimicking the model jointly trained with all classes (i.e., the oracle model). Our CwD is simple to implement and easy to plug into existing methods. Extensive experiments on various benchmark datasets show that CwD consistently and significantly improves the performance of existing state-of-the-art methods by around 1% to 3%. 9 | 10 |
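For reference, the loss computed by `DecorrelateLossClass` in "src/approach/aux\_loss.py" (non-DDP path) can be written as below. This is transcribed from that code rather than from the paper, so its normalization may differ from the paper's equations. Here $Z_c \in \mathbb{R}^{N_c \times d}$ holds the $N_c$ features of class $c$ in the batch, $\mu_c$ and $\sigma_c^2$ are their per-dimension mean and unbiased variance, $\epsilon = 10^{-8}$, and $\tau$ is `reject_threshold` (default 1).

$$
\hat{Z}_c = \left(Z_c - \mathbf{1}\,\mu_c^\top\right)\operatorname{diag}\!\left(\sigma_c^2 + \epsilon\right)^{-1/2},
\qquad
M_c = \hat{Z}_c^\top \hat{Z}_c
$$

$$
\mathcal{L}_{\mathrm{CwD}} = \frac{1}{\sum_{c \in \mathcal{C}} N_c} \sum_{c \in \mathcal{C}} \frac{1}{d(d-1)} \sum_{i \neq j} \left(M_c\right)_{ij}^2,
\qquad
\mathcal{C} = \{\, c : N_c > \tau \,\}
$$

Penalizing the off-diagonal entries of each class's correlation matrix pushes that class's feature dimensions toward being uncorrelated, which is what spreads the class representation more uniformly instead of along a long, narrow region.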
11 | 12 | 13 | 14 | ### 2. Instructions to Run Our Code 15 | 16 | The current codebase only contains experiments on [LUCIR](https://openaccess.thecvf.com/content_CVPR_2019/papers/Hou_Learning_a_Unified_Classifier_Incrementally_via_Rebalancing_CVPR_2019_paper.pdf) with CIFAR100 and ImageNet100. Code reproducing our results on top of [PODNet](https://github.com/arthurdouillard/incremental_learning.pytorch) and [AANet](https://github.com/yaoyao-liu/class-incremental-learning) builds on their respective repos and is coming soon! 17 | 18 |
19 | 20 | #### CIFAR100 Experiments w/ LUCIR 21 | 22 | There is no need to download the dataset; it is handled automatically. 23 | 24 | For the LUCIR baseline, navigate to the "src" folder and run: 25 | 26 | ```bash 27 | bash exp_cifar_lucir.sh 28 | ``` 29 | 30 | For LUCIR + CwD, navigate to the "src" folder and run: 31 | 32 | ```bash 33 | bash exp_cifar_lucir_cwd.sh 34 | ``` 35 | 36 | #### ImageNet100 Experiments w/ LUCIR 37 | 38 | To run ImageNet100, follow these two steps: 39 | 40 | Step 1: 41 | 42 | Download and extract the ImageNet dataset into the "src/data/imagenet" folder. 43 | 44 | Then, under "src/data/imagenet", run: 45 | 46 | ```bash 47 | python3 gen_lst_imagenet.py 48 | ``` 49 | 50 | This command generates two lists that determine the order of classes for class incremental learning. As in most previous works, the class order is shuffled with seed 1993. 51 | 52 |
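To make the idea of a fixed-seed class order concrete, here is a minimal sketch assuming NumPy's `RandomState`. It is not the contents of `gen_lst_imagenet.py`, and the permutation it returns is not guaranteed to match the lists that script writes. `shuffled_class_order`, `split_into_phases` and the 50 + 10 split in the usage note are hypothetical names and values chosen for illustration.

```python
from typing import List

import numpy as np


def shuffled_class_order(num_classes: int, seed: int = 1993) -> List[int]:
    """Return a reproducible permutation of class indices 0..num_classes-1.

    Hypothetical helper: illustrates fixing the class order with a seed;
    it is not taken from gen_lst_imagenet.py.
    """
    rng = np.random.RandomState(seed)  # private RNG, independent of global NumPy state
    return rng.permutation(num_classes).tolist()


def split_into_phases(order: List[int], nc_first: int, nc_inc: int) -> List[List[int]]:
    """Split a class order into an initial phase of nc_first classes,
    followed by phases of nc_inc classes each (hypothetical parameters)."""
    phases = [order[:nc_first]]
    for start in range(nc_first, len(order), nc_inc):
        phases.append(order[start:start + nc_inc])
    return phases
```

For example, `split_into_phases(shuffled_class_order(100), 50, 10)` yields one initial phase of 50 classes and five incremental phases of 10 classes. Because the order depends only on the seed, every run (whatever its "SEED" value) sees the same class sequence.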
53 | 54 | Step 2: 55 | 56 | For the LUCIR baseline, navigate to the "src" folder and run: 57 | 58 | ```bash 59 | bash exp_im100_lucir.sh 60 | ``` 61 | 62 | For LUCIR + CwD, navigate to the "src" folder and run: 63 | 64 | ```bash 65 | bash exp_im100_lucir_cwd.sh 66 | ``` 67 | 68 | 69 | 70 | #### Some Comments on Running Scripts 71 | 72 | The "SEED" variable in the scripts is not the seed used to shuffle the class order; it determines model initialisation, data loader sampling, etc. We run each script with "SEED" set to 0, 1 and 2 and average the resulting Average Incremental Accuracy to obtain the results reported in the paper. 73 | 74 |
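The aggregation described above can be sketched as follows. The helper names are hypothetical and not part of this repo. The sketch assumes that Average Incremental Accuracy is the plain mean of the accuracies measured after each phase, including the initial one, each evaluated on all classes seen so far. Check the logger output of your runs to confirm which phases are included.

```python
from statistics import mean
from typing import Sequence


def average_incremental_accuracy(phase_accs: Sequence[float]) -> float:
    """Mean of the accuracies recorded after each phase (assumed here to
    include the initial phase), each measured on all classes seen so far."""
    if not phase_accs:
        raise ValueError("need at least one phase accuracy")
    return mean(phase_accs)


def average_over_seeds(runs: Sequence[Sequence[float]]) -> float:
    """Average the Average Incremental Accuracy over runs that differ only in SEED."""
    return mean(average_incremental_accuracy(run) for run in runs)
```

With three runs (SEED = 0, 1, 2), pass the three lists of per-phase accuracies to `average_over_seeds`. The class order stays fixed across those runs because it comes from the seed-1993 lists, not from "SEED".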
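75 | 76 | 77 | 78 | ### 3. For Customized Usage 79 | 80 | To use our CwD loss in your own project, simply copy the `DecorrelateLossClass` module and its `off_diagonal` helper from "src/approach/aux\_loss.py". With `ddp=True`, the loss also relies on `GatherLayer` and `global_gather`, which aux\_loss.py pulls in via `from .utils import *`. 81 | 82 |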
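Before wiring the copied loss into a training loop, you may want to check what it computes. Below is a NumPy port of `off_diagonal` and of the non-DDP path of `DecorrelateLossClass.forward`, written as a sketch for inspection. It has no autograd, so it is not a drop-in replacement; use the PyTorch module itself for training. `cwd_loss` is a hypothetical name.

```python
import numpy as np


def off_diagonal(x: np.ndarray) -> np.ndarray:
    # Same reshape trick as the PyTorch helper: dropping the last element (a
    # diagonal entry) and viewing the rest as (n-1, n+1) puts every remaining
    # diagonal entry in column 0, which is then sliced away.
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].reshape(n - 1, n + 1)[:, 1:].flatten()


def cwd_loss(x: np.ndarray, y: np.ndarray, reject_threshold: int = 1, eps: float = 1e-8) -> float:
    """NumPy port of DecorrelateLossClass.forward with ddp=False (no gradients)."""
    loss, n_count = 0.0, 0
    labels, counts = np.unique(y, return_counts=True)
    for label, count in zip(labels, counts):
        if count <= reject_threshold:
            continue  # too few samples in this class to estimate correlations
        x_label = x[y == label]
        x_label = x_label - x_label.mean(axis=0, keepdims=True)
        # torch.var is unbiased by default, hence ddof=1
        x_label = x_label / np.sqrt(eps + x_label.var(axis=0, ddof=1, keepdims=True))
        corr_mat = x_label.T @ x_label
        loss += float(np.mean(off_diagonal(corr_mat) ** 2))
        n_count += x_label.shape[0]
    # The PyTorch version returns the integer 0 when no class qualifies
    return 0.0 if n_count == 0 else loss / n_count
```

Features whose dimensions are perfectly correlated within a class give a large loss, while uncorrelated dimensions give a loss near zero. In training, the PyTorch module's output is typically added, scaled by a weight you choose, to the method's existing loss; the abstract describes applying CwD at the initial phase.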
83 | 84 | 85 | 86 | ### 4. Citation 87 | 88 | If you find our repo/paper helpful, please consider citing our work :) 89 | ``` 90 | @inproceedings{shi2022mimicking, 91 | title={Mimicking the oracle: an initial phase decorrelation approach for class incremental learning}, 92 | author={Shi, Yujun and Zhou, Kuangqi and Liang, Jian and Jiang, Zihang and Feng, Jiashi and Torr, Philip HS and Bai, Song and Tan, Vincent YF}, 93 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 94 | pages={16722--16731}, 95 | year={2022} 96 | } 97 | ``` 98 | 99 | 100 | 101 | ### 5. Contact 102 | 103 | Yujun Shi (shi.yujun@u.nus.edu) 104 | 105 | 106 | 107 | ### 6. Acknowledgements 108 | 109 | Our code is based on [FACIL](https://github.com/mmasana/FACIL), one of the most well-written CIL libraries, in my opinion :) 110 | 111 | 112 | 113 | ### 7. Some Additional Remarks 114 | 115 | Building on the original FACIL implementation, I also implemented Distributed Data Parallel (DDP) to enable multi-GPU training. However, its performance appears to be slightly worse than single-GPU training (about 0.5% lower), so all experiments use single-GPU training.
116 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # NOTE: Earlier versions of PyTorch and torchvision might also work, 2 | # but we haven't tested them yet 3 | torch>=1.7.1 4 | torchvision>=0.8.2 5 | matplotlib 6 | numpy 7 | tensorboard -------------------------------------------------------------------------------- /src/approach/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yujun-Shi/CwD/291e2289b00140b81477f5a5b3e5e78938c6e8cd/src/approach/.DS_Store -------------------------------------------------------------------------------- /src/approach/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # list all approaches available 4 | __all__ = list( 5 | map(lambda x: x[:-3], 6 | filter(lambda x: x not in ['__init__.py', 'aux_loss.py', 'incremental_learning.py'] and x.endswith('.py'), 7 | os.listdir(os.path.dirname(__file__)) 8 | ) 9 | ) 10 | ) 11 | -------------------------------------------------------------------------------- /src/approach/aux_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from .utils import * 6 | from torch.distributions import Normal, Independent 7 | from torch import distributed as dist 8 | 9 | # function credit to https://github.com/facebookresearch/barlowtwins/blob/main/main.py 10 | def off_diagonal(x): 11 | # return a flattened view of the off-diagonal elements of a square matrix 12 | n, m = x.shape 13 | assert n == m 14 | return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() 15 | 16 | class DecorrelateLossClass(nn.Module): 17 | 18 | def __init__(self, reject_threshold=1, ddp=False): 19 | super(DecorrelateLossClass,
self).__init__() 20 | self.eps = 1e-8 21 | self.reject_threshold = reject_threshold 22 | self.ddp = ddp 23 | 24 | def forward(self, x, y): 25 | _, C = x.shape 26 | if self.ddp: 27 | # if DDP 28 | # first gather all x and labels from the world 29 | x = torch.cat(GatherLayer.apply(x), dim=0) 30 | y = global_gather(y) 31 | 32 | loss = 0.0 33 | uniq_l, uniq_c = y.unique(return_counts=True) 34 | n_count = 0 35 | for i, label in enumerate(uniq_l): 36 | if uniq_c[i] <= self.reject_threshold: 37 | continue 38 | x_label = x[y==label, :] 39 | x_label = x_label - x_label.mean(dim=0, keepdim=True) 40 | x_label = x_label / torch.sqrt(self.eps + x_label.var(dim=0, keepdim=True)) 41 | 42 | N = x_label.shape[0] 43 | corr_mat = torch.matmul(x_label.t(), x_label) 44 | 45 | # Notice that here the implementation is a little bit different 46 | # from the paper as we extract only the off-diagonal terms for regularization. 47 | # Mathematically, these two are the same thing since diagonal terms are all constant 1. 48 | # However, we find that this implementation is more numerically stable. 
49 | loss += (off_diagonal(corr_mat).pow(2)).mean() 50 | 51 | n_count += N 52 | 53 | if n_count == 0: 54 | # no class has enough samples to compute a correlation matrix 55 | return 0 56 | else: 57 | loss = loss / n_count 58 | return loss 59 | -------------------------------------------------------------------------------- /src/approach/dmc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from copy import deepcopy 4 | from argparse import ArgumentParser 5 | 6 | from datasets.data_loader import get_loaders 7 | from .incremental_learning import Inc_Learning_Appr 8 | from datasets.exemplars_dataset import ExemplarsDataset 9 | 10 | 11 | class Appr(Inc_Learning_Appr): 12 | """ Class implementing the Deep Model Consolidation (DMC) approach 13 | described in https://arxiv.org/abs/1903.07864 14 | Original code available at https://github.com/juntingzh/incremental-learning-baselines 15 | """ 16 | 17 | def __init__(self, model, device, nepochs=160, lr=0.1, lr_min=1e-4, lr_factor=10, lr_patience=8, clipgrad=10000, 18 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 19 | logger=None, exemplars_dataset=None, aux_dataset='imagenet_32', aux_batch_size=128): 20 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 21 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 22 | exemplars_dataset) 23 | self.model_old = None 24 | self.model_new = None 25 | self.aux_dataset = aux_dataset 26 | self.aux_batch_size = aux_batch_size 27 | # get dataloader for auxiliary dataset 28 | aux_trn_ldr, _, aux_val_ldr, _ = get_loaders([self.aux_dataset], num_tasks=1, nc_first_task=None, validation=0, 29 | batch_size=self.aux_batch_size, num_workers=4, pin_memory=False) 30 | self.aux_trn_loader = aux_trn_ldr[0] 31 | self.aux_val_loader = aux_val_ldr[0] 32 | # Since an auxiliary dataset is
available, using exemplars could be redundant 33 | have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class 34 | assert (have_exemplars == 0), 'Warning: DMC does not use exemplars. Comment this line to force it.' 35 | 36 | @staticmethod 37 | def exemplars_dataset_class(): 38 | return ExemplarsDataset 39 | 40 | @staticmethod 41 | def extra_parser(args): 42 | """Returns a parser containing the approach specific parameters""" 43 | parser = ArgumentParser() 44 | # Sec. 4.2.1 "We use ImageNet32x32 dataset as the source for auxiliary data in the model consolidation stage." 45 | parser.add_argument('--aux-dataset', default='imagenet_32_reduced', type=str, required=False, 46 | help='Auxiliary dataset (default=%(default)s)') 47 | parser.add_argument('--aux-batch-size', default=128, type=int, required=False, 48 | help='Batch size for auxiliary dataset (default=%(default)s)') 49 | return parser.parse_known_args(args) 50 | 51 | def _get_optimizer(self): 52 | """Returns the optimizer""" 53 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 54 | # if there are no exemplars, previous heads are not modified 55 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 56 | else: 57 | params = self.model.parameters() 58 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 59 | 60 | def pre_train_process(self, t, trn_loader): 61 | """Runs before training all epochs of the task (before the train session)""" 62 | if t > 0: 63 | # Re-initialize model 64 | for m in self.model.modules(): 65 | if isinstance(m, (nn.Conv2d, nn.BatchNorm2d, nn.Linear)): 66 | m.reset_parameters() 67 | # Get new model 68 | self.model_new = deepcopy(self.model) 69 | for h in self.model_new.heads[:-1]: 70 | with torch.no_grad(): 71 | h.weight.zero_() 72 | h.bias.zero_() 73 | for p in h.parameters(): 74 | p.requires_grad = False 75 | else: 76 | self.model_new = self.model 77 
| 78 | def train_loop(self, t, trn_loader, val_loader): 79 | """Contains the epochs loop""" 80 | if t > 0: 81 | # Args for the new data trainer and for the student trainer are the same 82 | dmc_args = dict(nepochs=self.nepochs, lr=self.lr, lr_min=self.lr_min, lr_factor=self.lr_factor, 83 | lr_patience=self.lr_patience, clipgrad=self.clipgrad, momentum=self.momentum, 84 | wd=self.wd, multi_softmax=self.multi_softmax, wu_nepochs=self.warmup_epochs, 85 | wu_lr_factor=self.warmup_lr, fix_bn=self.fix_bn, logger=self.logger) 86 | # Train new model in new data 87 | new_trainer = NewTaskTrainer(self.model_new, self.device, **dmc_args) 88 | new_trainer.train_loop(t, trn_loader, val_loader) 89 | self.model_new.eval() 90 | self.model_new.freeze_all() 91 | print('=' * 108) 92 | print("Training of student") 93 | print('=' * 108) 94 | # Train student model using both old and new model 95 | student_trainer = StudentTrainer(self.model, self.model_new, self.model_old, self.device, **dmc_args) 96 | student_trainer.train_loop(t, self.aux_trn_loader, self.aux_val_loader) 97 | else: 98 | # FINETUNING TRAINING -- contains the epochs loop 99 | super().train_loop(t, trn_loader, val_loader) 100 | 101 | def post_train_process(self, t, trn_loader): 102 | """Runs after training all the epochs of the task (after the train session)""" 103 | 104 | # Restore best and save model for future tasks 105 | self.model_old = deepcopy(self.model) 106 | self.model_old.eval() 107 | self.model_old.freeze_all() 108 | 109 | 110 | class NewTaskTrainer(Inc_Learning_Appr): 111 | def __init__(self, model, device, nepochs=160, lr=0.1, lr_min=1e-4, lr_factor=10, lr_patience=8, clipgrad=10000, 112 | momentum=0.9, wd=5e-4, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, 113 | eval_on_train=False, logger=None): 114 | super(NewTaskTrainer, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, 115 | momentum, wd, multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, 116 | 
eval_on_train, logger) 117 | 118 | 119 | class StudentTrainer(Inc_Learning_Appr): 120 | def __init__(self, model, model_new, model_old, device, nepochs=160, lr=0.1, lr_min=1e-4, lr_factor=10, 121 | lr_patience=8, clipgrad=10000, momentum=0.9, wd=5e-4, multi_softmax=False, wu_nepochs=0, 122 | wu_lr_factor=1, fix_bn=False, eval_on_train=False, logger=None): 123 | super(StudentTrainer, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, 124 | momentum, wd, multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, 125 | eval_on_train, logger) 126 | 127 | self.model_old = model_old 128 | self.model_new = model_new 129 | 130 | # Runs a single epoch of student's training 131 | def train_epoch(self, t, trn_loader): 132 | self.model.train() 133 | if self.fix_bn and t > 0: 134 | self.model.freeze_bn() 135 | for images, targets in trn_loader: 136 | images, targets = images.cuda(), targets.cuda() 137 | # Forward old and new model 138 | targets_old = self.model_old(images) 139 | targets_new = self.model_new(images) 140 | # Forward current model 141 | outputs = self.model(images) 142 | loss = self.criterion(t, outputs, targets_old, targets_new) 143 | # Backward 144 | self.optimizer.zero_grad() 145 | loss.backward() 146 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad) 147 | self.optimizer.step() 148 | 149 | # Contains the evaluation code for evaluating the student 150 | def eval(self, t, val_loader): 151 | with torch.no_grad(): 152 | total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0 153 | self.model.eval() 154 | for images, targets in val_loader: 155 | images = images.cuda() 156 | # Forward old and new model 157 | targets_old = self.model_old(images) 158 | targets_new = self.model_new(images) 159 | # Forward current model 160 | outputs = self.model(images) 161 | loss = self.criterion(t, outputs, targets_old, targets_new) 162 | # Log 163 | total_loss += loss.item() * len(targets) 164 | total_num += len(targets) 165 | 
return total_loss / total_num, -1, -1 166 | 167 | # Returns the loss value for the student 168 | def criterion(self, t, outputs, targets_old, targets_new=None): 169 | # Eq. 2: Model Consolidation 170 | with torch.no_grad(): 171 | # Eq. 4: "The regression target of the consolidated model is the concatenation of normalized logits of 172 | # the two specialist models." 173 | targets = torch.cat(targets_old[:t] + [targets_new[t]], dim=1) 174 | targets -= targets.mean(0) 175 | # Eq. 3: Double Distillation Loss 176 | return torch.nn.functional.mse_loss(torch.cat(outputs, dim=1), targets.detach(), reduction='mean') 177 | -------------------------------------------------------------------------------- /src/approach/eeil.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | from copy import deepcopy 4 | from argparse import ArgumentParser 5 | from torch.nn import functional as F 6 | from torch.utils.data import DataLoader 7 | 8 | from .incremental_learning import Inc_Learning_Appr 9 | from datasets.exemplars_dataset import ExemplarsDataset 10 | 11 | 12 | class Appr(Inc_Learning_Appr): 13 | """Class implementing the End-to-end Incremental Learning (EEIL) approach described in 14 | http://openaccess.thecvf.com/content_ECCV_2018/papers/Francisco_M._Castro_End-to-End_Incremental_Learning_ECCV_2018_paper.pdf 15 | Original code available at https://github.com/fmcp/EndToEndIncrementalLearning 16 | Helpful code from https://github.com/arthurdouillard/incremental_learning.pytorch 17 | """ 18 | 19 | def __init__(self, model, device, nepochs=90, lr=0.1, lr_min=1e-6, lr_factor=10, lr_patience=5, clipgrad=10000, 20 | momentum=0.9, wd=0.0001, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, 21 | eval_on_train=False, logger=None, exemplars_dataset=None, lamb=1.0, T=2, lr_finetuning_factor=0.1, 22 | nepochs_finetuning=40, noise_grad=False): 23 | super(Appr, self).__init__(model, device, nepochs, lr, 
lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 24 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 25 | exemplars_dataset) 26 | self.model_old = None 27 | self.lamb = lamb 28 | self.T = T 29 | self.lr_finetuning_factor = lr_finetuning_factor 30 | self.nepochs_finetuning = nepochs_finetuning 31 | self.noise_grad = noise_grad 32 | 33 | self._train_epoch = 0 34 | self._finetuning_balanced = None 35 | 36 | # EEIL is expected to be used with exemplars. If needed to be used without exemplars, overwrite here the 37 | # `_get_optimizer` function with the one in LwF and update the criterion 38 | have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class 39 | if not have_exemplars: 40 | warnings.warn("Warning: EEIL is expected to use exemplars. Check documentation.") 41 | 42 | @staticmethod 43 | def exemplars_dataset_class(): 44 | return ExemplarsDataset 45 | 46 | @staticmethod 47 | def extra_parser(args): 48 | """Returns a parser containing the approach specific parameters""" 49 | parser = ArgumentParser() 50 | # Added trade-off between the terms of Eq. 1 -- L = L_C + lamb * L_D 51 | parser.add_argument('--lamb', default=1.0, type=float, required=False, 52 | help='Forgetting-intransigence trade-off (default=%(default)s)') 53 | # Page 6: "Based on our empirical results, we set T to 2 for all our experiments" 54 | parser.add_argument('--T', default=2.0, type=float, required=False, 55 | help='Temperature scaling (default=%(default)s)') 56 | # "The same reduction is used in the case of fine-tuning, except that the starting rate is 0.01." 
57 | parser.add_argument('--lr-finetuning-factor', default=0.01, type=float, required=False, 58 | help='Finetuning learning rate factor (default=%(default)s)') 59 | # Number of epochs for balanced training 60 | parser.add_argument('--nepochs-finetuning', default=40, type=int, required=False, 61 | help='Number of epochs for balanced training (default=%(default)s)') 62 | # the addition of noise to the gradients 63 | parser.add_argument('--noise-grad', action='store_true', 64 | help='Add noise to gradients (default=%(default)s)') 65 | return parser.parse_known_args(args) 66 | 67 | def _train_unbalanced(self, t, trn_loader, val_loader): 68 | """Unbalanced training""" 69 | self._finetuning_balanced = False 70 | self._train_epoch = 0 71 | loader = self._get_train_loader(trn_loader, False) 72 | super().train_loop(t, loader, val_loader) 73 | return loader 74 | 75 | def _train_balanced(self, t, trn_loader, val_loader): 76 | """Balanced finetuning""" 77 | self._finetuning_balanced = True 78 | self._train_epoch = 0 79 | orig_lr = self.lr 80 | self.lr *= self.lr_finetuning_factor 81 | orig_nepochs = self.nepochs 82 | self.nepochs = self.nepochs_finetuning 83 | loader = self._get_train_loader(trn_loader, True) 84 | super().train_loop(t, loader, val_loader) 85 | self.lr = orig_lr 86 | self.nepochs = orig_nepochs 87 | 88 | def _get_train_loader(self, trn_loader, balanced=False): 89 | """Modify loader to be balanced or unbalanced""" 90 | exemplars_ds = self.exemplars_dataset 91 | trn_dataset = trn_loader.dataset 92 | if balanced: 93 | indices = torch.randperm(len(trn_dataset)) 94 | trn_dataset = torch.utils.data.Subset(trn_dataset, indices[:len(exemplars_ds)]) 95 | ds = exemplars_ds + trn_dataset 96 | return DataLoader(ds, batch_size=trn_loader.batch_size, 97 | shuffle=True, 98 | num_workers=trn_loader.num_workers, 99 | pin_memory=trn_loader.pin_memory) 100 | 101 | def _noise_grad(self, parameters, iteration, eta=0.3, gamma=0.55): 102 | """Add noise to the gradients""" 103 | 
parameters = list(filter(lambda p: p.grad is not None, parameters)) 104 | variance = eta / ((1 + iteration) ** gamma) 105 | for p in parameters: 106 | p.grad.add_(torch.randn(p.grad.shape, device=p.grad.device) * variance) 107 | 108 | def train_loop(self, t, trn_loader, val_loader): 109 | """Contains the epochs loop""" 110 | if t == 0: # First task is simple training 111 | super().train_loop(t, trn_loader, val_loader) 112 | loader = trn_loader 113 | else: 114 | # Page 4: "4. Incremental Learning" -- The only modification is that instead of preparing exemplars before 115 | # training, we do it online using the stored old model. 116 | 117 | # Training process (new + old) - unbalanced training 118 | loader = self._train_unbalanced(t, trn_loader, val_loader) 119 | # Balanced fine-tuning (new + old) 120 | self._train_balanced(t, trn_loader, val_loader) 121 | 122 | # After task training: update exemplars 123 | self.exemplars_dataset.collect_exemplars(self.model, loader, val_loader.dataset.transform) 124 | 125 | def post_train_process(self, t, trn_loader): 126 | """Runs after training all the epochs of the task (after the train session)""" 127 | 128 | # Save old model to extract features later 129 | self.model_old = deepcopy(self.model) 130 | self.model_old.eval() 131 | self.model_old.freeze_all() 132 | 133 | def train_epoch(self, t, trn_loader): 134 | """Runs a single epoch""" 135 | self.model.train() 136 | if self.fix_bn and t > 0: 137 | self.model.freeze_bn() 138 | for images, targets in trn_loader: 139 | images = images.to(self.device) 140 | # Forward old model 141 | outputs_old = None 142 | if t > 0: 143 | outputs_old = self.model_old(images) 144 | # Forward current model 145 | outputs = self.model(images) 146 | loss = self.criterion(t, outputs, targets.to(self.device), outputs_old) 147 | # Backward 148 | self.optimizer.zero_grad() 149 | loss.backward() 150 | # Page 8: "We apply L2-regularization and random noise [21] (with parameters eta = 0.3, gamma = 0.55) 151 | # on
the gradients to minimize overfitting" 152 | # https://github.com/fmcp/EndToEndIncrementalLearning/blob/master/cnn_train_dag_exemplars.m#L367 153 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad) 154 | if self.noise_grad: 155 | self._noise_grad(self.model.parameters(), self._train_epoch) 156 | self.optimizer.step() 157 | self._train_epoch += 1 158 | 159 | def criterion(self, t, outputs, targets, outputs_old=None): 160 | """Returns the loss value""" 161 | 162 | # Classification loss for new classes 163 | loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 164 | # Distillation loss 165 | if t > 0 and outputs_old: 166 | # take into account current head when doing balanced finetuning 167 | last_head_idx = t if self._finetuning_balanced else (t - 1) 168 | for i in range(last_head_idx): 169 | loss += self.lamb * F.binary_cross_entropy(F.softmax(outputs[i] / self.T, dim=1), 170 | F.softmax(outputs_old[i] / self.T, dim=1)) 171 | return loss 172 | -------------------------------------------------------------------------------- /src/approach/ewc.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | import itertools 4 | from argparse import ArgumentParser 5 | 6 | from datasets.exemplars_dataset import ExemplarsDataset 7 | from datasets.exemplars_selection import override_dataset_transform 8 | from .incremental_learning import Inc_Learning_Appr 9 | from torch.distributions.categorical import Categorical 10 | 11 | from torch.utils.data import DataLoader 12 | 13 | class Appr(Inc_Learning_Appr): 14 | """Class implementing the Elastic Weight Consolidation (EWC) approach 15 | described in http://arxiv.org/abs/1612.00796 16 | """ 17 | 18 | def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], clipgrad=10000, 19 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, ddp=False, 20 |
logger=None, exemplars_dataset=None, lamb=5000, alpha=0.5, fi_sampling_type='max_pred', 21 | fi_num_samples=-1, save_fisher=False): 22 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, clipgrad, momentum, wd, 23 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 24 | exemplars_dataset) 25 | self.lamb = lamb 26 | self.alpha = alpha 27 | self.sampling_type = fi_sampling_type 28 | self.num_samples = fi_num_samples 29 | 30 | # In all cases, we only keep importance weights for the model, but not for the heads. 31 | feat_ext = self.model.model 32 | # Store current parameters as the initial parameters before first task starts 33 | self.older_params = {n: p.clone().detach() for n, p in feat_ext.named_parameters() if p.requires_grad} 34 | # Store fisher information weight importance 35 | self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() 36 | if p.requires_grad} 37 | self.save_fisher = save_fisher 38 | 39 | @staticmethod 40 | def exemplars_dataset_class(): 41 | return ExemplarsDataset 42 | 43 | @staticmethod 44 | def extra_parser(args): 45 | """Returns a parser containing the approach specific parameters""" 46 | parser = ArgumentParser() 47 | # Eq. 
3: "lambda sets how important the old task is compared to the new one" 48 | parser.add_argument('--lamb', default=5000, type=float, required=False, 49 | help='Forgetting-intransigence trade-off (default=%(default)s)') 50 | # Define how old and new fisher is fused, by default it is a 50-50 fusion 51 | parser.add_argument('--alpha', default=0.5, type=float, required=False, 52 | help='EWC alpha (default=%(default)s)') 53 | parser.add_argument('--fi-sampling-type', default='max_pred', type=str, required=False, 54 | choices=['true', 'max_pred', 'multinomial'], 55 | help='Sampling type for Fisher information (default=%(default)s)') 56 | parser.add_argument('--fi-num-samples', default=-1, type=int, required=False, 57 | help='Number of samples for Fisher information (-1: all available) (default=%(default)s)') 58 | parser.add_argument('--save-fisher', action='store_true', 59 | help='whether to save Fisher information') 60 | return parser.parse_known_args(args) 61 | 62 | def _get_optimizer(self): 63 | """Returns the optimizer""" 64 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 65 | # if there are no exemplars, previous heads are not modified 66 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 67 | else: 68 | params = self.model.parameters() 69 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 70 | 71 | # def compute_fisher_matrix_diag(self, trn_loader): 72 | # # Store Fisher Information 73 | # fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters() 74 | # if p.requires_grad} 75 | # # Compute fisher information for specified number of samples -- rounded to the batch size 76 | # n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \ 77 | # else (len(trn_loader.dataset) // trn_loader.batch_size) 78 | # # Do forward and backward pass to compute the fisher information 79 | # self.model.train() 80 | 
# for images, targets in itertools.islice(trn_loader, n_samples_batches): 81 | # outputs = self.model.forward(images.to(self.device)) 82 | 83 | # if self.sampling_type == 'true': 84 | # # Use the labels to compute the gradients based on the CE-loss with the ground truth 85 | # preds = targets.to(self.device) 86 | # elif self.sampling_type == 'max_pred': 87 | # # Not use labels and compute the gradients related to the prediction the model has learned 88 | # preds = torch.cat(outputs, dim=1).argmax(1).flatten() 89 | # elif self.sampling_type == 'multinomial': 90 | # # Use a multinomial sampling to compute the gradients 91 | # # probs = torch.nn.functional.softmax(torch.cat(outputs, dim=1), dim=1) 92 | # # preds = torch.multinomial(probs, len(targets)).flatten() 93 | # preds = Categorical(logits=outputs[-1]).sample() 94 | 95 | # loss = torch.nn.functional.cross_entropy(outputs[-1], preds) 96 | # self.optimizer.zero_grad() 97 | # loss.backward() 98 | # # Accumulate all gradients from loss with regularization 99 | # for n, p in self.model.model.named_parameters(): 100 | # if p.grad is not None: 101 | # fisher[n] += p.grad.pow(2) * len(targets) 102 | # # Apply mean across all samples 103 | # n_samples = n_samples_batches * trn_loader.batch_size 104 | # fisher = {n: (p / n_samples) for n, p in fisher.items()} 105 | # return fisher 106 | 107 | def compute_fisher_matrix_diag(self, trn_loader, val_loader): 108 | # Store Fisher Information 109 | fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters() 110 | if p.requires_grad} 111 | # Do forward and backward pass to compute the fisher information 112 | self.model.eval() 113 | with override_dataset_transform(trn_loader.dataset, val_loader.dataset.transform) as _ds: 114 | fisher_loader = DataLoader(_ds, batch_size=1, shuffle=False, 115 | num_workers=trn_loader.num_workers, pin_memory=trn_loader.pin_memory) 116 | n_samples = 0 117 | for images, targets in tqdm(fisher_loader): 118 | 
images, targets = images.to(self.device), targets.to(self.device) 119 | 120 | outputs = self.model.forward(images) 121 | if self.sampling_type == 'true': 122 | preds = targets 123 | elif self.sampling_type == 'max_pred': 124 | preds = torch.cat(outputs, dim=1).argmax(1).flatten() 125 | elif self.sampling_type == 'multinomial': 126 | preds = Categorical(logits=outputs[-1]).sample() 127 | 128 | loss = torch.nn.functional.cross_entropy(outputs[-1], preds) 129 | self.optimizer.zero_grad() 130 | loss.backward() 131 | # Accumulate all gradients from loss with regularization 132 | for n, p in self.model.model.named_parameters(): 133 | if p.grad is not None: 134 | fisher[n] += p.grad.pow(2) * len(targets) 135 | n_samples += len(targets) 136 | 137 | fisher = {n: (p / n_samples) for n, p in fisher.items()} 138 | return fisher 139 | 140 | def train_loop(self, t, trn_loader, val_loader): 141 | """Contains the epochs loop""" 142 | 143 | # add exemplars to train_loader 144 | if len(self.exemplars_dataset) > 0 and t > 0: 145 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 146 | batch_size=trn_loader.batch_size, 147 | shuffle=True, 148 | num_workers=trn_loader.num_workers, 149 | pin_memory=trn_loader.pin_memory) 150 | 151 | # FINETUNING TRAINING -- contains the epochs loop 152 | super().train_loop(t, trn_loader, val_loader) 153 | 154 | # EXEMPLAR MANAGEMENT -- select training subset 155 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 156 | 157 | def post_train_process(self, t, trn_loader, val_loader): 158 | """Runs after training all the epochs of the task (after the train session)""" 159 | 160 | # Store current parameters for the next task 161 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad} 162 | 163 | # calculate Fisher information 164 | curr_fisher = self.compute_fisher_matrix_diag(trn_loader, val_loader) 165 | # merge fisher 
information, we do not want to keep fisher information for each task in memory 166 | for n in self.fisher.keys(): 167 | # Added option to accumulate fisher over time with a pre-fixed growing alpha 168 | # if self.alpha == -1: 169 | # alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device) 170 | # self.fisher[n] = alpha * self.fisher[n] + (1 - alpha) * curr_fisher[n] 171 | # else: 172 | # self.fisher[n] = (self.alpha * self.fisher[n] + (1 - self.alpha) * curr_fisher[n]) 173 | 174 | # sum old and new Fisher so the constraints of all previous tasks accumulate (the alpha fusion above is disabled) 175 | self.fisher[n] = self.fisher[n] + curr_fisher[n] 176 | 177 | if self.save_fisher: 178 | torch.save(self.fisher, './fisher/lamb_{}_task_{}.pt'.format(self.lamb, t)) 179 | 180 | def criterion(self, t, outputs, targets): 181 | """Returns the loss value""" 182 | loss = 0 183 | if t > 0: 184 | loss_reg = 0 185 | # Eq. 3: elastic weight consolidation quadratic penalty 186 | for n, p in self.model.model.named_parameters(): 187 | if n in self.fisher.keys(): 188 | loss_reg += torch.sum(self.fisher[n] * (p - self.older_params[n]).pow(2)) / 2 189 | loss += self.lamb * loss_reg 190 | # Current cross-entropy loss -- with exemplars use all heads 191 | if len(self.exemplars_dataset) > 0: 192 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 193 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 194 | -------------------------------------------------------------------------------- /src/approach/finetuning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .incremental_learning import Inc_Learning_Appr 5 | from datasets.exemplars_dataset import ExemplarsDataset 6 | 7 | import torch.nn.functional as F 8 | from argparse import ArgumentParser 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from .lucir_utils import CosineLinear, BasicBlockNoRelu,
BottleneckNoRelu 11 | import warnings 12 | 13 | class Appr(Inc_Learning_Appr): 14 | """Class implementing the finetuning baseline with exemplar rehearsal. It reuses the IL2M setup (ReLU-free 15 | last block, CosineLinear head) but trains with plain cross-entropy and applies no IL2M score rectification. 16 | """ 17 | 18 | def __init__(self, model, device, nepochs=160, lr=0.1, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000, 19 | momentum=0.9, wd=5e-4, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, 20 | eval_on_train=False, ddp=False, local_rank=0, logger=None, exemplars_dataset=None, 21 | first_task_lr=0.1, first_task_bz=128): 22 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd, 23 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, ddp, local_rank, 24 | logger, exemplars_dataset) 25 | self.init_classes_means = [] 26 | self.current_classes_means = [] 27 | self.models_confidence = [] 28 | # FLAG to not do scores rectification while finetuning training 29 | self.ft_train = False 30 | 31 | self.first_task_lr = first_task_lr 32 | self.first_task_bz = first_task_bz 33 | self.first_task = True 34 | 35 | have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class 36 | assert (have_exemplars > 0), 'Error: this finetuning baseline needs exemplars.'
37 | 38 | @staticmethod 39 | def exemplars_dataset_class(): 40 | return ExemplarsDataset 41 | 42 | @staticmethod 43 | def extra_parser(args): 44 | """Returns a parser containing the approach specific parameters""" 45 | parser = ArgumentParser() 46 | parser.add_argument('--first-task-lr', default=0.1, type=float) 47 | parser.add_argument('--first-task-bz', default=32, type=int) 48 | return parser.parse_known_args(args) 49 | 50 | def _get_optimizer(self): 51 | """Returns the optimizer""" 52 | if self.ddp: 53 | model = self.model.module 54 | else: 55 | model = self.model 56 | 57 | params = model.parameters() 58 | 59 | if self.first_task: 60 | self.first_task = False 61 | optimizer = torch.optim.SGD(params, lr=self.first_task_lr, weight_decay=self.wd, momentum=self.momentum) 62 | else: 63 | optimizer = torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 64 | print(optimizer.param_groups[0]['lr']) 65 | return optimizer 66 | 67 | def pre_train_process(self, t, trn_loader): 68 | """Runs before training all epochs of the task (before the train session)""" 69 | if self.ddp: 70 | model = self.model.module 71 | else: 72 | model = self.model 73 | 74 | if t == 0: 75 | # Sec. 
4.1: "the ReLU in the penultimate layer is removed to allow the features to take both positive and 76 | # negative values" 77 | if model.model.__class__.__name__ == 'ResNetCifar': 78 | old_block = model.model.layer3[-1] 79 | model.model.layer3[-1] = BasicBlockNoRelu(old_block.conv1, old_block.bn1, old_block.relu, 80 | old_block.conv2, old_block.bn2, old_block.downsample) 81 | elif model.model.__class__.__name__ == 'ResNet': 82 | old_block = model.model.layer4[-1] 83 | model.model.layer4[-1] = BasicBlockNoRelu(old_block.conv1, old_block.bn1, old_block.relu, 84 | old_block.conv2, old_block.bn2, old_block.downsample) 85 | elif model.model.__class__.__name__ == 'ResNetBottleneck': 86 | old_block = model.model.layer4[-1] 87 | model.model.layer4[-1] = BottleneckNoRelu(old_block.conv1, old_block.bn1, 88 | old_block.relu, old_block.conv2, old_block.bn2, 89 | old_block.conv3, old_block.bn3, old_block.downsample) 90 | else: 91 | warnings.warn("Warning: ReLU not removed from last block.") 92 | 93 | # Changes the new head to a CosineLinear 94 | model.heads[-1] = CosineLinear(model.heads[-1].in_features, model.heads[-1].out_features) 95 | model.to(self.device) 96 | # if t > 0: 97 | # Share sigma (Eta in paper) between all the heads 98 | # Yujun: according to il2m, since we'll correct this with model confidence 99 | # maybe we shouldn't share sigma here. 100 | # model.heads[-1].sigma = model.heads[-2].sigma 101 | 102 | # and we probably shouldn't freeze sigma here. 103 | # for h in model.heads[:-1]: 104 | # for param in h.parameters(): 105 | # param.requires_grad = False 106 | # model.heads[-1].sigma.requires_grad = True 107 | 108 | # if ddp option is activated, need to re-wrap the ddp model 109 | if self.ddp: 110 | self.model = DDP(self.model.module, device_ids=[self.local_rank]) 111 | 112 | # The original code has an option called "imprint weights" that seems to initialize the new head. 
113 | # However, this is not mentioned in the paper and doesn't seem to make a significant difference. 114 | super().pre_train_process(t, trn_loader) 115 | 116 | def train_loop(self, t, trn_loader, val_loader): 117 | """Contains the epochs loop""" 118 | if t == 0: 119 | dset = trn_loader.dataset 120 | trn_loader = torch.utils.data.DataLoader(dset, 121 | batch_size=self.first_task_bz, 122 | sampler=trn_loader.sampler, 123 | num_workers=trn_loader.num_workers, 124 | pin_memory=trn_loader.pin_memory) 125 | 126 | # add exemplars to train_loader 127 | if len(self.exemplars_dataset) > 0 and t > 0: 128 | dset = trn_loader.dataset + self.exemplars_dataset 129 | if self.ddp: 130 | trn_sampler = torch.utils.data.DistributedSampler(dset, shuffle=True) 131 | trn_loader = torch.utils.data.DataLoader(dset, 132 | batch_size=trn_loader.batch_size, 133 | sampler=trn_sampler, 134 | num_workers=trn_loader.num_workers, 135 | pin_memory=trn_loader.pin_memory) 136 | else: 137 | trn_loader = torch.utils.data.DataLoader(dset, 138 | batch_size=trn_loader.batch_size, 139 | shuffle=True, 140 | num_workers=trn_loader.num_workers, 141 | pin_memory=trn_loader.pin_memory) 142 | 143 | 144 | # FINETUNING TRAINING -- contains the epochs loop 145 | self.ft_train = True 146 | super().train_loop(t, trn_loader, val_loader) 147 | self.ft_train = False 148 | 149 | if self.ddp: 150 | # need to change the trainloader to the original version without distributed sampler 151 | dset = trn_loader.dataset 152 | trn_loader = torch.utils.data.DataLoader(dset, 153 | batch_size=200, shuffle=False, num_workers=trn_loader.num_workers, 154 | pin_memory=trn_loader.pin_memory) 155 | 156 | # EXEMPLAR MANAGEMENT -- select training subset 157 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform, self.ddp) 158 | 159 | def criterion(self, t, outputs, targets): 160 | if self.ddp: 161 | model = self.model.module 162 | else: 163 | model = self.model 164 | 165 | if type(outputs[0]) == 
dict: 166 | outputs = [o['wsigma'] for o in outputs] 167 | 168 | """Returns the loss value""" 169 | if len(self.exemplars_dataset) > 0: 170 | return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 171 | return torch.nn.functional.cross_entropy(outputs[t], targets - model.task_offset[t]) 172 | -------------------------------------------------------------------------------- /src/approach/freezing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from argparse import ArgumentParser 3 | 4 | from .incremental_learning import Inc_Learning_Appr 5 | from datasets.exemplars_dataset import ExemplarsDataset 6 | 7 | 8 | class Appr(Inc_Learning_Appr): 9 | """Class implementing the freezing baseline""" 10 | 11 | def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000, 12 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 13 | ddp=False, local_rank=0, logger=None, exemplars_dataset=None, freeze_after=0, all_outputs=False): 14 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd, 15 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, ddp, local_rank, 16 | logger, exemplars_dataset) 17 | self.freeze_after = freeze_after 18 | self.all_out = all_outputs 19 | 20 | @staticmethod 21 | def exemplars_dataset_class(): 22 | return ExemplarsDataset 23 | 24 | @staticmethod 25 | def extra_parser(args): 26 | """Returns a parser containing the approach specific parameters""" 27 | parser = ArgumentParser() 28 | parser.add_argument('--freeze-after', default=0, type=int, required=False, 29 | help='Freeze model except current head after the specified task (default=%(default)s)') 30 | parser.add_argument('--all-outputs', action='store_true', required=False, 31 | help='Allow all weights related to all outputs to be modified (default=%(default)s)') 32 | return
parser.parse_known_args(args) 33 | 34 | def _get_optimizer(self): 35 | """Returns the optimizer""" 36 | return torch.optim.SGD(self._train_parameters(), lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 37 | 38 | def _has_exemplars(self): 39 | """Returns True in case exemplars are being used""" 40 | return self.exemplars_dataset is not None and len(self.exemplars_dataset) > 0 41 | 42 | def post_train_process(self, t, trn_loader, val_loader): 43 | """Runs after training all the epochs of the task (after the train session)""" 44 | if t >= self.freeze_after: 45 | self.model.freeze_backbone() 46 | 47 | def train_loop(self, t, trn_loader, val_loader): 48 | """Contains the epochs loop""" 49 | 50 | # add exemplars to train_loader 51 | if t > 0 and self._has_exemplars(): 52 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 53 | batch_size=trn_loader.batch_size, 54 | shuffle=True, 55 | num_workers=trn_loader.num_workers, 56 | pin_memory=trn_loader.pin_memory) 57 | 58 | # FINETUNING TRAINING -- contains the epochs loop 59 | super().train_loop(t, trn_loader, val_loader) 60 | 61 | # EXEMPLAR MANAGEMENT -- select training subset 62 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 63 | 64 | def train_epoch(self, t, trn_loader): 65 | """Runs a single epoch""" 66 | self._model_train(t) 67 | for images, targets in trn_loader: 68 | # Forward current model 69 | outputs = self.model(images.to(self.device)) 70 | loss = self.criterion(t, outputs, targets.to(self.device)) 71 | # Backward 72 | self.optimizer.zero_grad() 73 | loss.backward() 74 | torch.nn.utils.clip_grad_norm_(self._train_parameters(), self.clipgrad) 75 | self.optimizer.step() 76 | 77 | def _model_train(self, t): 78 | """Freezes the necessary weights""" 79 | if self.fix_bn and t > 0: 80 | self.model.freeze_bn() 81 | if self.freeze_after >= 0 and t <= self.freeze_after: # non-frozen task - whole model to train 82 | 
self.model.train() 83 |
else: 84 | self.model.model.eval() 85 | if self._has_exemplars(): 86 | # with exemplars - use all heads 87 | for head in self.model.heads: 88 | head.train() 89 | else: 90 | # no exemplars - use current head 91 | self.model.heads[-1].train() 92 | 93 | def _train_parameters(self): 94 | """Includes the necessary weights to the optimizer""" 95 | if len(self.model.heads) <= (self.freeze_after + 1): 96 | return self.model.parameters() 97 | else: 98 | if self._has_exemplars(): 99 | return [p for head in self.model.heads for p in head.parameters()] 100 | else: 101 | return self.model.heads[-1].parameters() 102 | 103 | def criterion(self, t, outputs, targets): 104 | """Returns the loss value""" 105 | if self.all_out or self._has_exemplars(): 106 | return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 107 | return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 108 | -------------------------------------------------------------------------------- /src/approach/il2m.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .incremental_learning import Inc_Learning_Appr 5 | from datasets.exemplars_dataset import ExemplarsDataset 6 | 7 | import torch.nn.functional as F 8 | from argparse import ArgumentParser 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from .lucir_utils import CosineLinear, BasicBlockNoRelu, BottleneckNoRelu 11 | import warnings 12 | 13 | class Appr(Inc_Learning_Appr): 14 | """Class implementing the Class Incremental Learning With Dual Memory (IL2M) approach described in 15 | https://openaccess.thecvf.com/content_ICCV_2019/papers/Belouadah_IL2M_Class_Incremental_Learning_With_Dual_Memory_ICCV_2019_paper.pdf 16 | """ 17 | 18 | def __init__(self, model, device, nepochs=160, lr=0.1, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000, 19 | momentum=0.9, wd=5e-4, multi_softmax=False, wu_nepochs=0, 
wu_lr_factor=1,
fix_bn=False, 20 | eval_on_train=False, ddp=False, local_rank=0, logger=None, exemplars_dataset=None, 21 | first_task_lr=0.1, first_task_bz=128): 22 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd, 23 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, ddp, local_rank, 24 | logger, exemplars_dataset) 25 | self.init_classes_means = [] 26 | self.current_classes_means = [] 27 | self.models_confidence = [] 28 | # FLAG to not do scores rectification while finetuning training 29 | self.ft_train = False 30 | 31 | self.first_task_lr = first_task_lr 32 | self.first_task_bz = first_task_bz 33 | self.first_task = True 34 | 35 | have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class 36 | assert (have_exemplars > 0), 'Error: IL2M needs exemplars.' 37 | 38 | @staticmethod 39 | def exemplars_dataset_class(): 40 | return ExemplarsDataset 41 | 42 | @staticmethod 43 | def extra_parser(args): 44 | """Returns a parser containing the approach specific parameters""" 45 | parser = ArgumentParser() 46 | parser.add_argument('--first-task-lr', default=0.1, type=float) 47 | parser.add_argument('--first-task-bz', default=32, type=int) 48 | return parser.parse_known_args(args) 49 | 50 | def _get_optimizer(self): 51 | """Returns the optimizer""" 52 | if self.ddp: 53 | model = self.model.module 54 | else: 55 | model = self.model 56 | 57 | params = model.parameters() 58 | 59 | if self.first_task: 60 | self.first_task = False 61 | optimizer = torch.optim.SGD(params, lr=self.first_task_lr, weight_decay=self.wd, momentum=self.momentum) 62 | else: 63 | optimizer = torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 64 | print(optimizer.param_groups[0]['lr']) 65 | return optimizer 66 | 67 | def pre_train_process(self, t, trn_loader): 68 | """Runs before training all epochs of the task (before the train session)""" 69 | if self.ddp: 70 | model = 
self.model.module 71 | else: 72 | model = self.model 73 | 74 | if t == 0: 75 | # Sec. 4.1: "the ReLU in the penultimate layer is removed to allow the features to take both positive and 76 | # negative values" 77 | if model.model.__class__.__name__ == 'ResNetCifar': 78 | old_block = model.model.layer3[-1] 79 | model.model.layer3[-1] = BasicBlockNoRelu(old_block.conv1, old_block.bn1, old_block.relu, 80 | old_block.conv2, old_block.bn2, old_block.downsample) 81 | elif model.model.__class__.__name__ == 'ResNet': 82 | old_block = model.model.layer4[-1] 83 | model.model.layer4[-1] = BasicBlockNoRelu(old_block.conv1, old_block.bn1, old_block.relu, 84 | old_block.conv2, old_block.bn2, old_block.downsample) 85 | elif model.model.__class__.__name__ == 'ResNetBottleneck': 86 | old_block = model.model.layer4[-1] 87 | model.model.layer4[-1] = BottleneckNoRelu(old_block.conv1, old_block.bn1, 88 | old_block.relu, old_block.conv2, old_block.bn2, 89 | old_block.conv3, old_block.bn3, old_block.downsample) 90 | else: 91 | warnings.warn("Warning: ReLU not removed from last block.") 92 | 93 | # Changes the new head to a CosineLinear 94 | model.heads[-1] = CosineLinear(model.heads[-1].in_features, model.heads[-1].out_features) 95 | model.to(self.device) 96 | # if t > 0: 97 | # Share sigma (Eta in paper) between all the heads 98 | # Yujun: according to il2m, since we'll correct this with model confidence 99 | # maybe we shouldn't share sigma here. 100 | # model.heads[-1].sigma = model.heads[-2].sigma 101 | 102 | # and we probably shouldn't freeze sigma here. 103 | # for h in model.heads[:-1]: 104 | # for param in h.parameters(): 105 | # param.requires_grad = False 106 | # model.heads[-1].sigma.requires_grad = True 107 | 108 | # if ddp option is activated, need to re-wrap the ddp model 109 | if self.ddp: 110 | self.model = DDP(self.model.module, device_ids=[self.local_rank]) 111 | 112 | # The original code has an option called "imprint weights" that seems to initialize the new head. 
113 | # However, this is not mentioned in the paper and doesn't seem to make a significant difference. 114 | super().pre_train_process(t, trn_loader) 115 | 116 | # assumes trn_loader uses a plain sampler (not a DistributedSampler), so every process sees the full training set 117 | def il2m(self, t, trn_loader): 118 | """Compute and store statistics for score rectification""" 119 | if self.ddp: 120 | model = self.model.module 121 | else: 122 | model = self.model 123 | 124 | old_classes_number = sum(model.task_cls[:t]) 125 | classes_counts = [0 for _ in range(sum(model.task_cls))] 126 | models_counts = 0 127 | 128 | # to store statistics for the classes as learned in the current incremental state 129 | self.current_classes_means = [0 for _ in range(old_classes_number)] 130 | # to store statistics for past classes as learned in their initial states 131 | for cls in range(old_classes_number, old_classes_number + model.task_cls[t]): 132 | self.init_classes_means.append(0) 133 | # to store statistics for model confidence in different states (i.e. avg top-1 pred scores) 134 | self.models_confidence.append(0) 135 | 136 | # compute the mean prediction scores that will be used to rectify scores in subsequent tasks 137 | with torch.no_grad(): 138 | self.model.eval() 139 | for images, targets in trn_loader: 140 | outputs = self.model(images.to(self.device)) 141 | scores = np.array(torch.cat(outputs, dim=1).data.cpu().numpy(), dtype=np.float64) 142 | for m in range(len(targets)): 143 | if targets[m] < old_classes_number: 144 | # computation of class means for past classes of the current state.
145 | self.current_classes_means[targets[m]] += scores[m, targets[m]] 146 | classes_counts[targets[m]] += 1 147 | else: 148 | # compute the mean prediction scores for the new classes of the current state 149 | self.init_classes_means[targets[m]] += scores[m, targets[m]] 150 | classes_counts[targets[m]] += 1 151 | # compute the mean top scores for the new classes of the current state 152 | self.models_confidence[t] += np.max(scores[m, ]) 153 | models_counts += 1 154 | # Normalize by corresponding number of images 155 | for cls in range(old_classes_number): 156 | self.current_classes_means[cls] /= classes_counts[cls] 157 | for cls in range(old_classes_number, old_classes_number + model.task_cls[t]): 158 | self.init_classes_means[cls] /= classes_counts[cls] 159 | self.models_confidence[t] /= models_counts 160 | 161 | def train_loop(self, t, trn_loader, val_loader): 162 | """Contains the epochs loop""" 163 | if t == 0: 164 | dset = trn_loader.dataset 165 | trn_loader = torch.utils.data.DataLoader(dset, 166 | batch_size=self.first_task_bz, 167 | sampler=trn_loader.sampler, 168 | num_workers=trn_loader.num_workers, 169 | pin_memory=trn_loader.pin_memory) 170 | 171 | # add exemplars to train_loader 172 | if len(self.exemplars_dataset) > 0 and t > 0: 173 | dset = trn_loader.dataset + self.exemplars_dataset 174 | if self.ddp: 175 | trn_sampler = torch.utils.data.DistributedSampler(dset, shuffle=True) 176 | trn_loader = torch.utils.data.DataLoader(dset, 177 | batch_size=trn_loader.batch_size, 178 | sampler=trn_sampler, 179 | num_workers=trn_loader.num_workers, 180 | pin_memory=trn_loader.pin_memory) 181 | else: 182 | trn_loader = torch.utils.data.DataLoader(dset, 183 | batch_size=trn_loader.batch_size, 184 | shuffle=True, 185 | num_workers=trn_loader.num_workers, 186 | pin_memory=trn_loader.pin_memory) 187 | 188 | 189 | # FINETUNING TRAINING -- contains the epochs loop 190 | self.ft_train = True 191 | super().train_loop(t, trn_loader, val_loader) 192 | self.ft_train = False 
193 | 194 | if self.ddp: 195 | # need to change the trainloader to the original version without distributed sampler 196 | dset = trn_loader.dataset 197 | trn_loader = torch.utils.data.DataLoader(dset, 198 | batch_size=200, shuffle=False, num_workers=trn_loader.num_workers, 199 | pin_memory=trn_loader.pin_memory) 200 | 201 | # IL2M outputs rectification 202 | self.il2m(t, trn_loader) 203 | 204 | # EXEMPLAR MANAGEMENT -- select training subset 205 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform, self.ddp) 206 | 207 | def calculate_metrics(self, outputs, targets): 208 | """Contains the main Task-Aware and Task-Agnostic metrics""" 209 | if self.ft_train: 210 | # no score rectification while training 211 | hits_taw, hits_tag = super().calculate_metrics(outputs, targets) 212 | else: 213 | if self.ddp: 214 | model = self.model.module 215 | else: 216 | model = self.model 217 | # Task-Aware Multi-Head 218 | pred = torch.zeros_like(targets.to(self.device)) 219 | for m in range(len(pred)): 220 | this_task = (model.task_cls.cumsum(0) <= targets[m]).sum() 221 | pred[m] = outputs[this_task][m].argmax() + model.task_offset[this_task] 222 | hits_taw = (pred == targets.to(self.device)).float() 223 | # Task-Agnostic Multi-Head 224 | if self.multi_softmax: 225 | outputs = [torch.nn.functional.log_softmax(output, dim=1) for output in outputs] 226 | # Eq. 
1: rectify predicted scores 227 | old_classes_number = sum(model.task_cls[:-1]) 228 | for m in range(len(targets)): 229 | rectified_outputs = torch.cat(outputs, dim=1) 230 | pred[m] = rectified_outputs[m].argmax() 231 | if old_classes_number: 232 | # if the top-1 class predicted by the network is a new one, rectify the score 233 | if int(pred[m]) >= old_classes_number: 234 | for o in range(old_classes_number): 235 | o_task = int((model.task_cls.cumsum(0) <= o).sum()) 236 | rectified_outputs[m, o] *= (self.init_classes_means[o] / self.current_classes_means[o]) * \ 237 | (self.models_confidence[-1] / self.models_confidence[o_task]) 238 | pred[m] = rectified_outputs[m].argmax() 239 | # otherwise, rectification is not done because an old class is directly predicted 240 | hits_tag = (pred == targets.to(self.device)).float() 241 | return hits_taw, hits_tag 242 | 243 | def criterion(self, t, outputs, targets): 244 | if self.ddp: 245 | model = self.model.module 246 | else: 247 | model = self.model 248 | 249 | if type(outputs[0]) == dict: 250 | outputs = [o['wsigma'] for o in outputs] 251 | 252 | """Returns the loss value""" 253 | if len(self.exemplars_dataset) > 0: 254 | return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 255 | return torch.nn.functional.cross_entropy(outputs[t], targets - model.task_offset[t]) 256 | -------------------------------------------------------------------------------- /src/approach/incremental_learning.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import time 3 | import torch 4 | import numpy as np 5 | from argparse import ArgumentParser 6 | 7 | from loggers.exp_logger import ExperimentLogger 8 | from datasets.exemplars_dataset import ExemplarsDataset 9 | 10 | from .utils import reduce_tensor_mean, reduce_tensor_sum 11 | 12 | class Inc_Learning_Appr: 13 | """Basic class for implementing incremental learning approaches""" 14 | 15 | def 
__init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], 16 | lr_decay=0.1, clipgrad=10000, 17 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, 18 | eval_on_train=False, ddp=False, local_rank=0, 19 | logger: ExperimentLogger = None, exemplars_dataset: ExemplarsDataset = None): 20 | self.model = model 21 | self.device = device 22 | self.nepochs = nepochs 23 | self.lr = lr 24 | self.decay_mile_stone = decay_mile_stone 25 | self.lr_decay = lr_decay 26 | self.clipgrad = clipgrad 27 | self.momentum = momentum 28 | self.wd = wd 29 | self.multi_softmax = multi_softmax 30 | self.logger = logger 31 | self.exemplars_dataset = exemplars_dataset 32 | self.warmup_epochs = wu_nepochs 33 | self.warmup_lr = lr * wu_lr_factor 34 | self.warmup_loss = torch.nn.CrossEntropyLoss() 35 | self.fix_bn = fix_bn 36 | self.eval_on_train = eval_on_train 37 | self.ddp = ddp 38 | self.local_rank = local_rank 39 | self.optimizer = None 40 | 41 | @staticmethod 42 | def extra_parser(args): 43 | """Returns a parser containing the approach specific parameters""" 44 | parser = ArgumentParser() 45 | return parser.parse_known_args(args) 46 | 47 | @staticmethod 48 | def exemplars_dataset_class(): 49 | """Returns a exemplar dataset to use during the training if the approach needs it 50 | :return: ExemplarDataset class or None 51 | """ 52 | return None 53 | 54 | def _get_optimizer(self): 55 | """Returns the optimizer""" 56 | return torch.optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 57 | 58 | def train(self, t, trn_loader, val_loader): 59 | """Main train structure""" 60 | self.pre_train_process(t, trn_loader) 61 | self.train_loop(t, trn_loader, val_loader) 62 | self.post_train_process(t, trn_loader, val_loader) 63 | 64 | def pre_train_process(self, t, trn_loader): 65 | """Runs before training all epochs of the task (before the train session)""" 66 | 67 | # Warm-up phase 68 | if self.warmup_epochs and t 
> 0: 69 | self.optimizer = torch.optim.SGD(self.model.heads[-1].parameters(), lr=self.warmup_lr) 70 | # Loop epochs -- train warm-up head 71 | for e in range(self.warmup_epochs): 72 | warmupclock0 = time.time() 73 | self.model.heads[-1].train() 74 | for images, targets in trn_loader: 75 | outputs = self.model(images.to(self.device)) 76 | loss = self.warmup_loss(outputs[t], targets.to(self.device) - self.model.task_offset[t]) 77 | self.optimizer.zero_grad() 78 | loss.backward() 79 | torch.nn.utils.clip_grad_norm_(self.model.heads[-1].parameters(), self.clipgrad) 80 | self.optimizer.step() 81 | warmupclock1 = time.time() 82 | with torch.no_grad(): 83 | total_loss, total_acc_taw = 0, 0 84 | self.model.eval() 85 | for images, targets in trn_loader: 86 | outputs = self.model(images.to(self.device)) 87 | loss = self.warmup_loss(outputs[t], targets.to(self.device) - self.model.task_offset[t]) 88 | pred = torch.zeros_like(targets.to(self.device)) 89 | for m in range(len(pred)): 90 | this_task = (self.model.task_cls.cumsum(0) <= targets[m]).sum() 91 | pred[m] = outputs[this_task][m].argmax() + self.model.task_offset[this_task] 92 | hits_taw = (pred == targets.to(self.device)).float() 93 | total_loss += loss.item() * len(targets) 94 | total_acc_taw += hits_taw.sum().item() 95 | total_num = len(trn_loader.dataset.labels) 96 | trn_loss, trn_acc = total_loss / total_num, total_acc_taw / total_num 97 | warmupclock2 = time.time() 98 | if self.local_rank == 0: 99 | print('| Warm-up Epoch {:3d}, time={:5.1f}s/{:5.1f}s | Train: loss={:.3f}, TAw acc={:5.1f}% |'.format( 100 | e + 1, warmupclock1 - warmupclock0, warmupclock2 - warmupclock1, trn_loss, 100 * trn_acc)) 101 | self.logger.log_scalar(task=t, iter=e + 1, name="loss", value=trn_loss, group="warmup") 102 | self.logger.log_scalar(task=t, iter=e + 1, name="acc", value=100 * trn_acc, group="warmup") 103 | 104 | 105 | def train_loop(self, t, trn_loader, val_loader): 106 | """Contains the epochs loop""" 107 | ####################### 
108 | # best_acc = 0 109 | # if self.ddp: 110 | # best_model = self.model.module.state_dict() 111 | # else: 112 | # best_model = self.model.state_dict() 113 | ####################### 114 | 115 | self.optimizer = self._get_optimizer() 116 | scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.decay_mile_stone, gamma=self.lr_decay) 117 | 118 | # Loop epochs 119 | for e in range(self.nepochs): 120 | # Train 121 | clock0 = time.time() 122 | self.train_epoch(t, trn_loader) 123 | clock1 = time.time() 124 | if self.eval_on_train: 125 | train_loss, train_acc, _ = self.eval(t, trn_loader) 126 | clock2 = time.time() 127 | if self.local_rank == 0: 128 | print('| Epoch {:3d}, time={:5.1f}s/{:5.1f}s | Train: loss={:.3f}, TAw acc={:5.1f}% |'.format( 129 | e + 1, clock1 - clock0, clock2 - clock1, train_loss, 100 * train_acc), end='') 130 | else: 131 | if self.local_rank == 0: 132 | print('| Epoch {:3d}, time={:5.1f}s | Train: skip eval |'.format(e + 1, clock1 - clock0), end='') 133 | 134 | # Valid 135 | clock3 = time.time() 136 | valid_loss, valid_acc, _ = self.eval(t, val_loader) 137 | clock4 = time.time() 138 | if self.local_rank == 0: 139 | print(' Valid: time={:5.1f}s loss={:.3f}, TAw acc={:5.1f}% |'.format( 140 | clock4 - clock3, valid_loss, 100 * valid_acc), end='') 141 | 142 | scheduler.step() 143 | ####################### 144 | # if valid_acc > best_acc: 145 | # if self.ddp: 146 | # best_model = deepcopy(self.model.module.state_dict()) 147 | # else: 148 | # best_model = deepcopy(self.model.state_dict()) 149 | # best_acc = valid_acc 150 | ####################### 151 | if self.local_rank == 0: 152 | print() 153 | 154 | ####################### 155 | # if self.ddp: 156 | # self.model.module.set_state_dict(best_model) 157 | # else: 158 | # self.model.set_state_dict(best_model) 159 | ####################### 160 | 161 | def post_train_process(self, t, trn_loader, val_loader): 162 | """Runs after training all the epochs of the task (after the train 
session)""" 163 | pass 164 | 165 | def train_epoch(self, t, trn_loader): 166 | """Runs a single epoch""" 167 | self.model.train() 168 | if self.fix_bn and t > 0: 169 | self.model.freeze_bn() 170 | for images, targets in trn_loader: 171 | # Forward current model 172 | outputs = self.model(images.to(self.device)) 173 | loss = self.criterion(t, outputs, targets.to(self.device)) 174 | # Backward 175 | self.optimizer.zero_grad() 176 | loss.backward() 177 | # clipgrad < 0 implicitly implies disabling gradient clipping 178 | if self.clipgrad > 0: 179 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad) 180 | self.optimizer.step() 181 | 182 | def eval(self, t, val_loader): 183 | """Contains the evaluation code""" 184 | with torch.no_grad(): 185 | total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0 186 | self.model.eval() 187 | for images, targets in val_loader: 188 | # Forward current model 189 | outputs = self.model(images.to(self.device)) 190 | loss = self.criterion(t, outputs, targets.to(self.device)) 191 | hits_taw, hits_tag = self.calculate_metrics(outputs, targets) 192 | 193 | # if self.ddp: 194 | # hits_taw, hits_tag = reduce_tensor_mean(hits_taw, self.world_size), reduce_tensor_mean(hits_tag, self.world_size) 195 | # loss = reduce_tensor_mean(loss, self.world_size) 196 | 197 | total_loss += loss.item() * len(targets) 198 | total_acc_taw += hits_taw.sum().item() 199 | total_acc_tag += hits_tag.sum().item() 200 | total_num += len(targets) 201 | return total_loss / total_num, total_acc_taw / total_num, total_acc_tag / total_num 202 | 203 | def calculate_metrics(self, outputs, targets): 204 | """Contains the main Task-Aware and Task-Agnostic metrics""" 205 | pred = torch.zeros_like(targets.to(self.device)) 206 | # Task-Aware Multi-Head 207 | if self.ddp: 208 | for m in range(len(pred)): 209 | this_task = (self.model.module.task_cls.cumsum(0) <= targets[m]).sum() 210 | pred[m] = outputs[this_task][m].argmax() + 
self.model.module.task_offset[this_task] 211 | else: 212 | for m in range(len(pred)): 213 | this_task = (self.model.task_cls.cumsum(0) <= targets[m]).sum() 214 | pred[m] = outputs[this_task][m].argmax() + self.model.task_offset[this_task] 215 | hits_taw = (pred == targets.to(self.device)).float() 216 | # Task-Agnostic Multi-Head 217 | if self.multi_softmax: 218 | outputs = [torch.nn.functional.log_softmax(output, dim=1) for output in outputs] 219 | pred = torch.cat(outputs, dim=1).argmax(1) 220 | else: 221 | pred = torch.cat(outputs, dim=1).argmax(1) 222 | hits_tag = (pred == targets.to(self.device)).float() 223 | return hits_taw, hits_tag 224 | 225 | def criterion(self, t, outputs, targets): 226 | """Returns the loss value""" 227 | return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 228 | -------------------------------------------------------------------------------- /src/approach/joint.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from argparse import ArgumentParser 3 | import torch 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | from torch.utils.data import DataLoader, Dataset 6 | from .incremental_learning import Inc_Learning_Appr 7 | from datasets.exemplars_dataset import ExemplarsDataset 8 | from .lucir_utils import CosineLinear, BasicBlockNoRelu, BottleneckNoRelu 9 | class Appr(Inc_Learning_Appr): 10 | """Class implementing the joint baseline""" 11 | 12 | def __init__(self, model, device, nepochs=160, lr=0.1, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000, 13 | momentum=0.9, wd=5e-4, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, 14 | eval_on_train=False, ddp=False, local_rank=0, logger=None, exemplars_dataset=None, 15 | lamb=5., lamb_mr=1., dist=0.5, K=2): 16 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd, 17 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, ddp, 
local_rank, 18 | logger, exemplars_dataset) 19
| self.trn_datasets = [] 20 | self.val_datasets = [] 21 | 22 | have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class 23 | assert (have_exemplars == 0), 'Warning: Joint does not use exemplars. Comment this line to force it.' 24 | 25 | @staticmethod 26 | def exemplars_dataset_class(): 27 | return ExemplarsDataset 28 | 29 | @staticmethod 30 | def extra_parser(args): 31 | """Returns a parser containing the approach specific parameters""" 32 | parser = ArgumentParser() 33 | return parser.parse_known_args(args) 34 | 35 | def pre_train_process(self, t, trn_loader): 36 | """Runs before training all epochs of the task (before the train session)""" 37 | if self.ddp: 38 | model = self.model.module 39 | else: 40 | model = self.model 41 | 42 | if t == 0: 43 | # Sec. 4.1: "the ReLU in the penultimate layer is removed to allow the features to take both positive and 44 | # negative values" 45 | if model.model.__class__.__name__ == 'ResNetCifar': 46 | old_block = model.model.layer3[-1] 47 | model.model.layer3[-1] = BasicBlockNoRelu(old_block.conv1, old_block.bn1, old_block.relu, 48 | old_block.conv2, old_block.bn2, old_block.downsample) 49 | elif model.model.__class__.__name__ == 'ResNet': 50 | old_block = model.model.layer4[-1] 51 | model.model.layer4[-1] = BasicBlockNoRelu(old_block.conv1, old_block.bn1, old_block.relu, 52 | old_block.conv2, old_block.bn2, old_block.downsample) 53 | elif model.model.__class__.__name__ == 'ResNetBottleneck': 54 | old_block = model.model.layer4[-1] 55 | model.model.layer4[-1] = BottleneckNoRelu(old_block.conv1, old_block.bn1, 56 | old_block.relu, old_block.conv2, old_block.bn2, 57 | old_block.conv3, old_block.bn3, old_block.downsample) 58 | else: 59 | warnings.warn("Warning: ReLU not removed from last block.") 60 | # Changes the new head to a CosineLinear 61 | model.heads[-1] = CosineLinear(model.heads[-1].in_features, model.heads[-1].out_features) 62 | model.to(self.device) 63 | 64 | # if 
ddp option is activated, need to re-wrap the ddp model 65 | # yujun: debug to make sure this one is ok 66 | if self.ddp: 67 | self.model = DDP(self.model.module, device_ids=[self.local_rank]) 68 | # The original code has an option called "imprint weights" that seems to initialize the new head. 69 | # However, this is not mentioned in the paper and doesn't seem to make a significant difference. 70 | super().pre_train_process(t, trn_loader) 71 | 72 | def post_train_process(self, t, trn_loader, val_loader): 73 | """Runs after training all the epochs of the task (after the train session)""" 74 | pass 75 | 76 | def train_loop(self, t, trn_loader, val_loader): 77 | """Contains the epochs loop""" 78 | 79 | # add new datasets to existing cumulative ones 80 | self.trn_datasets.append(trn_loader.dataset) 81 | self.val_datasets.append(val_loader.dataset) 82 | trn_dset = JointDataset(self.trn_datasets) 83 | val_dset = JointDataset(self.val_datasets) 84 | trn_loader = DataLoader(trn_dset, 85 | batch_size=trn_loader.batch_size, 86 | shuffle=True, 87 | num_workers=trn_loader.num_workers, 88 | pin_memory=trn_loader.pin_memory) 89 | val_loader = DataLoader(val_dset, 90 | batch_size=val_loader.batch_size, 91 | shuffle=False, 92 | num_workers=val_loader.num_workers, 93 | pin_memory=val_loader.pin_memory) 94 | # continue training as usual 95 | super().train_loop(t, trn_loader, val_loader) 96 | 97 | def train_epoch(self, t, trn_loader): 98 | """Runs a single epoch""" 99 | self.model.train() 100 | if self.fix_bn and t > 0: 101 | self.model.freeze_bn() 102 | 103 | for images, targets in trn_loader: 104 | images, targets = images.to(self.device), targets.to(self.device) 105 | # Forward current model 106 | outputs = self.model(images) 107 | loss = self.criterion(t, outputs, targets) 108 | # Backward 109 | self.optimizer.zero_grad() 110 | loss.backward() 111 | self.optimizer.step() 112 | 113 | def criterion(self, t, outputs, targets): 114 | """Returns the loss value""" 115 | if 
type(outputs[0])==dict: 116 | outputs = torch.cat([o['wsigma'] for o in outputs], dim=1) 117 | else: 118 | outputs = torch.cat([o for o in outputs], dim=1) 119 | return torch.nn.functional.cross_entropy(outputs, targets) 120 | 121 | class JointDataset(Dataset): 122 | """Characterizes a dataset for PyTorch -- this dataset accumulates each task dataset incrementally""" 123 | 124 | def __init__(self, datasets): 125 | self.datasets = datasets 126 | self._len = sum([len(d) for d in self.datasets]) 127 | 128 | def __len__(self): 129 | 'Denotes the total number of samples' 130 | return self._len 131 | 132 | def __getitem__(self, index): 133 | for d in self.datasets: 134 | if len(d) <= index: 135 | index -= len(d) 136 | else: 137 | x, y = d[index] 138 | return x, y 139 | -------------------------------------------------------------------------------- /src/approach/lucir_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from torch.nn import Module, Parameter 6 | 7 | 8 | # Sec 3.2: This class implements the cosine normalizing linear layer module using Eq. 4 9 | class CosineLinear(Module): 10 | def __init__(self, in_features, out_features, sigma=True): 11 | super(CosineLinear, self).__init__() 12 | self.in_features = in_features 13 | self.out_features = out_features 14 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 15 | if sigma: 16 | self.sigma = Parameter(torch.Tensor(1)) 17 | else: 18 | self.register_parameter('sigma', None) 19 | self.reset_parameters() 20 | 21 | def reset_parameters(self): 22 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 23 | self.weight.data.uniform_(-stdv, stdv) 24 | if self.sigma is not None: 25 | self.sigma.data.fill_(1) # initialize the learnable scale sigma to 1 26 | # Computes sigma * cos(theta) between L2-normalized features and L2-normalized class weights 27 | def forward(self, input): 28 | out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1)) 29 | if self.sigma is not None: 30 | out_s = self.sigma * out 31 | else: 32 | out_s = out 33 | if self.training: 34 | return {'wsigma': out_s, 'wosigma': out} 35 | else: 36 | return out_s 37 | 38 | 39 | # This class implements a ResNet Basic Block without the final ReLU in the forward pass 40 | class BasicBlockNoRelu(nn.Module): 41 | expansion = 1 42 | 43 | def __init__(self, conv1, bn1, relu, conv2, bn2, downsample): 44 | super(BasicBlockNoRelu, self).__init__() 45 | self.conv1 = conv1 46 | self.bn1 = bn1 47 | self.relu = relu 48 | self.conv2 = conv2 49 | self.bn2 = bn2 50 | self.downsample = downsample 51 | 52 | def forward(self, x): 53 | residual = x 54 | out = self.relu(self.bn1(self.conv1(x))) 55 | out = self.bn2(self.conv2(out)) 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | out += residual 59 | # Removed final ReLU 60 | return out 61 | # This class implements a ResNet Bottleneck Block without the final ReLU in the forward pass 62 | class BottleneckNoRelu(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, conv1, bn1, relu, conv2, bn2, conv3, bn3, downsample): 66 | super(BottleneckNoRelu, self).__init__() 67 | self.conv1 = conv1 68 | self.bn1 = bn1 69 | self.conv2 = conv2 70 | self.bn2 = bn2 71 | self.conv3 = conv3 72 | self.bn3 = bn3 73 | self.relu = relu 74 | self.downsample = downsample 75 | 76 | def forward(self, x): 77 | identity = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = 
self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | identity = self.downsample(x) 92 | 93 | out += identity 94 | # Removed final ReLU 95 | return out
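The `CosineLinear` head above scores each class by the cosine similarity between the L2-normalized feature vector and the L2-normalized class weight row, optionally multiplied by the learnable scale `sigma` (the source comment attributes this to Sec. 3.2, Eq. 4). The sketch below is a minimal illustration of that computation, assuming a single feature vector instead of a batch and plain Python lists instead of tensors; `l2_normalize` and `cosine_linear` are hypothetical helper names, and the training-mode dict output (`wsigma`/`wosigma`) of the real module is deliberately left out.

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float], eps: float = 1e-12) -> List[float]:
    """Scale v to unit L2 norm, like F.normalize(x, p=2) applied to one row.

    eps guards against division by zero for an all-zero vector.
    """
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def cosine_linear(features: Sequence[float],
                  weight: Sequence[Sequence[float]],
                  sigma: float = 1.0) -> List[float]:
    """Return one logit per class: sigma * cos(angle(features, weight[c]))."""
    f = l2_normalize(features)
    logits = []
    for w_row in weight:
        w = l2_normalize(w_row)  # normalize each class weight row independently
        logits.append(sigma * sum(a * b for a, b in zip(f, w)))
    return logits


if __name__ == "__main__":
    # Three hypothetical class weight rows for a 2-D feature vector.
    W = [[1.0, 0.0], [0.0, 2.0], [-3.0, -4.0]]
    print(cosine_linear([3.0, 4.0], W, sigma=10.0))  # roughly [6.0, 8.0, -10.0]
```

Because both operands are normalized, every logit lies in `[-sigma, sigma]`, and rescaling a weight row (here `[0.0, 2.0]` behaves exactly like `[0.0, 1.0]`) does not change its logit; this is presumably why the module learns `sigma` to control how peaked the softmax can become.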
-------------------------------------------------------------------------------- /src/approach/lwf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | from argparse import ArgumentParser 4 | 5 | from .incremental_learning import Inc_Learning_Appr 6 | from datasets.exemplars_dataset import ExemplarsDataset 7 | 8 | 9 | class Appr(Inc_Learning_Appr): 10 | """Class implementing the Learning Without Forgetting (LwF) approach 11 | described in https://arxiv.org/abs/1606.09282 12 | """ 13 | 14 | # Weight decay of 0.0005 is used in the original article (page 4). 15 | # Page 4: "The warm-up step greatly enhances fine-tuning’s old-task performance, but is not so crucial to either our 16 | # method or the compared Less Forgetting Learning (see Table 2(b))." 17 | def __init__(self, model, device, nepochs=160, lr=0.1, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000, 18 | momentum=0.9, wd=5e-4, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, 19 | eval_on_train=False, ddp=False, local_rank=0, logger=None, exemplars_dataset=None, 20 | lamb=1, T=2): 21 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd, 22 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, ddp, local_rank, 23 | logger, exemplars_dataset) 24 | self.model_old = None 25 | self.lamb = lamb 26 | self.T = T 27 | 28 | @staticmethod 29 | def exemplars_dataset_class(): 30 | return ExemplarsDataset 31 | 32 | @staticmethod 33 | def extra_parser(args): 34 | """Returns a parser containing the approach specific parameters""" 35 | parser = ArgumentParser() 36 | # Page 5: "lambda is a loss balance weight, set to 1 for most our experiments. Making lambda larger will favor 37 | # the old task performance over the new task’s, so we can obtain a old-task-new-task performance line by 38 | # changing lambda." 
39 | parser.add_argument('--lamb', default=1, type=float, required=False, 40 | help='Forgetting-intransigence trade-off (default=%(default)s)') 41 | # Page 5: "We use T=2 according to a grid search on a held out set, which aligns with the authors’ 42 | # recommendations." -- Using a higher value for T produces a softer probability distribution over classes. 43 | parser.add_argument('--T', default=2, type=int, required=False, 44 | help='Temperature scaling (default=%(default)s)') 45 | return parser.parse_known_args(args) 46 | 47 | def _get_optimizer(self): 48 | """Returns the optimizer""" 49 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 50 | # if there are no exemplars, previous heads are not modified 51 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 52 | else: 53 | params = self.model.parameters() 54 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 55 | 56 | def train_loop(self, t, trn_loader, val_loader): 57 | """Contains the epochs loop""" 58 | 59 | # add exemplars to train_loader 60 | if len(self.exemplars_dataset) > 0 and t > 0: 61 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 62 | batch_size=trn_loader.batch_size, 63 | shuffle=True, 64 | num_workers=trn_loader.num_workers, 65 | pin_memory=trn_loader.pin_memory) 66 | 67 | # FINETUNING TRAINING -- contains the epochs loop 68 | super().train_loop(t, trn_loader, val_loader) 69 | 70 | # EXEMPLAR MANAGEMENT -- select training subset 71 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform, self.ddp) 72 | 73 | def post_train_process(self, t, trn_loader, val_loader): 74 | """Runs after training all the epochs of the task (after the train session)""" 75 | 76 | # Restore best and save model for future tasks 77 | self.model_old = deepcopy(self.model) 78 | self.model_old.eval() 79 | self.model_old.freeze_all() 80 | 81 | def 
train_epoch(self, t, trn_loader): 82 | """Runs a single epoch""" 83 | self.model.train() 84 | if self.fix_bn and t > 0: 85 | self.model.freeze_bn() 86 | for images, targets in trn_loader: 87 | # Forward old model 88 | targets_old = None 89 | if t > 0: 90 | targets_old = self.model_old(images.to(self.device)) 91 | # Forward current model 92 | outputs = self.model(images.to(self.device)) 93 | loss = self.criterion(t, outputs, targets.to(self.device), targets_old) 94 | # Backward 95 | self.optimizer.zero_grad() 96 | loss.backward() 97 | self.optimizer.step() 98 | 99 | def eval(self, t, val_loader): 100 | """Contains the evaluation code""" 101 | with torch.no_grad(): 102 | total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0 103 | self.model.eval() 104 | for images, targets in val_loader: 105 | # Forward old model 106 | targets_old = None 107 | if t > 0: 108 | targets_old = self.model_old(images.to(self.device)) 109 | # Forward current model 110 | outputs = self.model(images.to(self.device)) 111 | loss = self.criterion(t, outputs, targets.to(self.device), targets_old) 112 | hits_taw, hits_tag = self.calculate_metrics(outputs, targets) 113 | # Log 114 | total_loss += loss.data.cpu().numpy().item() * len(targets) 115 | total_acc_taw += hits_taw.sum().data.cpu().numpy().item() 116 | total_acc_tag += hits_tag.sum().data.cpu().numpy().item() 117 | total_num += len(targets) 118 | return total_loss / total_num, total_acc_taw / total_num, total_acc_tag / total_num 119 | 120 | def cross_entropy(self, outputs, targets, exp=1.0, size_average=True, eps=1e-5): 121 | """Calculates cross-entropy with temperature scaling""" 122 | out = torch.nn.functional.softmax(outputs, dim=1) 123 | tar = torch.nn.functional.softmax(targets, dim=1) 124 | if exp != 1: 125 | out = out.pow(exp) 126 | out = out / out.sum(1).view(-1, 1).expand_as(out) 127 | tar = tar.pow(exp) 128 | tar = tar / tar.sum(1).view(-1, 1).expand_as(tar) 129 | out = out + eps / out.size(1) 130 | out = out / 
out.sum(1).view(-1, 1).expand_as(out) 131 | ce = -(tar * out.log()).sum(1) 132 | if size_average: 133 | ce = ce.mean() 134 | return ce 135 | 136 | def criterion(self, t, outputs, targets, outputs_old=None): 137 | """Returns the loss value""" 138 | loss = 0 139 | if t > 0: 140 | # Knowledge distillation loss for all previous tasks 141 | loss += self.lamb * self.cross_entropy(torch.cat(outputs[:t], dim=1), 142 | torch.cat(outputs_old[:t], dim=1), exp=1.0 / self.T) 143 | # Current cross-entropy loss -- with exemplars use all heads 144 | if len(self.exemplars_dataset) > 0: 145 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 146 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 147 | -------------------------------------------------------------------------------- /src/approach/mas.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools 3 | from argparse import ArgumentParser 4 | 5 | from .incremental_learning import Inc_Learning_Appr 6 | from datasets.exemplars_dataset import ExemplarsDataset 7 | 8 | 9 | class Appr(Inc_Learning_Appr): 10 | """Class implementing the Memory Aware Synapses (MAS) approach (global version) 11 | described in https://arxiv.org/abs/1711.09601 12 | Original code available at https://github.com/rahafaljundi/MAS-Memory-Aware-Synapses 13 | """ 14 | 15 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000, 16 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 17 | logger=None, exemplars_dataset=None, lamb=1, alpha=0.5, fi_num_samples=-1): 18 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 19 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 20 | exemplars_dataset) 21 | self.lamb = lamb 22 | 
self.alpha = alpha 23 | self.num_samples = fi_num_samples 24 | 25 | # In all cases, we only keep importance weights for the model, but not for the heads. 26 | feat_ext = self.model.model 27 | # Store current parameters as the initial parameters before first task starts 28 | self.older_params = {n: p.clone().detach() for n, p in feat_ext.named_parameters() if p.requires_grad} 29 | # Store fisher information weight importance 30 | self.importance = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() 31 | if p.requires_grad} 32 | 33 | @staticmethod 34 | def exemplars_dataset_class(): 35 | return ExemplarsDataset 36 | 37 | @staticmethod 38 | def extra_parser(args): 39 | """Returns a parser containing the approach specific parameters""" 40 | parser = ArgumentParser() 41 | # Eq. 3: lambda is the regularizer trade-off -- In original code: MAS.ipynb block [4]: lambda set to 1 42 | parser.add_argument('--lamb', default=1, type=float, required=False, 43 | help='Forgetting-intransigence trade-off (default=%(default)s)') 44 | # Define how old and new importance is fused, by default it is a 50-50 fusion 45 | parser.add_argument('--alpha', default=0.5, type=float, required=False, 46 | help='MAS alpha (default=%(default)s)') 47 | # Number of samples from train for estimating importance 48 | parser.add_argument('--fi-num-samples', default=-1, type=int, required=False, 49 | help='Number of samples for Fisher information (-1: all available) (default=%(default)s)') 50 | return parser.parse_known_args(args) 51 | 52 | def _get_optimizer(self): 53 | """Returns the optimizer""" 54 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 55 | # if there are no exemplars, previous heads are not modified 56 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 57 | else: 58 | params = self.model.parameters() 59 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 60 | 61 | # 
Section 4.1: MAS (global) is implemented since the paper shows it is more efficient than l-MAS (local) 62 | def estimate_parameter_importance(self, trn_loader): 63 | # Initialize importance matrices 64 | importance = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters() 65 | if p.requires_grad} 66 | # Estimate importance on the specified number of samples -- rounded to whole batches 67 | n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \ 68 | else (len(trn_loader.dataset) // trn_loader.batch_size) 69 | # Do forward and backward pass to accumulate L2-loss gradients 70 | self.model.train() 71 | for images, targets in itertools.islice(trn_loader, n_samples_batches): 72 | # MAS can use any unlabeled data for the estimation; as in the main experiments, we use the current task data 73 | outputs = self.model.forward(images.to(self.device)) 74 | # Page 6: labels not required, "...use the gradients of the squared L2-norm of the learned function output." 75 | loss = torch.norm(torch.cat(outputs, dim=1), p=2, dim=1).mean() 76 | self.optimizer.zero_grad() 77 | loss.backward() 78 | # Eq. 2: accumulate the gradients over the inputs to obtain importance weights 79 | for n, p in self.model.model.named_parameters(): 80 | if p.grad is not None: 81 | importance[n] += p.grad.abs() * len(targets) 82 | # Eq.
2: divide by N total number of samples 83 | n_samples = n_samples_batches * trn_loader.batch_size 84 | importance = {n: (p / n_samples) for n, p in importance.items()} 85 | return importance 86 | 87 | def train_loop(self, t, trn_loader, val_loader): 88 | """Contains the epochs loop""" 89 | 90 | # add exemplars to train_loader 91 | if len(self.exemplars_dataset) > 0 and t > 0: 92 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 93 | batch_size=trn_loader.batch_size, 94 | shuffle=True, 95 | num_workers=trn_loader.num_workers, 96 | pin_memory=trn_loader.pin_memory) 97 | 98 | # FINETUNING TRAINING -- contains the epochs loop 99 | super().train_loop(t, trn_loader, val_loader) 100 | 101 | # EXEMPLAR MANAGEMENT -- select training subset 102 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 103 | 104 | def post_train_process(self, t, trn_loader): 105 | """Runs after training all the epochs of the task (after the train session)""" 106 | 107 | # Store current parameters for the next task 108 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad} 109 | 110 | # estimate the MAS importance weights for the current task 111 | curr_importance = self.estimate_parameter_importance(trn_loader) 112 | # merge old and new importance so that per-task importance does not have to be kept in memory 113 | for n in self.importance.keys(): 114 | # Added option to accumulate importance over time with a pre-fixed growing alpha 115 | if self.alpha == -1: 116 | alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device) 117 | self.importance[n] = alpha * self.importance[n] + (1 - alpha) * curr_importance[n] 118 | else: 119 | # As in original code: MAS_utils/MAS_based_Training.py line 638 -- just add prev and new 120 | self.importance[n] = self.alpha * self.importance[n] + (1 - self.alpha) * curr_importance[n] 121 | 
122 | def criterion(self, t, outputs, targets):
123 | """Returns the loss value""" 124 | loss = 0 125 | if t > 0: 126 | loss_reg = 0 127 | # Eq. 3: memory aware synapses regularizer penalty 128 | for n, p in self.model.model.named_parameters(): 129 | if n in self.importance.keys(): 130 | loss_reg += torch.sum(self.importance[n] * (p - self.older_params[n]).pow(2)) / 2 131 | loss += self.lamb * loss_reg 132 | # Current cross-entropy loss -- with exemplars use all heads 133 | if len(self.exemplars_dataset) > 0: 134 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 135 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 136 | -------------------------------------------------------------------------------- /src/approach/path_integral.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from argparse import ArgumentParser 3 | 4 | from .incremental_learning import Inc_Learning_Appr 5 | from datasets.exemplars_dataset import ExemplarsDataset 6 | 7 | 8 | class Appr(Inc_Learning_Appr): 9 | """Class implementing the Path Integral (aka Synaptic Intelligence) approach 10 | described in http://proceedings.mlr.press/v70/zenke17a.html 11 | Original code available at https://github.com/ganguli-lab/pathint 12 | """ 13 | 14 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000, 15 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 16 | logger=None, exemplars_dataset=None, lamb=0.1, damping=0.1): 17 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 18 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 19 | exemplars_dataset) 20 | self.lamb = lamb 21 | self.damping = damping 22 | 23 | # In all cases, we only keep importance weights for the model, but not for the heads. 
24 | feat_ext = self.model.model 25 | # Page 3, following Eq. 3: "The w now have an intuitive interpretation as the parameter specific contribution to 26 | # changes in the total loss." 27 | self.w = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() if p.requires_grad} 28 | # Store current parameters as the initial parameters before first task starts 29 | self.older_params = {n: p.clone().detach().to(self.device) for n, p in feat_ext.named_parameters() 30 | if p.requires_grad} 31 | # Store importance weights matrices 32 | self.importance = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() 33 | if p.requires_grad} 34 | 35 | @staticmethod 36 | def exemplars_dataset_class(): 37 | return ExemplarsDataset 38 | 39 | @staticmethod 40 | def extra_parser(args): 41 | """Returns a parser containing the approach specific parameters""" 42 | parser = ArgumentParser() 43 | # Eq. 4: lamb is the 'c' trade-off parameter from the surrogate loss -- 1e-3 < c < 0.1 44 | parser.add_argument('--lamb', default=0.1, type=float, required=False, 45 | help='Forgetting-intransigence trade-off (default=%(default)s)') 46 | # Eq. 
5: damping parameter is set to 0.1 in the MNIST case 47 | parser.add_argument('--damping', default=0.1, type=float, required=False, 48 | help='Damping (default=%(default)s)') 49 | return parser.parse_known_args(args) 50 | 51 | def _get_optimizer(self): 52 | """Returns the optimizer""" 53 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 54 | # if there are no exemplars, previous heads are not modified 55 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 56 | else: 57 | params = self.model.parameters() 58 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 59 | 60 | def train_loop(self, t, trn_loader, val_loader): 61 | """Contains the epochs loop""" 62 | 63 | # add exemplars to train_loader 64 | if len(self.exemplars_dataset) > 0 and t > 0: 65 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 66 | batch_size=trn_loader.batch_size, 67 | shuffle=True, 68 | num_workers=trn_loader.num_workers, 69 | pin_memory=trn_loader.pin_memory) 70 | 71 | # FINETUNING TRAINING -- contains the epochs loop 72 | super().train_loop(t, trn_loader, val_loader) 73 | 74 | # EXEMPLAR MANAGEMENT -- select training subset 75 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 76 | 77 | def post_train_process(self, t, trn_loader): 78 | """Runs after training all the epochs of the task (after the train session)""" 79 | 80 | # Eq. 
5: accumulate Omega regularization strength (importance matrix) 81 | with torch.no_grad(): 82 | curr_params = {n: p for n, p in self.model.model.named_parameters() if p.requires_grad} 83 | for n, p in self.importance.items(): 84 | p += self.w[n] / ((curr_params[n] - self.older_params[n]) ** 2 + self.damping) 85 | self.w[n].zero_() 86 | 87 | # Store current parameters for the next task 88 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad} 89 | 90 | def train_epoch(self, t, trn_loader): 91 | """Runs a single epoch""" 92 | self.model.train() 93 | if self.fix_bn and t > 0: 94 | self.model.freeze_bn() 95 | for images, targets in trn_loader: 96 | # store current model without heads 97 | curr_feat_ext = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad} 98 | 99 | # Forward current model 100 | outputs = self.model(images.to(self.device)) 101 | # theoretically this is correct for 2 tasks; with more tasks, the current-task loss may be the intended one 102 | # check https://github.com/ganguli-lab/pathint/blob/master/pathint/optimizers.py line 123 103 | # cross-entropy loss on current task 104 | if len(self.exemplars_dataset) == 0: 105 | loss = torch.nn.functional.cross_entropy(outputs[t], targets.to(self.device) - self.model.task_offset[t]) 106 | else: 107 | # with exemplars we use the outputs of all heads (the train data contains labels of all seen tasks) 108 | loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets.to(self.device)) 109 | self.optimizer.zero_grad() 110 | loss.backward(retain_graph=True) 111 | # store gradients without regularization term 112 | unreg_grads = {n: p.grad.clone().detach() for n, p in self.model.model.named_parameters() 113 | if p.grad is not None} 114 | # apply loss with path integral regularization 115 | loss = self.criterion(t, outputs, targets.to(self.device)) 116 | 117 | # Backward 118 | self.optimizer.zero_grad() 119 | 
loss.backward() 120 |
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad) 121 | self.optimizer.step() 122 | 123 | # Eq. 3: accumulate w, compute the path integral -- "In practice, we can approximate w online as the running 124 | # sum of the product of the gradient with the parameter update". 125 | with torch.no_grad(): 126 | for n, p in self.model.model.named_parameters(): 127 | if n in unreg_grads.keys(): 128 | # w[n] should end up >= 0: the minus sign turns loss-decreasing updates into positive contributions 129 | self.w[n] -= unreg_grads[n] * (p.detach() - curr_feat_ext[n]) 130 | 131 | def criterion(self, t, outputs, targets): 132 | """Returns the loss value""" 133 | loss = 0 134 | if t > 0: 135 | loss_reg, params = 0, dict(self.model.model.named_parameters()) 136 | # Eq. 4: quadratic surrogate loss (only over parameters that have an importance entry) 137 | for n, imp in self.importance.items(): 138 | loss_reg += torch.sum(imp * (params[n] - self.older_params[n]).pow(2)) 139 | loss += self.lamb * loss_reg 140 | # Current cross-entropy loss -- with exemplars use all heads 141 | if len(self.exemplars_dataset) > 0: 142 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 143 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 144 | -------------------------------------------------------------------------------- /src/approach/r_walk.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools 3 | from argparse import ArgumentParser 4 | from torch.utils.data import DataLoader 5 | 6 | from .incremental_learning import Inc_Learning_Appr 7 | from datasets.exemplars_dataset import ExemplarsDataset 8 | 9 | 10 | class Appr(Inc_Learning_Appr): 11 | """Class implementing the Riemannian Walk (RWalk) approach described in 12 | 
http://openaccess.thecvf.com/content_ECCV_2018/papers/Arslan_Chaudhry__Riemannian_Walk_ECCV_2018_paper.pdf 13 | """ 14 | 15 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000, 16 | momentum=0, wd=0,
multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 17 | logger=None, exemplars_dataset=None, lamb=1, alpha=0.5, damping=0.1, fim_sampling_type='max_pred', 18 | fim_num_samples=-1): 19 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 20 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 21 | exemplars_dataset) 22 | self.lamb = lamb 23 | self.alpha = alpha 24 | self.damping = damping 25 | self.sampling_type = fim_sampling_type 26 | self.num_samples = fim_num_samples 27 | 28 | # In all cases, we only keep importance weights for the model, but not for the heads. 29 | feat_ext = self.model.model 30 | # Page 7: "task-specific parameter importance over the entire training trajectory." 31 | self.w = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() if p.requires_grad} 32 | # Store current parameters as the initial parameters before first task starts 33 | self.older_params = {n: p.clone().detach().to(self.device) for n, p in feat_ext.named_parameters() 34 | if p.requires_grad} 35 | # Store scores and fisher information 36 | self.scores = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() 37 | if p.requires_grad} 38 | self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() 39 | if p.requires_grad} 40 | 41 | @staticmethod 42 | def exemplars_dataset_class(): 43 | return ExemplarsDataset 44 | 45 | @staticmethod 46 | def extra_parser(args): 47 | """Returns a parser containing the approach specific parameters""" 48 | parser = ArgumentParser() 49 | # Eq. 5 and 8: "regularization hyperparameter lambda being less sensitive to the number of tasks. 
 Whereas, 50 | # EWC and Path Integral are highly sensitive to lambda, making them relatively less reliable for IL" 51 | parser.add_argument('--lamb', default=1, type=float, required=False, 52 | help='Forgetting-intransigence trade-off (default=%(default)s)') 53 | # Define how the old and new Fisher information are fused; by default it is a 50-50 fusion (-1: growing alpha) 54 | parser.add_argument('--alpha', default=0.5, type=float, required=False, 55 | help='RWalk alpha: weight of the old Fisher when fusing (default=%(default)s)') 56 | # Damping parameter as in Path Integral 57 | parser.add_argument('--damping', default=0.1, type=float, required=False, 58 | help='RWalk damping added to the score denominator (default=%(default)s)') 59 | parser.add_argument('--fim_sampling_type', default='max_pred', type=str, required=False, 60 | choices=['true', 'max_pred', 'multinomial'], 61 | help='Sampling type for Fisher information (default=%(default)s)') 62 | parser.add_argument('--fim_num_samples', default=-1, type=int, required=False, 63 | help='Number of samples for Fisher information (-1: all available) (default=%(default)s)') 64 | return parser.parse_known_args(args) 65 | 66 | def _get_optimizer(self): 67 | """Returns the optimizer""" 68 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 69 | # if there are no exemplars, previous heads are not modified 70 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 71 | else: 72 | params = self.model.parameters() 73 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 74 | 75 | def compute_fisher_matrix_diag(self, trn_loader): 76 | # Store Fisher Information 77 | fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters() 78 | if p.requires_grad} 79 | # Compute fisher information for specified number of samples -- rounded to the batch size 80 | n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \ 81 | else (len(trn_loader.dataset) // trn_loader.batch_size) 82 | # Do forward and
backward pass to compute the fisher information 83 | self.model.train() 84 | for images, targets in itertools.islice(trn_loader, n_samples_batches): 85 | outputs = self.model.forward(images.to(self.device)) 86 | 87 | if self.sampling_type == 'true': 88 | # Use the labels to compute the gradients based on the CE-loss with the ground truth 89 | preds = targets.to(self.device) 90 | elif self.sampling_type == 'max_pred': 91 | # Do not use labels; compute the gradients w.r.t. the class the model currently predicts 92 | preds = torch.cat(outputs, dim=1).argmax(1).flatten() 93 | elif self.sampling_type == 'multinomial': 94 | # Sample one label per example from the model's predictive distribution 95 | probs = torch.nn.functional.softmax(torch.cat(outputs, dim=1), dim=1) 96 | preds = torch.multinomial(probs, 1).flatten() 97 | 98 | loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), preds) 99 | self.optimizer.zero_grad() 100 | loss.backward() 101 | # Page 6: "the Fisher component [...] is the expected square of the loss gradient w.r.t the i-th parameter."
102 | for n, p in self.model.model.named_parameters(): 103 | if p.grad is not None: 104 | fisher[n] += p.grad.pow(2) * len(targets) 105 | # Apply mean across all samples 106 | n_samples = n_samples_batches * trn_loader.batch_size 107 | fisher = {n: (p / n_samples) for n, p in fisher.items()} 108 | return fisher 109 | 110 | def train_loop(self, t, trn_loader, val_loader): 111 | """Contains the epochs loop""" 112 | 113 | # add exemplars to train_loader 114 | if len(self.exemplars_dataset) > 0 and t > 0: 115 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 116 | batch_size=trn_loader.batch_size, 117 | shuffle=True, 118 | num_workers=trn_loader.num_workers, 119 | pin_memory=trn_loader.pin_memory) 120 | 121 | # FINETUNING TRAINING -- contains the epochs loop 122 | super().train_loop(t, trn_loader, val_loader) 123 | 124 | # EXEMPLAR MANAGEMENT -- select training subset 125 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 126 | 127 | def post_train_process(self, t, trn_loader): 128 | """Runs after training all the epochs of the task (after the train session)""" 129 | 130 | # calculate Fisher Information Matrix 131 | curr_fisher = self.compute_fisher_matrix_diag(trn_loader) 132 | 133 | # Eq. 
10: efficiently update Fisher Information Matrix 134 | for n in self.fisher.keys(): 135 | # Added option to accumulate fisher over time with a pre-fixed growing alpha 136 | if self.alpha == -1: 137 | alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device) 138 | self.fisher[n] = alpha * self.fisher[n] + (1 - alpha) * curr_fisher[n] 139 | else: 140 | self.fisher[n] = self.alpha * self.fisher[n] + (1 - self.alpha) * curr_fisher[n] 141 | # Page 7: Optimization Path-based Parameter Importance: importance scores computation 142 | curr_score = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters() 143 | if p.requires_grad} 144 | with torch.no_grad(): 145 | curr_params = {n: p for n, p in self.model.model.named_parameters() if p.requires_grad} 146 | for n, p in self.scores.items(): 147 | curr_score[n] = self.w[n] / ( 148 | self.fisher[n] * ((curr_params[n] - self.older_params[n]) ** 2) + self.damping) 149 | self.w[n].zero_() 150 | # Page 7: "Since we care about positive influence of the parameters, negative scores are set to zero." 
151 | curr_score[n] = torch.nn.functional.relu(curr_score[n]) 152 | # Page 8: alleviating regularization getting increasingly rigid by averaging scores 153 | for n, p in self.scores.items(): 154 | self.scores[n] = (self.scores[n] + curr_score[n]) / 2 155 | 156 | # Store current parameters for the next task 157 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad} 158 | 159 | def train_epoch(self, t, trn_loader): 160 | """Runs a single epoch""" 161 | self.model.train() 162 | if self.fix_bn and t > 0: 163 | self.model.freeze_bn() 164 | for images, targets in trn_loader: 165 | # store current model 166 | curr_feat_ext = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad} 167 | 168 | # Forward current model 169 | outputs = self.model(images.to(self.device)) 170 | # cross-entropy loss on current task 171 | loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets.to(self.device)) 172 | self.optimizer.zero_grad() 173 | loss.backward(retain_graph=True) 174 | # store gradients without regularization term 175 | unreg_grads = {n: p.grad.clone().detach() for n, p in self.model.model.named_parameters() 176 | if p.grad is not None} 177 | # apply loss with RWalk regularization 178 | loss = self.criterion(t, outputs, targets.to(self.device)) 179 | 180 | # Backward 181 | self.optimizer.zero_grad() 182 | loss.backward() 183 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad) 184 | self.optimizer.step() 185 | 186 | # Page 7: "accumulate task-specific parameter importance over the entire training trajectory" 187 | # "the parameter importance is defined as the ratio of the change in the loss function to the distance 188 | # between the conditional likelihood distributions per step in the parameter space."
189 | with torch.no_grad(): 190 | for n, p in self.model.model.named_parameters(): 191 | if n in unreg_grads.keys(): 192 | self.w[n] -= unreg_grads[n] * (p.detach() - curr_feat_ext[n]) 193 | 194 | def criterion(self, t, outputs, targets): 195 | """Returns the loss value""" 196 | loss = 0 197 | if t > 0: 198 | loss_reg = 0 199 | # Eq. 9: final objective function 200 | for n, p in self.model.model.named_parameters(): 201 | if n in self.fisher.keys(): 202 | loss_reg += torch.sum((self.fisher[n] + self.scores[n]) * (p - self.older_params[n]).pow(2)) 203 | loss += self.lamb * loss_reg 204 | # Current cross-entropy loss -- with exemplars use all heads 205 | if len(self.exemplars_dataset) > 0: 206 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 207 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 208 | -------------------------------------------------------------------------------- /src/approach/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision import transforms 6 | from torch import distributed as dist 7 | 8 | 9 | # ------------------- SSL utils ----------------------- 10 | # Supervised learning and SSL 11 | class TransformSLAndSSL: 12 | def __init__(self, orig_transform, transform_ssl): 13 | self.orig_transform = orig_transform 14 | self.transform_ssl = transform_ssl 15 | 16 | def __call__(self, inp): 17 | out = self.orig_transform(inp) 18 | out_ssl_1 = self.transform_ssl(inp) 19 | out_ssl_2 = self.transform_ssl(inp) 20 | return out, out_ssl_1, out_ssl_2 21 | 22 | # (H,W): data input size 23 | def get_simclr_transforms(H, W, orig_transform): 24 | simclr_aug = transforms.Compose([ 25 | transforms.RandomResizedCrop((H,W)), 26 | transforms.RandomHorizontalFlip(), # with 0.5 probability 27 | 
transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8), 28 | transforms.RandomGrayscale(p=0.2), 29 | transforms.ToTensor(),]) 30 | return TransformSLAndSSL(orig_transform, simclr_aug) 31 | # ----------------------------------------------------- 32 | 33 | # transform a scalar label to a one-hot vector 34 | # targets: (N,) 35 | def scalar2onehot(targets, num_class): 36 | N = targets.shape[0] 37 | onehot_target = torch.zeros(N, num_class).to(targets.device).scatter_(1, targets.unsqueeze(-1), 1) 38 | return onehot_target 39 | 40 | def scalar2SmoothOneHot(targets, num_class): 41 | N = targets.shape[0] 42 | hot_prob = 0.9 43 | smooth_prob = (1.0 - hot_prob) / (num_class - 1) 44 | onehot_target = (smooth_prob*torch.ones(N, num_class)).to(targets.device).scatter_(1, targets.unsqueeze(-1), hot_prob) 45 | return onehot_target 46 | 47 | # Cutmix 48 | # assuming the targets are one-hot vectors 49 | def cut_and_mix(images, targets): 50 | N,_,H,W = images.shape 51 | assert H==W, 'only support square right now' 52 | 53 | cut_len = int(np.floor(np.random.rand()*H)) 54 | cut_len = max(min(cut_len, H-1), 1) 55 | mix_ratio = float(cut_len*cut_len) / (H*W) 56 | 57 | top = np.random.randint(0, H - cut_len) 58 | left = np.random.randint(0, W - cut_len) 59 | bottom = top + cut_len 60 | right = left + cut_len 61 | 62 | # cut and mix 63 | # shuffled batch images 64 | rp = torch.randperm(N) 65 | shuffled_images = images[rp,:,:,:] 66 | shuffled_targets = targets[rp,:] 67 | images[:, :, top:bottom, left:right] = shuffled_images[:, :, top:bottom, left:right] 68 | 69 | # adjust the target 70 | targets = (1-mix_ratio)*targets + mix_ratio*shuffled_targets 71 | 72 | return images, targets 73 | 74 | # Cutmix w/ small window 75 | # assuming the targets are one-hot vectors 76 | def cut_and_mix_small_window(images, targets): 77 | N,_,H,W = images.shape 78 | assert H==W, 'only support square right now' 79 | 80 | cut_len = int(np.floor(0.4*np.random.rand()*H)) 81 | cut_len = max(min(cut_len, H-1), 
1) 82 | mix_ratio = float(cut_len*cut_len) / (H*W) 83 | 84 | top = np.random.randint(0, H - cut_len) 85 | left = np.random.randint(0, W - cut_len) 86 | bottom = top + cut_len 87 | right = left + cut_len 88 | 89 | # cut and mix 90 | # shuffled batch images 91 | rp = torch.randperm(N) 92 | shuffled_images = images[rp,:,:,:] 93 | # shuffled_targets = targets[rp,:] 94 | images[:, :, top:bottom, left:right] = shuffled_images[:, :, top:bottom, left:right] 95 | 96 | # do not adjust the target 97 | # targets = (1-mix_ratio)*targets + mix_ratio*shuffled_targets 98 | 99 | return images, targets 100 | 101 | # Mixup 102 | # assuming the targets are one-hot vectors 103 | def mixup(images, targets): 104 | lam = np.random.beta(1.0, 1.0) 105 | N = images.shape[0] 106 | rp = torch.randperm(N) 107 | shuffled_images = images[rp,:,:,:] 108 | shuffled_targets = targets[rp,:] 109 | 110 | images = lam * images + (1-lam)*shuffled_images 111 | targets = lam * targets + (1-lam)*shuffled_targets 112 | 113 | return images, targets 114 | 115 | class MultiLabelCrossEntropyLoss(nn.Module): 116 | 117 | def __init__(self): 118 | super(MultiLabelCrossEntropyLoss, self).__init__() 119 | 120 | # logit: (N, C) 121 | # label: (N, C) 122 | def forward(self, logits, label): 123 | loss = -(label*F.log_softmax(logits, dim=1)).sum(dim=-1).mean() 124 | return loss 125 | 126 | class BalancedCrossEntropy(nn.Module): 127 | 128 | def __init__(self, tao): 129 | super(BalancedCrossEntropy, self).__init__() 130 | self.tao = tao 131 | self.eps = 1e-8 132 | 133 | def forward(self, logits, label): 134 | num_classes = logits.shape[1] 135 | label_onehot = scalar2onehot(label, num_classes) 136 | 137 | # v1 (undesired solution) 138 | # loss = -self.tao*(logits*label).sum(dim=-1) + (2.0 - self.tao)*torch.log(self.eps + logits.exp().sum(dim=-1)) 139 | 140 | # v2 (undesired solution) 141 | loss = -self.tao*(logits*label_onehot).sum(dim=-1) + torch.log(self.eps + logits.exp().sum(dim=-1)) 142 | 143 | # v3 (broken solution) 
144 | # pos_loss = -self.tao*(logits*label).sum(dim=-1) 145 | # neg_loss = (2.0 - self.tao)*torch.log(self.eps + ((1.-label)*logits.exp()).sum(dim=-1)) 146 | # loss = pos_loss + neg_loss 147 | 148 | # v4 (broken solution) 149 | # pos_loss = -self.tao*(logits*label).sum(dim=-1) 150 | # reweight = 1. - (1. - self.tao)*label 151 | # neg_loss = torch.log(self.eps + (reweight*logits).exp().sum(dim=-1)) 152 | # loss = pos_loss + neg_loss 153 | # return loss.mean() 154 | 155 | # v5 156 | # reweight = 1. - (1. - self.tao)*label_onehot 157 | # logits = reweight*logits 158 | # loss = nn.CrossEntropyLoss(None)(logits, label) 159 | # return loss 160 | 161 | # pos_loss = -(logits*label_onehot).sum(dim=-1) 162 | # reweight = (1. - label_onehot)*self.tao + label_onehot 163 | # neg_loss = torch.log(self.eps + (reweight*logits.exp()).sum(dim=-1)) 164 | # loss = pos_loss + neg_loss 165 | 166 | return loss.mean() 167 | 168 | class AugCrossEntropy(nn.Module): 169 | 170 | def __init__(self, n_aug): 171 | super(AugCrossEntropy, self).__init__() 172 | self.n_aug = n_aug 173 | self.ce = nn.CrossEntropyLoss() 174 | 175 | def forward(self, logits, features, targets): 176 | N,C = features.shape 177 | device = features.device 178 | # first generating random fake class embeddings 179 | pseudo_embeddings = F.normalize(torch.rand(self.n_aug, C, device=device) - 0.5, dim=-1) 180 | pseudo_logits = torch.matmul(features, pseudo_embeddings.t()) 181 | cat_logits = torch.cat([logits, pseudo_logits], dim=-1) 182 | loss = self.ce(cat_logits, targets) 183 | return loss 184 | 185 | 186 | # -------------- DDP utils ------------------- 187 | def reduce_tensor_mean(tensor, n): 188 | rt = tensor.clone() 189 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 190 | rt /= n 191 | return rt 192 | 193 | def reduce_tensor_sum(tensor): 194 | rt = tensor.clone() 195 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 196 | return rt 197 | 198 | def global_gather(x): 199 | all_x = [torch.ones_like(x) for _ in 
range(dist.get_world_size())] 200 | dist.all_gather(all_x, x, async_op=False) 201 | return torch.cat(all_x, dim=0) 202 | 203 | # differentiable gather layer 204 | class GatherLayer(torch.autograd.Function): 205 | """Gather tensors from all processes, supporting backward propagation.""" 206 | 207 | @staticmethod 208 | def forward(ctx, input): 209 | ctx.save_for_backward(input) 210 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())] 211 | dist.all_gather(output, input) 212 | return tuple(output) 213 | 214 | @staticmethod 215 | def backward(ctx, *grads): 216 | (input,) = ctx.saved_tensors 217 | grad_out = torch.zeros_like(input) 218 | grad_out[:] = grads[dist.get_rank()] 219 | return grad_out 220 | # -------------------------------------------- 221 | 222 | class SAM(torch.optim.Optimizer): 223 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): 224 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 225 | 226 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 227 | super(SAM, self).__init__(params, defaults) 228 | 229 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 230 | self.param_groups = self.base_optimizer.param_groups 231 | 232 | @torch.no_grad() 233 | def first_step(self, zero_grad=False): 234 | grad_norm = self._grad_norm() 235 | for group in self.param_groups: 236 | scale = group["rho"] / (grad_norm + 1e-12) 237 | 238 | for p in group["params"]: 239 | if p.grad is None: continue 240 | self.state[p]["old_p"] = p.data.clone() 241 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 242 | p.add_(e_w) # climb to the local maximum "w + e(w)" 243 | 244 | if zero_grad: self.zero_grad() 245 | 246 | @torch.no_grad() 247 | def second_step(self, zero_grad=False): 248 | for group in self.param_groups: 249 | for p in group["params"]: 250 | if p.grad is None: continue 251 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 252 | 253 | 
self.base_optimizer.step() # do the actual "sharpness-aware" update 254 | 255 | if zero_grad: self.zero_grad() 256 | 257 | @torch.no_grad() 258 | def step(self, closure=None): 259 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 260 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass 261 | 262 | self.first_step(zero_grad=True) 263 | closure() 264 | self.second_step() 265 | 266 | def _grad_norm(self): 267 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 268 | norm = torch.norm( 269 | torch.stack([ 270 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 271 | for group in self.param_groups for p in group["params"] 272 | if p.grad is not None 273 | ]), 274 | p=2 275 | ) 276 | return norm 277 | 278 | def load_state_dict(self, state_dict): 279 | super().load_state_dict(state_dict) 280 | self.base_optimizer.param_groups = self.param_groups 281 | -------------------------------------------------------------------------------- /src/data/imagenet/gen_lst_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | subset_num = 100 5 | 6 | root_dir = 'train' 7 | with open(root_dir+'_'+str(subset_num)+'.txt', 'w') as f: 8 | classes = sorted(entry.name for entry in os.scandir(root_dir) if entry.is_dir()) 9 | seed = 1993 10 | np.random.seed(seed) 11 | subset_classes = np.random.choice(classes, subset_num, replace=False) 12 | 13 | for class_id, class_name in enumerate(subset_classes): 14 | folder_name = os.path.join(root_dir, class_name) 15 | for img_name in sorted(os.listdir(folder_name)): 16 | write_line = os.path.join(root_dir, class_name, img_name) 17 | write_line += ' ' + str(class_id) + '\n' 18 | f.write(write_line) 19 | 20 | root_dir = 'val' 21 | with open(root_dir+'_'+str(subset_num)+'.txt', 
'w') as f: 22 | classes = sorted(entry.name for entry in os.scandir(root_dir) if entry.is_dir()) 23 | seed = 1993 24 | np.random.seed(seed) 25 | subset_classes = np.random.choice(classes, subset_num, replace=False) 26 | 27 | for class_id, class_name in enumerate(subset_classes): 28 | folder_name = os.path.join(root_dir, class_name) 29 | for img_name in sorted(os.listdir(folder_name)): 30 | write_line = os.path.join(root_dir, class_name, img_name) 31 | write_line += ' ' + str(class_id) + '\n' 32 | f.write(write_line) 33 | -------------------------------------------------------------------------------- /src/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class BaseDataset(Dataset): 9 | """Characterizes a dataset for PyTorch -- this dataset pre-loads all paths in memory""" 10 | 11 | def __init__(self, data, transform, class_indices=None): 12 | """Initialization""" 13 | self.labels = data['y'] 14 | self.images = data['x'] 15 | self.transform = transform 16 | self.class_indices = class_indices 17 | 18 | def __len__(self): 19 | """Denotes the total number of samples""" 20 | return len(self.images) 21 | 22 | def __getitem__(self, index): 23 | """Generates one sample of data""" 24 | x = Image.open(self.images[index]).convert('RGB') 25 | x = self.transform(x) 26 | y = self.labels[index] 27 | return x, y 28 | 29 | 30 | def get_data(path, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None, trn_lst=None, tst_lst=None): 31 | """Prepare data: dataset splits, task partition, class order""" 32 | 33 | data = {} 34 | taskcla = [] 35 | 36 | # read filenames and labels 37 | if trn_lst is None and tst_lst is None: 38 | trn_lines = np.loadtxt(os.path.join(path, 'train.txt'), dtype=str) 39 | tst_lines = np.loadtxt(os.path.join(path, 'test.txt'), dtype=str) 40 | else: 41 | 
trn_lines = np.loadtxt(os.path.join(path, trn_lst), dtype=str) 42 | tst_lines = np.loadtxt(os.path.join(path, tst_lst), dtype=str) 43 | 44 | if class_order is None: 45 | num_classes = len(np.unique(trn_lines[:, 1])) 46 | class_order = list(range(num_classes)) 47 | else: 48 | num_classes = len(class_order) 49 | class_order = class_order.copy() 50 | # yujun: a little hack here, in no case shall we shuffle the classes 51 | # if shuffle_classes: 52 | # np.random.shuffle(class_order) 53 | 54 | # compute classes per task and num_tasks 55 | if nc_first_task is None: 56 | cpertask = np.array([num_classes // num_tasks] * num_tasks) 57 | for i in range(num_classes % num_tasks): 58 | cpertask[i] += 1 59 | else: 60 | assert nc_first_task < num_classes, "first task wants more classes than exist" 61 | remaining_classes = num_classes - nc_first_task 62 | assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task" # better minimum 2 63 | cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1)) 64 | for i in range(remaining_classes % (num_tasks - 1)): 65 | cpertask[i + 1] += 1 66 | 67 | assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes" 68 | cpertask_cumsum = np.cumsum(cpertask) 69 | init_class = np.concatenate(([0], cpertask_cumsum[:-1])) 70 | 71 | # initialize data structure 72 | for tt in range(num_tasks): 73 | data[tt] = {} 74 | data[tt]['name'] = 'task-' + str(tt) 75 | data[tt]['trn'] = {'x': [], 'y': []} 76 | data[tt]['val'] = {'x': [], 'y': []} 77 | data[tt]['tst'] = {'x': [], 'y': []} 78 | 79 | # ALL OR TRAIN 80 | for this_image, this_label in trn_lines: 81 | if not os.path.isabs(this_image): 82 | this_image = os.path.join(path, this_image) 83 | this_label = int(this_label) 84 | if this_label not in class_order: 85 | continue 86 | # If shuffling is false, it won't change the class number 87 | this_label = class_order.index(this_label) 88 | 89 | # add it to the 
corresponding split 90 | this_task = (this_label >= cpertask_cumsum).sum() 91 | data[this_task]['trn']['x'].append(this_image) 92 | data[this_task]['trn']['y'].append(this_label - init_class[this_task]) 93 | 94 | # ALL OR TEST 95 | for this_image, this_label in tst_lines: 96 | if not os.path.isabs(this_image): 97 | this_image = os.path.join(path, this_image) 98 | this_label = int(this_label) 99 | if this_label not in class_order: 100 | continue 101 | # If shuffling is false, it won't change the class number 102 | this_label = class_order.index(this_label) 103 | 104 | # add it to the corresponding split 105 | this_task = (this_label >= cpertask_cumsum).sum() 106 | data[this_task]['tst']['x'].append(this_image) 107 | data[this_task]['tst']['y'].append(this_label - init_class[this_task]) 108 | 109 | # check classes 110 | for tt in range(num_tasks): 111 | data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y'])) 112 | assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes" 113 | 114 | # validation 115 | if validation > 0.0: 116 | for tt in data.keys(): 117 | for cc in range(data[tt]['ncla']): 118 | cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0]) 119 | rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation))) 120 | rnd_img.sort(reverse=True) 121 | for ii in range(len(rnd_img)): 122 | data[tt]['val']['x'].append(data[tt]['trn']['x'][rnd_img[ii]]) 123 | data[tt]['val']['y'].append(data[tt]['trn']['y'][rnd_img[ii]]) 124 | data[tt]['trn']['x'].pop(rnd_img[ii]) 125 | data[tt]['trn']['y'].pop(rnd_img[ii]) 126 | 127 | # other 128 | n = 0 129 | for t in data.keys(): 130 | taskcla.append((t, data[t]['ncla'])) 131 | n += data[t]['ncla'] 132 | data['ncla'] = n 133 | 134 | return data, taskcla, class_order 135 | -------------------------------------------------------------------------------- /src/datasets/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import numpy as np 3 | import torch 4 | from torch.utils import data 5 | import torchvision.transforms as transforms 6 | from torchvision.datasets import MNIST as TorchVisionMNIST 7 | from torchvision.datasets import CIFAR100 as TorchVisionCIFAR100 8 | from torchvision.datasets import SVHN as TorchVisionSVHN 9 | 10 | from . import base_dataset as basedat 11 | from . import memory_dataset as memd 12 | from .dataset_config import dataset_config 13 | 14 | 15 | def get_loaders(datasets, num_tasks, nc_first_task, batch_size, num_workers, pin_memory, validation=.1, ddp=False): 16 | """Apply transformations to Datasets and create the DataLoaders for each task""" 17 | 18 | trn_load, val_load, tst_load = [], [], [] 19 | taskcla = [] 20 | dataset_offset = 0 21 | for idx_dataset, cur_dataset in enumerate(datasets, 0): 22 | # get configuration for current dataset 23 | dc = dataset_config[cur_dataset] 24 | 25 | # transformations 26 | trn_transform, tst_transform = get_transforms(resize=dc['resize'], 27 | resize_test=dc['resize_test'], 28 | pad=dc['pad'], 29 | crop=dc['crop'], 30 | cifar_crop=dc['cifar_crop'], 31 | flip=dc['flip'], 32 | normalize=dc['normalize'], 33 | extend_channel=dc['extend_channel']) 34 | 35 | # datasets 36 | trn_dset, val_dset, tst_dset, curtaskcla = get_datasets(cur_dataset, dc['path'], num_tasks, nc_first_task, 37 | validation=validation, 38 | trn_transform=trn_transform, 39 | tst_transform=tst_transform, 40 | class_order=dc['class_order']) 41 | 42 | # apply offsets in case of multiple datasets 43 | if idx_dataset > 0: 44 | for tt in range(num_tasks): 45 | trn_dset[tt].labels = [elem + dataset_offset for elem in trn_dset[tt].labels] 46 | val_dset[tt].labels = [elem + dataset_offset for elem in val_dset[tt].labels] 47 | tst_dset[tt].labels = [elem + dataset_offset for elem in tst_dset[tt].labels] 48 | dataset_offset = dataset_offset + sum([tc[1] for tc in curtaskcla]) 49 | 50 | # reassign class idx for multiple dataset case 51 | curtaskcla = [(tc[0] + 
idx_dataset * num_tasks, tc[1]) for tc in curtaskcla] 52 | 53 | # extend final taskcla list 54 | taskcla.extend(curtaskcla) 55 | 56 | # loaders 57 | if ddp: 58 | for tt in range(num_tasks): 59 | trn_sampler = torch.utils.data.DistributedSampler(trn_dset[tt], shuffle=True) 60 | trn_load.append(data.DataLoader(trn_dset[tt], batch_size=batch_size, num_workers=num_workers, 61 | sampler=trn_sampler, pin_memory=pin_memory)) 62 | val_load.append(data.DataLoader(val_dset[tt], batch_size=batch_size, shuffle=False, 63 | num_workers=num_workers, pin_memory=pin_memory)) 64 | tst_load.append(data.DataLoader(tst_dset[tt], batch_size=batch_size, shuffle=False, 65 | num_workers=num_workers, pin_memory=pin_memory)) 66 | else: 67 | for tt in range(num_tasks): 68 | trn_load.append(data.DataLoader(trn_dset[tt], batch_size=batch_size, shuffle=True, num_workers=num_workers, 69 | pin_memory=pin_memory)) 70 | val_load.append(data.DataLoader(val_dset[tt], batch_size=batch_size, shuffle=False, num_workers=num_workers, 71 | pin_memory=pin_memory)) 72 | tst_load.append(data.DataLoader(tst_dset[tt], batch_size=batch_size, shuffle=False, num_workers=num_workers, 73 | pin_memory=pin_memory)) 74 | return trn_load, val_load, tst_load, taskcla 75 | 76 | 77 | def get_datasets(dataset, path, num_tasks, nc_first_task, validation, trn_transform, tst_transform, class_order=None): 78 | """Extract datasets and create Dataset class""" 79 | 80 | trn_dset, val_dset, tst_dset = [], [], [] 81 | 82 | if 'mnist' in dataset: 83 | tvmnist_trn = TorchVisionMNIST(path, train=True, download=True) 84 | tvmnist_tst = TorchVisionMNIST(path, train=False, download=True) 85 | trn_data = {'x': tvmnist_trn.data.numpy(), 'y': tvmnist_trn.targets.tolist()} 86 | tst_data = {'x': tvmnist_tst.data.numpy(), 'y': tvmnist_tst.targets.tolist()} 87 | # compute splits 88 | all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation, 89 | num_tasks=num_tasks, nc_first_task=nc_first_task, 90 | 
shuffle_classes=class_order is None, class_order=class_order) 91 | # set dataset type 92 | Dataset = memd.MemoryDataset 93 | 94 | elif 'cifar100' in dataset: 95 | tvcifar_trn = TorchVisionCIFAR100(path, train=True, download=True) 96 | tvcifar_tst = TorchVisionCIFAR100(path, train=False, download=True) 97 | trn_data = {'x': tvcifar_trn.data, 'y': tvcifar_trn.targets} 98 | tst_data = {'x': tvcifar_tst.data, 'y': tvcifar_tst.targets} 99 | # compute splits 100 | all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation, 101 | num_tasks=num_tasks, nc_first_task=nc_first_task, 102 | shuffle_classes=class_order is None, class_order=class_order) 103 | # set dataset type 104 | Dataset = memd.MemoryDataset 105 | 106 | elif dataset == 'svhn': 107 | tvsvhn_trn = TorchVisionSVHN(path, split='train', download=True) 108 | tvsvhn_tst = TorchVisionSVHN(path, split='test', download=True) 109 | trn_data = {'x': tvsvhn_trn.data.transpose(0, 2, 3, 1), 'y': tvsvhn_trn.labels} 110 | tst_data = {'x': tvsvhn_tst.data.transpose(0, 2, 3, 1), 'y': tvsvhn_tst.labels} 111 | # Notice that SVHN in Torchvision has an extra training set in case needed 112 | # tvsvhn_xtr = TorchVisionSVHN(path, split='extra', download=True) 113 | # xtr_data = {'x': tvsvhn_xtr.data.transpose(0, 2, 3, 1), 'y': tvsvhn_xtr.labels} 114 | 115 | # compute splits 116 | all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation, 117 | num_tasks=num_tasks, nc_first_task=nc_first_task, 118 | shuffle_classes=class_order is None, class_order=class_order) 119 | # set dataset type 120 | Dataset = memd.MemoryDataset 121 | 122 | elif 'imagenet_32' in dataset: 123 | import pickle 124 | # load data 125 | x_trn, y_trn = [], [] 126 | for i in range(1, 11): 127 | with open(os.path.join(path, 'train_data_batch_{}'.format(i)), 'rb') as f: 128 | d = pickle.load(f) 129 | x_trn.append(d['data']) 130 | y_trn.append(np.array(d['labels']) - 1) # labels from 0 to 999 131 | with 
open(os.path.join(path, 'val_data'), 'rb') as f: 132 | d = pickle.load(f) 133 | x_trn.append(d['data']) 134 | y_tst = np.array(d['labels']) - 1 # labels from 0 to 999 135 | # reshape data 136 | for i, d in enumerate(x_trn, 0): 137 | x_trn[i] = d.reshape(d.shape[0], 3, 32, 32).transpose(0, 2, 3, 1) 138 | x_tst = x_trn[-1] 139 | x_trn = np.vstack(x_trn[:-1]) 140 | y_trn = np.concatenate(y_trn) 141 | trn_data = {'x': x_trn, 'y': y_trn} 142 | tst_data = {'x': x_tst, 'y': y_tst} 143 | # compute splits 144 | all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation, 145 | num_tasks=num_tasks, nc_first_task=nc_first_task, 146 | shuffle_classes=class_order is None, class_order=class_order) 147 | # set dataset type 148 | Dataset = memd.MemoryDataset 149 | 150 | elif dataset == 'imagenet_100': 151 | # read data paths and compute splits -- path needs to have train_100.txt and val_100.txt with image-label pairs 152 | all_data, taskcla, class_indices = basedat.get_data(path, num_tasks=num_tasks, nc_first_task=nc_first_task, 153 | validation=validation, shuffle_classes=class_order is None, 154 | class_order=class_order, trn_lst='train_100.txt', tst_lst='val_100.txt') 155 | # set dataset type 156 | Dataset = basedat.BaseDataset 157 | 158 | elif dataset == 'imagenet_1000': 159 | # read data paths and compute splits -- path needs to have train_1000.txt and val_1000.txt with image-label pairs 160 | all_data, taskcla, class_indices = basedat.get_data(path, num_tasks=num_tasks, nc_first_task=nc_first_task, 161 | validation=validation, shuffle_classes=class_order is None, 162 | class_order=class_order, trn_lst='train_1000.txt', tst_lst='val_1000.txt') 163 | # set dataset type 164 | Dataset = basedat.BaseDataset 165 | 166 | else: 167 | # read data paths and compute splits -- path needs to have a train.txt and a test.txt with image-label pairs 168 | all_data, taskcla, class_indices = basedat.get_data(path, num_tasks=num_tasks, nc_first_task=nc_first_task, 169 |
validation=validation, shuffle_classes=class_order is None, 170 | class_order=class_order) 171 | # set dataset type 172 | Dataset = basedat.BaseDataset 173 | 174 | # get datasets, apply correct label offsets for each task 175 | offset = 0 176 | for task in range(num_tasks): 177 | all_data[task]['trn']['y'] = [label + offset for label in all_data[task]['trn']['y']] 178 | all_data[task]['val']['y'] = [label + offset for label in all_data[task]['val']['y']] 179 | all_data[task]['tst']['y'] = [label + offset for label in all_data[task]['tst']['y']] 180 | trn_dset.append(Dataset(all_data[task]['trn'], trn_transform, class_indices)) 181 | val_dset.append(Dataset(all_data[task]['val'], tst_transform, class_indices)) 182 | tst_dset.append(Dataset(all_data[task]['tst'], tst_transform, class_indices)) 183 | offset += taskcla[task][1] 184 | 185 | return trn_dset, val_dset, tst_dset, taskcla 186 | 187 | 188 | def get_transforms(resize, resize_test, pad, crop, cifar_crop, flip, normalize, extend_channel): 189 | """Unpack transformations and apply to train or test splits""" 190 | 191 | trn_transform_list = [] 192 | tst_transform_list = [] 193 | 194 | # resize 195 | if resize is not None: 196 | trn_transform_list.append(transforms.Resize(resize)) 197 | if resize_test is not None: 198 | tst_transform_list.append(transforms.Resize(resize_test)) 199 | 200 | # padding 201 | if pad is not None: 202 | trn_transform_list.append(transforms.Pad(pad)) 203 | # tst_transform_list.append(transforms.Pad(pad)) 204 | 205 | # crop 206 | if crop is not None: 207 | trn_transform_list.append(transforms.RandomResizedCrop(crop)) 208 | tst_transform_list.append(transforms.CenterCrop(crop)) 209 | 210 | if cifar_crop is not None: 211 | trn_transform_list.append(transforms.RandomCrop(cifar_crop)) 212 | 213 | # flips 214 | if flip: 215 | trn_transform_list.append(transforms.RandomHorizontalFlip()) 216 | 217 | # to tensor 218 | trn_transform_list.append(transforms.ToTensor()) 219 | 
tst_transform_list.append(transforms.ToTensor()) 220 | 221 | # normalization 222 | if normalize is not None: 223 | trn_transform_list.append(transforms.Normalize(mean=normalize[0], std=normalize[1])) 224 | tst_transform_list.append(transforms.Normalize(mean=normalize[0], std=normalize[1])) 225 | 226 | # gray to rgb 227 | if extend_channel is not None: 228 | trn_transform_list.append(transforms.Lambda(lambda x: x.repeat(extend_channel, 1, 1))) 229 | tst_transform_list.append(transforms.Lambda(lambda x: x.repeat(extend_channel, 1, 1))) 230 | 231 | print(trn_transform_list) 232 | print(tst_transform_list) 233 | return transforms.Compose(trn_transform_list), \ 234 | transforms.Compose(tst_transform_list) 235 | -------------------------------------------------------------------------------- /src/datasets/dataset_config.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | _BASE_DATA_PATH = "../data" 4 | 5 | dataset_config = { 6 | 'mnist': { 7 | 'path': join(_BASE_DATA_PATH, 'mnist'), 8 | 'normalize': ((0.1307,), (0.3081,)), 9 | # Use the next 3 lines to use MNIST with a 3x32x32 input 10 | # 'extend_channel': 3, 11 | # 'pad': 2, 12 | # 'normalize': ((0.1,), (0.2752,)) # values including padding 13 | }, 14 | 'svhn': { 15 | 'path': join(_BASE_DATA_PATH, 'svhn'), 16 | 'resize': (224, 224), 17 | 'crop': None, 18 | 'flip': False, 19 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 20 | }, 21 | 'cifar100': { 22 | 'path': join(_BASE_DATA_PATH, 'cifar100'), 23 | 'resize': None, 24 | 'pad': 4, 25 | 'crop': 32, 26 | 'flip': True, 27 | 'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)) 28 | }, 29 | 'cifar100_icarl': { 30 | 'path': join(_BASE_DATA_PATH, 'cifar100'), 31 | # 'resize': None, 32 | 'pad': 4, 33 | # 'crop': 32, 34 | 'cifar_crop': 32, 35 | 'flip': True, 36 | 'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)), 37 | 'class_order': [ 38 | 68, 56, 78, 8, 23, 84, 90, 
65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50, 39 | 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96, 40 | 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69, 41 | 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33 42 | ] 43 | }, 44 | 'vggface2': { 45 | 'path': join(_BASE_DATA_PATH, 'VGGFace2'), 46 | 'resize': 256, 47 | 'crop': 224, 48 | 'flip': True, 49 | 'normalize': ((0.5199, 0.4116, 0.3610), (0.2604, 0.2297, 0.2169)) 50 | }, 51 | 'imagenet_1000': { 52 | 'path': join(_BASE_DATA_PATH, 'imagenet'), 53 | 'resize': None, 54 | 'resize_test': 256, 55 | 'crop': 224, 56 | 'flip': True, 57 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 58 | }, 59 | 'imagenet_100': { 60 | 'path': join(_BASE_DATA_PATH, 'imagenet'), 61 | 'resize': None, 62 | 'resize_test': 256, 63 | 'crop': 224, 64 | 'flip': True, 65 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 66 | }, 67 | 'cars': { 68 | 'path': join(_BASE_DATA_PATH, 'cars'), 69 | 'resize': None, 70 | 'resize_test': 256, 71 | 'crop': 224, 72 | 'flip': True, 73 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 74 | }, 75 | 'aircraft': { 76 | 'path': join(_BASE_DATA_PATH, 'aircraft'), 77 | 'resize': None, 78 | 'resize_test': 256, 79 | 'crop': 224, 80 | 'flip': True, 81 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 82 | }, 83 | 'imagenet_32_reduced': { 84 | 'path': join(_BASE_DATA_PATH, 'ILSVRC12_32'), 85 | 'resize': None, 86 | 'pad': 4, 87 | 'crop': 32, 88 | 'flip': True, 89 | 'normalize': ((0.481, 0.457, 0.408), (0.260, 0.253, 0.268)), 90 | 'class_order': [ 91 | 472, 46, 536, 806, 547, 976, 662, 12, 955, 651, 492, 80, 999, 996, 788, 471, 911, 907, 680, 126, 42, 882, 92 | 327, 719, 716, 224, 918, 647, 808, 261, 140, 908, 833, 925, 57, 388, 407, 215, 45, 479, 525, 641, 915, 923, 93 | 108, 461, 186, 843, 115, 250, 829, 
625, 769, 323, 974, 291, 438, 50, 825, 441, 446, 200, 162, 373, 872, 112, 94 | 212, 501, 91, 672, 791, 370, 942, 172, 315, 959, 636, 635, 66, 86, 197, 182, 59, 736, 175, 445, 947, 268, 95 | 238, 298, 926, 851, 494, 760, 61, 293, 696, 659, 69, 819, 912, 486, 706, 343, 390, 484, 282, 729, 575, 731, 96 | 530, 32, 534, 838, 466, 734, 425, 400, 290, 660, 254, 266, 551, 775, 721, 134, 886, 338, 465, 236, 522, 655, 97 | 209, 861, 88, 491, 985, 304, 981, 560, 405, 902, 521, 909, 763, 455, 341, 905, 280, 776, 113, 434, 274, 581, 98 | 158, 738, 671, 702, 147, 718, 148, 35, 13, 585, 591, 371, 745, 281, 956, 935, 346, 352, 284, 604, 447, 415, 99 | 98, 921, 118, 978, 880, 509, 381, 71, 552, 169, 600, 334, 171, 835, 798, 77, 249, 318, 419, 990, 335, 374, 100 | 949, 316, 755, 878, 946, 142, 299, 863, 558, 306, 183, 417, 64, 765, 565, 432, 440, 939, 297, 805, 364, 735, 101 | 251, 270, 493, 94, 773, 610, 278, 16, 363, 92, 15, 593, 96, 468, 252, 699, 377, 95, 799, 868, 820, 328, 756, 102 | 81, 991, 464, 774, 584, 809, 844, 940, 720, 498, 310, 384, 619, 56, 406, 639, 285, 67, 634, 792, 232, 54, 103 | 664, 818, 513, 349, 330, 207, 361, 345, 279, 549, 944, 817, 353, 228, 312, 796, 193, 179, 520, 451, 871, 104 | 692, 60, 481, 480, 929, 499, 673, 331, 506, 70, 645, 759, 744, 459] 105 | }, 106 | 'mini_imagenet': { 107 | 'path': join(_BASE_DATA_PATH, 'mini_imagenet'), 108 | 'resize': None, 109 | 'crop': 84, 110 | 'flip': True, 111 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 112 | 'class_order': [26, 86, 2, 55, 75, 93, 16, 73, 54, 95, 53, 92, 78, 13, 7, 30, 22, 113 | 24, 33, 8, 43, 62, 3, 71, 45, 48, 6, 99, 82, 76, 60, 80, 90, 68, 114 | 51, 27, 18, 56, 63, 74, 1, 61, 42, 41, 4, 15, 17, 40, 38, 5, 91, 115 | 59, 0, 34, 28, 50, 11, 35, 23, 52, 10, 31, 66, 57, 79, 85, 32, 84, 116 | 14, 89, 19, 29, 49, 97, 98, 69, 20, 94, 72, 77, 25, 37, 81, 46, 39, 117 | 65, 58, 12, 88, 70, 87, 36, 21, 83, 9, 96, 67, 64, 47, 44] 118 | }, 119 | 120 | } 121 | 122 | # Add missing keys: 123 | for 
dset in dataset_config.keys(): 124 | for k in ['resize', 'pad', 'crop', 'normalize', 'class_order', 'extend_channel', 'resize_test', 'cifar_crop']: 125 | if k not in dataset_config[dset].keys(): 126 | dataset_config[dset][k] = None 127 | if 'flip' not in dataset_config[dset].keys(): 128 | dataset_config[dset]['flip'] = False 129 | -------------------------------------------------------------------------------- /src/datasets/exemplars_dataset.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from argparse import ArgumentParser 3 | 4 | from datasets.memory_dataset import MemoryDataset 5 | 6 | 7 | class ExemplarsDataset(MemoryDataset): 8 | """Exemplar storage for approaches with an interface of Dataset""" 9 | 10 | def __init__(self, transform, class_indices, 11 | num_exemplars=0, num_exemplars_per_class=0, exemplar_selection='random'): 12 | super().__init__({'x': [], 'y': []}, transform, class_indices=class_indices) 13 | self.max_num_exemplars_per_class = num_exemplars_per_class 14 | self.max_num_exemplars = num_exemplars 15 | assert (num_exemplars_per_class == 0) or (num_exemplars == 0), 'Cannot use both limits at once!' 
16 | cls_name = "{}ExemplarsSelector".format(exemplar_selection.capitalize()) 17 | selector_cls = getattr(importlib.import_module(name='datasets.exemplars_selection'), cls_name) 18 | self.exemplars_selector = selector_cls(self) 19 | 20 | # Returns a parser containing the approach specific parameters 21 | @staticmethod 22 | def extra_parser(args): 23 | parser = ArgumentParser("Exemplars Management Parameters") 24 | _group = parser.add_mutually_exclusive_group() 25 | _group.add_argument('--num-exemplars', default=0, type=int, required=False, 26 | help='Fixed memory, total number of exemplars (default=%(default)s)') 27 | _group.add_argument('--num-exemplars-per-class', default=0, type=int, required=False, 28 | help='Growing memory, number of exemplars per class (default=%(default)s)') 29 | parser.add_argument('--exemplar-selection', default='random', type=str, 30 | choices=['herding', 'random', 'entropy', 'distance'], 31 | required=False, help='Exemplar selection strategy (default=%(default)s)') 32 | return parser.parse_known_args(args) 33 | 34 | def _is_active(self): 35 | return self.max_num_exemplars_per_class > 0 or self.max_num_exemplars > 0 36 | 37 | def collect_exemplars(self, model, trn_loader, selection_transform, ddp): 38 | if self._is_active(): 39 | self.images, self.labels = self.exemplars_selector(model, trn_loader, selection_transform, ddp) 40 | -------------------------------------------------------------------------------- /src/datasets/exemplars_selection.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from contextlib import contextmanager 4 | from typing import Iterable 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader, ConcatDataset 9 | from torchvision.transforms import Lambda 10 | 11 | from datasets.exemplars_dataset import ExemplarsDataset 12 | from networks.network import LLL_Net 13 | 14 | 15 | class ExemplarsSelector: 16 | """Exemplar 
selector for approaches with an interface of Dataset""" 17 | 18 | def __init__(self, exemplars_dataset: ExemplarsDataset): 19 | self.exemplars_dataset = exemplars_dataset 20 | 21 | def __call__(self, model: LLL_Net, trn_loader: DataLoader, transform, ddp=False): 22 | clock0 = time.time() 23 | exemplars_per_class = self._exemplars_per_class_num(model) 24 | with override_dataset_transform(trn_loader.dataset, transform) as ds_for_selection: 25 | # iterate sequentially (shuffle=False) with the selection transforms so indices follow the dataset order 26 | sel_loader = DataLoader(ds_for_selection, batch_size=trn_loader.batch_size, shuffle=False, 27 | num_workers=trn_loader.num_workers, pin_memory=trn_loader.pin_memory) 28 | selected_indices = self._select_indices(model, sel_loader, exemplars_per_class, transform) 29 | if ddp: 30 | # make sure all processes use the same exemplar set (rank 0 broadcasts its selection) 31 | selected_indices = torch.from_numpy(np.array(selected_indices)).cuda() 32 | torch.distributed.broadcast(selected_indices, src=0) 33 | selected_indices = selected_indices.cpu().tolist() 34 | 35 | with override_dataset_transform(trn_loader.dataset, Lambda(lambda x: np.array(x))) as ds_for_raw: 36 | x, y = zip(*(ds_for_raw[idx] for idx in selected_indices)) 37 | clock1 = time.time() 38 | print('| Selected {:d} train exemplars, time={:5.1f}s'.format(len(x), clock1 - clock0)) 39 | return x, y 40 | 41 | def _exemplars_per_class_num(self, model: LLL_Net): 42 | if self.exemplars_dataset.max_num_exemplars_per_class: 43 | return self.exemplars_dataset.max_num_exemplars_per_class 44 | 45 | num_cls = model.task_cls.sum().item() 46 | num_exemplars = self.exemplars_dataset.max_num_exemplars 47 | exemplars_per_class = int(np.ceil(num_exemplars / num_cls)) 48 | assert exemplars_per_class > 0, \ 49 | "Not enough exemplars to cover all classes!\n" \ 50 | "Number of classes so far: {}. "
" \ 51 | "Limit of exemplars: {}".format(num_cls, 52 | num_exemplars) 53 | return exemplars_per_class 54 | 55 | def _select_indices(self, model: LLL_Net, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable: 56 | pass 57 | 58 | 59 | class RandomExemplarsSelector(ExemplarsSelector): 60 | """Selection of new samples. This is based on random selection, which produces a random list of samples.""" 61 | 62 | def __init__(self, exemplars_dataset): 63 | super().__init__(exemplars_dataset) 64 | 65 | def _select_indices(self, model: LLL_Net, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable: 66 | num_cls = sum(model.task_cls) 67 | result = [] 68 | labels = self._get_labels(sel_loader) 69 | for curr_cls in range(num_cls): 70 | # get all indices from current class -- check if there are exemplars from previous task in the loader 71 | cls_ind = np.where(labels == curr_cls)[0] 72 | assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls) 73 | assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store" 74 | # select the exemplars randomly 75 | result.extend(random.sample(list(cls_ind), exemplars_per_class)) 76 | return result 77 | 78 | def _get_labels(self, sel_loader): 79 | if hasattr(sel_loader.dataset, 'labels'): # BaseDataset, MemoryDataset 80 | labels = np.asarray(sel_loader.dataset.labels) 81 | elif isinstance(sel_loader.dataset, ConcatDataset): 82 | labels = [] 83 | for ds in sel_loader.dataset.datasets: 84 | labels.extend(ds.labels) 85 | labels = np.array(labels) 86 | else: 87 | raise RuntimeError("Unsupported dataset: {}".format(sel_loader.dataset.__class__.__name__)) 88 | return labels 89 | 90 | 91 | class HerdingExemplarsSelector(ExemplarsSelector): 92 | """Selection of new samples. This is based on herding selection, which produces a sorted list of samples of one 93 | class based on the distance to the mean sample of that class. 
From iCaRL algorithm 4 and 5: 94 | https://openaccess.thecvf.com/content_cvpr_2017/papers/Rebuffi_iCaRL_Incremental_Classifier_CVPR_2017_paper.pdf 95 | """ 96 | def __init__(self, exemplars_dataset): 97 | super().__init__(exemplars_dataset) 98 | 99 | def _select_indices(self, model: LLL_Net, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable: 100 | model_device = next(model.parameters()).device # we assume here that whole model is on a single device 101 | 102 | # extract outputs from the model for all train samples 103 | extracted_features = [] 104 | extracted_targets = [] 105 | with torch.no_grad(): 106 | model.eval() 107 | for images, targets in sel_loader: 108 | feats = model(images.to(model_device), return_features=True)[1] 109 | feats = feats / feats.norm(dim=1).view(-1, 1) # Feature normalization 110 | extracted_features.append(feats) 111 | extracted_targets.extend(targets) 112 | extracted_features = (torch.cat(extracted_features)).cpu() 113 | extracted_targets = np.array(extracted_targets) 114 | result = [] 115 | # iterate through all classes 116 | for curr_cls in np.unique(extracted_targets): 117 | # get all indices from current class 118 | cls_ind = np.where(extracted_targets == curr_cls)[0] 119 | assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls) 120 | assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store" 121 | # get all extracted features for current class 122 | cls_feats = extracted_features[cls_ind] 123 | # calculate the mean 124 | cls_mu = cls_feats.mean(0) 125 | # select the exemplars closer to the mean of each class 126 | selected = [] 127 | selected_feat = [] 128 | for k in range(exemplars_per_class): 129 | # fix this to the dimension of the model features 130 | sum_others = torch.zeros(cls_feats.shape[1]) 131 | for j in selected_feat: 132 | sum_others += j / (k + 1) 133 | dist_min = np.inf 134 | # choose the closest to the mean of the current class 135 | for item in 
cls_ind: 136 | if item not in selected: 137 | feat = extracted_features[item] 138 | dist = torch.norm(cls_mu - feat / (k + 1) - sum_others) 139 | if dist < dist_min: 140 | dist_min = dist 141 | newone = item 142 | newonefeat = feat 143 | selected_feat.append(newonefeat) 144 | selected.append(newone) 145 | result.extend(selected) 146 | return result 147 | 148 | 149 | class EntropyExemplarsSelector(ExemplarsSelector): 150 | """Selection of new samples. This is based on entropy selection, which produces a sorted list of samples of one 151 | class based on entropy of each sample. From RWalk http://arxiv-export-lb.library.cornell.edu/pdf/1801.10112 152 | """ 153 | def __init__(self, exemplars_dataset): 154 | super().__init__(exemplars_dataset) 155 | 156 | def _select_indices(self, model: LLL_Net, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable: 157 | model_device = next(model.parameters()).device # we assume here that whole model is on a single device 158 | 159 | # extract outputs from the model for all train samples 160 | extracted_logits = [] 161 | extracted_targets = [] 162 | with torch.no_grad(): 163 | model.eval() 164 | for images, targets in sel_loader: 165 | extracted_logits.append(torch.cat(model(images.to(model_device)), dim=1)) 166 | extracted_targets.extend(targets) 167 | extracted_logits = (torch.cat(extracted_logits)).cpu() 168 | extracted_targets = np.array(extracted_targets) 169 | result = [] 170 | # iterate through all classes 171 | for curr_cls in np.unique(extracted_targets): 172 | # get all indices from current class 173 | cls_ind = np.where(extracted_targets == curr_cls)[0] 174 | assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls) 175 | assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store" 176 | # get all extracted features for current class 177 | cls_logits = extracted_logits[cls_ind] 178 | # select the exemplars with higher entropy (lower: -entropy) 179 | probs = 
torch.softmax(cls_logits, dim=1) 180 | log_probs = torch.log_softmax(cls_logits, dim=1) # numerically stable; avoids log(0) = -inf and 0 * -inf = nan 181 | minus_entropy = (probs * log_probs).sum(1) # change sign of this variable for inverse order 182 | selected = cls_ind[minus_entropy.sort()[1][:exemplars_per_class]] 183 | result.extend(selected) 184 | return result 185 | 186 | 187 | class DistanceExemplarsSelector(ExemplarsSelector): 188 | """Selection of new samples. This is based on distance-based selection, which produces a sorted list of samples of 189 | one class based on closeness to decision boundary of each sample. From RWalk 190 | http://arxiv-export-lb.library.cornell.edu/pdf/1801.10112 191 | """ 192 | def __init__(self, exemplars_dataset): 193 | super().__init__(exemplars_dataset) 194 | 195 | def _select_indices(self, model: LLL_Net, sel_loader: DataLoader, exemplars_per_class: int, 196 | transform) -> Iterable: 197 | model_device = next(model.parameters()).device # we assume here that whole model is on a single device 198 | 199 | # extract outputs from the model for all train samples 200 | extracted_logits = [] 201 | extracted_targets = [] 202 | with torch.no_grad(): 203 | model.eval() 204 | for images, targets in sel_loader: 205 | extracted_logits.append(torch.cat(model(images.to(model_device)), dim=1)) 206 | extracted_targets.extend(targets) 207 | extracted_logits = (torch.cat(extracted_logits)).cpu() 208 | extracted_targets = np.array(extracted_targets) 209 | result = [] 210 | # iterate through all classes 211 | for curr_cls in np.unique(extracted_targets): 212 | # get all indices from current class 213 | cls_ind = np.where(extracted_targets == curr_cls)[0] 214 | assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls) 215 | assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store" 216 | # get the logits of the current class 217 | cls_logits = extracted_logits[cls_ind] 218 | # select the exemplars closer to boundary 219 | distance = cls_logits[:, curr_cls] # change sign of
this variable for inverse order 220 | selected = cls_ind[distance.sort()[1][:exemplars_per_class]] 221 | result.extend(selected) 222 | return result 223 | 224 | 225 | def dataset_transforms(dataset, transform_to_change): 226 | if isinstance(dataset, ConcatDataset): 227 | r = [] 228 | for ds in dataset.datasets: 229 | r += dataset_transforms(ds, transform_to_change) 230 | return r 231 | else: 232 | old_transform = dataset.transform 233 | dataset.transform = transform_to_change 234 | return [(dataset, old_transform)] 235 | 236 | 237 | @contextmanager 238 | def override_dataset_transform(dataset, transform): 239 | try: 240 | datasets_with_orig_transform = dataset_transforms(dataset, transform) 241 | yield dataset 242 | finally: 243 | # restore the original transformations 244 | for ds, orig_transform in datasets_with_orig_transform: 245 | ds.transform = orig_transform 246 | -------------------------------------------------------------------------------- /src/datasets/memory_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class MemoryDataset(Dataset): 8 | """Characterizes a dataset for PyTorch -- this dataset pre-loads all images in memory""" 9 | 10 | def __init__(self, data, transform, class_indices=None): 11 | """Initialization""" 12 | self.labels = data['y'] 13 | self.images = data['x'] 14 | self.transform = transform 15 | self.class_indices = class_indices 16 | 17 | def __len__(self): 18 | """Denotes the total number of samples""" 19 | return len(self.images) 20 | 21 | def __getitem__(self, index): 22 | """Generates one sample of data""" 23 | x = Image.fromarray(self.images[index]) 24 | x = self.transform(x) 25 | y = self.labels[index] 26 | return x, y 27 | 28 | 29 | def get_data(trn_data, tst_data, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None): 30 | """Prepare data: dataset
splits, task partition, class order""" 31 | 32 | data = {} 33 | taskcla = [] 34 | if class_order is None: 35 | num_classes = len(np.unique(trn_data['y'])) 36 | class_order = list(range(num_classes)) 37 | else: 38 | num_classes = len(class_order) 39 | class_order = class_order.copy() 40 | if shuffle_classes: 41 | np.random.shuffle(class_order) 42 | 43 | # compute classes per task and num_tasks 44 | if nc_first_task is None: 45 | cpertask = np.array([num_classes // num_tasks] * num_tasks) 46 | for i in range(num_classes % num_tasks): 47 | cpertask[i] += 1 48 | else: 49 | assert nc_first_task < num_classes, "first task wants more classes than exist" 50 | remaining_classes = num_classes - nc_first_task 51 | assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task" # better minimum 2 52 | cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1)) 53 | for i in range(remaining_classes % (num_tasks - 1)): 54 | cpertask[i + 1] += 1 55 | 56 | assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes" 57 | cpertask_cumsum = np.cumsum(cpertask) 58 | init_class = np.concatenate(([0], cpertask_cumsum[:-1])) 59 | 60 | # initialize data structure 61 | for tt in range(num_tasks): 62 | data[tt] = {} 63 | data[tt]['name'] = 'task-' + str(tt) 64 | data[tt]['trn'] = {'x': [], 'y': []} 65 | data[tt]['val'] = {'x': [], 'y': []} 66 | data[tt]['tst'] = {'x': [], 'y': []} 67 | 68 | # ALL OR TRAIN 69 | filtering = np.isin(trn_data['y'], class_order) 70 | if filtering.sum() != len(trn_data['y']): 71 | trn_data['x'] = trn_data['x'][filtering] 72 | trn_data['y'] = np.array(trn_data['y'])[filtering] 73 | for this_image, this_label in zip(trn_data['x'], trn_data['y']): 74 | # If shuffling is false, it won't change the class number 75 | this_label = class_order.index(this_label) 76 | # add it to the corresponding split 77 | this_task = (this_label >= cpertask_cumsum).sum() 78 | 
data[this_task]['trn']['x'].append(this_image) 79 | data[this_task]['trn']['y'].append(this_label - init_class[this_task]) 80 | 81 | # ALL OR TEST 82 | filtering = np.isin(tst_data['y'], class_order) 83 | if filtering.sum() != len(tst_data['y']): 84 | tst_data['x'] = tst_data['x'][filtering] 85 | tst_data['y'] = np.array(tst_data['y'])[filtering] # labels may be a plain list (e.g. CIFAR, MNIST) 86 | for this_image, this_label in zip(tst_data['x'], tst_data['y']): 87 | # If shuffling is false, it won't change the class number 88 | this_label = class_order.index(this_label) 89 | # add it to the corresponding split 90 | this_task = (this_label >= cpertask_cumsum).sum() 91 | data[this_task]['tst']['x'].append(this_image) 92 | data[this_task]['tst']['y'].append(this_label - init_class[this_task]) 93 | 94 | # check classes 95 | for tt in range(num_tasks): 96 | data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y'])) 97 | assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes" 98 | 99 | # validation 100 | if validation > 0.0: 101 | for tt in data.keys(): 102 | for cc in range(data[tt]['ncla']): 103 | cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0]) 104 | rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation))) 105 | rnd_img.sort(reverse=True) 106 | for ii in range(len(rnd_img)): 107 | data[tt]['val']['x'].append(data[tt]['trn']['x'][rnd_img[ii]]) 108 | data[tt]['val']['y'].append(data[tt]['trn']['y'][rnd_img[ii]]) 109 | data[tt]['trn']['x'].pop(rnd_img[ii]) 110 | data[tt]['trn']['y'].pop(rnd_img[ii]) 111 | 112 | # convert them to numpy arrays 113 | for tt in data.keys(): 114 | for split in ['trn', 'val', 'tst']: 115 | data[tt][split]['x'] = np.asarray(data[tt][split]['x']) 116 | 117 | # other 118 | n = 0 119 | for t in data.keys(): 120 | taskcla.append((t, data[t]['ncla'])) 121 | n += data[t]['ncla'] 122 | data['ncla'] = n 123 | 124 | return data, taskcla, class_order 125 | 
--------------------------------------------------------------------------------
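The class-to-task split performed by `get_data` in `memory_dataset.py` can be checked in isolation. The sketch below re-implements only the `cpertask` computation and the mapping from a (already reordered via `class_order`) global label to its task and within-task label; `classes_per_task` and `task_of` are hypothetical helper names, not functions in the repository. As in `get_datasets`, the per-task offset is added back afterwards, so the final labels are global again.

```python
import numpy as np


def classes_per_task(num_classes, num_tasks, nc_first_task=None):
    """Number of classes assigned to each task, following the split in get_data."""
    if nc_first_task is None:
        # even split; the first (num_classes % num_tasks) tasks get one extra class
        cpertask = np.array([num_classes // num_tasks] * num_tasks)
        for i in range(num_classes % num_tasks):
            cpertask[i] += 1
    else:
        assert nc_first_task < num_classes, "first task wants more classes than exist"
        remaining = num_classes - nc_first_task
        assert remaining >= num_tasks - 1, "at least one class is needed per task"
        # the remaining classes are split evenly over the later tasks
        cpertask = np.array([nc_first_task] + [remaining // (num_tasks - 1)] * (num_tasks - 1))
        for i in range(remaining % (num_tasks - 1)):
            cpertask[i + 1] += 1
    return cpertask


def task_of(label, cpertask):
    """Map a reordered global label to (task index, label within that task)."""
    cumsum = np.cumsum(cpertask)
    init_class = np.concatenate(([0], cumsum[:-1]))
    task = int((label >= cumsum).sum())
    return task, int(label - init_class[task])


if __name__ == '__main__':
    # the split used by the LUCIR scripts: --nc-first-task 50 --num-tasks 6
    split = classes_per_task(100, 6, nc_first_task=50)
    print(split.tolist())      # [50, 10, 10, 10, 10, 10]
    print(task_of(55, split))  # (1, 5)
```

With `nc_first_task=None` the split is as even as possible, e.g. 10 classes over 3 tasks gives `[4, 3, 3]`.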
/src/exp_cifar_lucir.sh: -------------------------------------------------------------------------------- 1 | device_id=0 2 | SEED=0 3 | bz=64 4 | lr=0.1 5 | mom=0.9 6 | wd=5e-4 7 | data=cifar100_icarl 8 | network=resnet18_cifar 9 | nepochs=160 10 | 11 | appr=lucir 12 | lamb=5.0 13 | nc_first=50 14 | ntask=2 15 | 16 | first_task_bz=128 17 | first_task_lr=0.1 18 | 19 | 20 | CUDA_VISIBLE_DEVICES=$device_id python3 main_incremental.py --exp-name ${nc_first}_${ntask}_${SEED} \ 21 | --datasets $data --num-tasks $ntask --nc-first-task $nc_first --network $network --seed $SEED \ 22 | --nepochs $nepochs --batch-size $bz --lr $lr --momentum $mom --weight-decay $wd --decay-mile-stone 80 120 \ 23 | --clipping -1 --results-path results --save-models \ 24 | --approach $appr --lamb $lamb --first-task-bz $first_task_bz --first-task-lr $first_task_lr \ 25 | --num-exemplars-per-class 20 --exemplar-selection herding 26 | -------------------------------------------------------------------------------- /src/exp_cifar_lucir_cwd.sh: -------------------------------------------------------------------------------- 1 | device_id=0 2 | SEED=0 3 | bz=128 4 | lr=0.1 5 | mom=0.9 6 | wd=5e-4 7 | data=cifar100_icarl 8 | network=resnet18_cifar 9 | nepochs=160 10 | n_exemplar=20 11 | 12 | appr=lucir_cwd 13 | lamb=5.0 14 | nc_first=50 15 | ntask=6 16 | 17 | aux_coef=0.5 18 | rej_thresh=1 19 | first_task_lr=0.1 20 | first_task_bz=128 21 | 22 | CUDA_VISIBLE_DEVICES=$device_id python3 main_incremental.py --exp-name nc_first_${nc_first}_ntask_${ntask} \ 23 | --datasets $data --num-tasks $ntask --nc-first-task $nc_first --network $network --seed $SEED \ 24 | --nepochs $nepochs --batch-size $bz --lr $lr --momentum $mom --weight-decay $wd --decay-mile-stone 80 120 \ 25 | --clipping -1 --results-path results --save-models \ 26 | --approach $appr --lamb $lamb --num-exemplars-per-class $n_exemplar --exemplar-selection herding \ 27 | --aux-coef $aux_coef --reject-threshold $rej_thresh \ 28 | --first-task-lr 
$first_task_lr --first-task-bz $first_task_bz 29 | 30 | -------------------------------------------------------------------------------- /src/exp_im100_joint.sh: -------------------------------------------------------------------------------- 1 | device_id=7 2 | SEED=1 3 | 4 | bz=128 5 | lr=0.1 6 | mom=0.9 7 | wd=1e-4 8 | data=imagenet_100 9 | network=resnet18 10 | nepochs=90 11 | 12 | appr=joint 13 | 14 | nc_first=10 15 | ntask=10 16 | 17 | CUDA_VISIBLE_DEVICES=$device_id python3 main_incremental.py --exp-name nc_first_${nc_first}_ntask_${ntask} \ 18 | --datasets $data --num-tasks $ntask --nc-first-task $nc_first --network $network --seed $SEED \ 19 | --nepochs $nepochs --batch-size $bz --lr $lr --momentum $mom --weight-decay $wd --decay-mile-stone 30 60 \ 20 | --clipping -1 --results-path results --save-models \ 21 | --approach $appr 22 | 23 | 24 | -------------------------------------------------------------------------------- /src/exp_im100_lucir.sh: -------------------------------------------------------------------------------- 1 | device_id=0 2 | SEED=0 3 | 4 | bz=128 5 | lr=0.1 6 | mom=0.9 7 | wd=1e-4 8 | data=imagenet_100 9 | network=resnet18 10 | nepochs=90 11 | n_exemplar=20 12 | 13 | appr=lucir 14 | lamb=10.0 15 | 16 | nc_first=50 17 | ntask=6 18 | 19 | first_task_lr=0.1 20 | first_task_bz=128 21 | 22 | CUDA_VISIBLE_DEVICES=$device_id python3 main_incremental.py --exp-name nc_first_${nc_first}_ntask_${ntask} \ 23 | --datasets $data --num-tasks $ntask --nc-first-task $nc_first --network $network --seed $SEED \ 24 | --nepochs $nepochs --batch-size $bz --lr $lr --momentum $mom --weight-decay $wd --decay-mile-stone 30 60 \ 25 | --clipping -1 --results-path results --save-models \ 26 | --approach $appr --lamb $lamb \ 27 | --num-exemplars-per-class $n_exemplar --exemplar-selection herding \ 28 | --first-task-lr $first_task_lr --first-task-bz $first_task_bz 29 | 30 | -------------------------------------------------------------------------------- 
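The experiment scripts use a growing memory (`--num-exemplars-per-class 20`), while `ExemplarsDataset` also accepts a fixed total budget via `--num-exemplars`. The sketch below mirrors the rule in `ExemplarsSelector._exemplars_per_class_num` to show how the two modes differ; `exemplars_per_class` is a hypothetical stand-alone helper, and it assumes `model.task_cls` sums to the number of classes seen so far.

```python
import math


def exemplars_per_class(num_classes_seen, num_exemplars=0, num_exemplars_per_class=0):
    """Per-class exemplar count, following ExemplarsSelector._exemplars_per_class_num."""
    # ExemplarsDataset forbids setting both limits at once
    assert num_exemplars == 0 or num_exemplars_per_class == 0, 'Cannot use both limits at once!'
    if num_exemplars_per_class:
        # growing memory: fixed count per class, total grows with the classes seen
        return num_exemplars_per_class
    # fixed memory: split the total budget over the classes seen, rounding up
    per_class = math.ceil(num_exemplars / num_classes_seen)
    assert per_class > 0, 'Not enough exemplars to cover all classes!'
    return per_class


if __name__ == '__main__':
    # classes seen after each task with --nc-first-task 50 --num-tasks 6
    for n in [50, 60, 70, 80, 90, 100]:
        grow = exemplars_per_class(n, num_exemplars_per_class=20)
        fixed = exemplars_per_class(n, num_exemplars=2000)
        print(n, grow * n, fixed, fixed * n)
```

Because the fixed-memory count is rounded up and each selector then stores that many exemplars for every class, the total can slightly exceed the budget: with 2000 exemplars and 60 classes, 34 per class are kept, i.e. 2040 in total.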
/src/exp_im100_lucir_cwd.sh:
--------------------------------------------------------------------------------
device_id=0
SEED=0
bz=128
lr=0.1
mom=0.9
wd=1e-4
data=imagenet_100
network=resnet18
nepochs=90
n_exemplar=20

appr=lucir_cwd
lamb=10.0
nc_first=50
ntask=6

aux_coef=0.75
rej_thresh=1
first_task_lr=0.2
first_task_bz=128

CUDA_VISIBLE_DEVICES=$device_id python3 main_incremental.py --exp-name nc_first_${nc_first}_ntask_${ntask} \
    --datasets $data --num-tasks $ntask --nc-first-task $nc_first --network $network --seed $SEED \
    --nepochs $nepochs --batch-size $bz --lr $lr --momentum $mom --weight-decay $wd --decay-mile-stone 30 60 \
    --clipping -1 --results-path results --save-models \
    --approach $appr --lamb $lamb --num-exemplars-per-class $n_exemplar --exemplar-selection herding \
    --aux-coef $aux_coef --reject-threshold $rej_thresh \
    --first-task-lr $first_task_lr --first-task-bz $first_task_bz
--------------------------------------------------------------------------------
/src/exp_im100_lucir_oracle.sh:
--------------------------------------------------------------------------------
device_id=0
SEED=0

bz=128
lr=0.1
mom=0.9
wd=1e-4
data=imagenet_100
network=resnet18
nepochs=90
n_exemplar=20

appr=lucir_oracle
lamb=10.0

nc_first=50
ntask=6

first_task_lr=0.1
first_task_bz=128
aux_coef=10.0
oracle_path=baselines/imagenet_subset_lucir_nc_first_99_ntask_2/models/task0.ckpt

CUDA_VISIBLE_DEVICES=$device_id python3 main_incremental.py --exp-name nc_first_${nc_first}_ntask_${ntask} \
    --datasets $data --num-tasks $ntask --nc-first-task $nc_first --network $network --seed $SEED \
    --nepochs $nepochs --batch-size $bz --lr $lr --momentum $mom --weight-decay $wd --decay-mile-stone 30 60 \
    --clipping -1 --results-path results \
    --approach $appr --lamb $lamb \
    --num-exemplars-per-class $n_exemplar --exemplar-selection herding \
    --first-task-lr $first_task_lr --first-task-bz $first_task_bz \
    --aux-coef $aux_coef --oracle-path $oracle_path
--------------------------------------------------------------------------------
/src/gridsearch.py:
--------------------------------------------------------------------------------
import importlib
from copy import deepcopy
from argparse import ArgumentParser

import utils


class GridSearch:
    """Basic class for implementing hyperparameter grid search"""

    def __init__(self, appr_ft, seed, gs_config='gridsearch_config', acc_drop_thr=0.2, hparam_decay=0.5,
                 max_num_searches=7):
        self.seed = seed
        GridSearchConfig = getattr(importlib.import_module(name=gs_config), 'GridSearchConfig')
        self.appr_ft = appr_ft
        self.gs_config = GridSearchConfig()
        self.acc_drop_thr = acc_drop_thr
        self.hparam_decay = hparam_decay
        self.max_num_searches = max_num_searches
        self.lr_first = 1.0

    @staticmethod
    def extra_parser(args):
        """Returns a parser containing the GridSearch specific parameters"""
        parser = ArgumentParser()
        # Configuration file with a GridSearchConfig class with all necessary args
        parser.add_argument('--gridsearch-config', type=str, default='gridsearch_config', required=False,
                            help='Configuration file for GridSearch options (default=%(default)s)')
        # Accuracy threshold drop below which the search stops for that phase
        parser.add_argument('--gridsearch-acc-drop-thr', default=0.2, type=float, required=False,
                            help='GridSearch accuracy drop threshold (default=%(default)f)')
        # Value at which hyperparameters decay
        parser.add_argument('--gridsearch-hparam-decay', default=0.5, type=float, required=False,
                            help='GridSearch hyperparameter decay (default=%(default)f)')
        # Maximum number of searches before the search stops for that phase
        parser.add_argument('--gridsearch-max-num-searches', default=7, type=int, required=False,
                            help='GridSearch maximum number of hyperparameter searches (default=%(default)d)')
        return parser.parse_known_args(args)

    def search_lr(self, model, t, trn_loader, val_loader):
        """Search for accuracy and best LR on finetuning"""
        best_ft_acc = 0.0
        best_ft_lr = 0.0

        # Get general parameters and fix the ones with only one value
        gen_params = self.gs_config.get_params('general')
        for k, v in gen_params.items():
            if not isinstance(v, list):
                setattr(self.appr_ft, k, v)
        if t > 0:
            # LR for search are 'lr_searches' largest LR below 'lr_first'
            list_lr = [lr for lr in gen_params['lr'] if lr < self.lr_first][:gen_params['lr_searches'][0]]
        else:
            # For first task, try larger LR range
            list_lr = gen_params['lr_first']

        # Iterate through the other variable parameters
        for curr_lr in list_lr:
            utils.seed_everything(seed=self.seed)
            self.appr_ft.model = deepcopy(model)
            self.appr_ft.lr = curr_lr
            self.appr_ft.train(t, trn_loader, val_loader)
            _, ft_acc_taw, _ = self.appr_ft.eval(t, val_loader)
            if ft_acc_taw > best_ft_acc:
                best_ft_acc = ft_acc_taw
                best_ft_lr = curr_lr
        print('Current best LR: ' + str(best_ft_lr))
        self.gs_config.current_lr = best_ft_lr
        print('Current best acc: {:5.1f}'.format(best_ft_acc * 100))
        # After first task, keep LR used
        if t == 0:
            self.lr_first = best_ft_lr

        return best_ft_acc, best_ft_lr

    def search_tradeoff(self, appr_name, appr, t, trn_loader, val_loader, best_ft_acc):
        """Search for less-forgetting tradeoff with minimum accuracy loss"""
        best_tradeoff = None
        tradeoff_name = None

        # Get approach-specific parameters and fix all the ones that have only one option
        appr_params = self.gs_config.get_params(appr_name)
        for k, v in appr_params.items():
            if isinstance(v, list):
                # get tradeoff name as the only one with multiple values
                tradeoff_name = k
            else:
                # Any other hyperparameters are fixed
                setattr(appr, k, v)

        # If there is no tradeoff, no need to gridsearch more
        if tradeoff_name is not None and t > 0:
            # get starting value for trade-off hyperparameter
            best_tradeoff = appr_params[tradeoff_name][0]
            # iterate through decreasing trade-off values -- limit to `max_num_searches` searches
            num_searches = 0
            while num_searches < self.max_num_searches:
                utils.seed_everything(seed=self.seed)
                # Make deepcopy of the appr without duplicating the logger
                appr_gs = type(appr)(deepcopy(appr.model), appr.device, exemplars_dataset=appr.exemplars_dataset)
                for attr, value in vars(appr).items():
                    if attr == 'logger':
                        setattr(appr_gs, attr, value)
                    else:
                        setattr(appr_gs, attr, deepcopy(value))

                # update tradeoff value
                setattr(appr_gs, tradeoff_name, best_tradeoff)
                # train this iteration
                appr_gs.train(t, trn_loader, val_loader)
                _, curr_acc, _ = appr_gs.eval(t, val_loader)
                print('Current acc: ' + str(curr_acc) + ' for ' + tradeoff_name + '=' + str(best_tradeoff))
                # Check if accuracy is within acceptable threshold drop
                if curr_acc < ((1 - self.acc_drop_thr) * best_ft_acc):
                    best_tradeoff = best_tradeoff * self.hparam_decay
                else:
                    break
                num_searches += 1
        else:
            print('There is no trade-off to gridsearch.')

        return best_tradeoff, tradeoff_name
--------------------------------------------------------------------------------
/src/gridsearch_config.py:
--------------------------------------------------------------------------------
class GridSearchConfig():
    def __init__(self):
        self.params = {
            'general': {
                'lr_first': [5e-1, 1e-1, 5e-2],
                'lr': [1e-1, 5e-2, 1e-2, 5e-3, 1e-3],
                'lr_searches': [3],
                'lr_min': 1e-4,
                'lr_factor': 3,
                'lr_patience': 10,
                'clipping': 10000,
                'momentum': 0.9,
                'wd': 0.0002
            },
            'finetuning': {
            },
            'freezing': {
            },
            'joint': {
            },
            'lwf': {
                'lamb': [10],
                'T': 2
            },
            'icarl': {
                'lamb': [4]
            },
            'dmc': {
                'aux_dataset': 'imagenet_32_reduced',
                'aux_batch_size': 128
            },
            'il2m': {
            },
            'eeil': {
                'lamb': [10],
                'T': 2,
                'lr_finetuning_factor': 0.1,
                'nepochs_finetuning': 40,
                'noise_grad': False
            },
            'bic': {
                'T': 2,
                'val_percentage': 0.1,
                'bias_epochs': 200
            },
            'lucir': {
                'lamda_base': [10],
                'lamda_mr': 1.0,
                'dist': 0.5,
                'K': 2
            },
            'lwm': {
                'beta': [2],
                'gamma': 1.0
            },
            'ewc': {
                'lamb': [10000]
            },
            'mas': {
                'lamb': [400]
            },
            'path_integral': {
                'lamb': [10],
            },
            'r_walk': {
                'lamb': [20],
            },
        }
        self.current_lr = self.params['general']['lr'][0]
        self.current_tradeoff = 0

    def get_params(self, approach):
        return self.params[approach]
--------------------------------------------------------------------------------
/src/last_layer_analysis.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt


def last_layer_analysis(heads, task, taskcla, y_lim=False, sort_weights=False):
    """Plot last layer weight and bias analysis"""
    print('Plotting last layer analysis...')
    num_classes = sum([x for (_, x) in taskcla])
    weights, biases, indexes = [], [], []
    class_id = 0
    with torch.no_grad():
        for t in range(task + 1):
            n_classes_t = taskcla[t][1]
            indexes.append(np.arange(class_id, class_id + n_classes_t))
            if type(heads) == torch.nn.Linear:  # Single head
                biases.append(heads.bias[class_id: class_id + n_classes_t].detach().cpu().numpy())
                weights.append((heads.weight[class_id: class_id + n_classes_t] ** 2).sum(1).sqrt().detach().cpu().numpy())
            else:  # Multi-head
                weights.append((heads[t].weight ** 2).sum(1).sqrt().detach().cpu().numpy())
                # heads created by LLL_Net.add_head use bias=False, so check for a bias before reading it
                if type(heads[t]) == torch.nn.Linear and heads[t].bias is not None:
                    biases.append(heads[t].bias.detach().cpu().numpy())
                else:
                    biases.append(np.zeros(weights[-1].shape))  # For LUCIR and bias-free heads
            class_id += n_classes_t

    # Figure weights
    f_weights = plt.figure(dpi=300)
    ax = f_weights.subplots(nrows=1, ncols=1)
    for i, (x, y) in enumerate(zip(indexes, weights), 0):
        if sort_weights:
            ax.bar(x, sorted(y, reverse=True), label="Task {}".format(i))
        else:
            ax.bar(x, y, label="Task {}".format(i))
    ax.set_xlabel("Classes", fontsize=11, fontfamily='serif')
    ax.set_ylabel("Weights L2-norm", fontsize=11, fontfamily='serif')
    if num_classes is not None:
        ax.set_xlim(0, num_classes)
    if y_lim:
        ax.set_ylim(0, 5)
    ax.legend(loc='upper left', fontsize='11')  #, fontfamily='serif')

    # Figure biases
    f_biases = plt.figure(dpi=300)
    ax = f_biases.subplots(nrows=1, ncols=1)
    for i, (x, y) in enumerate(zip(indexes, biases), 0):
        if sort_weights:
            ax.bar(x, sorted(y, reverse=True), label="Task {}".format(i))
        else:
            ax.bar(x, y, label="Task {}".format(i))
    ax.set_xlabel("Classes", fontsize=11, fontfamily='serif')
    ax.set_ylabel("Bias values", fontsize=11, fontfamily='serif')
    if num_classes is not None:
        ax.set_xlim(0, num_classes)
    if y_lim:
        ax.set_ylim(-1.0, 1.0)
    ax.legend(loc='upper left', fontsize='11')  #, fontfamily='serif')

    return f_weights, f_biases
--------------------------------------------------------------------------------
/src/loggers/disk_logger.py:
--------------------------------------------------------------------------------
import os
import sys
import json
import torch
import numpy as np
from datetime import datetime

from loggers.exp_logger import ExperimentLogger


class Logger(ExperimentLogger):
    """Characterizes a disk logger"""

    def __init__(self, log_path, exp_name, begin_time=None):
        super(Logger, self).__init__(log_path, exp_name, begin_time)

        self.begin_time_str = self.begin_time.strftime("%Y-%m-%d-%H-%M")

        # Duplicate standard outputs
        sys.stdout = FileOutputDuplicator(sys.stdout,
                                          os.path.join(self.exp_path, 'stdout-{}.txt'.format(self.begin_time_str)), 'w')
        sys.stderr = FileOutputDuplicator(sys.stderr,
                                          os.path.join(self.exp_path, 'stderr-{}.txt'.format(self.begin_time_str)), 'w')

        # Raw log file
        self.raw_log_file = open(os.path.join(self.exp_path, "raw_log-{}.txt".format(self.begin_time_str)), 'a')

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        if curtime is None:
            curtime = datetime.now()

        # Raw dump
        entry = {"task": task, "iter": iter, "name": name, "value": value, "group": group,
                 "time": curtime.strftime("%Y-%m-%d-%H-%M")}
        self.raw_log_file.write(json.dumps(entry, sort_keys=True) + "\n")
        self.raw_log_file.flush()

    def log_args(self, args):
        with open(os.path.join(self.exp_path, 'args-{}.txt'.format(self.begin_time_str)), 'w') as f:
            json.dump(args.__dict__, f, separators=(',\n', ' : '), sort_keys=True)

    def log_result(self, array, name, step):
        if array.ndim <= 1:
            array = array[None]
        np.savetxt(os.path.join(self.exp_path, 'results', '{}-{}.txt'.format(name, self.begin_time_str)),
                   array, '%.6f', delimiter='\t')

    def log_figure(self, name, iter, figure, curtime=None):
        # only fall back to the current time when no timestamp is passed in
        if curtime is None:
            curtime = datetime.now()
        figure.savefig(os.path.join(self.exp_path, 'figures',
                                    '{}_{}-{}.png'.format(name, iter, curtime.strftime("%Y-%m-%d-%H-%M-%S"))))
        figure.savefig(os.path.join(self.exp_path, 'figures',
                                    '{}_{}-{}.pdf'.format(name, iter, curtime.strftime("%Y-%m-%d-%H-%M-%S"))))

    def save_model(self, state_dict, task):
        torch.save(state_dict, os.path.join(self.exp_path, "models", "task{}.ckpt".format(task)))

    def __del__(self):
        self.raw_log_file.close()


class FileOutputDuplicator(object):
    def __init__(self, duplicate, fname, mode):
        self.file = open(fname, mode)
        self.duplicate = duplicate

    def __del__(self):
        self.file.close()

    def write(self, data):
        self.file.write(data)
        self.duplicate.write(data)

    def flush(self):
        # flush both the log file and the duplicated stream (stdout/stderr)
        self.file.flush()
        self.duplicate.flush()
--------------------------------------------------------------------------------
/src/loggers/exp_logger.py:
--------------------------------------------------------------------------------
import os
import importlib
from datetime import datetime


class ExperimentLogger:
    """Main class for experiment logging"""

    def __init__(self, log_path, exp_name, begin_time=None):
        self.log_path = log_path
        self.exp_name = exp_name
        self.exp_path = os.path.join(log_path, exp_name)
        if begin_time is None:
            self.begin_time = datetime.now()
        else:
            self.begin_time = begin_time

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        pass

    def log_args(self, args):
        pass

    def log_result(self, array, name, step):
        pass

    def log_figure(self, name, iter, figure, curtime=None):
        pass

    def save_model(self, state_dict, task):
        pass


class MultiLogger(ExperimentLogger):
    """This class allows to use multiple loggers"""

    def __init__(self, log_path, exp_name, loggers=None, save_models=True):
        super(MultiLogger, self).__init__(log_path, exp_name)
        if os.path.exists(self.exp_path):
            print("WARNING: {} already exists!".format(self.exp_path))
        else:
            os.makedirs(os.path.join(self.exp_path, 'models'))
            os.makedirs(os.path.join(self.exp_path, 'results'))
            os.makedirs(os.path.join(self.exp_path, 'figures'))

        self.save_models = save_models
        self.loggers = []
        for l in loggers:
            lclass = getattr(importlib.import_module(name='loggers.' + l + '_logger'), 'Logger')
            self.loggers.append(lclass(self.log_path, self.exp_name))

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        if curtime is None:
            curtime = datetime.now()
        for l in self.loggers:
            l.log_scalar(task, iter, name, value, group, curtime)

    def log_args(self, args):
        for l in self.loggers:
            l.log_args(args)

    def log_result(self, array, name, step):
        for l in self.loggers:
            l.log_result(array, name, step)

    def log_figure(self, name, iter, figure, curtime=None):
        if curtime is None:
            curtime = datetime.now()
        for l in self.loggers:
            l.log_figure(name, iter, figure, curtime)

    def save_model(self, state_dict, task):
        if self.save_models:
            for l in self.loggers:
                l.save_model(state_dict, task)
--------------------------------------------------------------------------------
/src/loggers/tensorboard_logger.py:
--------------------------------------------------------------------------------
from torch.utils.tensorboard import SummaryWriter

from loggers.exp_logger import ExperimentLogger
import json
import numpy as np


class Logger(ExperimentLogger):
    """Characterizes a Tensorboard logger"""

    def __init__(self, log_path, exp_name, begin_time=None):
        super(Logger, self).__init__(log_path, exp_name, begin_time)
        self.tbwriter = SummaryWriter(self.exp_path)

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        self.tbwriter.add_scalar(tag="t{}/{}_{}".format(task, group, name),
                                 scalar_value=value,
                                 global_step=iter)
        self.tbwriter.file_writer.flush()

    def log_figure(self, name, iter, figure, curtime=None):
        self.tbwriter.add_figure(tag=name, figure=figure, global_step=iter)
        self.tbwriter.file_writer.flush()

    def log_args(self, args):
        self.tbwriter.add_text(
            'args',
            json.dumps(args.__dict__,
                       separators=(',\n', ' : '),
                       sort_keys=True))
        self.tbwriter.file_writer.flush()

    def log_result(self, array, name, step):
        if array.ndim == 1:
            # log as scalars
            self.tbwriter.add_scalar(f'results/{name}', array[step], step)

        elif array.ndim == 2:
            s = ""
            i = step
            # for i in range(array.shape[0]):
            for j in range(array.shape[1]):
                s += '{:5.1f}% '.format(100 * array[i, j])
            if np.trace(array) == 0.0:
                if i > 0:
                    s += '\tAvg.:{:5.1f}% \n'.format(100 * array[i, :i].mean())
            else:
                s += '\tAvg.:{:5.1f}% \n'.format(100 * array[i, :i + 1].mean())
            self.tbwriter.add_text(f'results/{name}', s, step)

    def __del__(self):
        self.tbwriter.close()
--------------------------------------------------------------------------------
/src/networks/__init__.py:
--------------------------------------------------------------------------------
from torchvision import models

from .resnet18 import resnet18
from .resnet18_cifar import resnet18_cifar

# available torchvision models
tvmodels = ['alexnet',
            'densenet121', 'densenet169', 'densenet201', 'densenet161',
            'googlenet',
            'inception_v3',
            'mobilenet_v2',
            'resnet34',
            'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
            'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
            'squeezenet1_0', 'squeezenet1_1',
            'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19',
            'wide_resnet50_2', 'wide_resnet101_2'
            ]

allmodels = tvmodels + ['resnet18', 'resnet18_cifar']


def set_tvmodel_head_var(model):
    if type(model) == models.AlexNet:
        model.head_var = 'classifier'
    elif type(model) == models.DenseNet:
        model.head_var = 'classifier'
    elif type(model) == models.Inception3:
        model.head_var = 'fc'
    elif type(model) == models.ResNet:
        model.head_var = 'fc'
    elif type(model) == models.VGG:
        model.head_var = 'classifier'
    elif type(model) == models.GoogLeNet:
        model.head_var = 'fc'
    elif type(model) == models.MobileNetV2:
        model.head_var = 'classifier'
    elif type(model) == models.ShuffleNetV2:
        model.head_var = 'fc'
    elif type(model) == models.SqueezeNet:
        model.head_var = 'classifier'
    else:
        raise ModuleNotFoundError
--------------------------------------------------------------------------------
/src/networks/network.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from copy import deepcopy


class LLL_Net(nn.Module):
    """Basic class for implementing networks"""

    def __init__(self, model, remove_existing_head=False):
        head_var = model.head_var
        assert type(head_var) == str
        assert not remove_existing_head or hasattr(model, head_var), \
            "Given model does not have a variable called {}".format(head_var)
        assert not remove_existing_head or type(getattr(model, head_var)) in [nn.Sequential, nn.Linear], \
            "Given model's head {} is not an instance of nn.Sequential or nn.Linear".format(head_var)
        super(LLL_Net, self).__init__()

        self.model = model
        last_layer = getattr(self.model, head_var)

        if remove_existing_head:
            if type(last_layer) == nn.Sequential:
                self.out_size = last_layer[-1].in_features
                # strips off last linear layer of classifier
                del last_layer[-1]
            elif type(last_layer) == nn.Linear:
                self.out_size = last_layer.in_features
                # converts last layer into identity
                # setattr(self.model, head_var, nn.Identity())
                # WARNING: this is for when pytorch version is <1.2
                setattr(self.model, head_var, nn.Sequential())
        else:
            self.out_size = last_layer.out_features

        self.heads = nn.ModuleList()
        self.task_cls = []
        self.task_offset = []
        self._initialize_weights()

    def add_head(self, num_outputs):
        """Add a new head with the corresponding number of outputs. Also update the number of classes per task and the
        corresponding offsets
        """
        self.heads.append(nn.Linear(self.out_size, num_outputs, bias=False))
        # we re-compute instead of append in case an approach makes changes to the heads
        self.task_cls = torch.tensor([head.out_features for head in self.heads])
        self.task_offset = torch.cat([torch.LongTensor(1).zero_(), self.task_cls.cumsum(0)[:-1]])

    def forward(self, x, return_features=False):
        """Applies the forward pass

        Simplification to work on multi-head only -- returns all head outputs in a list
        Args:
            x (tensor): input images
            return_features (bool): return the representations before the heads
        """
        x = self.model(x)
        assert (len(self.heads) > 0), "Cannot access any head"
        y = []
        for head in self.heads:
            y.append(head(x))
        if return_features:
            return y, x
        else:
            return y

    # hard-coded interface specifically for PODNet
    # output: prediction y, features x, pod_features
    def forward_pod(self, x):
        x, pod_features = self.model(x, return_pod=True)
        y = []
        for head in self.heads:
            y.append(head(x))
        return y, x, pod_features

    def forward_repres(self, x):
        repres = self.model(x)
        return repres

    def forward_cls(self, repres):
        y = []
        for head in self.heads:
            y.append(head(repres))
        return y

    def get_copy(self):
        """Get weights from the model"""
        return deepcopy(self.state_dict())

    def set_state_dict(self, state_dict):
        """Load weights into the model"""
        self.load_state_dict(deepcopy(state_dict))
        return

    def freeze_all(self):
        """Freeze all parameters from the model, including the heads"""
        for param in self.parameters():
            param.requires_grad = False

    def freeze_backbone(self):
        """Freeze all parameters from the main model, but not the heads"""
        for param in self.model.parameters():
            param.requires_grad = False

    def freeze_bn(self):
        """Freeze all Batch Normalization layers from the model and use them in eval() mode"""
        for m in self.model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def _initialize_weights(self):
        """Initialize weights using different strategies"""
        # TODO: add different initialization strategies
        pass
--------------------------------------------------------------------------------
/src/networks/resnet18.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import math

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))

        # last classifier layer (head) with as many outputs as classes
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.last_dim = self.fc.in_features
        # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments
        self.head_var = 'fc'

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion

        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): unused; this constructor does not load any pretrained weights
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model
--------------------------------------------------------------------------------
/src/networks/resnet18_cifar.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import math

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))

        # last classifier layer (head) with as many outputs as classes
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.last_dim = self.fc.in_features
        # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments
        self.head_var = 'fc'

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion

        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18_cifar(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): unused; this constructor does not load any pretrained weights
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import os
import torch
import random
import numpy as np
from sklearn.metrics import confusion_matrix

cudnn_deterministic = True


def seed_everything(seed=0):
    """Fix all random seeds"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = cudnn_deterministic


def print_summary(taskcla, acc_taw, acc_tag, forg_taw, forg_tag):
    """Print summary of results"""
    tag_acc = []
    for name, metric in zip(['TAw Acc', 'TAg Acc', 'TAw Forg', 'TAg Forg'], [acc_taw, acc_tag, forg_taw, forg_tag]):
        print('*' * 108)
        print(name)
        for i in range(metric.shape[0]):
            print('\t', end='')
            for j in range(metric.shape[1]):
                print('{:5.1f}% '.format(100 * metric[i, j]), end='')

            # calculate average
            task_weight = np.array([ncla for _,ncla in taskcla[0:i+1]])
            task_weight = task_weight / task_weight.sum()

            if np.trace(metric) == 0.0:
                if i > 0:
                    print('\tAvg.:{:5.1f}% '.format(100 * metric[i, :i].mean()), end='')
            else:
                avg_metric = 100 * (metric[i, :i + 1]*task_weight).sum()
                print('\tAvg.:{:5.1f}% '.format(avg_metric), end='')
                if name == 'TAg Acc':
                    tag_acc.append(avg_metric)
            print()
    print('*' * 108)
    avg_tag_acc = np.array(tag_acc).mean()
    print('Average Incremental Accuracy: ', avg_tag_acc)
    print('done')

# save results of ablation study
def save_summary(save_path, taskcla, acc_taw, acc_tag, forg_taw, forg_tag, appr_args):
    """save summary of results"""
    with open(save_path, 'w') as f:
        for name, metric in zip(['TAw Acc', 'TAg Acc', 'TAw Forg', 'TAg Forg'], [acc_taw, acc_tag, forg_taw, forg_tag]):
            f.write('*' * 108 + '\n')
            f.write(name + '\n')
            for i in range(metric.shape[0]):
                f.write('\t')
                for j in range(metric.shape[1]):
                    f.write('{:5.1f}% '.format(100 * metric[i, j]))

                # calculate average
                task_weight = np.array([ncla for _,ncla in taskcla[0:i+1]])
                task_weight = task_weight / task_weight.sum()

                if np.trace(metric) == 0.0:
                    if i > 0:
                        f.write('\tAvg.:{:5.1f}% '.format(100 * metric[i, :i].mean()))
                else:
                    f.write('\tAvg.:{:5.1f}% '.format(100 * (metric[i, :i + 1]*task_weight).sum()))
                f.write('\n')

        # --------------- approach arguments ------------------
        f.write('*' * 108 + '\n')
        f.write('Approach arguments =\n')
        for arg in np.sort(list(vars(appr_args).keys())):
            f.write('\t' + arg + ': ' + str(getattr(appr_args, arg)) + '\n')
        f.write('=' * 108 + '\n')
        # -----------------------------------------------------

# val_loaders: a list of data loaders
# note: `num_classes` is recomputed from the model heads, so the passed-in value is ignored
def compute_confusion_matrix(model, val_loaders, num_classes):
    with torch.no_grad():
        model.eval()
        num_classes = sum([head.out_features for head in model.heads])
        cm = np.zeros((num_classes, num_classes))
        for loader in val_loaders:
            for images, targets in loader:
                images = images.cuda()
                outputs = model(images)
                outputs = torch.cat(outputs, dim=1)

                pred = outputs.argmax(dim=1).cpu().numpy()
                targets = targets.cpu().numpy()
                cm += confusion_matrix(targets, pred, labels=np.arange(num_classes))
    return cm
--------------------------------------------------------------------------------
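`print_summary` in `src/utils.py` weights each task's TAg accuracy by its number of classes, so with `nc_first=50` the first task dominates each per-step average. It then reports the mean of these per-step values as "Average Incremental Accuracy". The standalone NumPy sketch below mirrors the accuracy branch of that computation and skips the zero-trace forgetting branch. The function name `average_incremental_accuracy` and the numbers are hypothetical and for illustration only.

```python
import numpy as np


def average_incremental_accuracy(acc_tag, taskcla):
    """Class-weighted average accuracy per step, and its mean over steps.

    acc_tag[i, j]: task-agnostic accuracy (0..1) on task j after training task i.
    taskcla: list of (task_id, num_classes) pairs, as passed to print_summary.
    """
    step_avgs = []
    for i in range(acc_tag.shape[0]):
        # weight each seen task by its share of the classes seen so far
        n_cls = np.array([ncla for _, ncla in taskcla[:i + 1]], dtype=float)
        weights = n_cls / n_cls.sum()
        step_avgs.append(100 * float((acc_tag[i, :i + 1] * weights).sum()))
    return float(np.mean(step_avgs)), step_avgs


# Two phases: 50 classes first, then 10 new ones (illustrative numbers only)
acc = np.array([[0.8, 0.0],
                [0.7, 0.5]])
avg, per_step = average_incremental_accuracy(acc, [(0, 50), (1, 10)])
print(per_step, avg)
```

In the second step the accuracies 0.7 and 0.5 get weights 5/6 and 1/6, which gives about 66.7%. The reported mean is therefore (80 + 66.7) / 2, or about 73.3%, whereas an unweighted per-step mean would give 60% for the second step.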