├── .DS_Store ├── LICENSE ├── README.md ├── SizeOfSubnetworks.yml ├── env.yml ├── scrollnet.gif └── src ├── .DS_Store ├── approach ├── __init__.py ├── ewc.py ├── finetuning.py ├── incremental_learning.py └── lwf.py ├── datasets ├── base_dataset.py ├── data_loader.py ├── dataset_config.py ├── exemplars_dataset.py ├── exemplars_selection.py └── memory_dataset.py ├── loggers ├── README.md ├── __pycache__ │ ├── disk_logger.cpython-37.pyc │ ├── disk_logger.cpython-38.pyc │ ├── exp_logger.cpython-37.pyc │ └── exp_logger.cpython-38.pyc ├── disk_logger.py ├── exp_logger.py └── tensorboard_logger.py ├── main_incremental.py ├── networks ├── __init__.py ├── network.py ├── scroll_resnet18.py └── slimmable_ops.py ├── utils.py └── widths └── config.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 FireFYF 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | # [ScrollNet: Dynamic Weight Importance for Continual Learning](http://arxiv.org/abs/2308.16567) 5 | 6 |
7 | 8 | # Introduction 9 | The official PyTorch implementation for ScrollNet: Dynamic Weight Importance for Continual Learning, Visual Continual Learning workshop, ICCV 2023. 10 | 11 | # Installation 12 | ## Clone this GitHub repository 13 | ``` 14 | git clone https://github.com/FireFYF/ScrollNet.git 15 | cd ScrollNet 16 | ``` 17 | ## Create a conda environment 18 | ``` 19 | conda env create --file env.yml --name ScrollNet 20 | ``` 21 | *Notice:* `env.yml` does not install a `cudatoolkit` package; it pins `torch==2.0.0` together with the CUDA 11 pip wheels (`nvidia-*-cu11`). Adjust these pins to match the CUDA version supported by your driver. 22 | ## Environment activation/deactivation 23 | ``` 24 | conda activate ScrollNet 25 | conda deactivate 26 | ``` 27 | 28 | # Launch experiments 29 | 30 | ## Run with ScrollNet-FT 31 | ``` 32 | python -u src/main_incremental.py --gpu 0 --approach finetuning --results-path ./results/5splits/scrollnet_ft --num-tasks 5 33 | ``` 34 | ## Run with ScrollNet-LWF 35 | ``` 36 | python -u src/main_incremental.py --gpu 0 --approach lwf --results-path ./results/5splits/scrollnet_lwf --num-tasks 5 37 | ``` 38 | ## Run with ScrollNet-EWC 39 | ``` 40 | python -u src/main_incremental.py --gpu 0 --approach ewc --results-path ./results/5splits/scrollnet_ewc --num-tasks 5 41 | ``` 42 | 43 | # Tune the number of subnetworks 44 | Please modify the `width_mult_list` entry in the file 'SizeOfSubnetworks.yml'. The default setting is for 4 subnetworks with equal splitting (ScrollNet-4). 45 | 46 | # Acknowledgement 47 | The implementation is based on [FACIL](https://github.com/mmasana/FACIL), which was developed as a framework based on class-incremental learning. We suggest referring to it if you want to incorporate more CL methods into ScrollNet.
48 | 49 | # Cite 50 | If you find this work useful for your research, please cite: 51 | ```bibtex 52 | @misc{yang2023scrollnet, 53 | title={ScrollNet: Dynamic Weight Importance for Continual Learning}, 54 | author={Fei Yang and Kai Wang and Joost van de Weijer}, 55 | year={2023}, 56 | eprint={2308.16567}, 57 | archivePrefix={arXiv}, 58 | primaryClass={cs.CV} 59 | } 60 | ``` 61 | -------------------------------------------------------------------------------- /SizeOfSubnetworks.yml: -------------------------------------------------------------------------------- 1 | # =========================== Size of each subnetwork in ScrollNet =========================== 2 | width_mult: 1.0 3 | width_mult_list: [0.25, 0.5, 0.75, 1.0] 4 | 5 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: scrollnet 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.01.10=h06a4308_0 8 | - certifi=2022.12.7=py38h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.4.2=h6a678d5_6 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.4=h6a678d5_0 15 | - openssl=1.1.1t=h7f8727e_0 16 | - pip=23.0.1=py38h06a4308_0 17 | - python=3.8.16=h7a1cb2a_3 18 | - readline=8.2=h5eee18b_0 19 | - setuptools=65.6.3=py38h06a4308_0 20 | - sqlite=3.41.1=h5eee18b_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - wheel=0.38.4=py38h06a4308_0 23 | - xz=5.2.10=h5eee18b_1 24 | - zlib=1.2.13=h5eee18b_0 25 | - pip: 26 | - charset-normalizer==3.1.0 27 | - cmake==3.26.0 28 | - contourpy==1.0.7 29 | - cycler==0.11.0 30 | - filelock==3.10.0 31 | - fonttools==4.39.2 32 | - idna==3.4 33 | - importlib-resources==5.12.0 34 | - jinja2==3.1.2 35 | - kiwisolver==1.4.4 36 | - lit==15.0.7 37 | - markupsafe==2.1.2 38 | - matplotlib==3.7.1 39 | - mpmath==1.3.0 40 
| - networkx==3.0 41 | - numpy==1.24.2 42 | - nvidia-cublas-cu11==11.10.3.66 43 | - nvidia-cuda-cupti-cu11==11.7.101 44 | - nvidia-cuda-nvrtc-cu11==11.7.99 45 | - nvidia-cuda-runtime-cu11==11.7.99 46 | - nvidia-cudnn-cu11==8.5.0.96 47 | - nvidia-cufft-cu11==10.9.0.58 48 | - nvidia-curand-cu11==10.2.10.91 49 | - nvidia-cusolver-cu11==11.4.0.1 50 | - nvidia-cusparse-cu11==11.7.4.91 51 | - nvidia-nccl-cu11==2.14.3 52 | - nvidia-nvtx-cu11==11.7.91 53 | - packaging==23.0 54 | - pillow==9.4.0 55 | - ptflops==0.7 56 | - pyparsing==3.0.9 57 | - python-dateutil==2.8.2 58 | - pyyaml==6.0 59 | - requests==2.28.2 60 | - six==1.16.0 61 | - sympy==1.11.1 62 | - torch==2.0.0 63 | - torchaudio==2.0.1 64 | - torchvision==0.15.1 65 | - triton==2.0.0 66 | - typing-extensions==4.5.0 67 | - urllib3==1.26.15 68 | - zipp==3.15.0 69 | -------------------------------------------------------------------------------- /scrollnet.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/scrollnet.gif -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/src/.DS_Store -------------------------------------------------------------------------------- /src/approach/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # list all approaches available: the module name (file name without '.py') of every .py file in this package except __init__.py and the incremental_learning.py base class 4 | __all__ = list( 5 | map(lambda x: x[:-3], 6 | filter(lambda x: x not in ['__init__.py', 'incremental_learning.py'] and x.endswith('.py'), 7 | os.listdir(os.path.dirname(__file__)) 8 | ) 9 | ) 10 | ) 11 | -------------------------------------------------------------------------------- /src/approach/ewc.py:
-------------------------------------------------------------------------------- 1 | from turtle import width 2 | import torch 3 | import itertools 4 | from argparse import ArgumentParser 5 | 6 | from datasets.exemplars_dataset import ExemplarsDataset 7 | from .incremental_learning import Inc_Learning_Appr 8 | from widths.config import FLAGS 9 | 10 | class Appr(Inc_Learning_Appr): 11 | """Class implementing the Elastic Weight Consolidation (EWC) approach 12 | described in http://arxiv.org/abs/1612.00796 13 | """ 14 | 15 | def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000, 16 | momentum=0, wd=0, multi_softmax=False, scroll_step=1, fix_bn=False, eval_on_train=False, 17 | logger=None, exemplars_dataset=None, lamb=5000, alpha=0.5, fi_sampling_type='max_pred', 18 | fi_num_samples=-1): 19 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd, 20 | multi_softmax, scroll_step, fix_bn, eval_on_train, logger, 21 | exemplars_dataset) 22 | self.lamb = lamb 23 | self.alpha = alpha 24 | self.sampling_type = fi_sampling_type 25 | self.num_samples = fi_num_samples 26 | self.scroll_step = scroll_step 27 | 28 | # In all cases, we only keep importance weights for the model, but not for the heads. 29 | feat_ext = self.model.model 30 | # Store current parameters as the initial parameters before first task starts 31 | self.older_params = {n: p.clone().detach() for n, p in feat_ext.named_parameters() if p.requires_grad} 32 | # Store fisher information weight importance 33 | self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() 34 | if p.requires_grad} 35 | 36 | @staticmethod 37 | def exemplars_dataset_class(): 38 | return ExemplarsDataset 39 | 40 | @staticmethod 41 | def extra_parser(args): 42 | """Returns a parser containing the approach specific parameters""" 43 | parser = ArgumentParser() 44 | # Eq. 
3: "lambda sets how important the old task is compared to the new one" 45 | parser.add_argument('--lamb', default=5000, type=float, required=False, 46 | help='Forgetting-intransigence trade-off (default=%(default)s)') 47 | # Define how old and new fisher is fused, by default it is a 50-50 fusion 48 | parser.add_argument('--alpha', default=0.5, type=float, required=False, 49 | help='EWC alpha (default=%(default)s)') 50 | parser.add_argument('--fi-sampling-type', default='max_pred', type=str, required=False, 51 | choices=['true', 'max_pred', 'multinomial'], 52 | help='Sampling type for Fisher information (default=%(default)s)') 53 | parser.add_argument('--fi-num-samples', default=-1, type=int, required=False, 54 | help='Number of samples for Fisher information (-1: all available) (default=%(default)s)') 55 | 56 | return parser.parse_known_args(args) 57 | 58 | def _get_optimizer(self): 59 | """Returns the optimizer""" 60 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 61 | # if there are no exemplars, previous heads are not modified 62 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 63 | else: 64 | params = self.model.parameters() 65 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 66 | 67 | def compute_fisher_matrix_diag(self, trn_loader, t): 68 | # Store Fisher Information 69 | fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters() if p.requires_grad} 70 | # Compute fisher information for specified number of samples -- rounded to the batch size 71 | n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \ 72 | else (len(trn_loader.dataset) // trn_loader.batch_size) 73 | # Do forward and backward pass to compute the fisher information 74 | self.model.train() 75 | for images, targets in itertools.islice(trn_loader, n_samples_batches): 76 | self.model.apply(lambda m: setattr(m, 'scroll', 
self.scroll_step * t)) 77 | self.model.apply(lambda m: setattr(m, 'width_mult', 1.0)) 78 | 79 | outputs = self.model.forward(images.to(self.device)) 80 | 81 | if self.sampling_type == 'true': 82 | # Use the labels to compute the gradients based on the CE-loss with the ground truth 83 | preds = targets.to(self.device) 84 | elif self.sampling_type == 'max_pred': 85 | # Not use labels and compute the gradients related to the prediction the model has learned 86 | preds = torch.cat(outputs, dim=1).argmax(1).flatten() 87 | elif self.sampling_type == 'multinomial': 88 | # Use a multinomial sampling to compute the gradients 89 | probs = torch.nn.functional.softmax(torch.cat(outputs, dim=1), dim=1) 90 | preds = torch.multinomial(probs, len(targets)).flatten() 91 | 92 | loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), preds) 93 | self.optimizer.zero_grad() 94 | loss.backward() 95 | # Accumulate all gradients from loss with regularization 96 | for n, p in self.model.model.named_parameters(): 97 | if p.grad is not None: 98 | fisher[n] += p.grad.pow(2) * len(targets) 99 | # Apply mean across all samples 100 | n_samples = n_samples_batches * trn_loader.batch_size 101 | fisher = {n: (p / n_samples) for n, p in fisher.items()} 102 | return fisher 103 | 104 | def train_loop(self, t, trn_loader, val_loader): 105 | """Contains the epochs loop""" 106 | 107 | # add exemplars to train_loader 108 | if len(self.exemplars_dataset) > 0 and t > 0: 109 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 110 | batch_size=trn_loader.batch_size, 111 | shuffle=True, 112 | num_workers=trn_loader.num_workers, 113 | pin_memory=trn_loader.pin_memory) 114 | 115 | # FINETUNING TRAINING -- contains the epochs loop 116 | super().train_loop(t, trn_loader, val_loader) 117 | 118 | # EXEMPLAR MANAGEMENT -- select training subset 119 | self.exemplars_dataset.collect_exemplars(self.model, t, trn_loader, val_loader.dataset.transform) 120 | 121 | def 
post_train_process(self, t, trn_loader): 122 | """Runs after training all the epochs of the task (after the train session)""" 123 | 124 | # Store current parameters for the next task 125 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad} 126 | 127 | # calculate Fisher information 128 | curr_fisher = self.compute_fisher_matrix_diag(trn_loader, t) 129 | # merge fisher information, we do not want to keep fisher information for each task in memory 130 | for n in self.fisher.keys(): 131 | # Added option to accumulate fisher over time with a pre-fixed growing alpha 132 | if self.alpha == -1: 133 | alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device) 134 | self.fisher[n] = alpha * self.fisher[n] + (1 - alpha) * curr_fisher[n] 135 | else: 136 | # pdb.set_trace() 137 | self.fisher[n] = (self.alpha * self.fisher[n] + (1 - self.alpha) * curr_fisher[n]) 138 | 139 | def criterion(self, t, outputs, targets): 140 | """Returns the loss value""" 141 | loss = 0 142 | if t > 0: 143 | loss_reg = 0 144 | # Eq. 
3: elastic weight consolidation quadratic penalty 145 | for n, p in self.model.model.named_parameters(): 146 | if n in self.fisher.keys(): 147 | loss_reg += torch.sum(self.fisher[n] * (p - self.older_params[n]).pow(2)) / 2 148 | loss += self.lamb * loss_reg 149 | # Current cross-entropy loss -- with exemplars use all heads 150 | if len(self.exemplars_dataset) > 0: 151 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 152 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 153 | -------------------------------------------------------------------------------- /src/approach/finetuning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from argparse import ArgumentParser 3 | 4 | from .incremental_learning import Inc_Learning_Appr 5 | from datasets.exemplars_dataset import ExemplarsDataset 6 | 7 | 8 | class Appr(Inc_Learning_Appr): 9 | """Class implementing the finetuning baseline""" 10 | 11 | def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000, 12 | momentum=0, wd=0, multi_softmax=False, scroll_step=1, fix_bn=False, 13 | eval_on_train=False, logger=None, exemplars_dataset=None, all_outputs=False): 14 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd, 15 | multi_softmax, scroll_step, fix_bn, eval_on_train, logger, 16 | exemplars_dataset) 17 | self.all_out = all_outputs 18 | 19 | @staticmethod 20 | def exemplars_dataset_class(): 21 | return ExemplarsDataset 22 | 23 | @staticmethod 24 | def extra_parser(args): 25 | """Returns a parser containing the approach specific parameters""" 26 | parser = ArgumentParser() 27 | parser.add_argument('--all-outputs', action='store_true', required=False, 28 | help='Allow all weights related to all outputs to be modified (default=%(default)s)') 29 | return parser.parse_known_args(args) 
30 | 31 | def _get_optimizer(self): 32 | """Returns the optimizer""" 33 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1 and not self.all_out: 34 | # if there are no exemplars, previous heads are not modified 35 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 36 | else: 37 | params = self.model.parameters() 38 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 39 | 40 | def train_loop(self, t, trn_loader, val_loader): 41 | """Contains the epochs loop""" 42 | 43 | # add exemplars to train_loader 44 | if len(self.exemplars_dataset) > 0 and t > 0: 45 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 46 | batch_size=trn_loader.batch_size, 47 | shuffle=True, 48 | num_workers=trn_loader.num_workers, 49 | pin_memory=trn_loader.pin_memory) 50 | 51 | # FINETUNING TRAINING -- contains the epochs loop 52 | super().train_loop(t, trn_loader, val_loader) 53 | 54 | # EXEMPLAR MANAGEMENT -- select training subset 55 | self.exemplars_dataset.collect_exemplars(self.model, t, trn_loader, val_loader.dataset.transform) 56 | 57 | def criterion(self, t, outputs, targets): 58 | """Returns the loss value""" 59 | if self.all_out or len(self.exemplars_dataset) > 0: 60 | return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 61 | return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 62 | -------------------------------------------------------------------------------- /src/approach/incremental_learning.py: -------------------------------------------------------------------------------- 1 | from sched import scheduler 2 | import time 3 | import torch 4 | import numpy as np 5 | from argparse import ArgumentParser 6 | from widths.config import FLAGS 7 | from loggers.exp_logger import ExperimentLogger 8 | from datasets.exemplars_dataset import ExemplarsDataset 9 | 10 | class Inc_Learning_Appr: 11 | 
"""Basic class for implementing incremental learning approaches""" 12 | 13 | def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000, 14 | momentum=0, wd=0, multi_softmax=False, scroll_step=1, fix_bn=False, 15 | eval_on_train=False, logger: ExperimentLogger = None, exemplars_dataset: ExemplarsDataset = None): 16 | self.model = model 17 | self.device = device 18 | self.nepochs = nepochs 19 | self.lr = lr 20 | self.decay_mile_stone = decay_mile_stone 21 | self.lr_decay = lr_decay 22 | self.clipgrad = clipgrad 23 | self.momentum = momentum 24 | self.wd = wd 25 | self.multi_softmax = multi_softmax 26 | self.logger = logger 27 | self.exemplars_dataset = exemplars_dataset 28 | self.fix_bn = fix_bn 29 | self.eval_on_train = eval_on_train 30 | self.scroll_step = scroll_step 31 | self.optimizer = None 32 | 33 | @staticmethod 34 | def extra_parser(args): 35 | """Returns a parser containing the approach specific parameters""" 36 | parser = ArgumentParser() 37 | return parser.parse_known_args(args) 38 | 39 | @staticmethod 40 | def exemplars_dataset_class(): 41 | """Returns a exemplar dataset to use during the training if the approach needs it 42 | :return: ExemplarDataset class or None 43 | """ 44 | return None 45 | 46 | def _get_optimizer(self): 47 | """Returns the optimizer""" 48 | return torch.optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 49 | 50 | def train(self, t, trn_loader, val_loader): 51 | """Main train structure""" 52 | self.train_loop(t, trn_loader, val_loader) 53 | self.post_train_process(t, trn_loader) 54 | 55 | def train_loop(self, t, trn_loader, val_loader): 56 | """Contains the epochs loop""" 57 | lr = self.lr 58 | self.optimizer = self._get_optimizer() 59 | scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.decay_mile_stone, gamma=self.lr_decay) 60 | 61 | # Loop epochs 62 | for e in range(self.nepochs): 63 | # Train 64 | clock0 
= time.time() 65 | self.train_epoch(t, trn_loader) 66 | clock1 = time.time() 67 | if self.eval_on_train: 68 | train_loss, train_acc, _ = self.eval(t, trn_loader) 69 | clock2 = time.time() 70 | print('| Epoch {:3d}, time={:5.1f}s/{:5.1f}s | Train: loss={:.3f}, TAw acc={:5.1f}% |'.format( 71 | e + 1, clock1 - clock0, clock2 - clock1, train_loss, 100 * train_acc), end='') 72 | self.logger.log_scalar(task=t, iter=e + 1, name="loss", value=train_loss, group="train") 73 | self.logger.log_scalar(task=t, iter=e + 1, name="acc", value=100 * train_acc, group="train") 74 | else: 75 | print('| Epoch {:3d}, time={:5.1f}s | Train: skip eval |'.format(e + 1, clock1 - clock0), end='') 76 | 77 | # Valid 78 | clock3 = time.time() 79 | valid_loss, valid_acc, _ = self.eval(t, val_loader, t) 80 | clock4 = time.time() 81 | print(' Valid: time={:5.1f}s loss={:.3f} TAw acc={:5.1f}% |'.format( 82 | clock4 - clock3, valid_loss, 100 * valid_acc), end='') 83 | self.logger.log_scalar(task=t, iter=e + 1, name="loss", value=valid_loss, group="valid") 84 | self.logger.log_scalar(task=t, iter=e + 1, name="acc", value=100 * valid_acc, group="valid") 85 | scheduler.step() 86 | print(' lr={:.1e}'.format(self.optimizer.param_groups[0]['lr']), end='') 87 | self.logger.log_scalar(task=t, iter=e + 1, name="lr", value=lr, group="train") 88 | print() 89 | 90 | def post_train_process(self, t, trn_loader): 91 | """Runs after training all the epochs of the task (after the train session)""" 92 | pass 93 | 94 | def train_epoch(self, t, trn_loader): 95 | """Runs a single epoch""" 96 | self.model.train() 97 | if self.fix_bn and t > 0: 98 | self.model.freeze_bn() 99 | for images, targets in trn_loader: 100 | # Forward current model 101 | total_loss = 0.0 102 | self.model.apply(lambda m: setattr(m, 'scroll', self.scroll_step * t)) # change the perception of channels to shuffle 103 | for width_mult in sorted(FLAGS.width_mult_list, reverse=True): 104 | self.model.apply(lambda m: setattr(m, 'width_mult', width_mult)) 
105 | outputs = self.model(images.to(self.device)) 106 | loss = self.criterion(t, outputs, targets.to(self.device)) 107 | total_loss += loss 108 | # Backward 109 | self.optimizer.zero_grad() 110 | total_loss.backward() 111 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad) 112 | self.optimizer.step() 113 | 114 | def eval(self, t, val_loader, real_t): 115 | """Contains the evaluation code""" 116 | width_max = max(FLAGS.width_mult_list) 117 | with torch.no_grad(): 118 | total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0 119 | self.model.eval() 120 | self.model.apply(lambda m: setattr(m, 'scroll', self.scroll_step * real_t)) 121 | for images, targets in val_loader: 122 | # Forward current model 123 | self.model.apply(lambda m: setattr(m, 'width_mult', width_max)) 124 | outputs = self.model(images.to(self.device)) 125 | loss = self.criterion(t, outputs, targets.to(self.device)) 126 | hits_taw, hits_tag = self.calculate_metrics(outputs, targets) 127 | # Log 128 | total_loss += loss.item() * len(targets) 129 | total_acc_taw += hits_taw.sum().item() 130 | total_acc_tag += hits_tag.sum().item() 131 | total_num += len(targets) 132 | return total_loss / total_num, total_acc_taw / total_num, total_acc_tag / total_num 133 | 134 | def calculate_metrics(self, outputs, targets): 135 | """Contains the main Task-Aware and Task-Agnostic metrics""" 136 | pred = torch.zeros_like(targets.to(self.device)) 137 | # Task-Aware Multi-Head 138 | for m in range(len(pred)): 139 | this_task = (self.model.task_cls.cumsum(0) <= targets[m]).sum() 140 | pred[m] = outputs[this_task][m].argmax() + self.model.task_offset[this_task] 141 | hits_taw = (pred == targets.to(self.device)).float() 142 | # Task-Agnostic Multi-Head 143 | if self.multi_softmax: 144 | outputs = [torch.nn.functional.log_softmax(output, dim=1) for output in outputs] 145 | pred = torch.cat(outputs, dim=1).argmax(1) 146 | else: 147 | pred = torch.cat(outputs, dim=1).argmax(1) 148 | hits_tag = (pred 
== targets.to(self.device)).float() 149 | return hits_taw, hits_tag 150 | 151 | def criterion(self, t, outputs, targets): 152 | """Returns the loss value""" 153 | return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 154 | -------------------------------------------------------------------------------- /src/approach/lwf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | from argparse import ArgumentParser 4 | 5 | from .incremental_learning import Inc_Learning_Appr 6 | from datasets.exemplars_dataset import ExemplarsDataset 7 | from widths.config import FLAGS 8 | 9 | class Appr(Inc_Learning_Appr): 10 | """Class implementing the Learning Without Forgetting (LwF) approach 11 | described in https://arxiv.org/abs/1606.09282 12 | """ 13 | 14 | # Weight decay of 0.0005 is used in the original article (page 4). 15 | # Page 4: "The warm-up step greatly enhances fine-tuning’s old-task performance, but is not so crucial to either our 16 | # method or the compared Less Forgetting Learning (see Table 2(b))." 17 | def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000, 18 | momentum=0, wd=0, multi_softmax=False, scroll_step=1, fix_bn=False, eval_on_train=False, 19 | logger=None, exemplars_dataset=None, lamb=1, T=2): 20 | super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd, 21 | multi_softmax, scroll_step, fix_bn, eval_on_train, logger, 22 | exemplars_dataset) 23 | self.model_old = None 24 | self.lamb = lamb 25 | self.T = T 26 | 27 | @staticmethod 28 | def exemplars_dataset_class(): 29 | return ExemplarsDataset 30 | 31 | @staticmethod 32 | def extra_parser(args): 33 | """Returns a parser containing the approach specific parameters""" 34 | parser = ArgumentParser() 35 | # Page 5: "lambda is a loss balance weight, set to 1 for most our experiments. 
Making lambda larger will favor 36 | # the old task performance over the new task’s, so we can obtain a old-task-new-task performance line by 37 | # changing lambda." 38 | parser.add_argument('--lamb', default=1, type=float, required=False, 39 | help='Forgetting-intransigence trade-off (default=%(default)s)') 40 | # Page 5: "We use T=2 according to a grid search on a held out set, which aligns with the authors’ 41 | # recommendations." -- Using a higher value for T produces a softer probability distribution over classes. 42 | parser.add_argument('--T', default=2, type=int, required=False, 43 | help='Temperature scaling (default=%(default)s)') 44 | return parser.parse_known_args(args) 45 | 46 | def _get_optimizer(self): 47 | """Returns the optimizer""" 48 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 49 | # if there are no exemplars, previous heads are not modified 50 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 51 | else: 52 | params = self.model.parameters() 53 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 54 | 55 | def train_loop(self, t, trn_loader, val_loader): 56 | """Contains the epochs loop""" 57 | 58 | # add exemplars to train_loader 59 | if len(self.exemplars_dataset) > 0 and t > 0: 60 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 61 | batch_size=trn_loader.batch_size, 62 | shuffle=True, 63 | num_workers=trn_loader.num_workers, 64 | pin_memory=trn_loader.pin_memory) 65 | 66 | # FINETUNING TRAINING -- contains the epochs loop 67 | super().train_loop(t, trn_loader, val_loader) 68 | 69 | # EXEMPLAR MANAGEMENT -- select training subset 70 | self.exemplars_dataset.collect_exemplars(self.model, t, trn_loader, val_loader.dataset.transform) 71 | 72 | def post_train_process(self, t, trn_loader): 73 | """Runs after training all the epochs of the task (after the train session)""" 74 | 75 | # Restore best and save 
model for future tasks 76 | self.model_old = deepcopy(self.model) 77 | self.model_old.eval() 78 | self.model_old.freeze_all() 79 | 80 | def train_epoch(self, t, trn_loader): 81 | """Runs a single epoch""" 82 | self.model.train() 83 | 84 | if self.fix_bn and t > 0: 85 | self.model.freeze_bn() 86 | for images, targets in trn_loader: 87 | # Forward old model 88 | targets_old = None 89 | if t > 0: 90 | self.model_old.apply(lambda m: setattr(m, 'scroll', self.scroll_step * (t-1))) 91 | width_mult = max(FLAGS.width_mult_list) 92 | self.model_old.apply(lambda m: setattr(m, 'width_mult', width_mult)) 93 | targets_old = self.model_old(images.to(self.device)) 94 | # Forward current model 95 | total_loss = 0.0 96 | self.model.apply(lambda m: setattr(m, 'scroll', self.scroll_step * t)) 97 | for i, width_mult in enumerate(sorted(FLAGS.width_mult_list, reverse=True)): 98 | self.model.apply(lambda m: setattr(m, 'width_mult', width_mult)) 99 | outputs = self.model(images.to(self.device)) 100 | if i == 0: 101 | loss = self.criterion(t, outputs, targets.to(self.device), targets_old) 102 | else: 103 | loss = self.criterion(t, outputs, targets.to(self.device)) 104 | total_loss += loss 105 | # Backward 106 | self.optimizer.zero_grad() 107 | total_loss.backward() 108 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad) 109 | self.optimizer.step() 110 | 111 | def eval(self, t, val_loader, real_t): 112 | """Contains the evaluation code""" 113 | width_max = max(FLAGS.width_mult_list) 114 | with torch.no_grad(): 115 | total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0 116 | self.model.eval() 117 | for images, targets in val_loader: 118 | # Forward old model 119 | targets_old = None 120 | if t > 0: 121 | self.model_old.apply(lambda m: setattr(m, 'scroll', self.scroll_step * (real_t-1))) 122 | self.model_old.apply(lambda m: setattr(m, 'width_mult', width_max)) 123 | targets_old = self.model_old(images.to(self.device)) 124 | # Forward current model 125 | 
self.model.apply(lambda m: setattr(m, 'scroll', self.scroll_step * (real_t))) 126 | self.model.apply(lambda m: setattr(m, 'width_mult', width_max)) 127 | outputs = self.model(images.to(self.device)) 128 | loss = self.criterion(t, outputs, targets.to(self.device), targets_old) 129 | hits_taw, hits_tag = self.calculate_metrics(outputs, targets) 130 | # Log 131 | total_loss += loss.item() * len(targets) 132 | total_acc_taw += hits_taw.sum().item() 133 | total_acc_tag += hits_tag.sum().item() 134 | total_num += len(targets) 135 | return total_loss / total_num, total_acc_taw / total_num, total_acc_tag / total_num 136 | 137 | def cross_entropy(self, outputs, targets, exp=1.0, size_average=True, eps=1e-5): 138 | """Calculates cross-entropy with temperature scaling""" 139 | out = torch.nn.functional.softmax(outputs, dim=1) 140 | tar = torch.nn.functional.softmax(targets, dim=1) 141 | if exp != 1: 142 | out = out.pow(exp) 143 | out = out / out.sum(1).view(-1, 1).expand_as(out) 144 | tar = tar.pow(exp) 145 | tar = tar / tar.sum(1).view(-1, 1).expand_as(tar) 146 | out = out + eps / out.size(1) 147 | out = out / out.sum(1).view(-1, 1).expand_as(out) 148 | ce = -(tar * out.log()).sum(1) 149 | if size_average: 150 | ce = ce.mean() 151 | return ce 152 | 153 | def criterion(self, t, outputs, targets, outputs_old=None): 154 | """Returns the loss value""" 155 | loss = 0 156 | if t > 0 and outputs_old is not None: 157 | # Knowledge distillation loss for all previous tasks 158 | loss += self.lamb * self.cross_entropy(torch.cat(outputs[:t], dim=1), 159 | torch.cat(outputs_old[:t], dim=1), exp=1.0 / self.T) 160 | # Current cross-entropy loss -- with exemplars use all heads 161 | if len(self.exemplars_dataset) > 0: 162 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 163 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 164 | 
-------------------------------------------------------------------------------- /src/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class BaseDataset(Dataset): 9 | """Characterizes a dataset for PyTorch -- this dataset pre-loads all paths in memory""" 10 | 11 | def __init__(self, data, transform, class_indices=None): 12 | """Initialization""" 13 | self.labels = data['y'] 14 | self.images = data['x'] 15 | self.transform = transform 16 | self.class_indices = class_indices 17 | 18 | def __len__(self): 19 | """Denotes the total number of samples""" 20 | return len(self.images) 21 | 22 | def __getitem__(self, index): 23 | """Generates one sample of data""" 24 | x = Image.open(self.images[index]).convert('RGB') 25 | x = self.transform(x) 26 | y = self.labels[index] 27 | return x, y 28 | 29 | 30 | def get_data(path, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None): 31 | """Prepare data: dataset splits, task partition, class order""" 32 | 33 | data = {} 34 | taskcla = [] 35 | 36 | # read filenames and labels 37 | trn_lines = np.loadtxt(os.path.join(path, 'train.txt'), dtype=str) 38 | tst_lines = np.loadtxt(os.path.join(path, 'test.txt'), dtype=str) 39 | if class_order is None: 40 | num_classes = len(np.unique(trn_lines[:, 1])) 41 | class_order = list(range(num_classes)) 42 | else: 43 | num_classes = len(class_order) 44 | class_order = class_order.copy() 45 | if shuffle_classes: 46 | np.random.shuffle(class_order) 47 | 48 | # compute classes per task and num_tasks 49 | if nc_first_task is None: 50 | cpertask = np.array([num_classes // num_tasks] * num_tasks) 51 | for i in range(num_classes % num_tasks): 52 | cpertask[i] += 1 53 | else: 54 | assert nc_first_task < num_classes, "first task wants more classes than exist" 55 | remaining_classes = num_classes - 
nc_first_task 56 | assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task" # better minimum 2 57 | cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1)) 58 | for i in range(remaining_classes % (num_tasks - 1)): 59 | cpertask[i + 1] += 1 60 | 61 | assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes" 62 | cpertask_cumsum = np.cumsum(cpertask) 63 | init_class = np.concatenate(([0], cpertask_cumsum[:-1])) 64 | 65 | # initialize data structure 66 | for tt in range(num_tasks): 67 | data[tt] = {} 68 | data[tt]['name'] = 'task-' + str(tt) 69 | data[tt]['trn'] = {'x': [], 'y': []} 70 | data[tt]['val'] = {'x': [], 'y': []} 71 | data[tt]['tst'] = {'x': [], 'y': []} 72 | 73 | # ALL OR TRAIN 74 | for this_image, this_label in trn_lines: 75 | if not os.path.isabs(this_image): 76 | this_image = os.path.join(path, this_image) 77 | this_label = int(this_label) 78 | if this_label not in class_order: 79 | continue 80 | # If shuffling is false, it won't change the class number 81 | this_label = class_order.index(this_label) 82 | 83 | # add it to the corresponding split 84 | this_task = (this_label >= cpertask_cumsum).sum() 85 | data[this_task]['trn']['x'].append(this_image) 86 | data[this_task]['trn']['y'].append(this_label - init_class[this_task]) 87 | 88 | # ALL OR TEST 89 | for this_image, this_label in tst_lines: 90 | if not os.path.isabs(this_image): 91 | this_image = os.path.join(path, this_image) 92 | this_label = int(this_label) 93 | if this_label not in class_order: 94 | continue 95 | # If shuffling is false, it won't change the class number 96 | this_label = class_order.index(this_label) 97 | 98 | # add it to the corresponding split 99 | this_task = (this_label >= cpertask_cumsum).sum() 100 | data[this_task]['tst']['x'].append(this_image) 101 | data[this_task]['tst']['y'].append(this_label - init_class[this_task]) 102 | 103 | # check classes 104 | for tt in 
range(num_tasks): 105 | data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y'])) 106 | assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes" 107 | 108 | # validation 109 | if validation > 0.0: 110 | for tt in data.keys(): 111 | for cc in range(data[tt]['ncla']): 112 | cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0]) 113 | rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation))) 114 | rnd_img.sort(reverse=True) 115 | for ii in range(len(rnd_img)): 116 | data[tt]['val']['x'].append(data[tt]['trn']['x'][rnd_img[ii]]) 117 | data[tt]['val']['y'].append(data[tt]['trn']['y'][rnd_img[ii]]) 118 | data[tt]['trn']['x'].pop(rnd_img[ii]) 119 | data[tt]['trn']['y'].pop(rnd_img[ii]) 120 | 121 | # other 122 | n = 0 123 | for t in data.keys(): 124 | taskcla.append((t, data[t]['ncla'])) 125 | n += data[t]['ncla'] 126 | data['ncla'] = n 127 | 128 | return data, taskcla, class_order 129 | -------------------------------------------------------------------------------- /src/datasets/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils import data 4 | import torchvision.transforms as transforms 5 | from torchvision.datasets import MNIST as TorchVisionMNIST 6 | from torchvision.datasets import CIFAR100 as TorchVisionCIFAR100 7 | from torchvision.datasets import SVHN as TorchVisionSVHN 8 | 9 | from . import base_dataset as basedat 10 | from . 
import memory_dataset as memd 11 | from .dataset_config import dataset_config 12 | 13 | 14 | def get_loaders(datasets, num_tasks, nc_first_task, batch_size, num_workers, pin_memory, validation=.1): 15 | """Apply transformations to Datasets and create the DataLoaders for each task""" 16 | 17 | trn_load, val_load, tst_load = [], [], [] 18 | taskcla = [] 19 | dataset_offset = 0 20 | for idx_dataset, cur_dataset in enumerate(datasets, 0): 21 | # get configuration for current dataset 22 | dc = dataset_config[cur_dataset] 23 | 24 | # transformations 25 | trn_transform, tst_transform = get_transforms(resize=dc['resize'], 26 | pad=dc['pad'], 27 | crop=dc['crop'], 28 | flip=dc['flip'], 29 | normalize=dc['normalize'], 30 | extend_channel=dc['extend_channel']) 31 | 32 | # datasets 33 | trn_dset, val_dset, tst_dset, curtaskcla = get_datasets(cur_dataset, dc['path'], num_tasks, nc_first_task, 34 | validation=validation, 35 | trn_transform=trn_transform, 36 | tst_transform=tst_transform, 37 | class_order=dc['class_order']) 38 | 39 | # apply offsets in case of multiple datasets 40 | if idx_dataset > 0: 41 | for tt in range(num_tasks): 42 | trn_dset[tt].labels = [elem + dataset_offset for elem in trn_dset[tt].labels] 43 | val_dset[tt].labels = [elem + dataset_offset for elem in val_dset[tt].labels] 44 | tst_dset[tt].labels = [elem + dataset_offset for elem in tst_dset[tt].labels] 45 | dataset_offset = dataset_offset + sum([tc[1] for tc in curtaskcla]) 46 | 47 | # reassign class idx for multiple dataset case 48 | curtaskcla = [(tc[0] + idx_dataset * num_tasks, tc[1]) for tc in curtaskcla] 49 | 50 | # extend final taskcla list 51 | taskcla.extend(curtaskcla) 52 | 53 | # loaders 54 | for tt in range(num_tasks): 55 | trn_load.append(data.DataLoader(trn_dset[tt], batch_size=batch_size, shuffle=True, num_workers=num_workers, 56 | pin_memory=pin_memory)) 57 | val_load.append(data.DataLoader(val_dset[tt], batch_size=batch_size, shuffle=False, num_workers=num_workers, 58 | 
pin_memory=pin_memory)) 59 | tst_load.append(data.DataLoader(tst_dset[tt], batch_size=batch_size, shuffle=False, num_workers=num_workers, 60 | pin_memory=pin_memory)) 61 | return trn_load, val_load, tst_load, taskcla 62 | 63 | 64 | def get_datasets(dataset, path, num_tasks, nc_first_task, validation, trn_transform, tst_transform, class_order=None): 65 | """Extract datasets and create Dataset class""" 66 | 67 | trn_dset, val_dset, tst_dset = [], [], [] 68 | 69 | if 'mnist' in dataset: 70 | tvmnist_trn = TorchVisionMNIST(path, train=True, download=True) 71 | tvmnist_tst = TorchVisionMNIST(path, train=False, download=True) 72 | trn_data = {'x': tvmnist_trn.data.numpy(), 'y': tvmnist_trn.targets.tolist()} 73 | tst_data = {'x': tvmnist_tst.data.numpy(), 'y': tvmnist_tst.targets.tolist()} 74 | # compute splits 75 | all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation, 76 | num_tasks=num_tasks, nc_first_task=nc_first_task, 77 | shuffle_classes=class_order is None, class_order=class_order) 78 | # set dataset type 79 | Dataset = memd.MemoryDataset 80 | 81 | elif 'cifar100' in dataset: 82 | tvcifar_trn = TorchVisionCIFAR100(path, train=True, download=True) 83 | tvcifar_tst = TorchVisionCIFAR100(path, train=False, download=True) 84 | trn_data = {'x': tvcifar_trn.data, 'y': tvcifar_trn.targets} 85 | tst_data = {'x': tvcifar_tst.data, 'y': tvcifar_tst.targets} 86 | # compute splits 87 | all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation, 88 | num_tasks=num_tasks, nc_first_task=nc_first_task, 89 | shuffle_classes=class_order is None, class_order=class_order) 90 | # set dataset type 91 | Dataset = memd.MemoryDataset 92 | 93 | elif dataset == 'svhn': 94 | tvsvhn_trn = TorchVisionSVHN(path, split='train', download=True) 95 | tvsvhn_tst = TorchVisionSVHN(path, split='test', download=True) 96 | trn_data = {'x': tvsvhn_trn.data.transpose(0, 2, 3, 1), 'y': tvsvhn_trn.labels} 97 | tst_data = {'x': 
tvsvhn_tst.data.transpose(0, 2, 3, 1), 'y': tvsvhn_tst.labels} 98 | # Notice that SVHN in Torchvision has an extra training set in case needed 99 | # tvsvhn_xtr = TorchVisionSVHN(path, split='extra', download=True) 100 | # xtr_data = {'x': tvsvhn_xtr.data.transpose(0, 2, 3, 1), 'y': tvsvhn_xtr.labels} 101 | 102 | # compute splits 103 | all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation, 104 | num_tasks=num_tasks, nc_first_task=nc_first_task, 105 | shuffle_classes=class_order is None, class_order=class_order) 106 | # set dataset type 107 | Dataset = memd.MemoryDataset 108 | 109 | elif 'imagenet_32' in dataset: 110 | import pickle 111 | # load data 112 | x_trn, y_trn = [], [] 113 | for i in range(1, 11): 114 | with open(os.path.join(path, 'train_data_batch_{}'.format(i)), 'rb') as f: 115 | d = pickle.load(f) 116 | x_trn.append(d['data']) 117 | y_trn.append(np.array(d['labels']) - 1) # labels from 0 to 999 118 | with open(os.path.join(path, 'val_data'), 'rb') as f: 119 | d = pickle.load(f) 120 | x_trn.append(d['data']) 121 | y_tst = np.array(d['labels']) - 1 # labels from 0 to 999 122 | # reshape data 123 | for i, d in enumerate(x_trn, 0): 124 | x_trn[i] = d.reshape(d.shape[0], 3, 32, 32).transpose(0, 2, 3, 1) 125 | x_tst = x_trn[-1] 126 | x_trn = np.vstack(x_trn[:-1]) 127 | y_trn = np.concatenate(y_trn) 128 | trn_data = {'x': x_trn, 'y': y_trn} 129 | tst_data = {'x': x_tst, 'y': y_tst} 130 | # compute splits 131 | all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation, 132 | num_tasks=num_tasks, nc_first_task=nc_first_task, 133 | shuffle_classes=class_order is None, class_order=class_order) 134 | # set dataset type 135 | Dataset = memd.MemoryDataset 136 | 137 | else: 138 | # read data paths and compute splits -- path needs to have a train.txt and a test.txt with image-label pairs 139 | all_data, taskcla, class_indices = basedat.get_data(path, num_tasks=num_tasks, nc_first_task=nc_first_task, 
140 | validation=validation, shuffle_classes=class_order is None, 141 | class_order=class_order) 142 | # set dataset type 143 | Dataset = basedat.BaseDataset 144 | 145 | # get datasets, apply correct label offsets for each task 146 | offset = 0 147 | for task in range(num_tasks): 148 | all_data[task]['trn']['y'] = [label + offset for label in all_data[task]['trn']['y']] 149 | all_data[task]['val']['y'] = [label + offset for label in all_data[task]['val']['y']] 150 | all_data[task]['tst']['y'] = [label + offset for label in all_data[task]['tst']['y']] 151 | trn_dset.append(Dataset(all_data[task]['trn'], trn_transform, class_indices)) 152 | val_dset.append(Dataset(all_data[task]['val'], tst_transform, class_indices)) 153 | tst_dset.append(Dataset(all_data[task]['tst'], tst_transform, class_indices)) 154 | offset += taskcla[task][1] 155 | 156 | return trn_dset, val_dset, tst_dset, taskcla 157 | 158 | 159 | def get_transforms(resize, pad, crop, flip, normalize, extend_channel): 160 | """Unpack transformations and apply to train or test splits""" 161 | 162 | trn_transform_list = [] 163 | tst_transform_list = [] 164 | 165 | # resize 166 | if resize is not None: 167 | trn_transform_list.append(transforms.Resize(resize)) 168 | tst_transform_list.append(transforms.Resize(resize)) 169 | 170 | # padding 171 | if pad is not None: 172 | trn_transform_list.append(transforms.Pad(pad)) 173 | tst_transform_list.append(transforms.Pad(pad)) 174 | 175 | # crop 176 | if crop is not None: 177 | trn_transform_list.append(transforms.RandomResizedCrop(crop)) 178 | tst_transform_list.append(transforms.CenterCrop(crop)) 179 | 180 | # flips 181 | if flip: 182 | trn_transform_list.append(transforms.RandomHorizontalFlip()) 183 | 184 | # to tensor 185 | trn_transform_list.append(transforms.ToTensor()) 186 | tst_transform_list.append(transforms.ToTensor()) 187 | 188 | # normalization 189 | if normalize is not None: 190 | trn_transform_list.append(transforms.Normalize(mean=normalize[0], 
std=normalize[1])) 191 | tst_transform_list.append(transforms.Normalize(mean=normalize[0], std=normalize[1])) 192 | 193 | # gray to rgb 194 | if extend_channel is not None: 195 | trn_transform_list.append(transforms.Lambda(lambda x: x.repeat(extend_channel, 1, 1))) 196 | tst_transform_list.append(transforms.Lambda(lambda x: x.repeat(extend_channel, 1, 1))) 197 | 198 | return transforms.Compose(trn_transform_list), \ 199 | transforms.Compose(tst_transform_list) 200 | -------------------------------------------------------------------------------- /src/datasets/dataset_config.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | _BASE_DATA_PATH = "../data" 4 | 5 | dataset_config = { 6 | 'mnist': { 7 | 'path': join(_BASE_DATA_PATH, 'mnist'), 8 | 'normalize': ((0.1307,), (0.3081,)), 9 | # Use the next 3 lines to use MNIST with a 3x32x32 input 10 | # 'extend_channel': 3, 11 | # 'pad': 2, 12 | # 'normalize': ((0.1,), (0.2752,)) # values including padding 13 | }, 14 | 'svhn': { 15 | 'path': join(_BASE_DATA_PATH, 'svhn'), 16 | 'resize': (224, 224), 17 | 'crop': None, 18 | 'flip': False, 19 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 20 | }, 21 | 'cifar100': { 22 | 'path': join(_BASE_DATA_PATH, 'cifar100'), 23 | 'resize': None, 24 | 'pad': 4, 25 | 'crop': 32, 26 | 'flip': True, 27 | 'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)) 28 | }, 29 | 'cifar100_icarl': { 30 | 'path': join(_BASE_DATA_PATH, 'cifar100'), 31 | 'resize': None, 32 | 'pad': 4, 33 | 'crop': 32, 34 | 'flip': True, 35 | 'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)), 36 | 'class_order': [ 37 | 68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50, 38 | 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96, 39 | 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 
11, 12, 47, 25, 30, 46, 62, 69, 40 | 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33 41 | ] 42 | }, 43 | 'vggface2': { 44 | 'path': join(_BASE_DATA_PATH, 'VGGFace2'), 45 | 'resize': 256, 46 | 'crop': 224, 47 | 'flip': True, 48 | 'normalize': ((0.5199, 0.4116, 0.3610), (0.2604, 0.2297, 0.2169)) 49 | }, 50 | 'imagenet_256': { 51 | 'path': join(_BASE_DATA_PATH, 'ILSVRC12_256'), 52 | 'resize': None, 53 | 'crop': 224, 54 | 'flip': True, 55 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 56 | }, 57 | 'imagenet_subset': { 58 | 'path': join(_BASE_DATA_PATH, 'ILSVRC12_256'), 59 | 'resize': None, 60 | 'crop': 224, 61 | 'flip': True, 62 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 63 | 'class_order': [ 64 | 68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50, 65 | 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96, 66 | 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69, 67 | 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33 68 | ] 69 | }, 70 | 'imagenet_32_reduced': { 71 | 'path': join(_BASE_DATA_PATH, 'ILSVRC12_32'), 72 | 'resize': None, 73 | 'pad': 4, 74 | 'crop': 32, 75 | 'flip': True, 76 | 'normalize': ((0.481, 0.457, 0.408), (0.260, 0.253, 0.268)), 77 | 'class_order': [ 78 | 472, 46, 536, 806, 547, 976, 662, 12, 955, 651, 492, 80, 999, 996, 788, 471, 911, 907, 680, 126, 42, 882, 79 | 327, 719, 716, 224, 918, 647, 808, 261, 140, 908, 833, 925, 57, 388, 407, 215, 45, 479, 525, 641, 915, 923, 80 | 108, 461, 186, 843, 115, 250, 829, 625, 769, 323, 974, 291, 438, 50, 825, 441, 446, 200, 162, 373, 872, 112, 81 | 212, 501, 91, 672, 791, 370, 942, 172, 315, 959, 636, 635, 66, 86, 197, 182, 59, 736, 175, 445, 947, 268, 82 | 238, 298, 926, 851, 494, 760, 61, 293, 696, 659, 69, 819, 912, 486, 706, 343, 390, 484, 282, 729, 575, 731, 83 | 530, 32, 534, 
838, 466, 734, 425, 400, 290, 660, 254, 266, 551, 775, 721, 134, 886, 338, 465, 236, 522, 655, 84 | 209, 861, 88, 491, 985, 304, 981, 560, 405, 902, 521, 909, 763, 455, 341, 905, 280, 776, 113, 434, 274, 581, 85 | 158, 738, 671, 702, 147, 718, 148, 35, 13, 585, 591, 371, 745, 281, 956, 935, 346, 352, 284, 604, 447, 415, 86 | 98, 921, 118, 978, 880, 509, 381, 71, 552, 169, 600, 334, 171, 835, 798, 77, 249, 318, 419, 990, 335, 374, 87 | 949, 316, 755, 878, 946, 142, 299, 863, 558, 306, 183, 417, 64, 765, 565, 432, 440, 939, 297, 805, 364, 735, 88 | 251, 270, 493, 94, 773, 610, 278, 16, 363, 92, 15, 593, 96, 468, 252, 699, 377, 95, 799, 868, 820, 328, 756, 89 | 81, 991, 464, 774, 584, 809, 844, 940, 720, 498, 310, 384, 619, 56, 406, 639, 285, 67, 634, 792, 232, 54, 90 | 664, 818, 513, 349, 330, 207, 361, 345, 279, 549, 944, 817, 353, 228, 312, 796, 193, 179, 520, 451, 871, 91 | 692, 60, 481, 480, 929, 499, 673, 331, 506, 70, 645, 759, 744, 459] 92 | } 93 | } 94 | 95 | # Add missing keys: 96 | for dset in dataset_config.keys(): 97 | for k in ['resize', 'pad', 'crop', 'normalize', 'class_order', 'extend_channel']: 98 | if k not in dataset_config[dset].keys(): 99 | dataset_config[dset][k] = None 100 | if 'flip' not in dataset_config[dset].keys(): 101 | dataset_config[dset]['flip'] = False 102 | -------------------------------------------------------------------------------- /src/datasets/exemplars_dataset.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from argparse import ArgumentParser 3 | 4 | from datasets.memory_dataset import MemoryDataset 5 | 6 | 7 | class ExemplarsDataset(MemoryDataset): 8 | """Exemplar storage for approaches with an interface of Dataset""" 9 | 10 | def __init__(self, transform, class_indices, 11 | num_exemplars=0, num_exemplars_per_class=0, exemplar_selection='random'): 12 | super().__init__({'x': [], 'y': []}, transform, class_indices=class_indices) 13 | 
self.max_num_exemplars_per_class = num_exemplars_per_class 14 | self.max_num_exemplars = num_exemplars 15 | assert (num_exemplars_per_class == 0) or (num_exemplars == 0), 'Cannot use both limits at once!' 16 | cls_name = "{}ExemplarsSelector".format(exemplar_selection.capitalize()) 17 | selector_cls = getattr(importlib.import_module(name='datasets.exemplars_selection'), cls_name) 18 | self.exemplars_selector = selector_cls(self) 19 | 20 | # Returns a parser containing the approach specific parameters 21 | @staticmethod 22 | def extra_parser(args): 23 | parser = ArgumentParser("Exemplars Management Parameters") 24 | _group = parser.add_mutually_exclusive_group() 25 | _group.add_argument('--num-exemplars', default=0, type=int, required=False, 26 | help='Fixed memory, total number of exemplars (default=%(default)s)') 27 | _group.add_argument('--num-exemplars-per-class', default=0, type=int, required=False, 28 | help='Growing memory, number of exemplars per class (default=%(default)s)') 29 | parser.add_argument('--exemplar-selection', default='random', type=str, 30 | choices=['herding', 'random', 'entropy', 'distance'], 31 | required=False, help='Exemplar selection strategy (default=%(default)s)') 32 | return parser.parse_known_args(args) 33 | 34 | def _is_active(self): 35 | return self.max_num_exemplars_per_class > 0 or self.max_num_exemplars > 0 36 | 37 | def collect_exemplars(self, model, t, trn_loader, selection_transform): 38 | if self._is_active(): 39 | self.images, self.labels = self.exemplars_selector(model, t, trn_loader, selection_transform) 40 | -------------------------------------------------------------------------------- /src/datasets/exemplars_selection.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from contextlib import contextmanager 4 | from typing import Iterable 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader, ConcatDataset 9 | from 
torchvision.transforms import Lambda 10 | 11 | from datasets.exemplars_dataset import ExemplarsDataset 12 | from networks.network import LLL_Net 13 | 14 | 15 | class ExemplarsSelector: 16 | """Exemplar selector for approaches with an interface of Dataset""" 17 | 18 | def __init__(self, exemplars_dataset: ExemplarsDataset): 19 | self.exemplars_dataset = exemplars_dataset 20 | 21 | def __call__(self, model: LLL_Net, t: int, trn_loader: DataLoader, transform): 22 | clock0 = time.time() 23 | exemplars_per_class = self._exemplars_per_class_num(model) 24 | with override_dataset_transform(trn_loader.dataset, transform) as ds_for_selection: 25 | # change loader and fix to go sequentially (shuffle=False), keeps same order for later, eval transforms 26 | sel_loader = DataLoader(ds_for_selection, batch_size=trn_loader.batch_size, shuffle=False, 27 | num_workers=trn_loader.num_workers, pin_memory=trn_loader.pin_memory) 28 | selected_indices = self._select_indices(model, t, sel_loader, exemplars_per_class, transform) 29 | with override_dataset_transform(trn_loader.dataset, Lambda(lambda x: np.array(x))) as ds_for_raw: 30 | x, y = zip(*(ds_for_raw[idx] for idx in selected_indices)) 31 | clock1 = time.time() 32 | print('| Selected {:d} train exemplars, time={:5.1f}s'.format(len(x), clock1 - clock0)) 33 | return x, y 34 | 35 | def _exemplars_per_class_num(self, model: LLL_Net): 36 | if self.exemplars_dataset.max_num_exemplars_per_class: 37 | return self.exemplars_dataset.max_num_exemplars_per_class 38 | 39 | num_cls = model.task_cls.sum().item() 40 | num_exemplars = self.exemplars_dataset.max_num_exemplars 41 | exemplars_per_class = int(np.ceil(num_exemplars / num_cls)) 42 | assert exemplars_per_class > 0, \ 43 | "Not enough exemplars to cover all classes!\n" \ 44 | "Number of classes so far: {}. 
" \ 45 | "Limit of exemplars: {}".format(num_cls, 46 | num_exemplars) 47 | return exemplars_per_class 48 | 49 | def _select_indices(self, model: LLL_Net, t: int, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable: 50 | pass 51 | 52 | 53 | class RandomExemplarsSelector(ExemplarsSelector): 54 | """Selection of new samples. This is based on random selection, which produces a random list of samples.""" 55 | 56 | def __init__(self, exemplars_dataset): 57 | super().__init__(exemplars_dataset) 58 | 59 | def _select_indices(self, model: LLL_Net, t: int, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable: 60 | num_cls = sum(model.task_cls) 61 | result = [] 62 | labels = self._get_labels(sel_loader) 63 | for curr_cls in range(num_cls): 64 | # get all indices from current class -- check if there are exemplars from previous task in the loader 65 | cls_ind = np.where(labels == curr_cls)[0] 66 | assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls) 67 | assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store" 68 | # select the exemplars randomly 69 | result.extend(random.sample(list(cls_ind), exemplars_per_class)) 70 | return result 71 | 72 | def _get_labels(self, sel_loader): 73 | if hasattr(sel_loader.dataset, 'labels'): # BaseDataset, MemoryDataset 74 | labels = np.asarray(sel_loader.dataset.labels) 75 | elif isinstance(sel_loader.dataset, ConcatDataset): 76 | labels = [] 77 | for ds in sel_loader.dataset.datasets: 78 | labels.extend(ds.labels) 79 | labels = np.array(labels) 80 | else: 81 | raise RuntimeError("Unsupported dataset: {}".format(sel_loader.dataset.__class__.__name__)) 82 | return labels 83 | 84 | 85 | class HerdingExemplarsSelector(ExemplarsSelector): 86 | """Selection of new samples. This is based on herding selection, which produces a sorted list of samples of one 87 | class based on the distance to the mean sample of that class. 
From iCaRL algorithm 4 and 5: 88 | https://openaccess.thecvf.com/content_cvpr_2017/papers/Rebuffi_iCaRL_Incremental_Classifier_CVPR_2017_paper.pdf 89 | """ 90 | def __init__(self, exemplars_dataset): 91 | super().__init__(exemplars_dataset) 92 | 93 | def _select_indices(self, model: LLL_Net, t: int, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable: 94 | model_device = next(model.parameters()).device # we assume here that whole model is on a single device 95 | 96 | # extract outputs from the model for all train samples 97 | extracted_features = [] 98 | extracted_targets = [] 99 | with torch.no_grad(): 100 | model.eval() 101 | for images, targets in sel_loader: 102 | model.apply(lambda m: setattr(m, 'shift', 1*t)) 103 | model.apply(lambda m: setattr(m, 'width_mult', 1.0)) 104 | feats = model(images.to(model_device), return_features=True)[1] 105 | feats = feats / feats.norm(dim=1).view(-1, 1) # Feature normalization 106 | extracted_features.append(feats) 107 | extracted_targets.extend(targets) 108 | extracted_features = (torch.cat(extracted_features)).cpu() 109 | extracted_targets = np.array(extracted_targets) 110 | result = [] 111 | # iterate through all classes 112 | for curr_cls in np.unique(extracted_targets): 113 | # get all indices from current class 114 | cls_ind = np.where(extracted_targets == curr_cls)[0] 115 | assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls) 116 | assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store" 117 | # get all extracted features for current class 118 | cls_feats = extracted_features[cls_ind] 119 | # calculate the mean 120 | cls_mu = cls_feats.mean(0) 121 | # select the exemplars closer to the mean of each class 122 | selected = [] 123 | selected_feat = [] 124 | for k in range(exemplars_per_class): 125 | # fix this to the dimension of the model features 126 | sum_others = torch.zeros(cls_feats.shape[1]) 127 | for j in selected_feat: 128 | sum_others += j 
/ (k + 1) 129 | dist_min = np.inf 130 | # choose the closest to the mean of the current class 131 | for item in cls_ind: 132 | if item not in selected: 133 | feat = extracted_features[item] 134 | dist = torch.norm(cls_mu - feat / (k + 1) - sum_others) 135 | if dist < dist_min: 136 | dist_min = dist 137 | newone = item 138 | newonefeat = feat 139 | selected_feat.append(newonefeat) 140 | selected.append(newone) 141 | result.extend(selected) 142 | return result 143 | 144 | 145 | class EntropyExemplarsSelector(ExemplarsSelector): 146 | """Selection of new samples. This is based on entropy selection, which produces a sorted list of samples of one 147 | class based on entropy of each sample. From RWalk http://arxiv-export-lb.library.cornell.edu/pdf/1801.10112 148 | """ 149 | def __init__(self, exemplars_dataset): 150 | super().__init__(exemplars_dataset) 151 | 152 | def _select_indices(self, model: LLL_Net, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable: 153 | model_device = next(model.parameters()).device # we assume here that whole model is on a single device 154 | 155 | # extract outputs from the model for all train samples 156 | extracted_logits = [] 157 | extracted_targets = [] 158 | with torch.no_grad(): 159 | model.eval() 160 | for images, targets in sel_loader: 161 | extracted_logits.append(torch.cat(model(images.to(model_device)), dim=1)) 162 | extracted_targets.extend(targets) 163 | extracted_logits = (torch.cat(extracted_logits)).cpu() 164 | extracted_targets = np.array(extracted_targets) 165 | result = [] 166 | # iterate through all classes 167 | for curr_cls in np.unique(extracted_targets): 168 | # get all indices from current class 169 | cls_ind = np.where(extracted_targets == curr_cls)[0] 170 | assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls) 171 | assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store" 172 | # get all extracted features for current class 173 | cls_logits = 
extracted_logits[cls_ind] 174 | # select the exemplars with higher entropy (lower: -entropy) 175 | probs = torch.softmax(cls_logits, dim=1) 176 | log_probs = torch.log(probs) 177 | minus_entropy = (probs * log_probs).sum(1) # change sign of this variable for inverse order 178 | selected = cls_ind[minus_entropy.sort()[1][:exemplars_per_class]] 179 | result.extend(selected) 180 | return result 181 | 182 | 183 | class DistanceExemplarsSelector(ExemplarsSelector): 184 | """Selection of new samples. This is based on distance-based selection, which produces a sorted list of samples of 185 | one class based on closeness to decision boundary of each sample. From RWalk 186 | http://arxiv-export-lb.library.cornell.edu/pdf/1801.10112 187 | """ 188 | def __init__(self, exemplars_dataset): 189 | super().__init__(exemplars_dataset) 190 | 191 | def _select_indices(self, model: LLL_Net, sel_loader: DataLoader, exemplars_per_class: int, 192 | transform) -> Iterable: 193 | model_device = next(model.parameters()).device # we assume here that whole model is on a single device 194 | 195 | # extract outputs from the model for all train samples 196 | extracted_logits = [] 197 | extracted_targets = [] 198 | with torch.no_grad(): 199 | model.eval() 200 | for images, targets in sel_loader: 201 | extracted_logits.append(torch.cat(model(images.to(model_device)), dim=1)) 202 | extracted_targets.extend(targets) 203 | extracted_logits = (torch.cat(extracted_logits)).cpu() 204 | extracted_targets = np.array(extracted_targets) 205 | result = [] 206 | # iterate through all classes 207 | for curr_cls in np.unique(extracted_targets): 208 | # get all indices from current class 209 | cls_ind = np.where(extracted_targets == curr_cls)[0] 210 | assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls) 211 | assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store" 212 | # get all extracted features for current class 213 | cls_logits = 
extracted_logits[cls_ind] 214 | # select the exemplars closer to boundary 215 | distance = cls_logits[:, curr_cls] # change sign of this variable for inverse order 216 | selected = cls_ind[distance.sort()[1][:exemplars_per_class]] 217 | result.extend(selected) 218 | return result 219 | 220 | 221 | def dataset_transforms(dataset, transform_to_change): 222 | if isinstance(dataset, ConcatDataset): 223 | r = [] 224 | for ds in dataset.datasets: 225 | r += dataset_transforms(ds, transform_to_change) 226 | return r 227 | else: 228 | old_transform = dataset.transform 229 | dataset.transform = transform_to_change 230 | return [(dataset, old_transform)] 231 | 232 | 233 | @contextmanager 234 | def override_dataset_transform(dataset, transform): 235 | try: 236 | datasets_with_orig_transform = dataset_transforms(dataset, transform) 237 | yield dataset 238 | finally: 239 | # get bac original transformations 240 | for ds, orig_transform in datasets_with_orig_transform: 241 | ds.transform = orig_transform 242 | -------------------------------------------------------------------------------- /src/datasets/memory_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class MemoryDataset(Dataset): 8 | """Characterizes a dataset for PyTorch -- this dataset pre-loads all images in memory""" 9 | 10 | def __init__(self, data, transform, class_indices=None): 11 | """Initialization""" 12 | self.labels = data['y'] 13 | self.images = data['x'] 14 | self.transform = transform 15 | self.class_indices = class_indices 16 | 17 | def __len__(self): 18 | """Denotes the total number of samples""" 19 | return len(self.images) 20 | 21 | def __getitem__(self, index): 22 | """Generates one sample of data""" 23 | x = Image.fromarray(self.images[index]) 24 | x = self.transform(x) 25 | y = self.labels[index] 26 | return x, y 27 | 28 | 29 | def 
get_data(trn_data, tst_data, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None): 30 | """Prepare data: dataset splits, task partition, class order""" 31 | 32 | data = {} 33 | taskcla = [] 34 | if class_order is None: 35 | num_classes = len(np.unique(trn_data['y'])) 36 | class_order = list(range(num_classes)) 37 | else: 38 | num_classes = len(class_order) 39 | class_order = class_order.copy() 40 | if shuffle_classes: 41 | np.random.shuffle(class_order) 42 | 43 | # compute classes per task and num_tasks 44 | if nc_first_task is None: 45 | cpertask = np.array([num_classes // num_tasks] * num_tasks) 46 | for i in range(num_classes % num_tasks): 47 | cpertask[i] += 1 48 | else: 49 | assert nc_first_task < num_classes, "first task wants more classes than exist" 50 | remaining_classes = num_classes - nc_first_task 51 | assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task" # better minimum 2 52 | cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1)) 53 | for i in range(remaining_classes % (num_tasks - 1)): 54 | cpertask[i + 1] += 1 55 | 56 | assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes" 57 | cpertask_cumsum = np.cumsum(cpertask) 58 | init_class = np.concatenate(([0], cpertask_cumsum[:-1])) 59 | 60 | # initialize data structure 61 | for tt in range(num_tasks): 62 | data[tt] = {} 63 | data[tt]['name'] = 'task-' + str(tt) 64 | data[tt]['trn'] = {'x': [], 'y': []} 65 | data[tt]['val'] = {'x': [], 'y': []} 66 | data[tt]['tst'] = {'x': [], 'y': []} 67 | 68 | # ALL OR TRAIN 69 | filtering = np.isin(trn_data['y'], class_order) 70 | if filtering.sum() != len(trn_data['y']): 71 | trn_data['x'] = trn_data['x'][filtering] 72 | trn_data['y'] = np.array(trn_data['y'])[filtering] 73 | for this_image, this_label in zip(trn_data['x'], trn_data['y']): 74 | # If shuffling is false, it won't change the class number 75 | this_label = 
class_order.index(this_label) 76 | # add it to the corresponding split 77 | this_task = (this_label >= cpertask_cumsum).sum() 78 | data[this_task]['trn']['x'].append(this_image) 79 | data[this_task]['trn']['y'].append(this_label - init_class[this_task]) 80 | 81 | # ALL OR TEST 82 | filtering = np.isin(tst_data['y'], class_order) 83 | if filtering.sum() != len(tst_data['y']): 84 | tst_data['x'] = tst_data['x'][filtering] 85 | tst_data['y'] = tst_data['y'][filtering] 86 | for this_image, this_label in zip(tst_data['x'], tst_data['y']): 87 | # If shuffling is false, it won't change the class number 88 | this_label = class_order.index(this_label) 89 | # add it to the corresponding split 90 | this_task = (this_label >= cpertask_cumsum).sum() 91 | data[this_task]['tst']['x'].append(this_image) 92 | data[this_task]['tst']['y'].append(this_label - init_class[this_task]) 93 | 94 | # check classes 95 | for tt in range(num_tasks): 96 | data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y'])) 97 | assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes" 98 | 99 | # validation 100 | if validation > 0.0: 101 | for tt in data.keys(): 102 | for cc in range(data[tt]['ncla']): 103 | cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0]) 104 | rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation))) 105 | rnd_img.sort(reverse=True) 106 | for ii in range(len(rnd_img)): 107 | data[tt]['val']['x'].append(data[tt]['trn']['x'][rnd_img[ii]]) 108 | data[tt]['val']['y'].append(data[tt]['trn']['y'][rnd_img[ii]]) 109 | data[tt]['trn']['x'].pop(rnd_img[ii]) 110 | data[tt]['trn']['y'].pop(rnd_img[ii]) 111 | 112 | # convert them to numpy arrays 113 | for tt in data.keys(): 114 | for split in ['trn', 'val', 'tst']: 115 | data[tt][split]['x'] = np.asarray(data[tt][split]['x']) 116 | 117 | # other 118 | n = 0 119 | for t in data.keys(): 120 | taskcla.append((t, data[t]['ncla'])) 121 | n += data[t]['ncla'] 122 | data['ncla'] = n 123 | 124 | 
return data, taskcla, class_order 125 | -------------------------------------------------------------------------------- /src/loggers/README.md: -------------------------------------------------------------------------------- 1 | # Loggers 2 | 3 | We include a disk logger, which writes its output as files and folders on disk. We also provide a tensorboard logger, which 4 | offers a faster way to analyse a training process without the need for further development. They can be selected with 5 | `--log` followed by `disk`, `tensorboard` or both. Custom loggers can be defined by inheriting from `ExperimentLogger` 6 | in [exp_logger.py](exp_logger.py). 7 | 8 | When enabled, both loggers write everything to the path `[RESULTS_PATH]/[DATASETS]_[APPROACH]_[EXP_NAME]`, or to 9 | `[RESULTS_PATH]/[DATASETS]_[APPROACH]` if `--exp-name` is not set. 10 | 11 | ## Disk logger 12 | The disk logger produces the following file and folder structure: 13 | - **figures/**: folder where generated figures are logged. 14 | - **models/**: folder where model weight checkpoints are saved when `--save-models` is given. 15 | - **results/**: folder containing the results. 16 | - **acc_tag**: task-agnostic accuracy table. 17 | - **acc_taw**: task-aware accuracy table. 18 | - **avg_accs_tag**: task-agnostic average accuracies. 19 | - **avg_accs_taw**: task-aware average accuracies. 20 | - **forg_tag**: task-agnostic forgetting table. 21 | - **forg_taw**: task-aware forgetting table. 22 | - **wavg_accs_tag**: task-agnostic average accuracies weighted according to the number of classes of each task. 23 | - **wavg_accs_taw**: task-aware average accuracies weighted according to the number of classes of each task. 24 | - **raw_log**: JSON Lines file (one JSON object per logged metric) containing all the logged metrics, easily read by many tools (e.g. `pandas`). 25 | - stdout: a copy of the standard output of the terminal. 26 | - stderr: a copy of the error output of the terminal.
27 | 28 | ## TensorBoard logger 29 | The tensorboard logger outputs metrics analogous to those of the disk logger, separated into different tabs by task 30 | and into different graphs by data split. 31 | 32 | Screenshot of a 10-task experiment, showing the plots for the last task: 33 |

