├── .gitignore ├── LICENSE ├── README.md ├── agents ├── __init__.py ├── customization.py ├── default.py ├── exp_replay.py └── regularization.py ├── dataloaders ├── __init__.py ├── base.py ├── datasetGen.py └── wrapper.py ├── fig ├── results_split_cifar100.png ├── results_split_mnist.png └── task_shifts.png ├── iBatchLearn.py ├── models ├── __init__.py ├── lenet.py ├── mlp.py ├── resnet.py └── senet.py ├── modules ├── __init__.py └── criterions.py ├── requirements.txt ├── scripts ├── permuted_MNIST_incremental_class.sh ├── permuted_MNIST_incremental_domain.sh ├── permuted_MNIST_incremental_task.sh ├── split_CIFAR100_incremental_class.sh ├── split_CIFAR100_incremental_domain.sh ├── split_CIFAR100_incremental_task.sh ├── split_MNIST_incremental_class.sh ├── split_MNIST_incremental_domain.sh └── split_MNIST_incremental_task.sh └── utils ├── __init__.py └── metric.py /.gitignore: -------------------------------------------------------------------------------- 1 | # repo-specific stuff 2 | data/ 3 | outputs/ 4 | *.pt 5 | \#*# 6 | .idea 7 | *.sublime-* 8 | *.pkl 9 | .DS_Store 10 | *.pth 11 | *.png 12 | .swp 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | env/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | 92 | # SageMath parsed files 93 | *.sage.py 94 | 95 | # dotenv 96 | .env 97 | 98 | # virtualenv 99 | .venv 100 | venv/ 101 | ENV/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 GT-RIPL, Yen-Chang Hsu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall 
be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Continual-Learning-Benchmark 2 | Evaluate three types of task shifting with popular continual learning algorithms. 3 | 4 | This repository implements and modularizes the following algorithms with PyTorch: 5 | - EWC: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/regularization.py), [paper](https://arxiv.org/abs/1612.00796) (Overcoming catastrophic forgetting in neural networks) 6 | - Online EWC: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/regularization.py), [paper](https://arxiv.org/abs/1805.06370) 7 | - SI: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/regularization.py), [paper](https://arxiv.org/abs/1703.04200) (Continual Learning Through Synaptic Intelligence) 8 | - MAS: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/regularization.py), [paper](https://eccv2018.org/openaccess/content_ECCV_2018/papers/Rahaf_Aljundi_Memory_Aware_Synapses_ECCV_2018_paper.pdf) (Memory Aware Synapses: Learning what (not) to forget) 9 | - GEM: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/exp_replay.py), [paper](https://arxiv.org/abs/1706.08840) (Gradient Episodic Memory for Continual Learning) 10 | - (More are coming) 11 | 12 | All the above algorithms are compared to the following baselines with **the same static memory overhead**: 13 | - Naive rehearsal: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/exp_replay.py) 14 | - L2: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/regularization.py), [paper](https://arxiv.org/abs/1612.00796) 15 | 16 | Key tables: 17 | 18 | 19 | 20 | 21 | If this repository helps your work, please cite: 22 | ``` 23 | @inproceedings{Hsu18_EvalCL, 24 | title={Re-evaluating Continual Learning Scenarios: A Categorization and Case for Strong Baselines}, 25 | author={Yen-Chang Hsu and Yen-Cheng Liu and Anita Ramasamy and Zsolt Kira}, 26 | booktitle={NeurIPS Continual Learning Workshop}, 27 | year={2018}, 28 | url={https://arxiv.org/abs/1810.12488} 29 | } 30 | ``` 31 | 32 | ## Preparation 33 | This repository was tested with Python 3.6 and PyTorch 1.0.1.post2. A subset of the cases was also tested with PyTorch 1.5.1 and gave the same results. 34 | 35 | ```bash 36 | pip install -r requirements.txt 37 | ``` 38 | 39 | ## Demo 40 | The scripts for reproducing the results of the paper are under the scripts folder. 41 | 42 | - Example: Run all algorithms in the incremental domain scenario with split MNIST. 
43 | ```bash 44 | ./scripts/split_MNIST_incremental_domain.sh 0 45 | # The last number is the GPU id 46 | # Outputs will be saved in ./outputs 47 | ``` 48 | 49 | - Example outputs: Summary of repeats 50 | ```text 51 | ===Summary of experiment repeats: 3 / 3 === 52 | The regularization coefficient: 400.0 53 | The last avg acc of all repeats: [90.517 90.648 91.069] 54 | mean: 90.74466666666666 std: 0.23549144829955856 55 | ``` 56 | 57 | - Example outputs: The grid search for the regularization coefficient 58 | ```text 59 | reg_coef: 0.1 mean: 76.08566666666667 std: 1.097717733400629 60 | reg_coef: 1.0 mean: 77.59100000000001 std: 2.100847606721314 61 | reg_coef: 10.0 mean: 84.33933333333334 std: 0.3592671553160509 62 | reg_coef: 100.0 mean: 90.83800000000001 std: 0.6913701372395712 63 | reg_coef: 1000.0 mean: 87.48566666666666 std: 0.5440161353816179 64 | reg_coef: 5000.0 mean: 68.99133333333333 std: 1.6824762174313899 65 | 66 | ``` 67 | 68 | ## Usage 69 | - Enable the grid search for the regularization coefficient: Use the option with a list of values, ex: --reg_coef 0.1 1 10 100 ... 70 | - Repeat the experiment N times: Use the option --repeat N 71 | 72 | Look up the available options: 73 | ```bash 74 | python iBatchLearn.py -h 75 | ``` 76 | 77 | ## Other results 78 | 79 | Below are the CIFAR100 results. Please refer to the [scripts](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/scripts/split_CIFAR100_incremental_class.sh) for details. 80 | 81 | -------------------------------------------------------------------------------- /agents/__init__.py: -------------------------------------------------------------------------------- 1 | from . import default 2 | from . import regularization 3 | from . import customization 4 | from . import exp_replay -------------------------------------------------------------------------------- /agents/customization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .default import NormalNN 3 | from .regularization import SI, EWC, EWC_online 4 | from .exp_replay import Naive_Rehearsal, GEM 5 | from modules.criterions import BCEauto 6 | 7 | def init_zero_weights(m): 8 | with torch.no_grad(): 9 | if type(m) == torch.nn.Linear: 10 | m.weight.zero_() 11 | m.bias.zero_() 12 | elif type(m) == torch.nn.ModuleDict: 13 | for l in m.values(): 14 | init_zero_weights(l) 15 | else: 16 | assert False, 'Only support linear layer' 17 | 18 | 19 | def NormalNN_reset_optim(agent_config): 20 | agent = NormalNN(agent_config) 21 | agent.reset_optimizer = True 22 | return agent 23 | 24 | 25 | def NormalNN_BCE(agent_config): 26 | agent = NormalNN(agent_config) 27 | agent.criterion_fn = BCEauto() 28 | return agent 29 | 30 | 31 | def SI_BCE(agent_config): 32 | agent = SI(agent_config) 33 | agent.criterion_fn = BCEauto() 34 | return agent 35 | 36 | 37 | def SI_splitMNIST_zero_init(agent_config): 38 | agent = SI(agent_config) 39 | agent.damping_factor = 1e-3 40 | agent.reset_optimizer = True 41 | agent.model.last.apply(init_zero_weights) 42 | return agent 43 | 44 | 45 | def SI_splitMNIST_rand_init(agent_config): 46 | agent = SI(agent_config) 47 | agent.damping_factor = 1e-3 48 | agent.reset_optimizer = True 49 | return agent 50 | 51 | 52 | def EWC_BCE(agent_config): 53 | agent = EWC(agent_config) 54 | agent.criterion_fn = BCEauto() 55 | return agent 56 | 57 | 58 | def EWC_mnist(agent_config): 59 | agent = EWC(agent_config) 60 | agent.n_fisher_sample = 60000 61 | return agent 62 | 63 | 64 | def EWC_online_mnist(agent_config): 65 
| agent = EWC(agent_config) 66 | agent.n_fisher_sample = 60000 67 | agent.online_reg = True 68 | return agent 69 | 70 | 71 | def EWC_online_empFI(agent_config): 72 | agent = EWC(agent_config) 73 | agent.empFI = True 74 | return agent 75 | 76 | 77 | def EWC_zero_init(agent_config): 78 | agent = EWC(agent_config) 79 | agent.reset_optimizer = True 80 | agent.model.last.apply(init_zero_weights) 81 | return agent 82 | 83 | 84 | def EWC_rand_init(agent_config): 85 | agent = EWC(agent_config) 86 | agent.reset_optimizer = True 87 | return agent 88 | 89 | 90 | def EWC_reset_optim(agent_config): 91 | agent = EWC(agent_config) 92 | agent.reset_optimizer = True 93 | return agent 94 | 95 | 96 | def EWC_online_reset_optim(agent_config): 97 | agent = EWC_online(agent_config) 98 | agent.reset_optimizer = True 99 | return agent 100 | 101 | 102 | def Naive_Rehearsal_100(agent_config): 103 | agent = Naive_Rehearsal(agent_config) 104 | agent.memory_size = 100 105 | return agent 106 | 107 | 108 | def Naive_Rehearsal_200(agent_config): 109 | agent = Naive_Rehearsal(agent_config) 110 | agent.memory_size = 200 111 | return agent 112 | 113 | 114 | def Naive_Rehearsal_400(agent_config): 115 | agent = Naive_Rehearsal(agent_config) 116 | agent.memory_size = 400 117 | return agent 118 | 119 | 120 | def Naive_Rehearsal_1100(agent_config): 121 | agent = Naive_Rehearsal(agent_config) 122 | agent.memory_size = 1100 123 | return agent 124 | 125 | 126 | def Naive_Rehearsal_1400(agent_config): 127 | agent = Naive_Rehearsal(agent_config) 128 | agent.memory_size = 1400 129 | return agent 130 | 131 | 132 | def Naive_Rehearsal_4000(agent_config): 133 | agent = Naive_Rehearsal(agent_config) 134 | agent.memory_size = 4000 135 | return agent 136 | 137 | 138 | def Naive_Rehearsal_4400(agent_config): 139 | agent = Naive_Rehearsal(agent_config) 140 | agent.memory_size = 4400 141 | return agent 142 | 143 | 144 | def Naive_Rehearsal_5600(agent_config): 145 | agent = Naive_Rehearsal(agent_config) 146 | agent.memory_size = 5600 147 | return agent 148 | 149 | 150 | def Naive_Rehearsal_16000(agent_config): 151 | agent = Naive_Rehearsal(agent_config) 152 | agent.memory_size = 16000 153 | return agent 154 | 155 | 156 | def GEM_100(agent_config): 157 | agent = GEM(agent_config) 158 | agent.memory_size = 100 159 | return agent 160 | 161 | 162 | def GEM_200(agent_config): 163 | agent = GEM(agent_config) 164 | agent.memory_size = 200 165 | return agent 166 | 167 | 168 | def GEM_400(agent_config): 169 | agent = GEM(agent_config) 170 | agent.memory_size = 400 171 | return agent 172 | 173 | 174 | def GEM_orig_1100(agent_config): 175 | agent = GEM(agent_config) 176 | agent.skip_memory_concatenation = True 177 | agent.memory_size = 1100 178 | return agent 179 | 180 | 181 | def GEM_1100(agent_config): 182 | agent = GEM(agent_config) 183 | agent.memory_size = 1100 184 | return agent 185 | 186 | 187 | def GEM_4000(agent_config): 188 | agent = GEM(agent_config) 189 | agent.memory_size = 4000 190 | return agent 191 | 192 | 193 | def GEM_4400(agent_config): 194 | agent = GEM(agent_config) 195 | agent.memory_size = 4400 196 | return agent 197 | 198 | 199 | def GEM_16000(agent_config): 200 | agent = GEM(agent_config) 201 | agent.memory_size = 16000 202 | return agent -------------------------------------------------------------------------------- /agents/default.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | from types import 
MethodType 5 | import models 6 | from utils.metric import accuracy, AverageMeter, Timer 7 | 8 | class NormalNN(nn.Module): 9 | ''' 10 | Normal Neural Network with SGD for classification 11 | ''' 12 | def __init__(self, agent_config): 13 | ''' 14 | :param agent_config (dict): lr=float,momentum=float,weight_decay=float, 15 | schedule=[int], # The last number in the list is the end of epoch 16 | model_type=str,model_name=str,out_dim={task:dim},model_weights=str 17 | force_single_head=bool 18 | print_freq=int 19 | gpuid=[int] 20 | ''' 21 | super(NormalNN, self).__init__() 22 | self.log = print if agent_config['print_freq'] > 0 else lambda \ 23 | *args: None # Use a void function to replace the print 24 | self.config = agent_config 25 | # If out_dim is a dict, there is a list of tasks. The model will have a head for each task. 26 | self.multihead = True if len(self.config['out_dim'])>1 else False # A convenience flag to indicate multi-head/task 27 | self.model = self.create_model() 28 | self.criterion_fn = nn.CrossEntropyLoss() 29 | if agent_config['gpuid'][0] >= 0: 30 | self.cuda() 31 | self.gpu = True 32 | else: 33 | self.gpu = False 34 | self.init_optimizer() 35 | self.reset_optimizer = False 36 | self.valid_out_dim = 'ALL' # Default: 'ALL' means all output nodes are active 37 | # Set a interger here for the incremental class scenario 38 | 39 | def init_optimizer(self): 40 | optimizer_arg = {'params':self.model.parameters(), 41 | 'lr':self.config['lr'], 42 | 'weight_decay':self.config['weight_decay']} 43 | if self.config['optimizer'] in ['SGD','RMSprop']: 44 | optimizer_arg['momentum'] = self.config['momentum'] 45 | elif self.config['optimizer'] in ['Rprop']: 46 | optimizer_arg.pop('weight_decay') 47 | elif self.config['optimizer'] == 'amsgrad': 48 | optimizer_arg['amsgrad'] = True 49 | self.config['optimizer'] = 'Adam' 50 | 51 | self.optimizer = torch.optim.__dict__[self.config['optimizer']](**optimizer_arg) 52 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.config['schedule'], 53 | gamma=0.1) 54 | 55 | def create_model(self): 56 | cfg = self.config 57 | 58 | # Define the backbone (MLP, LeNet, VGG, ResNet ... etc) of model 59 | model = models.__dict__[cfg['model_type']].__dict__[cfg['model_name']]() 60 | 61 | # Apply network surgery to the backbone 62 | # Create the heads for tasks (It can be single task or multi-task) 63 | n_feat = model.last.in_features 64 | 65 | # The output of the model will be a dict: {task_name1:output1, task_name2:output2 ...} 66 | # For a single-headed model the output will be {'All':output} 67 | model.last = nn.ModuleDict() 68 | for task,out_dim in cfg['out_dim'].items(): 69 | model.last[task] = nn.Linear(n_feat,out_dim) 70 | 71 | # Redefine the task-dependent function 72 | def new_logits(self, x): 73 | outputs = {} 74 | for task, func in self.last.items(): 75 | outputs[task] = func(x) 76 | return outputs 77 | 78 | # Replace the task-dependent function 79 | model.logits = MethodType(new_logits, model) 80 | # Load pre-trained weights 81 | if cfg['model_weights'] is not None: 82 | print('=> Load model weights:', cfg['model_weights']) 83 | model_state = torch.load(cfg['model_weights'], 84 | map_location=lambda storage, loc: storage) # Load to CPU. 
85 | model.load_state_dict(model_state) 86 | print('=> Load Done') 87 | return model 88 | 89 | def forward(self, x): 90 | return self.model.forward(x) 91 | 92 | def predict(self, inputs): 93 | self.model.eval() 94 | out = self.forward(inputs) 95 | for t in out.keys(): 96 | out[t] = out[t].detach() 97 | return out 98 | 99 | def validation(self, dataloader): 100 | # This function doesn't distinguish tasks. 101 | batch_timer = Timer() 102 | acc = AverageMeter() 103 | batch_timer.tic() 104 | 105 | orig_mode = self.training 106 | self.eval() 107 | for i, (input, target, task) in enumerate(dataloader): 108 | 109 | if self.gpu: 110 | with torch.no_grad(): 111 | input = input.cuda() 112 | target = target.cuda() 113 | output = self.predict(input) 114 | 115 | # Summarize the performance of all tasks, or 1 task, depends on dataloader. 116 | # Calculated by total number of data. 117 | acc = accumulate_acc(output, target, task, acc) 118 | 119 | self.train(orig_mode) 120 | 121 | self.log(' * Val Acc {acc.avg:.3f}, Total time {time:.2f}' 122 | .format(acc=acc,time=batch_timer.toc())) 123 | return acc.avg 124 | 125 | def criterion(self, preds, targets, tasks, **kwargs): 126 | # The inputs and targets could come from single task or a mix of tasks 127 | # The network always makes the predictions with all its heads 128 | # The criterion will match the head and task to calculate the loss. 129 | if self.multihead: 130 | loss = 0 131 | for t,t_preds in preds.items(): 132 | inds = [i for i in range(len(tasks)) if tasks[i]==t] # The index of inputs that matched specific task 133 | if len(inds)>0: 134 | t_preds = t_preds[inds] 135 | t_target = targets[inds] 136 | loss += self.criterion_fn(t_preds, t_target) * len(inds) # restore the loss from average 137 | loss /= len(targets) # Average the total loss by the mini-batch size 138 | else: 139 | pred = preds['All'] 140 | if isinstance(self.valid_out_dim, int): # (Not 'ALL') Mask out the outputs of unseen classes for incremental class scenario 141 | pred = preds['All'][:,:self.valid_out_dim] 142 | loss = self.criterion_fn(pred, targets) 143 | return loss 144 | 145 | def update_model(self, inputs, targets, tasks): 146 | out = self.forward(inputs) 147 | loss = self.criterion(out, targets, tasks) 148 | self.optimizer.zero_grad() 149 | loss.backward() 150 | self.optimizer.step() 151 | return loss.detach(), out 152 | 153 | def learn_batch(self, train_loader, val_loader=None): 154 | if self.reset_optimizer: # Reset optimizer before learning each task 155 | self.log('Optimizer is reset!') 156 | self.init_optimizer() 157 | 158 | for epoch in range(self.config['schedule'][-1]): 159 | data_timer = Timer() 160 | batch_timer = Timer() 161 | batch_time = AverageMeter() 162 | data_time = AverageMeter() 163 | losses = AverageMeter() 164 | acc = AverageMeter() 165 | 166 | # Config the model and optimizer 167 | self.log('Epoch:{0}'.format(epoch)) 168 | self.model.train() 169 | self.scheduler.step(epoch) 170 | for param_group in self.optimizer.param_groups: 171 | self.log('LR:',param_group['lr']) 172 | 173 | # Learning with mini-batch 174 | data_timer.tic() 175 | batch_timer.tic() 176 | self.log('Itr\t\tTime\t\t Data\t\t Loss\t\tAcc') 177 | for i, (input, target, task) in enumerate(train_loader): 178 | 179 | data_time.update(data_timer.toc()) # measure data loading time 180 | 181 | if self.gpu: 182 | input = input.cuda() 183 | target = target.cuda() 184 | 185 | loss, output = self.update_model(input, target, task) 186 | input = input.detach() 187 | target = target.detach() 188 | 189 | # 
measure accuracy and record loss 190 | acc = accumulate_acc(output, target, task, acc) 191 | losses.update(loss, input.size(0)) 192 | 193 | batch_time.update(batch_timer.toc()) # measure elapsed time 194 | data_timer.toc() 195 | 196 | if ((self.config['print_freq']>0) and (i % self.config['print_freq'] == 0)) or (i+1)==len(train_loader): 197 | self.log('[{0}/{1}]\t' 198 | '{batch_time.val:.4f} ({batch_time.avg:.4f})\t' 199 | '{data_time.val:.4f} ({data_time.avg:.4f})\t' 200 | '{loss.val:.3f} ({loss.avg:.3f})\t' 201 | '{acc.val:.2f} ({acc.avg:.2f})'.format( 202 | i, len(train_loader), batch_time=batch_time, 203 | data_time=data_time, loss=losses, acc=acc)) 204 | 205 | self.log(' * Train Acc {acc.avg:.3f}'.format(acc=acc)) 206 | 207 | # Evaluate the performance of current task 208 | if val_loader != None: 209 | self.validation(val_loader) 210 | 211 | def learn_stream(self, data, label): 212 | assert False,'No implementation yet' 213 | 214 | def add_valid_output_dim(self, dim=0): 215 | # This function is kind of ad-hoc, but it is the simplest way to support incremental class learning 216 | self.log('Incremental class: Old valid output dimension:', self.valid_out_dim) 217 | if self.valid_out_dim == 'ALL': 218 | self.valid_out_dim = 0 # Initialize it with zero 219 | self.valid_out_dim += dim 220 | self.log('Incremental class: New Valid output dimension:', self.valid_out_dim) 221 | return self.valid_out_dim 222 | 223 | def count_parameter(self): 224 | return sum(p.numel() for p in self.model.parameters()) 225 | 226 | def save_model(self, filename): 227 | model_state = self.model.state_dict() 228 | if isinstance(self.model,torch.nn.DataParallel): 229 | # Get rid of 'module' before the name of states 230 | model_state = self.model.module.state_dict() 231 | for key in model_state.keys(): # Always save it to cpu 232 | model_state[key] = model_state[key].cpu() 233 | print('=> Saving model to:', filename) 234 | torch.save(model_state, filename + '.pth') 235 | print('=> Save Done') 236 | 237 | def cuda(self): 238 | torch.cuda.set_device(self.config['gpuid'][0]) 239 | self.model = self.model.cuda() 240 | self.criterion_fn = self.criterion_fn.cuda() 241 | # Multi-GPU 242 | if len(self.config['gpuid']) > 1: 243 | self.model = torch.nn.DataParallel(self.model, device_ids=self.config['gpuid'], output_device=self.config['gpuid'][0]) 244 | return self 245 | 246 | def accumulate_acc(output, target, task, meter): 247 | if 'All' in output.keys(): # Single-headed model 248 | meter.update(accuracy(output['All'], target), len(target)) 249 | else: # outputs from multi-headed (multi-task) model 250 | for t, t_out in output.items(): 251 | inds = [i for i in range(len(task)) if task[i] == t] # The index of inputs that matched specific task 252 | if len(inds) > 0: 253 | t_out = t_out[inds] 254 | t_target = target[inds] 255 | meter.update(accuracy(t_out, t_target), len(inds)) 256 | 257 | return meter 258 | -------------------------------------------------------------------------------- /agents/exp_replay.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from importlib import import_module 4 | from .default import NormalNN 5 | from .regularization import SI, L2, EWC, MAS 6 | from dataloaders.wrapper import Storage 7 | 8 | 9 | class Naive_Rehearsal(NormalNN): 10 | 11 | def __init__(self, agent_config): 12 | super(Naive_Rehearsal, self).__init__(agent_config) 13 | self.task_count = 0 14 | self.memory_size = 1000 15 | self.task_memory = {} 16 | 
self.skip_memory_concatenation = False 17 | 18 | def learn_batch(self, train_loader, val_loader=None): 19 | # 1.Combine training set 20 | if self.skip_memory_concatenation: 21 | new_train_loader = train_loader 22 | else: # default 23 | dataset_list = [] 24 | for storage in self.task_memory.values(): 25 | dataset_list.append(storage) 26 | dataset_list *= max(len(train_loader.dataset)//self.memory_size,1) # Let old data: new data = 1:1 27 | dataset_list.append(train_loader.dataset) 28 | dataset = torch.utils.data.ConcatDataset(dataset_list) 29 | new_train_loader = torch.utils.data.DataLoader(dataset, 30 | batch_size=train_loader.batch_size, 31 | shuffle=True, 32 | num_workers=train_loader.num_workers) 33 | 34 | # 2.Update model as normal 35 | super(Naive_Rehearsal, self).learn_batch(new_train_loader, val_loader) 36 | 37 | # 3.Randomly decide the images to stay in the memory 38 | self.task_count += 1 39 | # (a) Decide the number of samples for being saved 40 | num_sample_per_task = self.memory_size // self.task_count 41 | num_sample_per_task = min(len(train_loader.dataset),num_sample_per_task) 42 | # (b) Reduce current exemplar set to reserve the space for the new dataset 43 | for storage in self.task_memory.values(): 44 | storage.reduce(num_sample_per_task) 45 | # (c) Randomly choose some samples from new task and save them to the memory 46 | randind = torch.randperm(len(train_loader.dataset))[:num_sample_per_task] # randomly sample some data 47 | self.task_memory[self.task_count] = Storage(train_loader.dataset, randind) 48 | 49 | 50 | class Naive_Rehearsal_SI(Naive_Rehearsal, SI): 51 | 52 | def __init__(self, agent_config): 53 | super(Naive_Rehearsal_SI, self).__init__(agent_config) 54 | 55 | 56 | class Naive_Rehearsal_L2(Naive_Rehearsal, L2): 57 | 58 | def __init__(self, agent_config): 59 | super(Naive_Rehearsal_L2, self).__init__(agent_config) 60 | 61 | 62 | class Naive_Rehearsal_EWC(Naive_Rehearsal, EWC): 63 | 64 | def __init__(self, agent_config): 65 | super(Naive_Rehearsal_EWC, self).__init__(agent_config) 66 | self.online_reg = True # Online EWC 67 | 68 | 69 | class Naive_Rehearsal_MAS(Naive_Rehearsal, MAS): 70 | 71 | def __init__(self, agent_config): 72 | super(Naive_Rehearsal_MAS, self).__init__(agent_config) 73 | 74 | 75 | class GEM(Naive_Rehearsal): 76 | """ 77 | @inproceedings{GradientEpisodicMemory, 78 | title={Gradient Episodic Memory for Continual Learning}, 79 | author={Lopez-Paz, David and Ranzato, Marc'Aurelio}, 80 | booktitle={NIPS}, 81 | year={2017}, 82 | url={https://arxiv.org/abs/1706.08840} 83 | } 84 | """ 85 | 86 | def __init__(self, agent_config): 87 | super(GEM, self).__init__(agent_config) 88 | self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad} # For convenience 89 | self.task_grads = {} 90 | self.quadprog = import_module('quadprog') 91 | self.task_mem_cache = {} 92 | 93 | def grad_to_vector(self): 94 | vec = [] 95 | for n,p in self.params.items(): 96 | if p.grad is not None: 97 | vec.append(p.grad.view(-1)) 98 | else: 99 | # Part of the network might has no grad, fill zero for those terms 100 | vec.append(p.data.clone().fill_(0).view(-1)) 101 | return torch.cat(vec) 102 | 103 | def vector_to_grad(self, vec): 104 | # Overwrite current param.grad by slicing the values in vec (flatten grad) 105 | pointer = 0 106 | for n, p in self.params.items(): 107 | # The length of the parameter 108 | num_param = p.numel() 109 | if p.grad is not None: 110 | # Slice the vector, reshape it, and replace the old data of the grad 111 | 
p.grad.copy_(vec[pointer:pointer + num_param].view_as(p)) 112 | # Part of the network might has no grad, ignore those terms 113 | # Increment the pointer 114 | pointer += num_param 115 | 116 | def project2cone2(self, gradient, memories): 117 | """ 118 | Solves the GEM dual QP described in the paper given a proposed 119 | gradient "gradient", and a memory of task gradients "memories". 120 | Overwrites "gradient" with the final projected update. 121 | 122 | input: gradient, p-vector 123 | input: memories, (t * p)-vector 124 | output: x, p-vector 125 | 126 | Modified from: https://github.com/facebookresearch/GradientEpisodicMemory/blob/master/model/gem.py#L70 127 | """ 128 | margin = self.config['reg_coef'] 129 | memories_np = memories.cpu().contiguous().double().numpy() 130 | gradient_np = gradient.cpu().contiguous().view(-1).double().numpy() 131 | t = memories_np.shape[0] 132 | #print(memories_np.shape, gradient_np.shape) 133 | P = np.dot(memories_np, memories_np.transpose()) 134 | P = 0.5 * (P + P.transpose()) 135 | q = np.dot(memories_np, gradient_np) * -1 136 | G = np.eye(t) 137 | P = P + G * 0.001 138 | h = np.zeros(t) + margin 139 | v = self.quadprog.solve_qp(P, q, G, h)[0] 140 | x = np.dot(v, memories_np) + gradient_np 141 | new_grad = torch.Tensor(x).view(-1) 142 | if self.gpu: 143 | new_grad = new_grad.cuda() 144 | return new_grad 145 | 146 | def learn_batch(self, train_loader, val_loader=None): 147 | 148 | # Update model as normal 149 | super(GEM, self).learn_batch(train_loader, val_loader) 150 | 151 | # Cache the data for faster processing 152 | for t, mem in self.task_memory.items(): 153 | # Concatenate all data in each task 154 | mem_loader = torch.utils.data.DataLoader(mem, 155 | batch_size=len(mem), 156 | shuffle=False, 157 | num_workers=2) 158 | assert len(mem_loader)==1,'The length of mem_loader should be 1' 159 | for i, (mem_input, mem_target, mem_task) in enumerate(mem_loader): 160 | if self.gpu: 161 | mem_input = mem_input.cuda() 162 | mem_target = mem_target.cuda() 163 | self.task_mem_cache[t] = {'data':mem_input,'target':mem_target,'task':mem_task} 164 | 165 | def update_model(self, inputs, targets, tasks): 166 | 167 | # compute gradient on previous tasks 168 | if self.task_count > 0: 169 | for t,mem in self.task_memory.items(): 170 | self.zero_grad() 171 | # feed the data from memory and collect the gradients 172 | mem_out = self.forward(self.task_mem_cache[t]['data']) 173 | mem_loss = self.criterion(mem_out, self.task_mem_cache[t]['target'], self.task_mem_cache[t]['task']) 174 | mem_loss.backward() 175 | # Store the grads 176 | self.task_grads[t] = self.grad_to_vector() 177 | 178 | # now compute the grad on the current minibatch 179 | out = self.forward(inputs) 180 | loss = self.criterion(out, targets, tasks) 181 | self.optimizer.zero_grad() 182 | loss.backward() 183 | 184 | # check if gradient violates constraints 185 | if self.task_count > 0: 186 | current_grad_vec = self.grad_to_vector() 187 | mem_grad_vec = torch.stack(list(self.task_grads.values())) 188 | dotp = current_grad_vec * mem_grad_vec 189 | dotp = dotp.sum(dim=1) 190 | if (dotp < 0).sum() != 0: 191 | new_grad = self.project2cone2(current_grad_vec, mem_grad_vec) 192 | # copy gradients back 193 | self.vector_to_grad(new_grad) 194 | 195 | self.optimizer.step() 196 | return loss.detach(), out 197 | -------------------------------------------------------------------------------- /agents/regularization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import random 3 | from .default import NormalNN 4 | 5 | 6 | class L2(NormalNN): 7 | """ 8 | @article{kirkpatrick2017overcoming, 9 | title={Overcoming catastrophic forgetting in neural networks}, 10 | author={Kirkpatrick, James and Pascanu, Razvan and Rabinowitz, Neil and Veness, Joel and Desjardins, Guillaume and Rusu, Andrei A and Milan, Kieran and Quan, John and Ramalho, Tiago and Grabska-Barwinska, Agnieszka and others}, 11 | journal={Proceedings of the national academy of sciences}, 12 | year={2017}, 13 | url={https://arxiv.org/abs/1612.00796} 14 | } 15 | """ 16 | def __init__(self, agent_config): 17 | super(L2, self).__init__(agent_config) 18 | self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad} # For convenience 19 | self.regularization_terms = {} 20 | self.task_count = 0 21 | self.online_reg = True # True: There will be only one importance matrix and previous model parameters 22 | # False: Each task has its own importance matrix and model parameters 23 | 24 | def calculate_importance(self, dataloader): 25 | # Use an identity importance so it is an L2 regularization. 26 | importance = {} 27 | for n, p in self.params.items(): 28 | importance[n] = p.clone().detach().fill_(1) # Identity 29 | return importance 30 | 31 | def learn_batch(self, train_loader, val_loader=None): 32 | 33 | self.log('#reg_term:', len(self.regularization_terms)) 34 | 35 | # 1.Learn the parameters for current task 36 | super(L2, self).learn_batch(train_loader, val_loader) 37 | 38 | # 2.Backup the weight of current task 39 | task_param = {} 40 | for n, p in self.params.items(): 41 | task_param[n] = p.clone().detach() 42 | 43 | # 3.Calculate the importance of weights for current task 44 | importance = self.calculate_importance(train_loader) 45 | 46 | # Save the weight and importance of weights of current task 47 | self.task_count += 1 48 | if self.online_reg and len(self.regularization_terms)>0: 49 | # Always use only one slot in self.regularization_terms 50 | self.regularization_terms[1] = {'importance':importance, 'task_param':task_param} 51 | else: 52 | # Use a new slot to store the task-specific information 53 | self.regularization_terms[self.task_count] = {'importance':importance, 'task_param':task_param} 54 | 55 | def criterion(self, inputs, targets, tasks, regularization=True, **kwargs): 56 | loss = super(L2, self).criterion(inputs, targets, tasks, **kwargs) 57 | 58 | if regularization and len(self.regularization_terms)>0: 59 | # Calculate the reg_loss only when the regularization_terms exists 60 | reg_loss = 0 61 | for i,reg_term in self.regularization_terms.items(): 62 | task_reg_loss = 0 63 | importance = reg_term['importance'] 64 | task_param = reg_term['task_param'] 65 | for n, p in self.params.items(): 66 | task_reg_loss += (importance[n] * (p - task_param[n]) ** 2).sum() 67 | reg_loss += task_reg_loss 68 | loss += self.config['reg_coef'] * reg_loss 69 | return loss 70 | 71 | 72 | class EWC(L2): 73 | """ 74 | @article{kirkpatrick2017overcoming, 75 | title={Overcoming catastrophic forgetting in neural networks}, 76 | author={Kirkpatrick, James and Pascanu, Razvan and Rabinowitz, Neil and Veness, Joel and Desjardins, Guillaume and Rusu, Andrei A and Milan, Kieran and Quan, John and Ramalho, Tiago and Grabska-Barwinska, Agnieszka and others}, 77 | journal={Proceedings of the national academy of sciences}, 78 | year={2017}, 79 | url={https://arxiv.org/abs/1612.00796} 80 | } 81 | """ 82 | 83 | def __init__(self, agent_config): 84 | super(EWC, self).__init__(agent_config) 85 | 
self.online_reg = False 86 | self.n_fisher_sample = None 87 | self.empFI = False 88 | 89 | def calculate_importance(self, dataloader): 90 | # Update the diag fisher information 91 | # There are several ways to estimate the F matrix. 92 | # We keep the implementation as simple as possible while maintaining a similar performance to the literature. 93 | self.log('Computing EWC') 94 | 95 | # Initialize the importance matrix 96 | if self.online_reg and len(self.regularization_terms)>0: 97 | importance = self.regularization_terms[1]['importance'] 98 | else: 99 | importance = {} 100 | for n, p in self.params.items(): 101 | importance[n] = p.clone().detach().fill_(0) # zero initialized 102 | 103 | # Sample a subset (n_fisher_sample) of data to estimate the fisher information (batch_size=1) 104 | # Otherwise it uses mini-batches for the estimation. This speeds up the process a lot with similar performance. 105 | if self.n_fisher_sample is not None: 106 | n_sample = min(self.n_fisher_sample, len(dataloader.dataset)) 107 | self.log('Sample',self.n_fisher_sample,'for estimating the F matrix.') 108 | rand_ind = random.sample(list(range(len(dataloader.dataset))), n_sample) 109 | subdata = torch.utils.data.Subset(dataloader.dataset, rand_ind) 110 | dataloader = torch.utils.data.DataLoader(subdata, shuffle=True, num_workers=2, batch_size=1) 111 | 112 | mode = self.training 113 | self.eval() 114 | 115 | # Accumulate the square of gradients 116 | for i, (input, target, task) in enumerate(dataloader): 117 | if self.gpu: 118 | input = input.cuda() 119 | target = target.cuda() 120 | 121 | preds = self.forward(input) 122 | 123 | # Sample the labels for estimating the gradients 124 | # For multi-headed model, the batch of data will be from the same task, 125 | # so we just use task[0] as the task name to fetch corresponding predictions 126 | # For single-headed model, just use the max of predictions from preds['All'] 127 | task_name = task[0] if self.multihead else 'All' 128 | 129 | # The flag self.valid_out_dim is for handling the case of incremental class learning. 130 | # if self.valid_out_dim is an integer, it means only the first 'self.valid_out_dim' dimensions are used 131 | # in calculating the loss. 132 | pred = preds[task_name] if not isinstance(self.valid_out_dim, int) else preds[task_name][:,:self.valid_out_dim] 133 | ind = pred.max(1)[1].flatten() # Choose the one with max 134 | 135 | # - Alternative ind by multinomial sampling. Its performance is similar. - 136 | # prob = torch.nn.functional.softmax(preds['All'],dim=1) 137 | # ind = torch.multinomial(prob,1).flatten() 138 | 139 | if self.empFI: # Use groundtruth label (default is without this) 140 | ind = target 141 | 142 | loss = self.criterion(preds, ind, task, regularization=False) 143 | self.model.zero_grad() 144 | loss.backward() 145 | for n, p in importance.items(): 146 | if self.params[n].grad is not None: # Some heads can have no grad if no loss applied on them. 
147 | p += ((self.params[n].grad ** 2) * len(input) / len(dataloader)) 148 | 149 | self.train(mode=mode) 150 | 151 | return importance 152 | 153 | 154 | def EWC_online(agent_config): 155 | agent = EWC(agent_config) 156 | agent.online_reg = True 157 | return agent 158 | 159 | 160 | class SI(L2): 161 | """ 162 | @inproceedings{zenke2017continual, 163 | title={Continual Learning Through Synaptic Intelligence}, 164 | author={Zenke, Friedemann and Poole, Ben and Ganguli, Surya}, 165 | booktitle={International Conference on Machine Learning}, 166 | year={2017}, 167 | url={https://arxiv.org/abs/1703.04200} 168 | } 169 | """ 170 | 171 | def __init__(self, agent_config): 172 | super(SI, self).__init__(agent_config) 173 | self.online_reg = True # Original SI works in an online updating fashion 174 | self.damping_factor = 0.1 175 | self.w = {} 176 | for n, p in self.params.items(): 177 | self.w[n] = p.clone().detach().zero_() 178 | 179 | # The initial_params will only be used in the first task (when the regularization_terms is empty) 180 | self.initial_params = {} 181 | for n, p in self.params.items(): 182 | self.initial_params[n] = p.clone().detach() 183 | 184 | def update_model(self, inputs, targets, tasks): 185 | 186 | unreg_gradients = {} 187 | 188 | # 1.Save current parameters 189 | old_params = {} 190 | for n, p in self.params.items(): 191 | old_params[n] = p.clone().detach() 192 | 193 | # 2. Collect the gradients without regularization term 194 | out = self.forward(inputs) 195 | loss = self.criterion(out, targets, tasks, regularization=False) 196 | self.optimizer.zero_grad() 197 | loss.backward(retain_graph=True) 198 | for n, p in self.params.items(): 199 | if p.grad is not None: 200 | unreg_gradients[n] = p.grad.clone().detach() 201 | 202 | # 3. Normal update with regularization 203 | loss = self.criterion(out, targets, tasks, regularization=True) 204 | self.optimizer.zero_grad() 205 | loss.backward() 206 | self.optimizer.step() 207 | 208 | # 4. Accumulate the w 209 | for n, p in self.params.items(): 210 | delta = p.detach() - old_params[n] 211 | if n in unreg_gradients.keys(): # In multi-head network, some head could have no grad (lazy) since no loss go through it. 212 | self.w[n] -= unreg_gradients[n] * delta # w[n] is >=0 213 | 214 | return loss.detach(), out 215 | 216 | """ 217 | # - Alternative simplified implementation with similar performance - 218 | def update_model(self, inputs, targets, tasks): 219 | # A wrapper of original update step to include the estimation of w 220 | 221 | # Backup prev param if not done yet 222 | # The backup only happened at the beginning of a new task 223 | if len(self.prev_params) == 0: 224 | for n, p in self.params.items(): 225 | self.prev_params[n] = p.clone().detach() 226 | 227 | # 1.Save current parameters 228 | old_params = {} 229 | for n, p in self.params.items(): 230 | old_params[n] = p.clone().detach() 231 | 232 | # 2.Calculate the loss as usual 233 | loss, out = super(SI, self).update_model(inputs, targets, tasks) 234 | 235 | # 3.Accumulate the w 236 | for n, p in self.params.items(): 237 | delta = p.detach() - old_params[n] 238 | if p.grad is not None: # In multi-head network, some head could have no grad (lazy) since no loss go through it. 
239 | self.w[n] -= p.grad * delta # w[n] is >=0 240 | 241 | return loss.detach(), out 242 | """ 243 | 244 | def calculate_importance(self, dataloader): 245 | self.log('Computing SI') 246 | assert self.online_reg,'SI needs online_reg=True' 247 | 248 | # Initialize the importance matrix 249 | if len(self.regularization_terms)>0: # The case of after the first task 250 | importance = self.regularization_terms[1]['importance'] 251 | prev_params = self.regularization_terms[1]['task_param'] 252 | else: # It is in the first task 253 | importance = {} 254 | for n, p in self.params.items(): 255 | importance[n] = p.clone().detach().fill_(0) # zero initialized 256 | prev_params = self.initial_params 257 | 258 | # Calculate or accumulate the Omega (the importance matrix) 259 | for n, p in importance.items(): 260 | delta_theta = self.params[n].detach() - prev_params[n] 261 | p += self.w[n]/(delta_theta**2 + self.damping_factor) 262 | self.w[n].zero_() 263 | 264 | return importance 265 | 266 | 267 | class MAS(L2): 268 | """ 269 | @article{aljundi2017memory, 270 | title={Memory Aware Synapses: Learning what (not) to forget}, 271 | author={Aljundi, Rahaf and Babiloni, Francesca and Elhoseiny, Mohamed and Rohrbach, Marcus and Tuytelaars, Tinne}, 272 | booktitle={ECCV}, 273 | year={2018}, 274 | url={https://eccv2018.org/openaccess/content_ECCV_2018/papers/Rahaf_Aljundi_Memory_Aware_Synapses_ECCV_2018_paper.pdf} 275 | } 276 | """ 277 | 278 | def __init__(self, agent_config): 279 | super(MAS, self).__init__(agent_config) 280 | self.online_reg = True 281 | 282 | def calculate_importance(self, dataloader): 283 | self.log('Computing MAS') 284 | 285 | # Initialize the importance matrix 286 | if self.online_reg and len(self.regularization_terms)>0: 287 | importance = self.regularization_terms[1]['importance'] 288 | else: 289 | importance = {} 290 | for n, p in self.params.items(): 291 | importance[n] = p.clone().detach().fill_(0) # zero initialized 292 | 293 | mode = self.training 294 | self.eval() 295 | 296 | # Accumulate the gradients of L2 loss on the outputs 297 | for i, (input, target, task) in enumerate(dataloader): 298 | if self.gpu: 299 | input = input.cuda() 300 | target = target.cuda() 301 | 302 | preds = self.forward(input) 303 | 304 | # Sample the labels for estimating the gradients 305 | # For multi-headed model, the batch of data will be from the same task, 306 | # so we just use task[0] as the task name to fetch corresponding predictions 307 | # For single-headed model, just use the max of predictions from preds['All'] 308 | task_name = task[0] if self.multihead else 'All' 309 | 310 | # The flag self.valid_out_dim is for handling the case of incremental class learning. 311 | # if self.valid_out_dim is an integer, it means only the first 'self.valid_out_dim' dimensions are used 312 | # in calculating the loss. 313 | pred = preds[task_name] if not isinstance(self.valid_out_dim, int) else preds[task_name][:,:self.valid_out_dim] 314 | 315 | pred.pow_(2) 316 | loss = pred.mean() 317 | 318 | self.model.zero_grad() 319 | loss.backward() 320 | for n, p in importance.items(): 321 | if self.params[n].grad is not None: # Some heads can have no grad if no loss applied on them. 
322 | p += (self.params[n].grad.abs() / len(dataloader)) 323 | 324 | self.train(mode=mode) 325 | 326 | return importance -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-RIPL/Continual-Learning-Benchmark/d78b9973b6ec0059b2d2577872db355ae2489f6b/dataloaders/__init__.py -------------------------------------------------------------------------------- /dataloaders/base.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torchvision import transforms 3 | from .wrapper import CacheClassLabel 4 | 5 | def MNIST(dataroot, train_aug=False): 6 | # Add padding to make 32x32 7 | #normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,)) # for 28x28 8 | normalize = transforms.Normalize(mean=(0.1000,), std=(0.2752,)) # for 32x32 9 | 10 | val_transform = transforms.Compose([ 11 | transforms.Pad(2, fill=0, padding_mode='constant'), 12 | transforms.ToTensor(), 13 | normalize, 14 | ]) 15 | train_transform = val_transform 16 | if train_aug: 17 | train_transform = transforms.Compose([ 18 | transforms.RandomCrop(32, padding=4), 19 | transforms.ToTensor(), 20 | normalize, 21 | ]) 22 | 23 | train_dataset = torchvision.datasets.MNIST( 24 | root=dataroot, 25 | train=True, 26 | download=True, 27 | transform=train_transform 28 | ) 29 | train_dataset = CacheClassLabel(train_dataset) 30 | 31 | val_dataset = torchvision.datasets.MNIST( 32 | dataroot, 33 | train=False, 34 | transform=val_transform 35 | ) 36 | val_dataset = CacheClassLabel(val_dataset) 37 | 38 | return train_dataset, val_dataset 39 | 40 | def CIFAR10(dataroot, train_aug=False): 41 | normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]) 42 | 43 | val_transform = transforms.Compose([ 44 | transforms.ToTensor(), 45 | normalize, 46 | ]) 47 | train_transform = val_transform 48 | if train_aug: 49 | train_transform = transforms.Compose([ 50 | transforms.RandomCrop(32, padding=4), 51 | transforms.RandomHorizontalFlip(), 52 | transforms.ToTensor(), 53 | normalize, 54 | ]) 55 | 56 | train_dataset = torchvision.datasets.CIFAR10( 57 | root=dataroot, 58 | train=True, 59 | download=True, 60 | transform=train_transform 61 | ) 62 | train_dataset = CacheClassLabel(train_dataset) 63 | 64 | val_dataset = torchvision.datasets.CIFAR10( 65 | root=dataroot, 66 | train=False, 67 | download=True, 68 | transform=val_transform 69 | ) 70 | val_dataset = CacheClassLabel(val_dataset) 71 | 72 | return train_dataset, val_dataset 73 | 74 | 75 | def CIFAR100(dataroot, train_aug=False): 76 | normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) 77 | 78 | val_transform = transforms.Compose([ 79 | transforms.ToTensor(), 80 | normalize, 81 | ]) 82 | train_transform = val_transform 83 | if train_aug: 84 | train_transform = transforms.Compose([ 85 | transforms.RandomCrop(32, padding=4), 86 | transforms.RandomHorizontalFlip(), 87 | transforms.ToTensor(), 88 | normalize, 89 | ]) 90 | 91 | train_dataset = torchvision.datasets.CIFAR100( 92 | root=dataroot, 93 | train=True, 94 | download=True, 95 | transform=train_transform 96 | ) 97 | train_dataset = CacheClassLabel(train_dataset) 98 | 99 | val_dataset = torchvision.datasets.CIFAR100( 100 | root=dataroot, 101 | train=False, 102 | download=True, 103 | transform=val_transform 104 | ) 105 | val_dataset = CacheClassLabel(val_dataset) 106 | 107 
| return train_dataset, val_dataset 108 | 109 | -------------------------------------------------------------------------------- /dataloaders/datasetGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from random import shuffle 3 | from .wrapper import Subclass, AppendName, Permutation 4 | 5 | 6 | def SplitGen(train_dataset, val_dataset, first_split_sz=2, other_split_sz=2, rand_split=False, remap_class=False): 7 | ''' 8 | Generate the dataset splits based on the labels. 9 | :param train_dataset: (torch.utils.data.dataset) 10 | :param val_dataset: (torch.utils.data.dataset) 11 | :param first_split_sz: (int) 12 | :param other_split_sz: (int) 13 | :param rand_split: (bool) Randomize the set of label in each split 14 | :param remap_class: (bool) Ex: remap classes in a split from [2,4,6 ...] to [0,1,2 ...] 15 | :return: train_loaders {task_name:loader}, val_loaders {task_name:loader}, out_dim {task_name:num_classes} 16 | ''' 17 | assert train_dataset.number_classes==val_dataset.number_classes,'Train/Val has different number of classes' 18 | num_classes = train_dataset.number_classes 19 | 20 | # Calculate the boundary index of classes for splits 21 | # Ex: [0,2,4,6,8,10] or [0,50,60,70,80,90,100] 22 | split_boundaries = [0, first_split_sz] 23 | while split_boundaries[-1]0: 20 | train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(train_dataset, val_dataset, 21 | args.n_permutation, 22 | remap_class=not args.no_class_remap) 23 | else: 24 | train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(train_dataset, val_dataset, 25 | first_split_sz=args.first_split_size, 26 | other_split_sz=args.other_split_size, 27 | rand_split=args.rand_split, 28 | remap_class=not args.no_class_remap) 29 | 30 | # Prepare the Agent (model) 31 | agent_config = {'lr': args.lr, 'momentum': args.momentum, 'weight_decay': args.weight_decay,'schedule': args.schedule, 32 | 'model_type':args.model_type, 'model_name': args.model_name, 'model_weights':args.model_weights, 33 | 'out_dim':{'All':args.force_out_dim} if args.force_out_dim>0 else task_output_space, 34 | 'optimizer':args.optimizer, 35 | 'print_freq':args.print_freq, 'gpuid': args.gpuid, 36 | 'reg_coef':args.reg_coef} 37 | agent = agents.__dict__[args.agent_type].__dict__[args.agent_name](agent_config) 38 | print(agent.model) 39 | print('#parameter of model:',agent.count_parameter()) 40 | 41 | # Decide split ordering 42 | task_names = sorted(list(task_output_space.keys()), key=int) 43 | print('Task order:',task_names) 44 | if args.rand_split_order: 45 | shuffle(task_names) 46 | print('Shuffled task order:', task_names) 47 | 48 | acc_table = OrderedDict() 49 | if args.offline_training: # Non-incremental learning / offline_training / measure the upper-bound performance 50 | task_names = ['All'] 51 | train_dataset_all = torch.utils.data.ConcatDataset(train_dataset_splits.values()) 52 | val_dataset_all = torch.utils.data.ConcatDataset(val_dataset_splits.values()) 53 | train_loader = torch.utils.data.DataLoader(train_dataset_all, 54 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers) 55 | val_loader = torch.utils.data.DataLoader(val_dataset_all, 56 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 57 | 58 | agent.learn_batch(train_loader, val_loader) 59 | 60 | acc_table['All'] = {} 61 | acc_table['All']['All'] = agent.validation(val_loader) 62 | 63 | else: # Incremental learning 64 | # Feed data to agent and evaluate agent's performance 65 | 
for i in range(len(task_names)): 66 | train_name = task_names[i] 67 | print('======================',train_name,'=======================') 68 | train_loader = torch.utils.data.DataLoader(train_dataset_splits[train_name], 69 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers) 70 | val_loader = torch.utils.data.DataLoader(val_dataset_splits[train_name], 71 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 72 | 73 | if args.incremental_class: 74 | agent.add_valid_output_dim(task_output_space[train_name]) 75 | 76 | # Learn 77 | agent.learn_batch(train_loader, val_loader) 78 | 79 | # Evaluate 80 | acc_table[train_name] = OrderedDict() 81 | for j in range(i+1): 82 | val_name = task_names[j] 83 | print('validation split name:', val_name) 84 | val_data = val_dataset_splits[val_name] if not args.eval_on_train_set else train_dataset_splits[val_name] 85 | val_loader = torch.utils.data.DataLoader(val_data, 86 | batch_size=args.batch_size, shuffle=False, 87 | num_workers=args.workers) 88 | acc_table[val_name][train_name] = agent.validation(val_loader) 89 | 90 | return acc_table, task_names 91 | 92 | def get_args(argv): 93 | # This function prepares the variables shared across demo.py 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument('--gpuid', nargs="+", type=int, default=[0], 96 | help="The list of gpuid, ex:--gpuid 3 1. Negative value means cpu-only") 97 | parser.add_argument('--model_type', type=str, default='mlp', help="The type (mlp|lenet|vgg|resnet) of backbone network") 98 | parser.add_argument('--model_name', type=str, default='MLP', help="The name of actual model for the backbone") 99 | parser.add_argument('--force_out_dim', type=int, default=2, help="Set 0 to let the task decide the required output dimension") 100 | parser.add_argument('--agent_type', type=str, default='default', help="The type (filename) of agent") 101 | parser.add_argument('--agent_name', type=str, default='NormalNN', help="The class name of agent") 102 | parser.add_argument('--optimizer', type=str, default='SGD', help="SGD|Adam|RMSprop|amsgrad|Adadelta|Adagrad|Adamax ...") 103 | parser.add_argument('--dataroot', type=str, default='data', help="The root folder of dataset or downloaded data") 104 | parser.add_argument('--dataset', type=str, default='MNIST', help="MNIST(default)|CIFAR10|CIFAR100") 105 | parser.add_argument('--n_permutation', type=int, default=0, help="Enable permuted tests when >0") 106 | parser.add_argument('--first_split_size', type=int, default=2) 107 | parser.add_argument('--other_split_size', type=int, default=2) 108 | parser.add_argument('--no_class_remap', dest='no_class_remap', default=False, action='store_true', 109 | help="Avoid the dataset with a subset of classes doing the remapping. Ex: [2,5,6 ...] 
-> [0,1,2 ...]") 110 | parser.add_argument('--train_aug', dest='train_aug', default=False, action='store_true', 111 | help="Allow data augmentation during training") 112 | parser.add_argument('--rand_split', dest='rand_split', default=False, action='store_true', 113 | help="Randomize the classes in splits") 114 | parser.add_argument('--rand_split_order', dest='rand_split_order', default=False, action='store_true', 115 | help="Randomize the order of splits") 116 | parser.add_argument('--workers', type=int, default=3, help="Number of worker threads for the dataloader") 117 | parser.add_argument('--batch_size', type=int, default=100) 118 | parser.add_argument('--lr', type=float, default=0.01, help="Learning rate") 119 | parser.add_argument('--momentum', type=float, default=0) 120 | parser.add_argument('--weight_decay', type=float, default=0) 121 | parser.add_argument('--schedule', nargs="+", type=int, default=[2], 122 | help="The list of epoch numbers to reduce the learning rate by a factor of 0.1. The last number is the end epoch") 123 | parser.add_argument('--print_freq', type=float, default=100, help="Print the log every x iterations") 124 | parser.add_argument('--model_weights', type=str, default=None, 125 | help="The path to the file for the model weights (*.pth).") 126 | parser.add_argument('--reg_coef', nargs="+", type=float, default=[0.], help="The coefficient for regularization. Larger means less plasticity. Give a list for hyperparameter search.") 127 | parser.add_argument('--eval_on_train_set', dest='eval_on_train_set', default=False, action='store_true', 128 | help="Force the evaluation on the train set") 129 | parser.add_argument('--offline_training', dest='offline_training', default=False, action='store_true', 130 | help="Non-incremental learning by making all data available in one batch. 
For measuring the upper-bound performance.") 131 | parser.add_argument('--repeat', type=int, default=1, help="Repeat the experiment N times") 132 | parser.add_argument('--incremental_class', dest='incremental_class', default=False, action='store_true', 133 | help="The number of output nodes in the single-headed model increases along with new categories.") 134 | args = parser.parse_args(argv) 135 | return args 136 | 137 | if __name__ == '__main__': 138 | args = get_args(sys.argv[1:]) 139 | reg_coef_list = args.reg_coef 140 | avg_final_acc = {} 141 | 142 | # The for loops over hyper-parameters or repeats 143 | for reg_coef in reg_coef_list: 144 | args.reg_coef = reg_coef 145 | avg_final_acc[reg_coef] = np.zeros(args.repeat) 146 | for r in range(args.repeat): 147 | 148 | # Run the experiment 149 | acc_table, task_names = run(args) 150 | print(acc_table) 151 | 152 | # Calculate average performance across tasks 153 | # Customize this part for a different performance metric 154 | avg_acc_history = [0] * len(task_names) 155 | for i in range(len(task_names)): 156 | train_name = task_names[i] 157 | cls_acc_sum = 0 158 | for j in range(i + 1): 159 | val_name = task_names[j] 160 | cls_acc_sum += acc_table[val_name][train_name] 161 | avg_acc_history[i] = cls_acc_sum / (i + 1) 162 | print('Task', train_name, 'average acc:', avg_acc_history[i]) 163 | 164 | # Gather the final avg accuracy 165 | avg_final_acc[reg_coef][r] = avg_acc_history[-1] 166 | 167 | # Print the summary so far 168 | print('===Summary of experiment repeats:',r+1,'/',args.repeat,'===') 169 | print('The regularization coefficient:', args.reg_coef) 170 | print('The last avg acc of all repeats:', avg_final_acc[reg_coef]) 171 | print('mean:', avg_final_acc[reg_coef].mean(), 'std:', avg_final_acc[reg_coef].std()) 172 | for reg_coef,v in avg_final_acc.items(): 173 | print('reg_coef:', reg_coef,'mean:', avg_final_acc[reg_coef].mean(), 'std:', avg_final_acc[reg_coef].std()) 174 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import mlp 2 | from . import lenet 3 | from . import resnet 4 | from . 
import senet -------------------------------------------------------------------------------- /models/lenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LeNet(nn.Module): 5 | 6 | def __init__(self, out_dim=10, in_channel=1, img_sz=32): 7 | super(LeNet, self).__init__() 8 | feat_map_sz = img_sz//4 9 | self.n_feat = 50 * feat_map_sz * feat_map_sz 10 | 11 | self.conv = nn.Sequential( 12 | nn.Conv2d(in_channel, 20, 5, padding=2), 13 | nn.BatchNorm2d(20), 14 | nn.ReLU(inplace=True), 15 | nn.MaxPool2d(2, 2), 16 | nn.Conv2d(20, 50, 5, padding=2), 17 | nn.BatchNorm2d(50), 18 | nn.ReLU(inplace=True), 19 | nn.MaxPool2d(2, 2) 20 | ) 21 | self.linear = nn.Sequential( 22 | nn.Linear(self.n_feat, 500), 23 | nn.BatchNorm1d(500), 24 | nn.ReLU(inplace=True), 25 | ) 26 | self.last = nn.Linear(500, out_dim) # Subject to be replaced dependent on task 27 | 28 | def features(self, x): 29 | x = self.conv(x) 30 | x = self.linear(x.view(-1, self.n_feat)) 31 | return x 32 | 33 | def logits(self, x): 34 | x = self.last(x) 35 | return x 36 | 37 | def forward(self, x): 38 | x = self.features(x) 39 | x = self.logits(x) 40 | return x 41 | 42 | 43 | def LeNetC(out_dim=10): # LeNet with color input 44 | return LeNet(out_dim=out_dim, in_channel=3, img_sz=32) -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MLP(nn.Module): 6 | 7 | def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=256): 8 | super(MLP, self).__init__() 9 | self.in_dim = in_channel*img_sz*img_sz 10 | self.linear = nn.Sequential( 11 | nn.Linear(self.in_dim, hidden_dim), 12 | #nn.BatchNorm1d(hidden_dim), 13 | nn.ReLU(inplace=True), 14 | nn.Linear(hidden_dim, hidden_dim), 15 | #nn.BatchNorm1d(hidden_dim), 16 | nn.ReLU(inplace=True), 17 | ) 18 | self.last = nn.Linear(hidden_dim, out_dim) # Subject to be replaced dependent on task 19 | 20 | def features(self, x): 21 | x = self.linear(x.view(-1,self.in_dim)) 22 | return x 23 | 24 | def logits(self, x): 25 | x = self.last(x) 26 | return x 27 | 28 | def forward(self, x): 29 | x = self.features(x) 30 | x = self.logits(x) 31 | return x 32 | 33 | 34 | def MLP100(): 35 | return MLP(hidden_dim=100) 36 | 37 | 38 | def MLP400(): 39 | return MLP(hidden_dim=400) 40 | 41 | 42 | def MLP1000(): 43 | return MLP(hidden_dim=1000) 44 | 45 | 46 | def MLP2000(): 47 | return MLP(hidden_dim=2000) 48 | 49 | 50 | def MLP5000(): 51 | return MLP(hidden_dim=5000) -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import math 4 | from torch.nn import init 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 9 | 10 | 11 | class PreActBlock(nn.Module): 12 | '''Pre-activation version of the BasicBlock.''' 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1, droprate=0): 16 | super(PreActBlock, self).__init__() 17 | self.bn1 = nn.BatchNorm2d(in_planes) 18 | self.conv1 = conv3x3(in_planes, planes, stride) 19 | self.drop = nn.Dropout(p=droprate) if droprate>0 else None 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = 
conv3x3(planes, planes) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | if self.drop is not None: 33 | out = self.drop(out) 34 | out = self.conv2(F.relu(self.bn2(out))) 35 | out += shortcut 36 | return out 37 | 38 | 39 | class PreActBottleneck(nn.Module): 40 | '''Pre-activation version of the original Bottleneck module.''' 41 | expansion = 4 42 | 43 | def __init__(self, in_planes, planes, stride=1, droprate=None): 44 | super(PreActBottleneck, self).__init__() 45 | self.bn1 = nn.BatchNorm2d(in_planes) 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 51 | 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 55 | ) 56 | 57 | def forward(self, x): 58 | out = F.relu(self.bn1(x)) 59 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 60 | out = self.conv1(out) 61 | out = self.conv2(F.relu(self.bn2(out))) 62 | out = self.conv3(F.relu(self.bn3(out))) 63 | out += shortcut 64 | return out 65 | 66 | 67 | class PreActResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10, in_channels=3): 69 | super(PreActResNet, self).__init__() 70 | self.in_planes = 64 71 | last_planes = 512*block.expansion 72 | 73 | self.conv1 = conv3x3(in_channels, 64) 74 | self.stage1 = self._make_layer(block, 64, num_blocks[0], stride=1) 75 | self.stage2 = self._make_layer(block, 128, num_blocks[1], stride=2) 76 | self.stage3 = self._make_layer(block, 256, num_blocks[2], stride=2) 77 | self.stage4 = self._make_layer(block, 512, num_blocks[3], stride=2) 78 | self.bn_last = nn.BatchNorm2d(last_planes) 79 | self.last = nn.Linear(last_planes, num_classes) 80 | 81 | def _make_layer(self, block, planes, num_blocks, stride): 82 | strides = [stride] + [1]*(num_blocks-1) 83 | layers = [] 84 | for stride in strides: 85 | layers.append(block(self.in_planes, planes, stride)) 86 | self.in_planes = planes * block.expansion 87 | return nn.Sequential(*layers) 88 | 89 | def features(self, x): 90 | out = self.conv1(x) 91 | out = self.stage1(out) 92 | out = self.stage2(out) 93 | out = self.stage3(out) 94 | out = self.stage4(out) 95 | return out 96 | 97 | def logits(self, x): 98 | x = self.last(x) 99 | return x 100 | 101 | def forward(self, x): 102 | x = self.features(x) 103 | x = F.relu(self.bn_last(x)) 104 | x = F.adaptive_avg_pool2d(x, 1) 105 | x = self.logits(x.view(x.size(0), -1)) 106 | return x 107 | 108 | 109 | class PreActResNet_cifar(nn.Module): 110 | def __init__(self, block, num_blocks, filters, num_classes=10, droprate=0): 111 | super(PreActResNet_cifar, self).__init__() 112 | self.in_planes = 16 113 | last_planes = filters[2]*block.expansion 114 | 115 | self.conv1 = conv3x3(3, self.in_planes) 116 | self.stage1 = self._make_layer(block, filters[0], num_blocks[0], stride=1, droprate=droprate) 117 | self.stage2 = self._make_layer(block, filters[1], num_blocks[1], stride=2, 
droprate=droprate) 118 | self.stage3 = self._make_layer(block, filters[2], num_blocks[2], stride=2, droprate=droprate) 119 | self.bn_last = nn.BatchNorm2d(last_planes) 120 | self.last = nn.Linear(last_planes, num_classes) 121 | 122 | """ 123 | for m in self.modules(): 124 | if isinstance(m, nn.Conv2d): 125 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 126 | m.weight.data.normal_(0, math.sqrt(2. / n)) 127 | # m.bias.data.zero_() 128 | elif isinstance(m, nn.BatchNorm2d): 129 | m.weight.data.fill_(1) 130 | m.bias.data.zero_() 131 | elif isinstance(m, nn.Linear): 132 | init.kaiming_normal(m.weight) 133 | m.bias.data.zero_() 134 | """ 135 | 136 | def _make_layer(self, block, planes, num_blocks, stride, droprate): 137 | strides = [stride] + [1]*(num_blocks-1) 138 | layers = [] 139 | for stride in strides: 140 | layers.append(block(self.in_planes, planes, stride, droprate)) 141 | self.in_planes = planes * block.expansion 142 | return nn.Sequential(*layers) 143 | 144 | def features(self, x): 145 | out = self.conv1(x) 146 | out = self.stage1(out) 147 | out = self.stage2(out) 148 | out = self.stage3(out) 149 | return out 150 | 151 | def logits(self, x): 152 | x = self.last(x) 153 | return x 154 | 155 | def forward(self, x): 156 | out = self.features(x) 157 | out = F.relu(self.bn_last(out)) 158 | out = F.avg_pool2d(out, 8) 159 | out = self.logits(out.view(out.size(0), -1)) 160 | return out 161 | 162 | 163 | # ResNet for Cifar10/100 or the dataset with image size 32x32 164 | 165 | def ResNet20_cifar(out_dim=10): 166 | return PreActResNet_cifar(PreActBlock, [3 , 3 , 3 ], [16, 32, 64], num_classes=out_dim) 167 | 168 | def ResNet56_cifar(out_dim=10): 169 | return PreActResNet_cifar(PreActBlock, [9 , 9 , 9 ], [16, 32, 64], num_classes=out_dim) 170 | 171 | def ResNet110_cifar(out_dim=10): 172 | return PreActResNet_cifar(PreActBlock, [18, 18, 18], [16, 32, 64], num_classes=out_dim) 173 | 174 | def ResNet29_cifar(out_dim=10): 175 | return PreActResNet_cifar(PreActBottleneck, [3 , 3 , 3 ], [16, 32, 64], num_classes=out_dim) 176 | 177 | def ResNet164_cifar(out_dim=10): 178 | return PreActResNet_cifar(PreActBottleneck, [18, 18, 18], [16, 32, 64], num_classes=out_dim) 179 | 180 | def WideResNet_28_2_cifar(out_dim=10): 181 | return PreActResNet_cifar(PreActBlock, [4, 4, 4], [32, 64, 128], num_classes=out_dim) 182 | 183 | def WideResNet_28_2_drop_cifar(out_dim=10): 184 | return PreActResNet_cifar(PreActBlock, [4, 4, 4], [32, 64, 128], num_classes=out_dim, droprate=0.3) 185 | 186 | def WideResNet_28_10_cifar(out_dim=10): 187 | return PreActResNet_cifar(PreActBlock, [4, 4, 4], [160, 320, 640], num_classes=out_dim) 188 | 189 | # ResNet for general purpose. 
Ex:ImageNet 190 | 191 | def ResNet10(out_dim=10): 192 | return PreActResNet(PreActBlock, [1,1,1,1], num_classes=out_dim) 193 | 194 | def ResNet18S(out_dim=10): 195 | return PreActResNet(PreActBlock, [2,2,2,2], num_classes=out_dim, in_channels=1) 196 | 197 | def ResNet18(out_dim=10): 198 | return PreActResNet(PreActBlock, [2,2,2,2], num_classes=out_dim) 199 | 200 | def ResNet34(out_dim=10): 201 | return PreActResNet(PreActBlock, [3,4,6,3], num_classes=out_dim) 202 | 203 | def ResNet50(out_dim=10): 204 | return PreActResNet(PreActBottleneck, [3,4,6,3], num_classes=out_dim) 205 | 206 | def ResNet101(out_dim=10): 207 | return PreActResNet(PreActBottleneck, [3,4,23,3], num_classes=out_dim) 208 | 209 | def ResNet152(out_dim=10): 210 | return PreActResNet(PreActBottleneck, [3,8,36,3], num_classes=out_dim) -------------------------------------------------------------------------------- /models/senet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .resnet import conv3x3, PreActResNet, PreActResNet_cifar 5 | 6 | 7 | class SE_PreActBlock(nn.Module): 8 | '''Pre-activation version of the BasicBlock.''' 9 | expansion = 1 10 | 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(SE_PreActBlock, self).__init__() 13 | self.bn1 = nn.BatchNorm2d(in_planes) 14 | self.conv1 = conv3x3(in_planes, planes, stride) 15 | self.bn2 = nn.BatchNorm2d(planes) 16 | self.conv2 = conv3x3(planes, planes) 17 | 18 | if stride != 1 or in_planes != self.expansion*planes: 19 | self.shortcut = nn.Sequential( 20 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 21 | ) 22 | 23 | # SE layers 24 | self.fc1 = nn.Conv2d(planes, planes // 16, kernel_size=1) 25 | self.fc2 = nn.Conv2d(planes // 16, planes, kernel_size=1) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(x)) 29 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 30 | out = self.conv1(out) 31 | out = self.conv2(F.relu(self.bn2(out))) 32 | # Squeeze 33 | w = F.avg_pool2d(out, out.size(2)) 34 | w = F.relu(self.fc1(w)) 35 | w = torch.sigmoid(self.fc2(w)) 36 | # Excitation 37 | out = out * w # New broadcasting feature from v0.2! 
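# 'w' is a per-channel gate of shape (N, C, 1, 1): the global average pool above is the
# "squeeze", the two 1x1 convolutions plus the sigmoid are the "excitation", and the
# broadcasted multiply rescales each channel of 'out' before the residual addition below.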
38 | out += shortcut 39 | return out 40 | 41 | 42 | class SE_PreActBottleneck(nn.Module): 43 | '''Pre-activation version of the original Bottleneck module.''' 44 | expansion = 4 45 | 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(SE_PreActBottleneck, self).__init__() 48 | self.bn1 = nn.BatchNorm2d(in_planes) 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 52 | self.bn3 = nn.BatchNorm2d(planes) 53 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 54 | 55 | if stride != 1 or in_planes != self.expansion*planes: 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 58 | ) 59 | 60 | # SE layers 61 | self.fc1 = nn.Conv2d(self.expansion*planes, self.expansion*planes // 16, kernel_size=1) 62 | self.fc2 = nn.Conv2d(self.expansion*planes // 16, self.expansion*planes, kernel_size=1) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(x)) 66 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 67 | out = self.conv1(out) 68 | out = self.conv2(F.relu(self.bn2(out))) 69 | out = self.conv3(F.relu(self.bn3(out))) 70 | # Squeeze 71 | w = F.avg_pool2d(out, out.size(2)) 72 | w = F.relu(self.fc1(w)) 73 | w = torch.sigmoid(self.fc2(w)) 74 | # Excitation 75 | out = out * w 76 | out += shortcut 77 | return out 78 | 79 | 80 | # ResNet for Cifar10/100 or the dataset with image size 32x32 81 | 82 | def SE_ResNet20_cifar(out_dim=10): 83 | return PreActResNet_cifar(SE_PreActBlock, [3 , 3 , 3 ], [16, 32, 64], num_classes=out_dim) 84 | 85 | def SE_ResNet56_cifar(out_dim=10): 86 | return PreActResNet_cifar(SE_PreActBlock, [9 , 9 , 9 ], [16, 32, 64], num_classes=out_dim) 87 | 88 | def ResNet110_cifar(out_dim=10): 89 | return PreActResNet_cifar(SE_PreActBlock, [18, 18, 18], [16, 32, 64], num_classes=out_dim) 90 | 91 | def SE_ResNet29_cifar(out_dim=10): 92 | return PreActResNet_cifar(SE_PreActBottleneck, [3 , 3 , 3 ], [16, 32, 64], num_classes=out_dim) 93 | 94 | def SE_ResNet164_cifar(out_dim=10): 95 | return PreActResNet_cifar(SE_PreActBottleneck, [18, 18, 18], [16, 32, 64], num_classes=out_dim) 96 | 97 | def SE_WideResNet_28_2_cifar(out_dim=10): 98 | return PreActResNet_cifar(SE_PreActBlock, [4, 4, 4], [32, 64, 128], num_classes=out_dim) 99 | 100 | def SE_WideResNet_28_10_cifar(out_dim=10): 101 | return PreActResNet_cifar(SE_PreActBlock, [4, 4, 4], [160, 320, 640], num_classes=out_dim) 102 | 103 | # ResNet for general purpose. 
Ex:ImageNet 104 | 105 | def SE_ResNet10(out_dim=10): 106 | return PreActResNet(SE_PreActBlock, [1,1,1,1], num_classes=out_dim) 107 | 108 | def SE_ResNet18S(out_dim=10): 109 | return PreActResNet(SE_PreActBlock, [2,2,2,2], num_classes=out_dim, in_channels=1) 110 | 111 | def SE_ResNet18(out_dim=10): 112 | return PreActResNet(SE_PreActBlock, [2,2,2,2], num_classes=out_dim) 113 | 114 | def SE_ResNet34(out_dim=10): 115 | return PreActResNet(SE_PreActBlock, [3,4,6,3], num_classes=out_dim) 116 | 117 | def SE_ResNet50(out_dim=10): 118 | return PreActResNet(SE_PreActBottleneck, [3,4,6,3], num_classes=out_dim) 119 | 120 | def SE_ResNet101(out_dim=10): 121 | return PreActResNet(SE_PreActBottleneck, [3,4,23,3], num_classes=out_dim) 122 | 123 | def SE_ResNet152(out_dim=10): 124 | return PreActResNet(SE_PreActBottleneck, [3,8,36,3], num_classes=out_dim) 125 | 126 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-RIPL/Continual-Learning-Benchmark/d78b9973b6ec0059b2d2577872db355ae2489f6b/modules/__init__.py -------------------------------------------------------------------------------- /modules/criterions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class BCEauto(torch.nn.BCEWithLogitsLoss): 4 | """ 5 | BCE with logits loss + automatically convert the target from class label to one-hot vector 6 | """ 7 | def forward(self, x, y): 8 | assert x.ndimension() == 2, 'Input size must be 2D' 9 | assert y.numel() == x.size(0), 'The size of input and target doesnt match. Number of input:' + str(x.size(0)) + ' Number of target:' + str(y.numel()) 10 | y_onehot = x.clone().zero_() 11 | y_onehot.scatter_(1, y.view(-1, 1), 1) 12 | 13 | return super(BCEauto, self).forward(x, y_onehot) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=0.4.1 2 | torchvision>=0.2.1 3 | argparse 4 | quadprog 5 | 6 | -------------------------------------------------------------------------------- /scripts/permuted_MNIST_incremental_class.sh: -------------------------------------------------------------------------------- 1 | GPUID=$1 2 | OUTDIR=outputs/permuted_MNIST_incremental_class 3 | REPEAT=10 4 | mkdir -p $OUTDIR 5 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.0001 --offline_training | tee ${OUTDIR}/Offline.log 6 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.0001 | tee ${OUTDIR}/Adam.log 7 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.001 | tee ${OUTDIR}/SGD.log 8 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adagrad --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.001 | tee ${OUTDIR}/Adagrad.log 9 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 
--schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name EWC_online_mnist --lr 0.0001 --reg_coef 50 | tee ${OUTDIR}/EWC_online.log 10 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name EWC_mnist --lr 0.0001 --reg_coef 10 | tee ${OUTDIR}/EWC.log 11 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name SI --lr 0.0001 --reg_coef 0.3 | tee ${OUTDIR}/SI.log 12 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name L2 --lr 0.0001 --reg_coef 0 | tee ${OUTDIR}/L2.log 13 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name Naive_Rehearsal_4000 --lr 0.0001 | tee ${OUTDIR}/Naive_Rehearsal_4000.log 14 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name Naive_Rehearsal_16000 --lr 0.0001 | tee ${OUTDIR}/Naive_Rehearsal_16000.log 15 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name MAS --lr 0.0001 --reg_coef 0.003 | tee ${OUTDIR}/MAS.log 16 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name GEM_4000 --lr 0.1 --reg_coef 0.5 | tee ${OUTDIR}/GEM_4000.log 17 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name GEM_16000 --lr 0.1 --reg_coef 0.5 | tee ${OUTDIR}/GEM_16000.log 18 | -------------------------------------------------------------------------------- /scripts/permuted_MNIST_incremental_domain.sh: -------------------------------------------------------------------------------- 1 | GPUID=$1 2 | OUTDIR=outputs/permuted_MNIST_incremental_domain 3 | REPEAT=10 4 | mkdir -p $OUTDIR 5 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.0001 --offline_training | tee ${OUTDIR}/Offline.log 6 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.0001 | tee ${OUTDIR}/Adam.log 7 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.001 | tee ${OUTDIR}/SGD.log 8 | python -u iBatchLearn.py --gpuid $GPUID 
--repeat $REPEAT --optimizer Adagrad --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.001 | tee ${OUTDIR}/Adagrad.log 9 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name EWC_online --lr 0.0001 --reg_coef 250 | tee ${OUTDIR}/EWC_online.log 10 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name EWC --lr 0.0001 --reg_coef 150 | tee ${OUTDIR}/EWC.log 11 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name SI --lr 0.0001 --reg_coef 10 | tee ${OUTDIR}/SI.log 12 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name L2 --lr 0.0001 --reg_coef 0.02 | tee ${OUTDIR}/L2.log 13 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name Naive_Rehearsal_4000 --lr 0.0001 | tee ${OUTDIR}/Naive_Rehearsal_4000.log 14 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name Naive_Rehearsal_16000 --lr 0.0001 | tee ${OUTDIR}/Naive_Rehearsal_16000.log 15 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name MAS --lr 0.0001 --reg_coef 0.1 | tee ${OUTDIR}/MAS.log 16 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name GEM_4000 --lr 0.1 --reg_coef 0.5 | tee ${OUTDIR}/GEM_4000.log 17 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name GEM_16000 --lr 0.1 --reg_coef 0.5 | tee ${OUTDIR}/GEM_16000.log -------------------------------------------------------------------------------- /scripts/permuted_MNIST_incremental_task.sh: -------------------------------------------------------------------------------- 1 | GPUID=$1 2 | OUTDIR=outputs/permuted_MNIST_incremental_task 3 | REPEAT=10 4 | mkdir -p $OUTDIR 5 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.0001 --offline_training | tee ${OUTDIR}/Offline.log 6 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.0001 | tee ${OUTDIR}/Adam.log 7 | 
python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.01 | tee ${OUTDIR}/SGD.log 8 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adagrad --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.001 | tee ${OUTDIR}/Adagrad.log 9 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name EWC_online_mnist --lr 0.0001 --reg_coef 500 | tee ${OUTDIR}/EWC_online.log 10 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name EWC_mnist --lr 0.0001 --reg_coef 500 | tee ${OUTDIR}/EWC.log 11 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name SI --lr 0.0001 --reg_coef 1 | tee ${OUTDIR}/SI.log 12 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name L2 --lr 0.0001 --reg_coef 0.001 | tee ${OUTDIR}/L2.log 13 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name Naive_Rehearsal_4000 --lr 0.0001 | tee ${OUTDIR}/Naive_Rehearsal_4000.log 14 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name Naive_Rehearsal_16000 --lr 0.0001 | tee ${OUTDIR}/Naive_Rehearsal_16000.log 15 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name MAS --lr 0.0001 --reg_coef 0.01 | tee ${OUTDIR}/MAS.log 16 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name GEM_4000 --lr 0.1 --reg_coef 0.5 | tee ${OUTDIR}/GEM_4000.log 17 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name GEM_16000 --lr 0.1 --reg_coef 0.5 | tee ${OUTDIR}/GEM_16000.log 18 | -------------------------------------------------------------------------------- /scripts/split_CIFAR100_incremental_class.sh: -------------------------------------------------------------------------------- 1 | GPUID=$1 2 | OUTDIR=outputs/split_CIFAR100_incremental_class 3 | REPEAT=5 4 | mkdir -p $OUTDIR 5 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --force_out_dim 100 --no_class_remap --first_split_size 20 
--other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.1 --momentum 0.9 --weight_decay 1e-4 --offline_training | tee ${OUTDIR}/Offline_SGD_WideResNet_28_2_cifar.log 6 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.001 --offline_training | tee ${OUTDIR}/Offline_Adam_WideResNet_28_2_cifar.log 7 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.001 | tee ${OUTDIR}/Adam.log 8 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.1 | tee ${OUTDIR}/SGD.log 9 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adagrad --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.1 | tee ${OUTDIR}/Adagrad.log 10 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC --lr 0.001 --reg_coef 2 | tee ${OUTDIR}/EWC.log 11 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC_online --lr 0.001 --reg_coef 2 | tee ${OUTDIR}/EWC_online.log 12 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name SI --lr 0.001 --reg_coef 0.001 | tee ${OUTDIR}/SI.log 13 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name L2 --lr 0.001 --reg_coef 500 | tee ${OUTDIR}/L2.log 14 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet 
--agent_type customization --agent_name Naive_Rehearsal_1400 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_1400.log 15 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name Naive_Rehearsal_5600 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_4600.log 16 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name MAS --lr 0.001 --reg_coef 0.001 |tee ${OUTDIR}/MAS.log -------------------------------------------------------------------------------- /scripts/split_CIFAR100_incremental_domain.sh: -------------------------------------------------------------------------------- 1 | GPUID=$1 2 | OUTDIR=outputs/split_CIFAR100_incremental_domain 3 | REPEAT=5 4 | mkdir -p $OUTDIR 5 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --momentum 0.9 --weight_decay 1e-4 --lr 0.1 --offline_training | tee ${OUTDIR}/Offline_SGD.log 6 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.001 --offline_training | tee ${OUTDIR}/Offline_adam.log 7 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.001 | tee ${OUTDIR}/Adam.log 8 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.01 | tee ${OUTDIR}/SGD.log 9 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adagrad --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.01 | tee ${OUTDIR}/Adagrad.log 10 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC_online --lr 0.001 --reg_coef 20 | tee ${OUTDIR}/EWC_online.log 11 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC --lr 0.001 --reg_coef 10 | tee 
${OUTDIR}/EWC.log 12 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name SI --lr 0.001 --reg_coef 10000 | tee ${OUTDIR}/SI.log 13 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name L2 --lr 0.001 --reg_coef 0.0001 | tee ${OUTDIR}/L2.log 14 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name Naive_Rehearsal_1400 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_1400.log 15 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name Naive_Rehearsal_5600 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_5600.log 16 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name MAS --lr 0.001 --reg_coef 1000000 | tee ${OUTDIR}/MAS.log -------------------------------------------------------------------------------- /scripts/split_CIFAR100_incremental_task.sh: -------------------------------------------------------------------------------- 1 | GPUID=$1 2 | OUTDIR=outputs/split_CIFAR100_incremental_task 3 | REPEAT=5 4 | mkdir -p $OUTDIR 5 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --momentum 0.9 --weight_decay 1e-4 --lr 0.1 --offline_training | tee ${OUTDIR}/Offline_SGD.log 6 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.001 --offline_training | tee ${OUTDIR}/Offline_adam.log 7 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.001 | tee ${OUTDIR}/Adam.log 8 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.01 | tee ${OUTDIR}/SGD.log 9 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT 
--optimizer Adagrad --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.01 | tee ${OUTDIR}/Adagrad.log 10 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC_online --lr 0.001 --reg_coef 3000 | tee ${OUTDIR}/EWC_online.log 11 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC --lr 0.001 --reg_coef 100 | tee ${OUTDIR}/EWC.log 12 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name SI --lr 0.001 --reg_coef 2 | tee ${OUTDIR}/SI.log 13 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name L2 --lr 0.001 --reg_coef 1 | tee ${OUTDIR}/L2.log 14 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name Naive_Rehearsal_1400 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_1400.log 15 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name Naive_Rehearsal_5600 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_5600.log 16 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name MAS --lr 0.001 --reg_coef 10 | tee ${OUTDIR}/MAS.log -------------------------------------------------------------------------------- /scripts/split_MNIST_incremental_class.sh: -------------------------------------------------------------------------------- 1 | GPUID=$1 2 | OUTDIR=outputs/split_MNIST_incremental_class 3 | REPEAT=10 4 | mkdir -p $OUTDIR 5 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.001 --offline_training | tee ${OUTDIR}/Offline.log 6 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 
4 --batch_size 128 --model_name MLP400 --lr 0.001 | tee ${OUTDIR}/Adam.log 7 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.01 | tee ${OUTDIR}/SGD.log 8 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adagrad --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.01 | tee ${OUTDIR}/Adagrad.log 9 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name EWC_mnist --lr 0.001 --reg_coef 600 | tee ${OUTDIR}/EWC.log 10 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name EWC_online_mnist --lr 0.001 --reg_coef 100 | tee ${OUTDIR}/EWC_online.log 11 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name SI --lr 0.001 --reg_coef 600 | tee ${OUTDIR}/SI.log 12 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name L2 --lr 0.001 --reg_coef 100 | tee ${OUTDIR}/L2.log 13 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name Naive_Rehearsal_1100 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_1100.log 14 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name Naive_Rehearsal_4400 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_4400.log 15 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name MAS --lr 0.001 --reg_coef 1 |tee ${OUTDIR}/MAS.log 16 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name GEM_1100 --lr 0.01 --reg_coef 0.5 |tee ${OUTDIR}/GEM_1100.log 17 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name GEM_4400 --lr 0.01 --reg_coef 0.5 |tee 
${OUTDIR}/GEM_4400.log -------------------------------------------------------------------------------- /scripts/split_MNIST_incremental_domain.sh: -------------------------------------------------------------------------------- 1 | GPUID=$1 2 | OUTDIR=outputs/split_MNIST_incremental_domain 3 | REPEAT=10 4 | mkdir -p $OUTDIR 5 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.001 --offline_training | tee ${OUTDIR}/Offline.log 6 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.001 | tee ${OUTDIR}/Adam.log 7 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.01 | tee ${OUTDIR}/SGD.log 8 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adagrad --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.01 | tee ${OUTDIR}/Adagrad.log 9 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name EWC_online_mnist --lr 0.001 --reg_coef 700 | tee ${OUTDIR}/EWC_online.log 10 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name EWC_mnist --lr 0.001 --reg_coef 100 | tee ${OUTDIR}/EWC.log 11 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name SI --lr 0.001 --reg_coef 3000 | tee ${OUTDIR}/SI.log 12 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name L2 --lr 0.001 --reg_coef 0.5 | tee ${OUTDIR}/L2.log 13 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name Naive_Rehearsal_1100 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_1100.log 14 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name Naive_Rehearsal_4400 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_4400.log 15 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name MAS --lr 0.001 --reg_coef 10000 | tee ${OUTDIR}/MAS.log 16 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name GEM_1100 --lr 0.01 --reg_coef 0.5 
| tee ${OUTDIR}/GEM_1100.log 17 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name GEM_4400 --lr 0.01 --reg_coef 0.5 | tee ${OUTDIR}/GEM_4400.log -------------------------------------------------------------------------------- /scripts/split_MNIST_incremental_task.sh: -------------------------------------------------------------------------------- 1 | GPUID=$1 2 | OUTDIR=outputs/split_MNIST_incremental_task 3 | REPEAT=10 4 | mkdir -p $OUTDIR 5 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.001 --offline_training | tee ${OUTDIR}/Offline.log 6 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.001 | tee ${OUTDIR}/Adam.log 7 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.01 | tee ${OUTDIR}/SGD.log 8 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adagrad --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.01 | tee ${OUTDIR}/Adagrad.log 9 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name EWC_online_mnist --lr 0.001 --reg_coef 400 | tee ${OUTDIR}/EWC_online.log 10 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name EWC_mnist --lr 0.001 --reg_coef 100 | tee ${OUTDIR}/EWC.log 11 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name SI --lr 0.001 --reg_coef 300 | tee ${OUTDIR}/SI.log 12 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name L2 --lr 0.001 --reg_coef 0.01 | tee ${OUTDIR}/L2.log 13 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name Naive_Rehearsal_1100 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_1100.log 14 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name Naive_Rehearsal_4400 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_4400.log 15 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name MAS --lr 0.001 
--reg_coef 1 | tee ${OUTDIR}/MAS.log 16 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name GEM_1100 --lr 0.01 --reg_coef 0.5 | tee ${OUTDIR}/GEM_1100.log 17 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name GEM_4400 --lr 0.01 --reg_coef 0.5 | tee ${OUTDIR}/GEM_4400.log -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-RIPL/Continual-Learning-Benchmark/d78b9973b6ec0059b2d2577872db355ae2489f6b/utils/__init__.py -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | def accuracy(output, target, topk=(1,)): 5 | """Computes the precision@k for the specified values of k""" 6 | with torch.no_grad(): 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum().item() 17 | res.append(correct_k*100.0 / batch_size) 18 | 19 | if len(res)==1: 20 | return res[0] 21 | else: 22 | return res 23 | 24 | 25 | class AverageMeter(object): 26 | """Computes and stores the average and current value""" 27 | 28 | def __init__(self): 29 | self.reset() 30 | 31 | def reset(self): 32 | self.val = 0 33 | self.avg = 0 34 | self.sum = 0 35 | self.count = 0 36 | 37 | def update(self, val, n=1): 38 | self.val = val 39 | self.sum += val * n 40 | self.count += n 41 | self.avg = float(self.sum) / self.count 42 | 43 | 44 | class Timer(object): 45 | """ 46 | """ 47 | 48 | def __init__(self): 49 | self.reset() 50 | 51 | def reset(self): 52 | self.interval = 0 53 | self.time = time.time() 54 | 55 | def value(self): 56 | return time.time() - self.time 57 | 58 | def tic(self): 59 | self.time = time.time() 60 | 61 | def toc(self): 62 | self.interval = time.time() - self.time 63 | self.time = time.time() 64 | return self.interval --------------------------------------------------------------------------------
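
A note on reading the logs produced by these scripts: the number reported after each task in `iBatchLearn.py` is the mean accuracy over all tasks seen so far, taken from `acc_table[val_name][train_name]` (accuracy on task `val_name` evaluated after training on task `train_name`). The sketch below reproduces that bookkeeping on a hypothetical 3-task table; in the benchmark the table is filled by `run(args)`.

```python
# Standalone sketch of the average-accuracy bookkeeping in iBatchLearn.py's __main__ block.
# The accuracy numbers are hypothetical; acc_table[val_name][train_name] is the accuracy on
# task `val_name` measured right after training on task `train_name`.
task_names = ['1', '2', '3']
acc_table = {
    '1': {'1': 99.0, '2': 97.5, '3': 96.0},  # task 1, evaluated after training task 1, 2, 3
    '2': {'2': 98.0, '3': 95.5},             # task 2, evaluated after training task 2, 3
    '3': {'3': 97.0},                        # task 3, evaluated after training task 3
}

avg_acc_history = [0] * len(task_names)
for i, train_name in enumerate(task_names):
    # Average over every task seen so far, i.e. tasks task_names[0..i]
    cls_acc_sum = sum(acc_table[task_names[j]][train_name] for j in range(i + 1))
    avg_acc_history[i] = cls_acc_sum / (i + 1)
    print('Task', train_name, 'average acc:', avg_acc_history[i])

# avg_acc_history[-1] is what iBatchLearn.py stores in avg_final_acc[reg_coef][r]
print('final average accuracy:', avg_acc_history[-1])
```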
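
The split-CIFAR100 scripts select their backbone with `--model_type resnet --model_name WideResNet_28_2_cifar`; the factories in `models/resnet.py` can also be instantiated directly. A minimal shape-check sketch, assuming it is run from the repository root so `models` is importable:

```python
# Minimal sketch: build the CIFAR backbone used by the split-CIFAR100 scripts and
# verify the output shape on a random batch. Assumes the repository root is the cwd.
import torch
from models.resnet import WideResNet_28_2_cifar

net = WideResNet_28_2_cifar(out_dim=20)  # 20 outputs, matching --force_out_dim 20 above
x = torch.randn(4, 3, 32, 32)            # a random CIFAR-sized batch
print(net(x).shape)                      # -> torch.Size([4, 20])
```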
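
`modules/criterions.py` provides `BCEauto`, a BCE-with-logits loss that converts integer class labels into one-hot targets on the fly (via the `scatter_` call above). A minimal sketch pairing it with `models.mlp.MLP400` on random data, with shapes following the defaults (1x32x32 inputs, 10 classes) and the repository root again assumed importable:

```python
# Minimal sketch: one forward/backward step with models.mlp.MLP400 and
# modules.criterions.BCEauto. The data is random and purely illustrative.
import torch
from models.mlp import MLP400
from modules.criterions import BCEauto

model = MLP400()                 # two hidden layers of width 400, out_dim=10
criterion = BCEauto()            # targets are given as class indices, not one-hot vectors

x = torch.randn(16, 1, 32, 32)   # MLP.forward flattens this to (16, 1024) internally
y = torch.randint(0, 10, (16,))  # integer labels; BCEauto one-hot encodes them itself

logits = model(x)                # (16, 10)
loss = criterion(logits, y)
loss.backward()
print('loss:', loss.item())
```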
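
The helpers in `utils/metric.py` are independent of the agents and can be exercised on their own. A minimal sketch with random tensors, again assuming the repository root is on the import path:

```python
# Minimal sketch: top-1 accuracy, a running average, and timing with utils/metric.py,
# using random tensors in place of real model outputs.
import torch
from utils.metric import accuracy, AverageMeter, Timer

meter = AverageMeter()
timer = Timer()

timer.tic()
for _ in range(5):
    logits = torch.randn(128, 10)          # fake outputs for a 10-class problem
    target = torch.randint(0, 10, (128,))  # fake integer labels
    top1 = accuracy(logits, target)        # a single k=(1,) returns a plain float (percent)
    meter.update(top1, n=target.size(0))   # weight the running average by batch size
print('avg top-1: %.2f%%, %.3fs elapsed' % (meter.avg, timer.toc()))
```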