34 | Tensorboard Screenshot 35 |

36 | -------------------------------------------------------------------------------- /src/loggers/__pycache__/disk_logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/src/loggers/__pycache__/disk_logger.cpython-37.pyc -------------------------------------------------------------------------------- /src/loggers/__pycache__/disk_logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/src/loggers/__pycache__/disk_logger.cpython-38.pyc -------------------------------------------------------------------------------- /src/loggers/__pycache__/exp_logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/src/loggers/__pycache__/exp_logger.cpython-37.pyc -------------------------------------------------------------------------------- /src/loggers/__pycache__/exp_logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/src/loggers/__pycache__/exp_logger.cpython-38.pyc -------------------------------------------------------------------------------- /src/loggers/disk_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import numpy as np 6 | from datetime import datetime 7 | 8 | from loggers.exp_logger import ExperimentLogger 9 | 10 | 11 | class Logger(ExperimentLogger): 12 | """Characterizes a disk logger""" 13 | 14 | def __init__(self, log_path, exp_name, begin_time=None): 15 | super(Logger, self).__init__(log_path, 
exp_name, begin_time) 16 | 17 | self.begin_time_str = self.begin_time.strftime("%Y-%m-%d-%H-%M") 18 | 19 | # Duplicate standard outputs 20 | sys.stdout = FileOutputDuplicator(sys.stdout, 21 | os.path.join(self.exp_path, 'stdout-{}.txt'.format(self.begin_time_str)), 'w') 22 | sys.stderr = FileOutputDuplicator(sys.stderr, 23 | os.path.join(self.exp_path, 'stderr-{}.txt'.format(self.begin_time_str)), 'w') 24 | 25 | # Raw log file 26 | self.raw_log_file = open(os.path.join(self.exp_path, "raw_log-{}.txt".format(self.begin_time_str)), 'a') 27 | 28 | def log_scalar(self, task, iter, name, value, group=None, curtime=None): 29 | if curtime is None: 30 | curtime = datetime.now() 31 | 32 | # Raw dump 33 | entry = {"task": task, "iter": iter, "name": name, "value": value, "group": group, 34 | "time": curtime.strftime("%Y-%m-%d-%H-%M")} 35 | self.raw_log_file.write(json.dumps(entry, sort_keys=True) + "\n") 36 | self.raw_log_file.flush() 37 | 38 | def log_args(self, args): 39 | with open(os.path.join(self.exp_path, 'args-{}.txt'.format(self.begin_time_str)), 'w') as f: 40 | json.dump(args.__dict__, f, separators=(',\n', ' : '), sort_keys=True) 41 | 42 | def log_result(self, array, name, step): 43 | if array.ndim <= 1: 44 | array = array[None] 45 | np.savetxt(os.path.join(self.exp_path, 'results', '{}-{}.txt'.format(name, self.begin_time_str)), 46 | array, '%.6f', delimiter='\t') 47 | 48 | def log_figure(self, name, iter, figure, curtime=None): 49 | curtime = datetime.now() 50 | figure.savefig(os.path.join(self.exp_path, 'figures', 51 | '{}_{}-{}.png'.format(name, iter, curtime.strftime("%Y-%m-%d-%H-%M-%S")))) 52 | figure.savefig(os.path.join(self.exp_path, 'figures', 53 | '{}_{}-{}.pdf'.format(name, iter, curtime.strftime("%Y-%m-%d-%H-%M-%S")))) 54 | 55 | def save_model(self, state_dict, task): 56 | torch.save(state_dict, os.path.join(self.exp_path, "models", "task{}.ckpt".format(task))) 57 | 58 | def __del__(self): 59 | self.raw_log_file.close() 60 | 61 | 62 | class 
FileOutputDuplicator(object): 63 | def __init__(self, duplicate, fname, mode): 64 | self.file = open(fname, mode) 65 | self.duplicate = duplicate 66 | 67 | def __del__(self): 68 | self.file.close() 69 | 70 | def write(self, data): 71 | self.file.write(data) 72 | self.duplicate.write(data) 73 | 74 | def flush(self): 75 | self.file.flush() 76 | -------------------------------------------------------------------------------- /src/loggers/exp_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | from datetime import datetime 4 | 5 | 6 | class ExperimentLogger: 7 | """Main class for experiment logging""" 8 | 9 | def __init__(self, log_path, exp_name, begin_time=None): 10 | self.log_path = log_path 11 | self.exp_name = exp_name 12 | self.exp_path = os.path.join(log_path, exp_name) 13 | if begin_time is None: 14 | self.begin_time = datetime.now() 15 | else: 16 | self.begin_time = begin_time 17 | 18 | def log_scalar(self, task, iter, name, value, group=None, curtime=None): 19 | pass 20 | 21 | def log_args(self, args): 22 | pass 23 | 24 | def log_result(self, array, name, step): 25 | pass 26 | 27 | def log_figure(self, name, iter, figure, curtime=None): 28 | pass 29 | 30 | def save_model(self, state_dict, task): 31 | pass 32 | 33 | 34 | class MultiLogger(ExperimentLogger): 35 | """This class allows to use multiple loggers""" 36 | 37 | def __init__(self, log_path, exp_name, loggers=None, save_models=True): 38 | super(MultiLogger, self).__init__(log_path, exp_name) 39 | if os.path.exists(self.exp_path): 40 | print("WARNING: {} already exists!".format(self.exp_path)) 41 | else: 42 | os.makedirs(os.path.join(self.exp_path, 'models')) 43 | os.makedirs(os.path.join(self.exp_path, 'results')) 44 | os.makedirs(os.path.join(self.exp_path, 'figures')) 45 | 46 | self.save_models = save_models 47 | self.loggers = [] 48 | for l in loggers: 49 | lclass = getattr(importlib.import_module(name='loggers.' 
+ l + '_logger'), 'Logger') 50 | self.loggers.append(lclass(self.log_path, self.exp_name)) 51 | 52 | def log_scalar(self, task, iter, name, value, group=None, curtime=None): 53 | if curtime is None: 54 | curtime = datetime.now() 55 | for l in self.loggers: 56 | l.log_scalar(task, iter, name, value, group, curtime) 57 | 58 | def log_args(self, args): 59 | for l in self.loggers: 60 | l.log_args(args) 61 | 62 | def log_result(self, array, name, step): 63 | for l in self.loggers: 64 | l.log_result(array, name, step) 65 | 66 | def log_figure(self, name, iter, figure, curtime=None): 67 | if curtime is None: 68 | curtime = datetime.now() 69 | for l in self.loggers: 70 | l.log_figure(name, iter, figure, curtime) 71 | 72 | def save_model(self, state_dict, task): 73 | if self.save_models: 74 | for l in self.loggers: 75 | l.save_model(state_dict, task) 76 | -------------------------------------------------------------------------------- /src/loggers/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | 3 | from loggers.exp_logger import ExperimentLogger 4 | import json 5 | import numpy as np 6 | 7 | 8 | class Logger(ExperimentLogger): 9 | """Characterizes a Tensorboard logger""" 10 | 11 | def __init__(self, log_path, exp_name, begin_time=None): 12 | super(Logger, self).__init__(log_path, exp_name, begin_time) 13 | self.tbwriter = SummaryWriter(self.exp_path) 14 | 15 | def log_scalar(self, task, iter, name, value, group=None, curtime=None): 16 | self.tbwriter.add_scalar(tag="t{}/{}_{}".format(task, group, name), 17 | scalar_value=value, 18 | global_step=iter) 19 | self.tbwriter.file_writer.flush() 20 | 21 | def log_figure(self, name, iter, figure, curtime=None): 22 | self.tbwriter.add_figure(tag=name, figure=figure, global_step=iter) 23 | self.tbwriter.file_writer.flush() 24 | 25 | def log_args(self, args): 26 | self.tbwriter.add_text( 27 | 'args', 28 | 
json.dumps(args.__dict__, 29 | separators=(',\n', ' : '), 30 | sort_keys=True)) 31 | self.tbwriter.file_writer.flush() 32 | 33 | def log_result(self, array, name, step): 34 | if array.ndim == 1: 35 | # log as scalars 36 | self.tbwriter.add_scalar(f'results/{name}', array[step], step) 37 | 38 | elif array.ndim == 2: 39 | s = "" 40 | i = step 41 | # for i in range(array.shape[0]): 42 | for j in range(array.shape[1]): 43 | s += '{:5.1f}% '.format(100 * array[i, j]) 44 | if np.trace(array) == 0.0: 45 | if i > 0: 46 | s += '\tAvg.:{:5.1f}% \n'.format(100 * array[i, :i].mean()) 47 | else: 48 | s += '\tAvg.:{:5.1f}% \n'.format(100 * array[i, :i + 1].mean()) 49 | self.tbwriter.add_text(f'results/{name}', s, step) 50 | 51 | def __del__(self): 52 | self.tbwriter.close() 53 | -------------------------------------------------------------------------------- /src/main_incremental.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import argparse 5 | import importlib 6 | import numpy as np 7 | from functools import reduce 8 | 9 | import utils 10 | import approach 11 | from loggers.exp_logger import MultiLogger 12 | from datasets.data_loader import get_loaders 13 | from datasets.dataset_config import dataset_config 14 | from networks import allmodels 15 | 16 | def main(argv=None): 17 | tstart = time.time() 18 | # Arguments 19 | parser = argparse.ArgumentParser(description='ScrollNet: Dynamic Weight Importance for Continual Learning') 20 | 21 | # miscellaneous args 22 | parser.add_argument('--gpu', type=int, default=0, 23 | help='GPU (default=%(default)s)') 24 | parser.add_argument('--results-path', type=str, default='../results', 25 | help='Results path (default=%(default)s)') 26 | parser.add_argument('--exp-name', default=None, type=str, 27 | help='Experiment name (default=%(default)s)') 28 | parser.add_argument('--seed', type=int, default=0, 29 | help='Random seed (default=%(default)s)') 30 | 
parser.add_argument('--log', default=['disk'], type=str, choices=['disk', 'tensorboard'], 31 | help='Loggers used (disk, tensorboard) (default=%(default)s)', nargs='*', metavar="LOGGER") 32 | parser.add_argument('--save-models', action='store_true', 33 | help='Save trained models (default=%(default)s)') 34 | parser.add_argument('--no-cudnn-deterministic', action='store_true', 35 | help='Disable CUDNN deterministic (default=%(default)s)') 36 | # dataset args 37 | parser.add_argument('--datasets', default=['cifar100'], type=str, choices=list(dataset_config.keys()), 38 | help='Dataset or datasets used (default=%(default)s)', nargs='+', metavar="DATASET") 39 | parser.add_argument('--num-workers', default=4, type=int, required=False, 40 | help='Number of subprocesses to use for dataloader (default=%(default)s)') 41 | parser.add_argument('--pin-memory', default=False, type=bool, required=False, 42 | help='Copy Tensors into CUDA pinned memory before returning them (default=%(default)s)') 43 | parser.add_argument('--batch-size', default=64, type=int, required=False, 44 | help='Number of samples per batch to load (default=%(default)s)') 45 | parser.add_argument('--num-tasks', default=10, type=int, required=False, 46 | help='Number of tasks per dataset (default=%(default)s)') 47 | parser.add_argument('--nc-first-task', default=None, type=int, required=False, 48 | help='Number of classes of the first task (default=%(default)s)') 49 | parser.add_argument('--use-valid-only', action='store_true', 50 | help='Use validation split instead of test (default=%(default)s)') 51 | parser.add_argument('--stop-at-task', default=0, type=int, required=False, 52 | help='Stop training after specified task (default=%(default)s)') 53 | # model args 54 | parser.add_argument('--network', default='scroll_resnet18', type=str, choices=allmodels, 55 | help='Network architecture used (default=%(default)s)', metavar="NETWORK") 56 | parser.add_argument('--keep-existing-head', action='store_true', 57 | 
help='Disable removing classifier last layer (default=%(default)s)') 58 | parser.add_argument('--pretrained', action='store_true', 59 | help='Use pretrained backbone (default=%(default)s)') 60 | # training args 61 | parser.add_argument('--approach', default='finetuning', type=str, choices=approach.__all__, 62 | help='Learning approach used (default=%(default)s)', metavar="APPROACH") 63 | parser.add_argument('--nepochs', default=200, type=int, required=False, 64 | help='Number of epochs per training session (default=%(default)s)') 65 | parser.add_argument('--lr', default=0.1, type=float, required=False, 66 | help='Starting learning rate (default=%(default)s)') 67 | parser.add_argument('--decay-mile-stone', nargs='+', type=int, 68 | help='mile stone of learning rate decay') 69 | parser.add_argument('--lr-decay', type=float, default=0.1, 70 | help='ratio of learning rate decay') 71 | parser.add_argument('--clipping', default=10000, type=float, required=False, 72 | help='Clip gradient norm (default=%(default)s)') 73 | parser.add_argument('--momentum', default=0.0, type=float, required=False, 74 | help='Momentum factor (default=%(default)s)') 75 | parser.add_argument('--weight-decay', default=0.0, type=float, required=False, 76 | help='Weight decay (L2 penalty) (default=%(default)s)') 77 | parser.add_argument('--multi-softmax', action='store_true', 78 | help='Apply separate softmax for each task (default=%(default)s)') 79 | parser.add_argument('--fix-bn', action='store_true', 80 | help='Fix batch normalization after first task (default=%(default)s)') 81 | parser.add_argument('--eval-on-train', action='store_true', 82 | help='Show train loss and accuracy (default=%(default)s)') 83 | # scrolling args 84 | parser.add_argument('--scroll_step', default=1, type=int, 85 | help='Scrolling step size.') 86 | 87 | # Args -- Incremental Learning Framework 88 | args, extra_args = parser.parse_known_args(argv) 89 | args.results_path = os.path.expanduser(args.results_path) 90 | 
base_kwargs = dict(nepochs=args.nepochs, lr=args.lr, clipgrad=args.clipping, momentum=args.momentum, 91 | wd=args.weight_decay, multi_softmax=args.multi_softmax, scroll_step=args.scroll_step, 92 | fix_bn=args.fix_bn, eval_on_train=args.eval_on_train) 93 | 94 | if args.no_cudnn_deterministic: 95 | print('WARNING: CUDNN Deterministic will be disabled.') 96 | utils.cudnn_deterministic = False 97 | 98 | utils.seed_everything(seed=args.seed) 99 | print('=' * 108) 100 | print('Arguments =') 101 | for arg in np.sort(list(vars(args).keys())): 102 | print('\t' + arg + ':', getattr(args, arg)) 103 | print('=' * 108) 104 | 105 | # Args -- CUDA 106 | if torch.cuda.is_available(): 107 | torch.cuda.set_device(args.gpu) 108 | device = 'cuda' 109 | else: 110 | print('WARNING: [CUDA unavailable] Using CPU instead!') 111 | device = 'cpu' 112 | 113 | # Args -- Network 114 | from networks.network import LLL_Net 115 | net = getattr(importlib.import_module(name='networks'), args.network) 116 | init_model = net(pretrained=False) 117 | 118 | # Args -- Continual Learning Approach 119 | from approach.incremental_learning import Inc_Learning_Appr 120 | Appr = getattr(importlib.import_module(name='approach.' 
+ args.approach), 'Appr') 121 | assert issubclass(Appr, Inc_Learning_Appr) 122 | appr_args, extra_args = Appr.extra_parser(extra_args) 123 | print('Approach arguments =') 124 | for arg in np.sort(list(vars(appr_args).keys())): 125 | print('\t' + arg + ':', getattr(appr_args, arg)) 126 | print('=' * 108) 127 | 128 | # Args -- Exemplars Management 129 | from datasets.exemplars_dataset import ExemplarsDataset 130 | Appr_ExemplarsDataset = Appr.exemplars_dataset_class() 131 | if Appr_ExemplarsDataset: 132 | assert issubclass(Appr_ExemplarsDataset, ExemplarsDataset) 133 | appr_exemplars_dataset_args, extra_args = Appr_ExemplarsDataset.extra_parser(extra_args) 134 | print('Exemplars dataset arguments =') 135 | for arg in np.sort(list(vars(appr_exemplars_dataset_args).keys())): 136 | print('\t' + arg + ':', getattr(appr_exemplars_dataset_args, arg)) 137 | print('=' * 108) 138 | else: 139 | appr_exemplars_dataset_args = argparse.Namespace() 140 | 141 | # Log all arguments 142 | full_exp_name = reduce((lambda x, y: x[0] + y[0]), args.datasets) if len(args.datasets) > 0 else args.datasets[0] 143 | full_exp_name += '_' + args.approach 144 | if args.exp_name is not None: 145 | full_exp_name += '_' + args.exp_name 146 | logger = MultiLogger(args.results_path, full_exp_name, loggers=args.log, save_models=args.save_models) 147 | logger.log_args(argparse.Namespace(**args.__dict__, **appr_args.__dict__, **appr_exemplars_dataset_args.__dict__)) 148 | 149 | # Loaders 150 | utils.seed_everything(seed=args.seed) 151 | trn_loader, val_loader, tst_loader, taskcla = get_loaders(args.datasets, args.num_tasks, args.nc_first_task, 152 | args.batch_size, num_workers=args.num_workers, 153 | pin_memory=args.pin_memory) 154 | # Apply arguments for loaders 155 | if args.use_valid_only: 156 | tst_loader = val_loader 157 | max_task = len(taskcla) if args.stop_at_task == 0 else args.stop_at_task 158 | 159 | # Network and Approach instances 160 | utils.seed_everything(seed=args.seed) 161 | net = 
LLL_Net(init_model, remove_existing_head=not args.keep_existing_head) 162 | utils.seed_everything(seed=args.seed) 163 | # taking transformations and class indices from first train dataset 164 | first_train_ds = trn_loader[0].dataset 165 | transform, class_indices = first_train_ds.transform, first_train_ds.class_indices 166 | appr_kwargs = {**base_kwargs, **dict(logger=logger, **appr_args.__dict__)} 167 | if Appr_ExemplarsDataset: 168 | appr_kwargs['exemplars_dataset'] = Appr_ExemplarsDataset(transform, class_indices, 169 | **appr_exemplars_dataset_args.__dict__) 170 | utils.seed_everything(seed=args.seed) 171 | appr = Appr(net, device, **appr_kwargs) 172 | 173 | # Loop tasks 174 | print(taskcla) 175 | acc_taw = np.zeros((max_task, max_task)) 176 | acc_tag = np.zeros((max_task, max_task)) 177 | forg_taw = np.zeros((max_task, max_task)) 178 | forg_tag = np.zeros((max_task, max_task)) 179 | for t, (_, ncla) in enumerate(taskcla): 180 | # Early stop tasks if flag 181 | if t >= max_task: 182 | continue 183 | 184 | print('*' * 108) 185 | print('Task {:2d}'.format(t)) 186 | print('*' * 108) 187 | 188 | # Add head for current task 189 | net.add_head(taskcla[t][1]) 190 | net.to(device) 191 | 192 | # Train 193 | appr.train(t, trn_loader[t], val_loader[t]) 194 | print('-' * 108) 195 | 196 | # Test 197 | for u in range(t + 1): 198 | test_loss, acc_taw[t, u], acc_tag[t, u] = appr.eval(u, tst_loader[u], t) 199 | 200 | if u < t: 201 | forg_taw[t, u] = acc_taw[:t, u].max(0) - acc_taw[t, u] 202 | forg_tag[t, u] = acc_tag[:t, u].max(0) - acc_tag[t, u] 203 | print('>>> Test on task {:2d} : loss={:.3f} | TAw acc={:5.1f}%, forg={:5.1f}%' 204 | '| TAg acc={:5.1f}%, forg={:5.1f}% <<<'.format(u, test_loss, 205 | 100 * acc_taw[t, u], 206 | 100 * forg_taw[t, u], 207 | 100 * acc_tag[t, u], 208 | 100 * forg_tag[t, u])) 209 | logger.log_scalar(task=t, iter=u, name='loss', group='test', value=test_loss) 210 | logger.log_scalar(task=t, iter=u, name='acc_taw', group='test', value=100 * acc_taw[t, 
u]) 211 | logger.log_scalar(task=t, iter=u, name='acc_tag', group='test', value=100 * acc_tag[t, u]) 212 | logger.log_scalar(task=t, iter=u, name='forg_taw', group='test', value=100 * forg_taw[t, u]) 213 | logger.log_scalar(task=t, iter=u, name='forg_tag', group='test', value=100 * forg_tag[t, u]) 214 | 215 | # Save 216 | print('Save at ' + os.path.join(args.results_path, full_exp_name)) 217 | logger.log_result(acc_taw, name="acc_taw", step=t) 218 | logger.log_result(acc_tag, name="acc_tag", step=t) 219 | logger.log_result(forg_taw, name="forg_taw", step=t) 220 | logger.log_result(forg_tag, name="forg_tag", step=t) 221 | logger.save_model(net.state_dict(), task=t) 222 | logger.log_result(acc_taw.sum(1) / np.tril(np.ones(acc_taw.shape[0])).sum(1), name="avg_accs_taw", step=t) 223 | logger.log_result(acc_tag.sum(1) / np.tril(np.ones(acc_tag.shape[0])).sum(1), name="avg_accs_tag", step=t) 224 | aux = np.tril(np.repeat([[tdata[1] for tdata in taskcla[:max_task]]], max_task, axis=0)) 225 | logger.log_result((acc_taw * aux).sum(1) / aux.sum(1), name="wavg_accs_taw", step=t) 226 | logger.log_result((acc_tag * aux).sum(1) / aux.sum(1), name="wavg_accs_tag", step=t) 227 | 228 | # Print Summary 229 | utils.print_summary(acc_taw, acc_tag, forg_taw, forg_tag) 230 | print('[Elapsed time = {:.1f} h]'.format((time.time() - tstart) / (60 * 60))) 231 | print('Done!') 232 | 233 | return acc_taw, acc_tag, forg_taw, forg_tag, logger.exp_path 234 | 235 | if __name__ == '__main__': 236 | main() 237 | -------------------------------------------------------------------------------- /src/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .scroll_resnet18 import scroll_resnet18 2 | allmodels = ['scroll_resnet18'] 3 | -------------------------------------------------------------------------------- /src/networks/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
torch import nn 3 | from copy import deepcopy 4 | from .slimmable_ops import SlimmableConv2d, SlimmableLinear 5 | from widths.config import FLAGS 6 | 7 | class LLL_Net(nn.Module): 8 | """Basic class for implementing networks""" 9 | 10 | def __init__(self, model, remove_existing_head=False): 11 | head_var = model.head_var 12 | assert type(head_var) == str 13 | assert not remove_existing_head or hasattr(model, head_var), \ 14 | "Given model does not have a variable called {}".format(head_var) 15 | assert not remove_existing_head or type(getattr(model, head_var)) in [nn.Sequential, nn.Linear, SlimmableLinear], \ 16 | "Given model's head {} does is not an instance of nn.Sequential or nn.Linear".format(head_var) 17 | super(LLL_Net, self).__init__() 18 | 19 | self.model = model 20 | last_layer = getattr(self.model, head_var) 21 | 22 | if remove_existing_head: 23 | if type(last_layer) == nn.Sequential: 24 | self.out_size = last_layer[-1].in_features 25 | # strips off last linear layer of classifier 26 | del last_layer[-1] 27 | elif type(last_layer) == nn.Linear: 28 | self.out_size = last_layer.in_features 29 | # converts last layer into identity 30 | # setattr(self.model, head_var, nn.Identity()) 31 | # WARNING: this is for when pytorch version is <1.2 32 | setattr(self.model, head_var, nn.Sequential()) 33 | elif type(last_layer) == SlimmableLinear: 34 | self.out_size = last_layer.in_features 35 | setattr(self.model, head_var, nn.Sequential()) 36 | else: 37 | self.out_size = last_layer.out_features 38 | 39 | self.heads = nn.ModuleList() 40 | self.task_cls = [] 41 | self.task_offset = [] 42 | self._initialize_weights() 43 | 44 | def add_head(self, num_outputs): 45 | """Add a new head with the corresponding number of outputs. 
Also update the number of classes per task and the 46 | corresponding offsets 47 | """ 48 | Ch_in = [int(self.out_size * width_mult) for width_mult in FLAGS.width_mult_list] 49 | Ch_out = [num_outputs for width_mult in FLAGS.width_mult_list] 50 | 51 | self.heads.append(SlimmableLinear(Ch_in, Ch_out)) 52 | # we re-compute instead of append in case an approach makes changes to the heads 53 | self.task_cls = torch.tensor([head.out_features for head in self.heads]) 54 | self.task_offset = torch.cat([torch.LongTensor(1).zero_(), self.task_cls.cumsum(0)[:-1]]) 55 | 56 | def forward(self, x, return_features=False): 57 | """Applies the forward pass 58 | 59 | Simplification to work on multi-head only -- returns all head outputs in a list 60 | Args: 61 | x (tensor): input images 62 | return_features (bool): return the representations before the heads 63 | """ 64 | x = self.model(x) 65 | assert (len(self.heads) > 0), "Cannot access any head" 66 | y = [] 67 | for head in self.heads: 68 | y.append(head(x)) 69 | if return_features: 70 | return y, x 71 | else: 72 | return y 73 | 74 | def get_copy(self): 75 | """Get weights from the model""" 76 | return deepcopy(self.state_dict()) 77 | 78 | def set_state_dict(self, state_dict): 79 | """Load weights into the model""" 80 | self.load_state_dict(deepcopy(state_dict)) 81 | return 82 | 83 | def freeze_all(self): 84 | """Freeze all parameters from the model, including the heads""" 85 | for param in self.parameters(): 86 | param.requires_grad = False 87 | 88 | def freeze_backbone(self): 89 | """Freeze all parameters from the main model, but not the heads""" 90 | for param in self.model.parameters(): 91 | param.requires_grad = False 92 | 93 | def freeze_bn(self): 94 | """Freeze all Batch Normalization layers from the model and use them in eval() mode""" 95 | for m in self.model.modules(): 96 | if isinstance(m, nn.BatchNorm2d): 97 | m.eval() 98 | 99 | def _initialize_weights(self): 100 | """Initialize weights using different strategies""" 
101 | # TODO: add different initialization strategies 102 | pass 103 | -------------------------------------------------------------------------------- /src/networks/scroll_resnet18.py: -------------------------------------------------------------------------------- 1 | from distutils.util import change_root 2 | import torch.nn as nn 3 | import math 4 | 5 | from .slimmable_ops import SwitchableBatchNorm2d 6 | from .slimmable_ops import SlimmableConv2d, SlimmableLinear 7 | from widths.config import FLAGS 8 | 9 | def slimconv3x3(in_planes, out_planes, stride=1): 10 | """3x3 convolution with padding""" 11 | return SlimmableConv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | channels_in = [int(inplanes * width_mult) for width_mult in FLAGS.width_mult_list] 20 | channels_out = [int(planes * width_mult) for width_mult in FLAGS.width_mult_list] 21 | 22 | self.conv1 = slimconv3x3(channels_in, channels_out, stride) 23 | self.bn1 = SwitchableBatchNorm2d(channels_out) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.conv2 = slimconv3x3(channels_out, channels_out) 26 | self.bn2 = SwitchableBatchNorm2d(channels_out) 27 | self.downsample = downsample 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | residual = x 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | out = self.conv2(out) 36 | out = self.bn2(out) 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | class Scroll_ResNet(nn.Module): 45 | 46 | def __init__(self, block, layers, num_classes=10): 47 | self.inplanes = 64 48 | super(Scroll_ResNet, self).__init__() 49 | chann_head_in = [3 for width_mult in FLAGS.width_mult_list] 50 | chann_head_out = [int(64 * width_mult) for 
width_mult in FLAGS.width_mult_list] 51 | self.conv1 = SlimmableConv2d(chann_head_in, chann_head_out, kernel_size=3, stride=1, padding=1, 52 | bias=False) 53 | 54 | self.bn1 = SwitchableBatchNorm2d(chann_head_out) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.layer1 = self._make_layer(block, 64, layers[0]) 57 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 58 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 59 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 60 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 61 | 62 | # last classifier layer (head) with as many outputs as classes 63 | chann_tail_in = [int(512 * block.expansion * width_mult) for width_mult in FLAGS.width_mult_list] 64 | chann_tail_out = [num_classes for width_mult in FLAGS.width_mult_list] 65 | self.fc = SlimmableLinear(chann_tail_in, chann_tail_out) 66 | # self.last_dim = self.fc.in_features 67 | # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments 68 | self.head_var = 'fc' 69 | 70 | for m in self.modules(): 71 | if isinstance(m, nn.Conv2d): 72 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 73 | elif isinstance(m, nn.BatchNorm2d): 74 | nn.init.constant_(m.weight, 1) 75 | nn.init.constant_(m.bias, 0) 76 | 77 | def _make_layer(self, block, planes, blocks, stride=1): 78 | downsample = None 79 | chann_d_in = [int(self.inplanes * width_mult) for width_mult in FLAGS.width_mult_list] 80 | chann_d_out = [int(planes * block.expansion * width_mult) for width_mult in FLAGS.width_mult_list] 81 | 82 | if stride != 1 or self.inplanes != planes * block.expansion: 83 | downsample = nn.Sequential( 84 | SlimmableConv2d(chann_d_in, chann_d_out, 85 | kernel_size=1, stride=stride, bias=False), 86 | SwitchableBatchNorm2d(chann_d_out), 87 | ) 88 | 89 | layers = [] 90 | layers.append(block(self.inplanes, planes, stride, downsample)) 91 | self.inplanes = planes * block.expansion 92 | 93 | 
for i in range(1, blocks): 94 | layers.append(block(self.inplanes, planes)) 95 | 96 | return nn.Sequential(*layers) 97 | 98 | def forward(self, x): 99 | x = self.conv1(x) 100 | x = self.bn1(x) 101 | x = self.relu(x) 102 | 103 | x = self.layer1(x) 104 | x = self.layer2(x) 105 | x = self.layer3(x) 106 | x = self.layer4(x) 107 | 108 | x = self.avgpool(x) 109 | x = x.view(x.size(0), -1) 110 | x = self.fc(x) 111 | 112 | return x 113 | 114 | def scroll_resnet18(pretrained=False, **kwargs): 115 | """Constructs a ResNet-18 model. 116 | 117 | Args: 118 | pretrained (bool): If True, returns a model pre-trained on ImageNet 119 | """ 120 | model = Scroll_ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 121 | return model 122 | -------------------------------------------------------------------------------- /src/networks/slimmable_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | from widths.config import FLAGS 5 | 6 | class SwitchableBatchNorm2d(nn.Module): 7 | def __init__(self, num_features_list): 8 | super(SwitchableBatchNorm2d, self).__init__() 9 | self.num_features_list = num_features_list 10 | self.num_features = max(num_features_list) 11 | bns = [] 12 | for i in num_features_list: 13 | bns.append(nn.BatchNorm2d(i)) 14 | self.bn = nn.ModuleList(bns) 15 | self.width_mult = max(FLAGS.width_mult_list) 16 | self.ignore_model_profiling = True 17 | self.scroll = 0.0 18 | 19 | def forward(self, input): 20 | idx = FLAGS.width_mult_list.index(self.width_mult) 21 | y = self.bn[idx](input) 22 | return y 23 | 24 | 25 | class SlimmableConv2d(nn.Conv2d): 26 | def __init__(self, in_channels_list, out_channels_list, 27 | kernel_size, stride=1, padding=0, dilation=1, 28 | groups_list=[1], bias=True): 29 | super(SlimmableConv2d, self).__init__( 30 | max(in_channels_list), max(out_channels_list), 31 | kernel_size, stride=stride, padding=padding, dilation=dilation, 32 | groups=max(groups_list), 
bias=bias) 33 | self.in_channels_list = in_channels_list 34 | self.out_channels_list = out_channels_list 35 | self.groups_list = groups_list 36 | if self.groups_list == [1]: 37 | self.groups_list = [1 for _ in range(len(in_channels_list))] 38 | self.width_mult = max(FLAGS.width_mult_list) 39 | self.scroll = 0.0 40 | self.inverse = 1 # 0 is up, 1 is down 41 | 42 | def cyc_scroll(self, scroll_num1, scroll_num2): 43 | h, w, _, _ = self.weight.shape 44 | matrix = torch.cat((self.weight[(h-scroll_num1):,:], self.weight[:(h-scroll_num1),:]), dim=0) 45 | weight = torch.cat((matrix[:,(h-scroll_num2):], matrix[:,:(h-scroll_num2)]), dim=1) 46 | 47 | return weight 48 | 49 | def cyc_scroll_bias(self, scroll_num): 50 | L = self.bias 51 | bias = torch.cat((self.bias[(L-scroll_num):], self.bias[:(L-scroll_num)]), dim=0) 52 | 53 | return bias 54 | 55 | def cyc_scroll_inverse(self, scroll_num1, scroll_num2): 56 | h, w, _, _ = self.weight.shape 57 | matrix = torch.cat((self.weight[scroll_num1:,:], self.weight[:scroll_num1,:]), dim=0) 58 | weight = torch.cat((matrix[:,scroll_num2:], matrix[:,:scroll_num2]), dim=1) 59 | 60 | return weight 61 | 62 | def cyc_scroll_bias_inverse(self, scroll_num): 63 | bias = torch.cat((self.bias[scroll_num:], self.bias[:scroll_num]), dim=0) 64 | 65 | return bias 66 | 67 | def forward(self, input): 68 | self.scroll = self.scroll % len(FLAGS.width_mult_list) # cycle scrolling 69 | idx = FLAGS.width_mult_list.index(self.width_mult) 70 | self.in_channels = self.in_channels_list[idx] 71 | self.out_channels = self.out_channels_list[idx] 72 | self.groups = self.groups_list[idx] 73 | scroll_num1 = int(self.scroll*(self.out_channels_list[1]-self.out_channels_list[0])) 74 | scroll_num2 = int(self.scroll*(self.in_channels_list[1]-self.in_channels_list[0])) 75 | 76 | if self.inverse==0: 77 | weight = self.cyc_scroll(scroll_num1, scroll_num2) 78 | elif self.inverse==1: 79 | weight = self.cyc_scroll_inverse(scroll_num1, scroll_num2) 80 | weight = 
weight[:self.out_channels, :self.in_channels, :, :] 81 | 82 | if self.bias is not None: 83 | if self.inverse==0: 84 | bias = self.cyc_scroll_bias(scroll_num1) 85 | elif self.inverse==1: 86 | bias = self.cyc_scroll_bias_inverse(scroll_num1) 87 | bias = bias[:self.out_channels] 88 | else: 89 | bias = self.bias 90 | 91 | y = nn.functional.conv2d( 92 | input, weight, bias, self.stride, self.padding, 93 | self.dilation, self.groups) 94 | return y 95 | 96 | 97 | class SlimmableLinear(nn.Linear): 98 | def __init__(self, in_features_list, out_features_list, bias=True): 99 | super(SlimmableLinear, self).__init__( 100 | max(in_features_list), max(out_features_list), bias=bias) 101 | self.in_features_list = in_features_list 102 | self.out_features_list = out_features_list 103 | self.width_mult = max(FLAGS.width_mult_list) 104 | self.scroll = 0.0 105 | 106 | def cyc_scroll_inverse(self, scroll_num1, scroll_num2): 107 | matrix = torch.cat((self.weight[scroll_num1:,:], self.weight[:scroll_num1,:]), dim=0) 108 | weight = torch.cat((matrix[:,scroll_num2:], matrix[:,:scroll_num2]), dim=1) 109 | return weight 110 | 111 | def cyc_scroll_bias_inverse(self, scroll_num): 112 | bias = torch.cat((self.bias[scroll_num:], self.bias[:scroll_num]), dim=0) 113 | return bias 114 | 115 | def forward(self, input): 116 | 117 | self.scroll = self.scroll % len(FLAGS.width_mult_list) # cycle scrolling 118 | 119 | idx = FLAGS.width_mult_list.index(self.width_mult) 120 | self.in_features = self.in_features_list[idx] 121 | self.out_features = self.out_features_list[idx] 122 | scroll_num1 = int(self.scroll*(self.out_features_list[1]-self.out_features_list[0])) 123 | scroll_num2 = int(self.scroll*(self.in_features_list[1]-self.in_features_list[0])) 124 | 125 | weight = self.cyc_scroll_inverse(scroll_num1, scroll_num2) 126 | weight = weight[:self.out_features, :self.in_features] 127 | if self.bias is not None: 128 | bias = self.cyc_scroll_bias_inverse(scroll_num1) 129 | bias = bias[:self.out_features] 130 
| else: 131 | bias = self.bias 132 | return nn.functional.linear(input, weight, bias) 133 | 134 | def make_divisible(v, divisor=8, min_value=1): 135 | """ 136 | forked from slim: 137 | https://github.com/tensorflow/models/blob/\ 138 | 0344c5503ee55e24f0de7f37336a6e08f10976fd/\ 139 | research/slim/nets/mobilenet/mobilenet.py#L62-L69 140 | """ 141 | if min_value is None: 142 | min_value = divisor 143 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 144 | # Make sure that round down does not go down by more than 10%. 145 | if new_v < 0.9 * v: 146 | new_v += divisor 147 | return new_v -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | 6 | cudnn_deterministic = True 7 | 8 | 9 | def seed_everything(seed=0): 10 | """Fix all random seeds""" 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | os.environ['PYTHONHASHSEED'] = str(seed) 16 | torch.backends.cudnn.deterministic = cudnn_deterministic 17 | 18 | 19 | def print_summary(acc_taw, acc_tag, forg_taw, forg_tag): 20 | """Print summary of results""" 21 | for name, metric in zip(['TAw Acc', 'TAg Acc', 'TAw Forg', 'TAg Forg'], [acc_taw, acc_tag, forg_taw, forg_tag]): 22 | print('*' * 108) 23 | print(name) 24 | for i in range(metric.shape[0]): 25 | print('\t', end='') 26 | for j in range(metric.shape[1]): 27 | print('{:5.1f}% '.format(100 * metric[i, j]), end='') 28 | if np.trace(metric) == 0.0: 29 | if i > 0: 30 | print('\tAvg.:{:5.1f}% '.format(100 * metric[i, :i].mean()), end='') 31 | else: 32 | print('\tAvg.:{:5.1f}% '.format(100 * metric[i, :i + 1].mean()), end='') 33 | print() 34 | print('*' * 108) 35 | -------------------------------------------------------------------------------- /src/widths/config.py: 
# --------------------------------------------------------------------------------
"""config utilities for yml file."""
import os
import sys
import yaml

# singleton: populated by `app()` at import time
FLAGS = None


class LoaderMeta(type):
    """Constructor for supporting `!include`.
    """
    def __new__(mcs, name, bases, namespace):
        """Add include constructor to class."""
        # register the include constructor on the class
        cls = super().__new__(mcs, name, bases, namespace)
        cls.add_constructor('!include', cls.construct_include)
        return cls


class Loader(yaml.Loader, metaclass=LoaderMeta):
    """YAML Loader with `!include` constructor.

    Relative include paths are resolved against the directory of the stream
    being parsed, or the current directory when the stream has no `name`.

    NOTE: this subclasses yaml.Loader (not SafeLoader), so it must only be
    used on trusted configuration files.
    """
    def __init__(self, stream):
        try:
            self._root = os.path.split(stream.name)[0]
        except AttributeError:
            self._root = os.path.curdir
        super().__init__(stream)

    def construct_include(self, node):
        """Include file referenced at node.

        `.yaml`/`.yml` files are parsed with this Loader (so nested includes
        work); any other file is returned as its raw text.
        """
        filename = os.path.abspath(
            os.path.join(self._root, self.construct_scalar(node)))
        extension = os.path.splitext(filename)[1].lstrip('.')
        with open(filename, 'r') as f:
            if extension in ('yaml', 'yml'):
                return yaml.load(f, Loader)
            else:
                return ''.join(f.readlines())

class AttrDict(dict):
    """Dict as attribute trick.

    Keys are readable as attributes. Nested dicts, and lists whose first
    element is a dict, are converted to AttrDict recursively.
    """
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        for key in self.__dict__:
            value = self.__dict__[key]
            if isinstance(value, dict):
                self.__dict__[key] = AttrDict(value)
            elif isinstance(value, list):
                # `value and ...` guards empty lists, which would otherwise
                # raise IndexError on value[0].
                if value and isinstance(value[0], dict):
                    self.__dict__[key] = [AttrDict(item) for item in value]
                else:
                    self.__dict__[key] = value

    def yaml(self):
        """Convert object to yaml dict and return.

        Nested AttrDicts (and lists of AttrDicts) are converted back to
        plain dicts recursively.
        """
        yaml_dict = {}
        for key in self.__dict__:
            value = self.__dict__[key]
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif isinstance(value, list):
                if value and isinstance(value[0], AttrDict):
                    yaml_dict[key] = [item.yaml() for item in value]
                else:
                    yaml_dict[key] = value
            else:
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Print all variables, one `key: value` per line, nesting indented.

        """
        ret_str = []
        for key in self.__dict__:
            value = self.__dict__[key]
            if isinstance(value, AttrDict):
                ret_str.append('{}:'.format(key))
                child_ret_str = value.__repr__().split('\n')
                for line in child_ret_str:
                    ret_str.append(' ' + line)
            elif isinstance(value, list):
                if value and isinstance(value[0], AttrDict):
                    ret_str.append('{}:'.format(key))
                    for item in value:
                        # treat as AttrDict above
                        child_ret_str = item.__repr__().split('\n')
                        for line in child_ret_str:
                            ret_str.append(' ' + line)
                else:
                    ret_str.append('{}: {}'.format(key, value))
            else:
                ret_str.append('{}: {}'.format(key, value))
        return '\n'.join(ret_str)


class Config(AttrDict):
    """Config with yaml file.

    This class is used to config model hyper-parameters, global constants, and
    other settings with yaml file. All settings in yaml file will be
    automatically logged into file.

    Args:
        filename(str): File name.
        verbose(bool): if True, print the loaded configuration.

    Raises:
        AssertionError: if `filename` does not exist.
        OSError: if the file cannot be opened/read (re-raised after a hint).

    Examples:

        yaml file ``model.yml``::

            NAME: 'neuralgym'
            ALPHA: 1.0
            DATASET: '/mnt/data/imagenet'

        Usage in .py:

        >>> from neuralgym import Config
        >>> config = Config('model.yml')
        >>> print(config.NAME)
            neuralgym
        >>> print(config.ALPHA)
            1.0
        >>> print(config.DATASET)
            /mnt/data/imagenet

    """

    def __init__(self, filename=None, verbose=False):
        assert os.path.exists(filename), 'File {} not exist.'.format(filename)
        try:
            with open(filename, 'r') as f:
                cfg_dict = yaml.load(f, Loader)
        except EnvironmentError:
            # Report the offending path, then propagate the real error instead
            # of continuing with an undefined `cfg_dict`.
            print('Please check the file with name of "%s"' % filename)
            raise
        super(Config, self).__init__(cfg_dict)
        if verbose:
            print(' pi.cfg '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))

def app():
    """Load the global FLAGS config once and return it.

    On the first call, reads 'SizeOfSubnetworks.yml' (relative to the current
    working directory) into the module-level FLAGS; later calls return the
    cached object.
    """
    global FLAGS
    if FLAGS is None:
        job_yaml_file = 'SizeOfSubnetworks.yml'
        FLAGS = Config(job_yaml_file)
    return FLAGS

app()
# --------------------------------------------------------------------------------