├── .gitignore
├── LICENSE
├── README.md
├── agents
│   ├── __init__.py
│   ├── customization.py
│   ├── default.py
│   ├── exp_replay.py
│   └── regularization.py
├── dataloaders
│   ├── __init__.py
│   ├── base.py
│   ├── datasetGen.py
│   └── wrapper.py
├── fig
│   ├── results_split_cifar100.png
│   ├── results_split_mnist.png
│   └── task_shifts.png
├── iBatchLearn.py
├── models
│   ├── __init__.py
│   ├── lenet.py
│   ├── mlp.py
│   ├── resnet.py
│   └── senet.py
├── modules
│   ├── __init__.py
│   └── criterions.py
├── requirements.txt
├── scripts
│   ├── permuted_MNIST_incremental_class.sh
│   ├── permuted_MNIST_incremental_domain.sh
│   ├── permuted_MNIST_incremental_task.sh
│   ├── split_CIFAR100_incremental_class.sh
│   ├── split_CIFAR100_incremental_domain.sh
│   ├── split_CIFAR100_incremental_task.sh
│   ├── split_MNIST_incremental_class.sh
│   ├── split_MNIST_incremental_domain.sh
│   └── split_MNIST_incremental_task.sh
└── utils
    ├── __init__.py
    └── metric.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # repo-specific stuff
2 | data/
3 | outputs/
4 | *.pt
5 | \#*#
6 | .idea
7 | *.sublime-*
8 | *.pkl
9 | .DS_Store
10 | *.pth
11 | *.png
12 | .swp
13 |
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 | *$py.class
18 |
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | env/
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | .hypothesis/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # celery beat schedule file
90 | celerybeat-schedule
91 |
92 | # SageMath parsed files
93 | *.sage.py
94 |
95 | # dotenv
96 | .env
97 |
98 | # virtualenv
99 | .venv
100 | venv/
101 | ENV/
102 |
103 | # Spyder project settings
104 | .spyderproject
105 | .spyproject
106 |
107 | # Rope project settings
108 | .ropeproject
109 |
110 | # mkdocs documentation
111 | /site
112 |
113 | # mypy
114 | .mypy_cache/
115 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 GT-RIPL, Yen-Chang Hsu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Continual-Learning-Benchmark
2 | Evaluate three types of task shifting with popular continual learning algorithms.
3 |
4 | This repository implements and modularizes the following algorithms in PyTorch:
5 | - EWC: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/regularization.py), [paper](https://arxiv.org/abs/1612.00796) (Overcoming catastrophic forgetting in neural networks)
6 | - Online EWC: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/regularization.py), [paper](https://arxiv.org/abs/1805.06370)
7 | - SI: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/regularization.py), [paper](https://arxiv.org/abs/1703.04200) (Continual Learning Through Synaptic Intelligence)
8 | - MAS: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/regularization.py), [paper](https://eccv2018.org/openaccess/content_ECCV_2018/papers/Rahaf_Aljundi_Memory_Aware_Synapses_ECCV_2018_paper.pdf) (Memory Aware Synapses: Learning what (not) to forget)
9 | - GEM: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/exp_replay.py), [paper](https://arxiv.org/abs/1706.08840) (Gradient Episodic Memory for Continual Learning)
10 | - (More are coming)
11 |
12 | All the above algorithms are compared to the following baselines with **the same static memory overhead**:
13 | - Naive rehearsal: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/exp_replay.py)
14 | - L2: [code](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/regularization.py), [paper](https://arxiv.org/abs/1612.00796)
15 |
16 | Key tables:
17 |
18 |
19 |
20 |
21 | If this repository helps your work, please cite:
22 | ```
23 | @inproceedings{Hsu18_EvalCL,
24 | title={Re-evaluating Continual Learning Scenarios: A Categorization and Case for Strong Baselines},
25 | author={Yen-Chang Hsu and Yen-Cheng Liu and Anita Ramasamy and Zsolt Kira},
26 |   booktitle={NeurIPS Continual Learning Workshop},
27 | year={2018},
28 | url={https://arxiv.org/abs/1810.12488}
29 | }
30 | ```
31 |
32 | ## Preparation
33 | This repository was tested with Python 3.6 and PyTorch 1.0.1.post2. A subset of the cases was also tested with PyTorch 1.5.1 and gives the same results.
34 |
35 | ```bash
36 | pip install -r requirements.txt
37 | ```
38 |
39 | ## Demo
40 | The scripts for reproducing the results in the paper are in the scripts folder.
41 |
42 | - Example: Run all algorithms in the incremental domain scenario with split MNIST.
43 | ```bash
44 | ./scripts/split_MNIST_incremental_domain.sh 0
45 | # The last number is gpuid
46 | # Outputs will be saved in ./outputs
47 | ```
48 |
49 | - Example outputs: Summary of repeats
50 | ```text
51 | ===Summary of experiment repeats: 3 / 3 ===
52 | The regularization coefficient: 400.0
53 | The last avg acc of all repeats: [90.517 90.648 91.069]
54 | mean: 90.74466666666666 std: 0.23549144829955856
55 | ```
56 |
57 | - Example outputs: The grid search for the regularization coefficient
58 | ```text
59 | reg_coef: 0.1 mean: 76.08566666666667 std: 1.097717733400629
60 | reg_coef: 1.0 mean: 77.59100000000001 std: 2.100847606721314
61 | reg_coef: 10.0 mean: 84.33933333333334 std: 0.3592671553160509
62 | reg_coef: 100.0 mean: 90.83800000000001 std: 0.6913701372395712
63 | reg_coef: 1000.0 mean: 87.48566666666666 std: 0.5440161353816179
64 | reg_coef: 5000.0 mean: 68.99133333333333 std: 1.6824762174313899
65 |
66 | ```
67 |
68 | ## Usage
69 | - Enable the grid search for the regularization coefficient: use the option with a list of values, ex: --reg_coef 0.1 1 10 100 ...
70 | - Repeat the experiment N times: use the option --repeat N (see the example command below)
71 |
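For example, the following command (an illustrative combination rather than one of the paper's runs) grid-searches the EWC regularization coefficient over three values and repeats each setting three times on the default split-MNIST setup:
```bash
python iBatchLearn.py --gpuid 0 --agent_type regularization --agent_name EWC \
       --reg_coef 1 10 100 --repeat 3
```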
72 | Lookup available options:
73 | ```bash
74 | python iBatchLearn.py -h
75 | ```
76 |
77 | ## Other results
78 |
79 | Below are CIFAR100 results. Please refer to the [scripts](https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/scripts/split_CIFAR100_incremental_class.sh) for details.
80 |
81 |
--------------------------------------------------------------------------------
/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from . import default
2 | from . import regularization
3 | from . import customization
4 | from . import exp_replay
--------------------------------------------------------------------------------
/agents/customization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .default import NormalNN
3 | from .regularization import SI, EWC, EWC_online
4 | from .exp_replay import Naive_Rehearsal, GEM
5 | from modules.criterions import BCEauto
6 |
7 | def init_zero_weights(m):
8 | with torch.no_grad():
9 | if type(m) == torch.nn.Linear:
10 | m.weight.zero_()
11 | m.bias.zero_()
12 | elif type(m) == torch.nn.ModuleDict:
13 | for l in m.values():
14 | init_zero_weights(l)
15 | else:
16 | assert False, 'Only support linear layer'
17 |
18 |
19 | def NormalNN_reset_optim(agent_config):
20 | agent = NormalNN(agent_config)
21 | agent.reset_optimizer = True
22 | return agent
23 |
24 |
25 | def NormalNN_BCE(agent_config):
26 | agent = NormalNN(agent_config)
27 | agent.criterion_fn = BCEauto()
28 | return agent
29 |
30 |
31 | def SI_BCE(agent_config):
32 | agent = SI(agent_config)
33 | agent.criterion_fn = BCEauto()
34 | return agent
35 |
36 |
37 | def SI_splitMNIST_zero_init(agent_config):
38 | agent = SI(agent_config)
39 | agent.damping_factor = 1e-3
40 | agent.reset_optimizer = True
41 | agent.model.last.apply(init_zero_weights)
42 | return agent
43 |
44 |
45 | def SI_splitMNIST_rand_init(agent_config):
46 | agent = SI(agent_config)
47 | agent.damping_factor = 1e-3
48 | agent.reset_optimizer = True
49 | return agent
50 |
51 |
52 | def EWC_BCE(agent_config):
53 | agent = EWC(agent_config)
54 | agent.criterion_fn = BCEauto()
55 | return agent
56 |
57 |
58 | def EWC_mnist(agent_config):
59 | agent = EWC(agent_config)
60 | agent.n_fisher_sample = 60000
61 | return agent
62 |
63 |
64 | def EWC_online_mnist(agent_config):
65 | agent = EWC(agent_config)
66 | agent.n_fisher_sample = 60000
67 | agent.online_reg = True
68 | return agent
69 |
70 |
71 | def EWC_online_empFI(agent_config):
72 | agent = EWC(agent_config)
73 | agent.empFI = True
74 | return agent
75 |
76 |
77 | def EWC_zero_init(agent_config):
78 | agent = EWC(agent_config)
79 | agent.reset_optimizer = True
80 | agent.model.last.apply(init_zero_weights)
81 | return agent
82 |
83 |
84 | def EWC_rand_init(agent_config):
85 | agent = EWC(agent_config)
86 | agent.reset_optimizer = True
87 | return agent
88 |
89 |
90 | def EWC_reset_optim(agent_config):
91 | agent = EWC(agent_config)
92 | agent.reset_optimizer = True
93 | return agent
94 |
95 |
96 | def EWC_online_reset_optim(agent_config):
97 | agent = EWC_online(agent_config)
98 | agent.reset_optimizer = True
99 | return agent
100 |
101 |
102 | def Naive_Rehearsal_100(agent_config):
103 | agent = Naive_Rehearsal(agent_config)
104 | agent.memory_size = 100
105 | return agent
106 |
107 |
108 | def Naive_Rehearsal_200(agent_config):
109 | agent = Naive_Rehearsal(agent_config)
110 | agent.memory_size = 200
111 | return agent
112 |
113 |
114 | def Naive_Rehearsal_400(agent_config):
115 | agent = Naive_Rehearsal(agent_config)
116 | agent.memory_size = 400
117 | return agent
118 |
119 |
120 | def Naive_Rehearsal_1100(agent_config):
121 | agent = Naive_Rehearsal(agent_config)
122 | agent.memory_size = 1100
123 | return agent
124 |
125 |
126 | def Naive_Rehearsal_1400(agent_config):
127 | agent = Naive_Rehearsal(agent_config)
128 | agent.memory_size = 1400
129 | return agent
130 |
131 |
132 | def Naive_Rehearsal_4000(agent_config):
133 | agent = Naive_Rehearsal(agent_config)
134 | agent.memory_size = 4000
135 | return agent
136 |
137 |
138 | def Naive_Rehearsal_4400(agent_config):
139 | agent = Naive_Rehearsal(agent_config)
140 | agent.memory_size = 4400
141 | return agent
142 |
143 |
144 | def Naive_Rehearsal_5600(agent_config):
145 | agent = Naive_Rehearsal(agent_config)
146 | agent.memory_size = 5600
147 | return agent
148 |
149 |
150 | def Naive_Rehearsal_16000(agent_config):
151 | agent = Naive_Rehearsal(agent_config)
152 | agent.memory_size = 16000
153 | return agent
154 |
155 |
156 | def GEM_100(agent_config):
157 | agent = GEM(agent_config)
158 | agent.memory_size = 100
159 | return agent
160 |
161 |
162 | def GEM_200(agent_config):
163 | agent = GEM(agent_config)
164 | agent.memory_size = 200
165 | return agent
166 |
167 |
168 | def GEM_400(agent_config):
169 | agent = GEM(agent_config)
170 | agent.memory_size = 400
171 | return agent
172 |
173 |
174 | def GEM_orig_1100(agent_config):
175 | agent = GEM(agent_config)
176 | agent.skip_memory_concatenation = True
177 | agent.memory_size = 1100
178 | return agent
179 |
180 |
181 | def GEM_1100(agent_config):
182 | agent = GEM(agent_config)
183 | agent.memory_size = 1100
184 | return agent
185 |
186 |
187 | def GEM_4000(agent_config):
188 | agent = GEM(agent_config)
189 | agent.memory_size = 4000
190 | return agent
191 |
192 |
193 | def GEM_4400(agent_config):
194 | agent = GEM(agent_config)
195 | agent.memory_size = 4400
196 | return agent
197 |
198 |
199 | def GEM_16000(agent_config):
200 | agent = GEM(agent_config)
201 | agent.memory_size = 16000
202 | return agent
--------------------------------------------------------------------------------
/agents/default.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.nn as nn
4 | from types import MethodType
5 | import models
6 | from utils.metric import accuracy, AverageMeter, Timer
7 |
8 | class NormalNN(nn.Module):
9 | '''
10 | Normal Neural Network with SGD for classification
11 | '''
12 | def __init__(self, agent_config):
13 | '''
14 | :param agent_config (dict): lr=float,momentum=float,weight_decay=float,
15 | schedule=[int], # The last number in the list is the end of epoch
16 | model_type=str,model_name=str,out_dim={task:dim},model_weights=str
17 | force_single_head=bool
18 | print_freq=int
19 | gpuid=[int]
20 | '''
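        # Illustrative config dict (mirrors what iBatchLearn.py builds from its default CLI arguments):
        # agent_config = {'lr': 0.01, 'momentum': 0, 'weight_decay': 0, 'schedule': [2],
        #                 'model_type': 'mlp', 'model_name': 'MLP', 'model_weights': None,
        #                 'out_dim': {'All': 2}, 'optimizer': 'SGD',
        #                 'print_freq': 100, 'gpuid': [0], 'reg_coef': 0.0}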
21 | super(NormalNN, self).__init__()
22 | self.log = print if agent_config['print_freq'] > 0 else lambda \
23 | *args: None # Use a void function to replace the print
24 | self.config = agent_config
25 | # If out_dim is a dict, there is a list of tasks. The model will have a head for each task.
26 | self.multihead = True if len(self.config['out_dim'])>1 else False # A convenience flag to indicate multi-head/task
27 | self.model = self.create_model()
28 | self.criterion_fn = nn.CrossEntropyLoss()
29 | if agent_config['gpuid'][0] >= 0:
30 | self.cuda()
31 | self.gpu = True
32 | else:
33 | self.gpu = False
34 | self.init_optimizer()
35 | self.reset_optimizer = False
36 | self.valid_out_dim = 'ALL' # Default: 'ALL' means all output nodes are active
37 | # Set an integer here for the incremental class scenario
38 |
39 | def init_optimizer(self):
40 | optimizer_arg = {'params':self.model.parameters(),
41 | 'lr':self.config['lr'],
42 | 'weight_decay':self.config['weight_decay']}
43 | if self.config['optimizer'] in ['SGD','RMSprop']:
44 | optimizer_arg['momentum'] = self.config['momentum']
45 | elif self.config['optimizer'] in ['Rprop']:
46 | optimizer_arg.pop('weight_decay')
47 | elif self.config['optimizer'] == 'amsgrad':
48 | optimizer_arg['amsgrad'] = True
49 | self.config['optimizer'] = 'Adam'
50 |
51 | self.optimizer = torch.optim.__dict__[self.config['optimizer']](**optimizer_arg)
52 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.config['schedule'],
53 | gamma=0.1)
54 |
55 | def create_model(self):
56 | cfg = self.config
57 |
58 | # Define the backbone (MLP, LeNet, VGG, ResNet ... etc) of model
59 | model = models.__dict__[cfg['model_type']].__dict__[cfg['model_name']]()
60 |
61 | # Apply network surgery to the backbone
62 | # Create the heads for tasks (It can be single task or multi-task)
63 | n_feat = model.last.in_features
64 |
65 | # The output of the model will be a dict: {task_name1:output1, task_name2:output2 ...}
66 | # For a single-headed model the output will be {'All':output}
67 | model.last = nn.ModuleDict()
68 | for task,out_dim in cfg['out_dim'].items():
69 | model.last[task] = nn.Linear(n_feat,out_dim)
70 |
71 | # Redefine the task-dependent function
72 | def new_logits(self, x):
73 | outputs = {}
74 | for task, func in self.last.items():
75 | outputs[task] = func(x)
76 | return outputs
77 |
78 | # Replace the task-dependent function
79 | model.logits = MethodType(new_logits, model)
80 | # Load pre-trained weights
81 | if cfg['model_weights'] is not None:
82 | print('=> Load model weights:', cfg['model_weights'])
83 | model_state = torch.load(cfg['model_weights'],
84 | map_location=lambda storage, loc: storage) # Load to CPU.
85 | model.load_state_dict(model_state)
86 | print('=> Load Done')
87 | return model
88 |
89 | def forward(self, x):
90 | return self.model.forward(x)
91 |
92 | def predict(self, inputs):
93 | self.model.eval()
94 | out = self.forward(inputs)
95 | for t in out.keys():
96 | out[t] = out[t].detach()
97 | return out
98 |
99 | def validation(self, dataloader):
100 | # This function doesn't distinguish tasks.
101 | batch_timer = Timer()
102 | acc = AverageMeter()
103 | batch_timer.tic()
104 |
105 | orig_mode = self.training
106 | self.eval()
107 | for i, (input, target, task) in enumerate(dataloader):
108 |
109 | if self.gpu:
110 | with torch.no_grad():
111 | input = input.cuda()
112 | target = target.cuda()
113 | output = self.predict(input)
114 |
115 | # Summarize the performance over all tasks, or a single task, depending on the dataloader.
116 | # The accuracy is averaged over the total number of samples.
117 | acc = accumulate_acc(output, target, task, acc)
118 |
119 | self.train(orig_mode)
120 |
121 | self.log(' * Val Acc {acc.avg:.3f}, Total time {time:.2f}'
122 | .format(acc=acc,time=batch_timer.toc()))
123 | return acc.avg
124 |
125 | def criterion(self, preds, targets, tasks, **kwargs):
126 | # The inputs and targets could come from single task or a mix of tasks
127 | # The network always makes the predictions with all its heads
128 | # The criterion will match the head and task to calculate the loss.
129 | if self.multihead:
130 | loss = 0
131 | for t,t_preds in preds.items():
132 | inds = [i for i in range(len(tasks)) if tasks[i]==t] # The indices of inputs that belong to task t
133 | if len(inds)>0:
134 | t_preds = t_preds[inds]
135 | t_target = targets[inds]
136 | loss += self.criterion_fn(t_preds, t_target) * len(inds) # restore the loss from average
137 | loss /= len(targets) # Average the total loss by the mini-batch size
138 | else:
139 | pred = preds['All']
140 | if isinstance(self.valid_out_dim, int): # (Not 'ALL') Mask out the outputs of unseen classes for incremental class scenario
141 | pred = preds['All'][:,:self.valid_out_dim]
142 | loss = self.criterion_fn(pred, targets)
143 | return loss
144 |
145 | def update_model(self, inputs, targets, tasks):
146 | out = self.forward(inputs)
147 | loss = self.criterion(out, targets, tasks)
148 | self.optimizer.zero_grad()
149 | loss.backward()
150 | self.optimizer.step()
151 | return loss.detach(), out
152 |
153 | def learn_batch(self, train_loader, val_loader=None):
154 | if self.reset_optimizer: # Reset optimizer before learning each task
155 | self.log('Optimizer is reset!')
156 | self.init_optimizer()
157 |
158 | for epoch in range(self.config['schedule'][-1]):
159 | data_timer = Timer()
160 | batch_timer = Timer()
161 | batch_time = AverageMeter()
162 | data_time = AverageMeter()
163 | losses = AverageMeter()
164 | acc = AverageMeter()
165 |
166 | # Config the model and optimizer
167 | self.log('Epoch:{0}'.format(epoch))
168 | self.model.train()
169 | self.scheduler.step(epoch)
170 | for param_group in self.optimizer.param_groups:
171 | self.log('LR:',param_group['lr'])
172 |
173 | # Learning with mini-batch
174 | data_timer.tic()
175 | batch_timer.tic()
176 | self.log('Itr\t\tTime\t\t Data\t\t Loss\t\tAcc')
177 | for i, (input, target, task) in enumerate(train_loader):
178 |
179 | data_time.update(data_timer.toc()) # measure data loading time
180 |
181 | if self.gpu:
182 | input = input.cuda()
183 | target = target.cuda()
184 |
185 | loss, output = self.update_model(input, target, task)
186 | input = input.detach()
187 | target = target.detach()
188 |
189 | # measure accuracy and record loss
190 | acc = accumulate_acc(output, target, task, acc)
191 | losses.update(loss, input.size(0))
192 |
193 | batch_time.update(batch_timer.toc()) # measure elapsed time
194 | data_timer.toc()
195 |
196 | if ((self.config['print_freq']>0) and (i % self.config['print_freq'] == 0)) or (i+1)==len(train_loader):
197 | self.log('[{0}/{1}]\t'
198 | '{batch_time.val:.4f} ({batch_time.avg:.4f})\t'
199 | '{data_time.val:.4f} ({data_time.avg:.4f})\t'
200 | '{loss.val:.3f} ({loss.avg:.3f})\t'
201 | '{acc.val:.2f} ({acc.avg:.2f})'.format(
202 | i, len(train_loader), batch_time=batch_time,
203 | data_time=data_time, loss=losses, acc=acc))
204 |
205 | self.log(' * Train Acc {acc.avg:.3f}'.format(acc=acc))
206 |
207 | # Evaluate the performance of current task
208 | if val_loader is not None:
209 | self.validation(val_loader)
210 |
211 | def learn_stream(self, data, label):
212 | assert False,'No implementation yet'
213 |
214 | def add_valid_output_dim(self, dim=0):
215 | # This function is kind of ad-hoc, but it is the simplest way to support incremental class learning
216 | self.log('Incremental class: Old valid output dimension:', self.valid_out_dim)
217 | if self.valid_out_dim == 'ALL':
218 | self.valid_out_dim = 0 # Initialize it with zero
219 | self.valid_out_dim += dim
220 | self.log('Incremental class: New Valid output dimension:', self.valid_out_dim)
221 | return self.valid_out_dim
222 |
223 | def count_parameter(self):
224 | return sum(p.numel() for p in self.model.parameters())
225 |
226 | def save_model(self, filename):
227 | model_state = self.model.state_dict()
228 | if isinstance(self.model,torch.nn.DataParallel):
229 | # Get rid of 'module' before the name of states
230 | model_state = self.model.module.state_dict()
231 | for key in model_state.keys(): # Always save it to cpu
232 | model_state[key] = model_state[key].cpu()
233 | print('=> Saving model to:', filename)
234 | torch.save(model_state, filename + '.pth')
235 | print('=> Save Done')
236 |
237 | def cuda(self):
238 | torch.cuda.set_device(self.config['gpuid'][0])
239 | self.model = self.model.cuda()
240 | self.criterion_fn = self.criterion_fn.cuda()
241 | # Multi-GPU
242 | if len(self.config['gpuid']) > 1:
243 | self.model = torch.nn.DataParallel(self.model, device_ids=self.config['gpuid'], output_device=self.config['gpuid'][0])
244 | return self
245 |
246 | def accumulate_acc(output, target, task, meter):
247 | if 'All' in output.keys(): # Single-headed model
248 | meter.update(accuracy(output['All'], target), len(target))
249 | else: # outputs from multi-headed (multi-task) model
250 | for t, t_out in output.items():
251 | inds = [i for i in range(len(task)) if task[i] == t] # The indices of inputs that belong to task t
252 | if len(inds) > 0:
253 | t_out = t_out[inds]
254 | t_target = target[inds]
255 | meter.update(accuracy(t_out, t_target), len(inds))
256 |
257 | return meter
258 |
--------------------------------------------------------------------------------
/agents/exp_replay.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from importlib import import_module
4 | from .default import NormalNN
5 | from .regularization import SI, L2, EWC, MAS
6 | from dataloaders.wrapper import Storage
7 |
8 |
9 | class Naive_Rehearsal(NormalNN):
10 |
11 | def __init__(self, agent_config):
12 | super(Naive_Rehearsal, self).__init__(agent_config)
13 | self.task_count = 0
14 | self.memory_size = 1000
15 | self.task_memory = {}
16 | self.skip_memory_concatenation = False
17 |
18 | def learn_batch(self, train_loader, val_loader=None):
19 | # 1.Combine training set
20 | if self.skip_memory_concatenation:
21 | new_train_loader = train_loader
22 | else: # default
23 | dataset_list = []
24 | for storage in self.task_memory.values():
25 | dataset_list.append(storage)
26 | dataset_list *= max(len(train_loader.dataset)//self.memory_size,1) # Let old data: new data = 1:1
27 | dataset_list.append(train_loader.dataset)
28 | dataset = torch.utils.data.ConcatDataset(dataset_list)
29 | new_train_loader = torch.utils.data.DataLoader(dataset,
30 | batch_size=train_loader.batch_size,
31 | shuffle=True,
32 | num_workers=train_loader.num_workers)
33 |
34 | # 2.Update model as normal
35 | super(Naive_Rehearsal, self).learn_batch(new_train_loader, val_loader)
36 |
37 | # 3.Randomly decide the images to stay in the memory
38 | self.task_count += 1
39 | # (a) Decide the number of samples for being saved
40 | num_sample_per_task = self.memory_size // self.task_count
41 | num_sample_per_task = min(len(train_loader.dataset),num_sample_per_task)
42 | # (b) Reduce current exemplar set to reserve the space for the new dataset
43 | for storage in self.task_memory.values():
44 | storage.reduce(num_sample_per_task)
45 | # (c) Randomly choose some samples from new task and save them to the memory
46 | randind = torch.randperm(len(train_loader.dataset))[:num_sample_per_task] # randomly sample some data
47 | self.task_memory[self.task_count] = Storage(train_loader.dataset, randind)
48 |
49 |
50 | class Naive_Rehearsal_SI(Naive_Rehearsal, SI):
51 |
52 | def __init__(self, agent_config):
53 | super(Naive_Rehearsal_SI, self).__init__(agent_config)
54 |
55 |
56 | class Naive_Rehearsal_L2(Naive_Rehearsal, L2):
57 |
58 | def __init__(self, agent_config):
59 | super(Naive_Rehearsal_L2, self).__init__(agent_config)
60 |
61 |
62 | class Naive_Rehearsal_EWC(Naive_Rehearsal, EWC):
63 |
64 | def __init__(self, agent_config):
65 | super(Naive_Rehearsal_EWC, self).__init__(agent_config)
66 | self.online_reg = True # Online EWC
67 |
68 |
69 | class Naive_Rehearsal_MAS(Naive_Rehearsal, MAS):
70 |
71 | def __init__(self, agent_config):
72 | super(Naive_Rehearsal_MAS, self).__init__(agent_config)
73 |
74 |
75 | class GEM(Naive_Rehearsal):
76 | """
77 | @inproceedings{GradientEpisodicMemory,
78 | title={Gradient Episodic Memory for Continual Learning},
79 | author={Lopez-Paz, David and Ranzato, Marc'Aurelio},
80 | booktitle={NIPS},
81 | year={2017},
82 | url={https://arxiv.org/abs/1706.08840}
83 | }
84 | """
85 |
86 | def __init__(self, agent_config):
87 | super(GEM, self).__init__(agent_config)
88 | self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad} # For convenience
89 | self.task_grads = {}
90 | self.quadprog = import_module('quadprog')
91 | self.task_mem_cache = {}
92 |
93 | def grad_to_vector(self):
94 | vec = []
95 | for n,p in self.params.items():
96 | if p.grad is not None:
97 | vec.append(p.grad.view(-1))
98 | else:
99 | # Part of the network might have no grad; fill zeros for those terms
100 | vec.append(p.data.clone().fill_(0).view(-1))
101 | return torch.cat(vec)
102 |
103 | def vector_to_grad(self, vec):
104 | # Overwrite current param.grad by slicing the values in vec (flatten grad)
105 | pointer = 0
106 | for n, p in self.params.items():
107 | # The length of the parameter
108 | num_param = p.numel()
109 | if p.grad is not None:
110 | # Slice the vector, reshape it, and replace the old data of the grad
111 | p.grad.copy_(vec[pointer:pointer + num_param].view_as(p))
112 | # Part of the network might have no grad; ignore those terms
113 | # Increment the pointer
114 | pointer += num_param
115 |
116 | def project2cone2(self, gradient, memories):
117 | """
118 | Solves the GEM dual QP described in the paper given a proposed
119 | gradient "gradient", and a memory of task gradients "memories".
120 | Overwrites "gradient" with the final projected update.
121 |
122 | input: gradient, p-vector
123 | input: memories, (t * p)-vector
124 | output: x, p-vector
125 |
126 | Modified from: https://github.com/facebookresearch/GradientEpisodicMemory/blob/master/model/gem.py#L70
127 | """
128 | margin = self.config['reg_coef']
129 | memories_np = memories.cpu().contiguous().double().numpy()
130 | gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()
131 | t = memories_np.shape[0]
132 | #print(memories_np.shape, gradient_np.shape)
133 | P = np.dot(memories_np, memories_np.transpose())
134 | P = 0.5 * (P + P.transpose())
135 | q = np.dot(memories_np, gradient_np) * -1
136 | G = np.eye(t)
137 | P = P + G * 0.001
138 | h = np.zeros(t) + margin
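        # Per quadprog's convention, solve_qp(P, q, G, h) minimizes 1/2 v^T P v - q^T v subject to G^T v >= h,
        # i.e. the GEM dual: min_v 1/2 v^T (M M^T) v + (M g)^T v with v >= margin (M = memories, g = gradient);
        # the projected gradient is then recovered below as x = M^T v + g.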
139 | v = self.quadprog.solve_qp(P, q, G, h)[0]
140 | x = np.dot(v, memories_np) + gradient_np
141 | new_grad = torch.Tensor(x).view(-1)
142 | if self.gpu:
143 | new_grad = new_grad.cuda()
144 | return new_grad
145 |
146 | def learn_batch(self, train_loader, val_loader=None):
147 |
148 | # Update model as normal
149 | super(GEM, self).learn_batch(train_loader, val_loader)
150 |
151 | # Cache the data for faster processing
152 | for t, mem in self.task_memory.items():
153 | # Concatenate all data in each task
154 | mem_loader = torch.utils.data.DataLoader(mem,
155 | batch_size=len(mem),
156 | shuffle=False,
157 | num_workers=2)
158 | assert len(mem_loader)==1,'The length of mem_loader should be 1'
159 | for i, (mem_input, mem_target, mem_task) in enumerate(mem_loader):
160 | if self.gpu:
161 | mem_input = mem_input.cuda()
162 | mem_target = mem_target.cuda()
163 | self.task_mem_cache[t] = {'data':mem_input,'target':mem_target,'task':mem_task}
164 |
165 | def update_model(self, inputs, targets, tasks):
166 |
167 | # compute gradient on previous tasks
168 | if self.task_count > 0:
169 | for t,mem in self.task_memory.items():
170 | self.zero_grad()
171 | # feed the data from memory and collect the gradients
172 | mem_out = self.forward(self.task_mem_cache[t]['data'])
173 | mem_loss = self.criterion(mem_out, self.task_mem_cache[t]['target'], self.task_mem_cache[t]['task'])
174 | mem_loss.backward()
175 | # Store the grads
176 | self.task_grads[t] = self.grad_to_vector()
177 |
178 | # now compute the grad on the current minibatch
179 | out = self.forward(inputs)
180 | loss = self.criterion(out, targets, tasks)
181 | self.optimizer.zero_grad()
182 | loss.backward()
183 |
184 | # check if gradient violates constraints
185 | if self.task_count > 0:
186 | current_grad_vec = self.grad_to_vector()
187 | mem_grad_vec = torch.stack(list(self.task_grads.values()))
188 | dotp = current_grad_vec * mem_grad_vec
189 | dotp = dotp.sum(dim=1)
190 | if (dotp < 0).sum() != 0:
191 | new_grad = self.project2cone2(current_grad_vec, mem_grad_vec)
192 | # copy gradients back
193 | self.vector_to_grad(new_grad)
194 |
195 | self.optimizer.step()
196 | return loss.detach(), out
197 |
--------------------------------------------------------------------------------
/agents/regularization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | from .default import NormalNN
4 |
5 |
6 | class L2(NormalNN):
7 | """
8 | @article{kirkpatrick2017overcoming,
9 | title={Overcoming catastrophic forgetting in neural networks},
10 | author={Kirkpatrick, James and Pascanu, Razvan and Rabinowitz, Neil and Veness, Joel and Desjardins, Guillaume and Rusu, Andrei A and Milan, Kieran and Quan, John and Ramalho, Tiago and Grabska-Barwinska, Agnieszka and others},
11 | journal={Proceedings of the national academy of sciences},
12 | year={2017},
13 | url={https://arxiv.org/abs/1612.00796}
14 | }
15 | """
16 | def __init__(self, agent_config):
17 | super(L2, self).__init__(agent_config)
18 | self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad} # For convenience
19 | self.regularization_terms = {}
20 | self.task_count = 0
21 | self.online_reg = True # True: There will be only one importance matrix and previous model parameters
22 | # False: Each task has its own importance matrix and model parameters
23 |
24 | def calculate_importance(self, dataloader):
25 | # Use an identity importance so it is an L2 regularization.
26 | importance = {}
27 | for n, p in self.params.items():
28 | importance[n] = p.clone().detach().fill_(1) # Identity
29 | return importance
30 |
31 | def learn_batch(self, train_loader, val_loader=None):
32 |
33 | self.log('#reg_term:', len(self.regularization_terms))
34 |
35 | # 1.Learn the parameters for current task
36 | super(L2, self).learn_batch(train_loader, val_loader)
37 |
38 | # 2.Backup the weight of current task
39 | task_param = {}
40 | for n, p in self.params.items():
41 | task_param[n] = p.clone().detach()
42 |
43 | # 3.Calculate the importance of weights for current task
44 | importance = self.calculate_importance(train_loader)
45 |
46 | # Save the weight and importance of weights of current task
47 | self.task_count += 1
48 | if self.online_reg and len(self.regularization_terms)>0:
49 | # Always use only one slot in self.regularization_terms
50 | self.regularization_terms[1] = {'importance':importance, 'task_param':task_param}
51 | else:
52 | # Use a new slot to store the task-specific information
53 | self.regularization_terms[self.task_count] = {'importance':importance, 'task_param':task_param}
54 |
55 | def criterion(self, inputs, targets, tasks, regularization=True, **kwargs):
56 | loss = super(L2, self).criterion(inputs, targets, tasks, **kwargs)
57 |
58 | if regularization and len(self.regularization_terms)>0:
59 | # Calculate the reg_loss only when the regularization_terms exists
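            # The penalty accumulated here is the sum over stored tasks i of sum_n importance_i[n] * (theta[n] - theta_i[n])^2,
            # scaled by reg_coef below; EWC, Online EWC, SI and MAS differ only in how calculate_importance() fills 'importance'.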
60 | reg_loss = 0
61 | for i,reg_term in self.regularization_terms.items():
62 | task_reg_loss = 0
63 | importance = reg_term['importance']
64 | task_param = reg_term['task_param']
65 | for n, p in self.params.items():
66 | task_reg_loss += (importance[n] * (p - task_param[n]) ** 2).sum()
67 | reg_loss += task_reg_loss
68 | loss += self.config['reg_coef'] * reg_loss
69 | return loss
70 |
71 |
72 | class EWC(L2):
73 | """
74 | @article{kirkpatrick2017overcoming,
75 | title={Overcoming catastrophic forgetting in neural networks},
76 | author={Kirkpatrick, James and Pascanu, Razvan and Rabinowitz, Neil and Veness, Joel and Desjardins, Guillaume and Rusu, Andrei A and Milan, Kieran and Quan, John and Ramalho, Tiago and Grabska-Barwinska, Agnieszka and others},
77 | journal={Proceedings of the national academy of sciences},
78 | year={2017},
79 | url={https://arxiv.org/abs/1612.00796}
80 | }
81 | """
82 |
83 | def __init__(self, agent_config):
84 | super(EWC, self).__init__(agent_config)
85 | self.online_reg = False
86 | self.n_fisher_sample = None
87 | self.empFI = False
88 |
89 | def calculate_importance(self, dataloader):
90 | # Update the diag fisher information
91 | # There are several ways to estimate the F matrix.
92 | # We keep the implementation as simple as possible while maintaining a similar performance to the literature.
93 | self.log('Computing EWC')
94 |
95 | # Initialize the importance matrix
96 | if self.online_reg and len(self.regularization_terms)>0:
97 | importance = self.regularization_terms[1]['importance']
98 | else:
99 | importance = {}
100 | for n, p in self.params.items():
101 | importance[n] = p.clone().detach().fill_(0) # zero initialized
102 |
103 | # Sample a subset (n_fisher_sample) of data to estimate the fisher information (batch_size=1)
104 | # Otherwise it uses mini-batches for the estimation. This speeds up the process a lot with similar performance.
105 | if self.n_fisher_sample is not None:
106 | n_sample = min(self.n_fisher_sample, len(dataloader.dataset))
107 | self.log('Sample',self.n_fisher_sample,'for estimating the F matrix.')
108 | rand_ind = random.sample(list(range(len(dataloader.dataset))), n_sample)
109 | subdata = torch.utils.data.Subset(dataloader.dataset, rand_ind)
110 | dataloader = torch.utils.data.DataLoader(subdata, shuffle=True, num_workers=2, batch_size=1)
111 |
112 | mode = self.training
113 | self.eval()
114 |
115 | # Accumulate the square of gradients
116 | for i, (input, target, task) in enumerate(dataloader):
117 | if self.gpu:
118 | input = input.cuda()
119 | target = target.cuda()
120 |
121 | preds = self.forward(input)
122 |
123 | # Sample the labels for estimating the gradients
124 | # For multi-headed model, the batch of data will be from the same task,
125 | # so we just use task[0] as the task name to fetch corresponding predictions
126 | # For single-headed model, just use the max of predictions from preds['All']
127 | task_name = task[0] if self.multihead else 'All'
128 |
129 | # The flag self.valid_out_dim is for handling the case of incremental class learning.
130 | # if self.valid_out_dim is an integer, it means only the first 'self.valid_out_dim' dimensions are used
131 | # in calculating the loss.
132 | pred = preds[task_name] if not isinstance(self.valid_out_dim, int) else preds[task_name][:,:self.valid_out_dim]
133 | ind = pred.max(1)[1].flatten() # Choose the one with max
134 |
135 | # - Alternative ind by multinomial sampling. Its performance is similar. -
136 | # prob = torch.nn.functional.softmax(preds['All'],dim=1)
137 | # ind = torch.multinomial(prob,1).flatten()
138 |
139 | if self.empFI: # Use groundtruth label (default is without this)
140 | ind = target
141 |
142 | loss = self.criterion(preds, ind, task, regularization=False)
143 | self.model.zero_grad()
144 | loss.backward()
145 | for n, p in importance.items():
146 | if self.params[n].grad is not None: # Some heads can have no grad if no loss is applied to them.
147 | p += ((self.params[n].grad ** 2) * len(input) / len(dataloader))
148 |
149 | self.train(mode=mode)
150 |
151 | return importance
152 |
153 |
154 | def EWC_online(agent_config):
155 | agent = EWC(agent_config)
156 | agent.online_reg = True
157 | return agent
158 |
159 |
160 | class SI(L2):
161 | """
162 | @inproceedings{zenke2017continual,
163 | title={Continual Learning Through Synaptic Intelligence},
164 | author={Zenke, Friedemann and Poole, Ben and Ganguli, Surya},
165 | booktitle={International Conference on Machine Learning},
166 | year={2017},
167 | url={https://arxiv.org/abs/1703.04200}
168 | }
169 | """
170 |
171 | def __init__(self, agent_config):
172 | super(SI, self).__init__(agent_config)
173 | self.online_reg = True # Original SI works in an online updating fashion
174 | self.damping_factor = 0.1
175 | self.w = {}
176 | for n, p in self.params.items():
177 | self.w[n] = p.clone().detach().zero_()
178 |
179 | # The initial_params will only be used in the first task (when the regularization_terms is empty)
180 | self.initial_params = {}
181 | for n, p in self.params.items():
182 | self.initial_params[n] = p.clone().detach()
183 |
184 | def update_model(self, inputs, targets, tasks):
185 |
186 | unreg_gradients = {}
187 |
188 | # 1.Save current parameters
189 | old_params = {}
190 | for n, p in self.params.items():
191 | old_params[n] = p.clone().detach()
192 |
193 | # 2. Collect the gradients without regularization term
194 | out = self.forward(inputs)
195 | loss = self.criterion(out, targets, tasks, regularization=False)
196 | self.optimizer.zero_grad()
197 | loss.backward(retain_graph=True)
198 | for n, p in self.params.items():
199 | if p.grad is not None:
200 | unreg_gradients[n] = p.grad.clone().detach()
201 |
202 | # 3. Normal update with regularization
203 | loss = self.criterion(out, targets, tasks, regularization=True)
204 | self.optimizer.zero_grad()
205 | loss.backward()
206 | self.optimizer.step()
207 |
208 | # 4. Accumulate the w
209 | for n, p in self.params.items():
210 | delta = p.detach() - old_params[n]
211 | if n in unreg_gradients.keys(): # In a multi-head network, some heads may have no grad (lazy) since no loss goes through them.
212 | self.w[n] -= unreg_gradients[n] * delta # w[n] is >=0
213 |
214 | return loss.detach(), out
215 |
216 | """
217 | # - Alternative simplified implementation with similar performance -
218 | def update_model(self, inputs, targets, tasks):
219 | # A wrapper of original update step to include the estimation of w
220 |
221 | # Backup prev param if not done yet
222 | # The backup only happened at the beginning of a new task
223 | if len(self.prev_params) == 0:
224 | for n, p in self.params.items():
225 | self.prev_params[n] = p.clone().detach()
226 |
227 | # 1.Save current parameters
228 | old_params = {}
229 | for n, p in self.params.items():
230 | old_params[n] = p.clone().detach()
231 |
232 | # 2.Calculate the loss as usual
233 | loss, out = super(SI, self).update_model(inputs, targets, tasks)
234 |
235 | # 3.Accumulate the w
236 | for n, p in self.params.items():
237 | delta = p.detach() - old_params[n]
238 | if p.grad is not None: # In a multi-head network, some heads may have no grad (lazy) since no loss goes through them.
239 | self.w[n] -= p.grad * delta # w[n] is >=0
240 |
241 | return loss.detach(), out
242 | """
243 |
244 | def calculate_importance(self, dataloader):
245 | self.log('Computing SI')
246 | assert self.online_reg,'SI needs online_reg=True'
247 |
248 | # Initialize the importance matrix
249 | if len(self.regularization_terms)>0: # The case of after the first task
250 | importance = self.regularization_terms[1]['importance']
251 | prev_params = self.regularization_terms[1]['task_param']
252 | else: # It is in the first task
253 | importance = {}
254 | for n, p in self.params.items():
255 | importance[n] = p.clone().detach().fill_(0) # zero initialized
256 | prev_params = self.initial_params
257 |
258 | # Calculate or accumulate the Omega (the importance matrix)
259 | for n, p in importance.items():
260 | delta_theta = self.params[n].detach() - prev_params[n]
261 | p += self.w[n]/(delta_theta**2 + self.damping_factor)
262 | self.w[n].zero_()
263 |
264 | return importance
265 |
266 |
267 | class MAS(L2):
268 | """
269 | @article{aljundi2017memory,
270 | title={Memory Aware Synapses: Learning what (not) to forget},
271 | author={Aljundi, Rahaf and Babiloni, Francesca and Elhoseiny, Mohamed and Rohrbach, Marcus and Tuytelaars, Tinne},
272 | booktitle={ECCV},
273 | year={2018},
274 | url={https://eccv2018.org/openaccess/content_ECCV_2018/papers/Rahaf_Aljundi_Memory_Aware_Synapses_ECCV_2018_paper.pdf}
275 | }
276 | """
277 |
278 | def __init__(self, agent_config):
279 | super(MAS, self).__init__(agent_config)
280 | self.online_reg = True
281 |
282 | def calculate_importance(self, dataloader):
283 | self.log('Computing MAS')
284 |
285 | # Initialize the importance matrix
286 | if self.online_reg and len(self.regularization_terms)>0:
287 | importance = self.regularization_terms[1]['importance']
288 | else:
289 | importance = {}
290 | for n, p in self.params.items():
291 | importance[n] = p.clone().detach().fill_(0) # zero initialized
292 |
293 | mode = self.training
294 | self.eval()
295 |
296 | # Accumulate the gradients of L2 loss on the outputs
297 | for i, (input, target, task) in enumerate(dataloader):
298 | if self.gpu:
299 | input = input.cuda()
300 | target = target.cuda()
301 |
302 | preds = self.forward(input)
303 |
304 | # Sample the labels for estimating the gradients
305 | # For multi-headed model, the batch of data will be from the same task,
306 | # so we just use task[0] as the task name to fetch corresponding predictions
307 | # For single-headed model, just use the max of predictions from preds['All']
308 | task_name = task[0] if self.multihead else 'All'
309 |
310 | # The flag self.valid_out_dim is for handling the case of incremental class learning.
311 | # if self.valid_out_dim is an integer, it means only the first 'self.valid_out_dim' dimensions are used
312 | # in calculating the loss.
313 | pred = preds[task_name] if not isinstance(self.valid_out_dim, int) else preds[task_name][:,:self.valid_out_dim]
314 |
315 | pred.pow_(2)
316 | loss = pred.mean()
317 |
318 | self.model.zero_grad()
319 | loss.backward()
320 | for n, p in importance.items():
321 | if self.params[n].grad is not None: # Some heads can have no grad if no loss is applied to them.
322 | p += (self.params[n].grad.abs() / len(dataloader))
323 |
324 | self.train(mode=mode)
325 |
326 | return importance
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GT-RIPL/Continual-Learning-Benchmark/d78b9973b6ec0059b2d2577872db355ae2489f6b/dataloaders/__init__.py
--------------------------------------------------------------------------------
/dataloaders/base.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torchvision import transforms
3 | from .wrapper import CacheClassLabel
4 |
5 | def MNIST(dataroot, train_aug=False):
6 | # Add padding to make 32x32
7 | #normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,)) # for 28x28
8 | normalize = transforms.Normalize(mean=(0.1000,), std=(0.2752,)) # for 32x32
9 |
10 | val_transform = transforms.Compose([
11 | transforms.Pad(2, fill=0, padding_mode='constant'),
12 | transforms.ToTensor(),
13 | normalize,
14 | ])
15 | train_transform = val_transform
16 | if train_aug:
17 | train_transform = transforms.Compose([
18 | transforms.RandomCrop(32, padding=4),
19 | transforms.ToTensor(),
20 | normalize,
21 | ])
22 |
23 | train_dataset = torchvision.datasets.MNIST(
24 | root=dataroot,
25 | train=True,
26 | download=True,
27 | transform=train_transform
28 | )
29 | train_dataset = CacheClassLabel(train_dataset)
30 |
31 | val_dataset = torchvision.datasets.MNIST(
32 | dataroot,
33 | train=False,
34 | transform=val_transform
35 | )
36 | val_dataset = CacheClassLabel(val_dataset)
37 |
38 | return train_dataset, val_dataset
39 |
40 | def CIFAR10(dataroot, train_aug=False):
41 | normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])
42 |
43 | val_transform = transforms.Compose([
44 | transforms.ToTensor(),
45 | normalize,
46 | ])
47 | train_transform = val_transform
48 | if train_aug:
49 | train_transform = transforms.Compose([
50 | transforms.RandomCrop(32, padding=4),
51 | transforms.RandomHorizontalFlip(),
52 | transforms.ToTensor(),
53 | normalize,
54 | ])
55 |
56 | train_dataset = torchvision.datasets.CIFAR10(
57 | root=dataroot,
58 | train=True,
59 | download=True,
60 | transform=train_transform
61 | )
62 | train_dataset = CacheClassLabel(train_dataset)
63 |
64 | val_dataset = torchvision.datasets.CIFAR10(
65 | root=dataroot,
66 | train=False,
67 | download=True,
68 | transform=val_transform
69 | )
70 | val_dataset = CacheClassLabel(val_dataset)
71 |
72 | return train_dataset, val_dataset
73 |
74 |
75 | def CIFAR100(dataroot, train_aug=False):
76 | normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
77 |
78 | val_transform = transforms.Compose([
79 | transforms.ToTensor(),
80 | normalize,
81 | ])
82 | train_transform = val_transform
83 | if train_aug:
84 | train_transform = transforms.Compose([
85 | transforms.RandomCrop(32, padding=4),
86 | transforms.RandomHorizontalFlip(),
87 | transforms.ToTensor(),
88 | normalize,
89 | ])
90 |
91 | train_dataset = torchvision.datasets.CIFAR100(
92 | root=dataroot,
93 | train=True,
94 | download=True,
95 | transform=train_transform
96 | )
97 | train_dataset = CacheClassLabel(train_dataset)
98 |
99 | val_dataset = torchvision.datasets.CIFAR100(
100 | root=dataroot,
101 | train=False,
102 | download=True,
103 | transform=val_transform
104 | )
105 | val_dataset = CacheClassLabel(val_dataset)
106 |
107 | return train_dataset, val_dataset
108 |
109 |
--------------------------------------------------------------------------------
/dataloaders/datasetGen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from random import shuffle
3 | from .wrapper import Subclass, AppendName, Permutation
4 |
5 |
6 | def SplitGen(train_dataset, val_dataset, first_split_sz=2, other_split_sz=2, rand_split=False, remap_class=False):
7 | '''
8 | Generate the dataset splits based on the labels.
9 | :param train_dataset: (torch.utils.data.dataset)
10 | :param val_dataset: (torch.utils.data.dataset)
11 | :param first_split_sz: (int)
12 | :param other_split_sz: (int)
13 | :param rand_split: (bool) Randomize the set of label in each split
14 | :param remap_class: (bool) Ex: remap classes in a split from [2,4,6 ...] to [0,1,2 ...]
15 | :return: train_loaders {task_name:loader}, val_loaders {task_name:loader}, out_dim {task_name:num_classes}
16 | '''
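    # For example, splitting the 10 MNIST classes with first_split_sz=2 and other_split_sz=2 yields 5 two-class
    # splits, so out_dim maps each of the 5 task names (numeric strings) to 2; with remap_class=True every
    # split's labels become [0,1].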
17 | assert train_dataset.number_classes==val_dataset.number_classes,'Train/Val has different number of classes'
18 | num_classes = train_dataset.number_classes
19 |
20 | # Calculate the boundary index of classes for splits
21 | # Ex: [0,2,4,6,8,10] or [0,50,60,70,80,90,100]
22 | split_boundaries = [0, first_split_sz]
23 | while split_boundaries[-1]<num_classes:
--------------------------------------------------------------------------------
/iBatchLearn.py:
--------------------------------------------------------------------------------
19 | if args.n_permutation>0:
20 | train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(train_dataset, val_dataset,
21 | args.n_permutation,
22 | remap_class=not args.no_class_remap)
23 | else:
24 | train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(train_dataset, val_dataset,
25 | first_split_sz=args.first_split_size,
26 | other_split_sz=args.other_split_size,
27 | rand_split=args.rand_split,
28 | remap_class=not args.no_class_remap)
29 |
30 | # Prepare the Agent (model)
31 | agent_config = {'lr': args.lr, 'momentum': args.momentum, 'weight_decay': args.weight_decay,'schedule': args.schedule,
32 | 'model_type':args.model_type, 'model_name': args.model_name, 'model_weights':args.model_weights,
33 | 'out_dim':{'All':args.force_out_dim} if args.force_out_dim>0 else task_output_space,
34 | 'optimizer':args.optimizer,
35 | 'print_freq':args.print_freq, 'gpuid': args.gpuid,
36 | 'reg_coef':args.reg_coef}
37 | agent = agents.__dict__[args.agent_type].__dict__[args.agent_name](agent_config)
38 | print(agent.model)
39 | print('#parameter of model:',agent.count_parameter())
40 |
41 | # Decide split ordering
42 | task_names = sorted(list(task_output_space.keys()), key=int)
43 | print('Task order:',task_names)
44 | if args.rand_split_order:
45 | shuffle(task_names)
46 | print('Shuffled task order:', task_names)
47 |
48 | acc_table = OrderedDict()
49 | if args.offline_training: # Non-incremental learning / offline_training / measure the upper-bound performance
50 | task_names = ['All']
51 | train_dataset_all = torch.utils.data.ConcatDataset(train_dataset_splits.values())
52 | val_dataset_all = torch.utils.data.ConcatDataset(val_dataset_splits.values())
53 | train_loader = torch.utils.data.DataLoader(train_dataset_all,
54 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
55 | val_loader = torch.utils.data.DataLoader(val_dataset_all,
56 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
57 |
58 | agent.learn_batch(train_loader, val_loader)
59 |
60 | acc_table['All'] = {}
61 | acc_table['All']['All'] = agent.validation(val_loader)
62 |
63 | else: # Incremental learning
64 | # Feed data to agent and evaluate agent's performance
65 | for i in range(len(task_names)):
66 | train_name = task_names[i]
67 | print('======================',train_name,'=======================')
68 | train_loader = torch.utils.data.DataLoader(train_dataset_splits[train_name],
69 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
70 | val_loader = torch.utils.data.DataLoader(val_dataset_splits[train_name],
71 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
72 |
73 | if args.incremental_class:
74 | agent.add_valid_output_dim(task_output_space[train_name])
75 |
76 | # Learn
77 | agent.learn_batch(train_loader, val_loader)
78 |
79 | # Evaluate
80 | acc_table[train_name] = OrderedDict()
81 | for j in range(i+1):
82 | val_name = task_names[j]
83 | print('validation split name:', val_name)
84 | val_data = val_dataset_splits[val_name] if not args.eval_on_train_set else train_dataset_splits[val_name]
85 | val_loader = torch.utils.data.DataLoader(val_data,
86 | batch_size=args.batch_size, shuffle=False,
87 | num_workers=args.workers)
88 | acc_table[val_name][train_name] = agent.validation(val_loader)
89 |
90 | return acc_table, task_names
91 |
92 | def get_args(argv):
93 | # This function prepares the variables shared across demo.py
94 | parser = argparse.ArgumentParser()
95 | parser.add_argument('--gpuid', nargs="+", type=int, default=[0],
96 | help="The list of gpuid, ex:--gpuid 3 1. Negative value means cpu-only")
97 | parser.add_argument('--model_type', type=str, default='mlp', help="The type (mlp|lenet|vgg|resnet) of backbone network")
98 | parser.add_argument('--model_name', type=str, default='MLP', help="The name of actual model for the backbone")
99 | parser.add_argument('--force_out_dim', type=int, default=2, help="Set 0 to let the task decide the required output dimension")
100 | parser.add_argument('--agent_type', type=str, default='default', help="The type (filename) of agent")
101 | parser.add_argument('--agent_name', type=str, default='NormalNN', help="The class name of agent")
102 | parser.add_argument('--optimizer', type=str, default='SGD', help="SGD|Adam|RMSprop|amsgrad|Adadelta|Adagrad|Adamax ...")
103 | parser.add_argument('--dataroot', type=str, default='data', help="The root folder of dataset or downloaded data")
104 | parser.add_argument('--dataset', type=str, default='MNIST', help="MNIST(default)|CIFAR10|CIFAR100")
105 | parser.add_argument('--n_permutation', type=int, default=0, help="Enable permuted tests when >0")
106 | parser.add_argument('--first_split_size', type=int, default=2)
107 | parser.add_argument('--other_split_size', type=int, default=2)
108 | parser.add_argument('--no_class_remap', dest='no_class_remap', default=False, action='store_true',
109 | help="Do not remap the class labels of each split. Ex: keep [2,5,6 ...] instead of remapping to [0,1,2 ...]")
110 | parser.add_argument('--train_aug', dest='train_aug', default=False, action='store_true',
111 | help="Allow data augmentation during training")
112 | parser.add_argument('--rand_split', dest='rand_split', default=False, action='store_true',
113 | help="Randomize the classes in splits")
114 | parser.add_argument('--rand_split_order', dest='rand_split_order', default=False, action='store_true',
115 | help="Randomize the order of splits")
116 | parser.add_argument('--workers', type=int, default=3, help="#Thread for dataloader")
117 | parser.add_argument('--batch_size', type=int, default=100)
118 | parser.add_argument('--lr', type=float, default=0.01, help="Learning rate")
119 | parser.add_argument('--momentum', type=float, default=0)
120 | parser.add_argument('--weight_decay', type=float, default=0)
121 | parser.add_argument('--schedule', nargs="+", type=int, default=[2],
122 | help="The list of epoch numbers to reduce learning rate by factor of 0.1. Last number is the end epoch")
123 | parser.add_argument('--print_freq', type=float, default=100, help="Print the log at every x iteration")
124 | parser.add_argument('--model_weights', type=str, default=None,
125 | help="The path to the file for the model weights (*.pth).")
126 | parser.add_argument('--reg_coef', nargs="+", type=float, default=[0.], help="The coefficient for regularization. Larger means less plasticity. Give a list for hyperparameter search.")
127 | parser.add_argument('--eval_on_train_set', dest='eval_on_train_set', default=False, action='store_true',
128 | help="Force the evaluation on train set")
129 | parser.add_argument('--offline_training', dest='offline_training', default=False, action='store_true',
130 | help="Non-incremental learning by making all data available in one batch. For measuring the upper-bound performance.")
131 | parser.add_argument('--repeat', type=int, default=1, help="Repeat the experiment N times")
132 | parser.add_argument('--incremental_class', dest='incremental_class', default=False, action='store_true',
133 | help="The number of output nodes in the single-headed model increases along with new categories.")
134 | args = parser.parse_args(argv)
135 | return args
136 |
137 | if __name__ == '__main__':
138 | args = get_args(sys.argv[1:])
139 | reg_coef_list = args.reg_coef
140 | avg_final_acc = {}
141 |
142 | # The for-loops over hyper-parameters and repeats
143 | for reg_coef in reg_coef_list:
144 | args.reg_coef = reg_coef
145 | avg_final_acc[reg_coef] = np.zeros(args.repeat)
146 | for r in range(args.repeat):
147 |
148 | # Run the experiment
149 | acc_table, task_names = run(args)
150 | print(acc_table)
151 |
152 | # Calculate average performance across tasks
153 | # Customize this part for a different performance metric
154 | avg_acc_history = [0] * len(task_names)
155 | for i in range(len(task_names)):
156 | train_name = task_names[i]
157 | cls_acc_sum = 0
158 | for j in range(i + 1):
159 | val_name = task_names[j]
160 | cls_acc_sum += acc_table[val_name][train_name]
161 | avg_acc_history[i] = cls_acc_sum / (i + 1)
162 | print('Task', train_name, 'average acc:', avg_acc_history[i])
163 |
164 | # Gather the final avg accuracy
165 | avg_final_acc[reg_coef][r] = avg_acc_history[-1]
166 |
167 | # Print the summary so far
168 | print('===Summary of experiment repeats:',r+1,'/',args.repeat,'===')
169 | print('The regularization coefficient:', args.reg_coef)
170 | print('The last avg acc of all repeats:', avg_final_acc[reg_coef])
171 | print('mean:', avg_final_acc[reg_coef].mean(), 'std:', avg_final_acc[reg_coef].std())
172 | for reg_coef,v in avg_final_acc.items():
173 | print('reg_coef:', reg_coef,'mean:', avg_final_acc[reg_coef].mean(), 'std:', avg_final_acc[reg_coef].std())
174 |
--------------------------------------------------------------------------------
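A note on the aggregation in the main loop above: run() returns acc_table as a nested dict where acc_table[val_name][train_name] is the accuracy on task val_name measured right after training on task train_name, and lines 154-162 average each such column over the tasks seen so far. A minimal, self-contained sketch with made-up numbers (illustrative only, not actual benchmark output):

# Illustrative only: a toy acc_table with the same nested layout that run() returns.
# acc_table[val_task][train_task] = accuracy on val_task right after training on train_task.
task_names = ['1', '2', '3']
acc_table = {
    '1': {'1': 99.0, '2': 97.5, '3': 96.0},
    '2': {'2': 98.0, '3': 95.0},
    '3': {'3': 97.0},
}
avg_acc_history = []
for i, train_name in enumerate(task_names):
    seen = task_names[:i + 1]                     # tasks learned so far
    avg_acc_history.append(sum(acc_table[v][train_name] for v in seen) / len(seen))
print(avg_acc_history)      # [99.0, 97.75, 96.0]
print(avg_acc_history[-1])  # 96.0  <- the final average accuracy reported for each repeat

--------------------------------------------------------------------------------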
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import mlp
2 | from . import lenet
3 | from . import resnet
4 | from . import senet
--------------------------------------------------------------------------------
/models/lenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class LeNet(nn.Module):
5 |
6 | def __init__(self, out_dim=10, in_channel=1, img_sz=32):
7 | super(LeNet, self).__init__()
8 | feat_map_sz = img_sz//4
9 | self.n_feat = 50 * feat_map_sz * feat_map_sz
10 |
11 | self.conv = nn.Sequential(
12 | nn.Conv2d(in_channel, 20, 5, padding=2),
13 | nn.BatchNorm2d(20),
14 | nn.ReLU(inplace=True),
15 | nn.MaxPool2d(2, 2),
16 | nn.Conv2d(20, 50, 5, padding=2),
17 | nn.BatchNorm2d(50),
18 | nn.ReLU(inplace=True),
19 | nn.MaxPool2d(2, 2)
20 | )
21 | self.linear = nn.Sequential(
22 | nn.Linear(self.n_feat, 500),
23 | nn.BatchNorm1d(500),
24 | nn.ReLU(inplace=True),
25 | )
26 |         self.last = nn.Linear(500, out_dim)  # Subject to be replaced depending on the task
27 |
28 | def features(self, x):
29 | x = self.conv(x)
30 | x = self.linear(x.view(-1, self.n_feat))
31 | return x
32 |
33 | def logits(self, x):
34 | x = self.last(x)
35 | return x
36 |
37 | def forward(self, x):
38 | x = self.features(x)
39 | x = self.logits(x)
40 | return x
41 |
42 |
43 | def LeNetC(out_dim=10): # LeNet with color input
44 | return LeNet(out_dim=out_dim, in_channel=3, img_sz=32)
--------------------------------------------------------------------------------
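For readers skimming the dump, a quick shape check of the LeNet above (a hedged sketch, assuming PyTorch is installed and the snippet runs from the repository root so the models package is importable):

# Illustrative smoke test: grayscale 32x32 input through the default LeNet.
import torch
from models.lenet import LeNet, LeNetC

net = LeNet(out_dim=10, in_channel=1, img_sz=32)
x = torch.randn(8, 1, 32, 32)                      # a fake batch of 8 images
print(net(x).shape)                                # torch.Size([8, 10])

# The color variant used for 32x32 RGB inputs:
print(LeNetC(out_dim=100)(torch.randn(4, 3, 32, 32)).shape)   # torch.Size([4, 100])

--------------------------------------------------------------------------------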
/models/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class MLP(nn.Module):
6 |
7 | def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=256):
8 | super(MLP, self).__init__()
9 | self.in_dim = in_channel*img_sz*img_sz
10 | self.linear = nn.Sequential(
11 | nn.Linear(self.in_dim, hidden_dim),
12 | #nn.BatchNorm1d(hidden_dim),
13 | nn.ReLU(inplace=True),
14 | nn.Linear(hidden_dim, hidden_dim),
15 | #nn.BatchNorm1d(hidden_dim),
16 | nn.ReLU(inplace=True),
17 | )
18 |         self.last = nn.Linear(hidden_dim, out_dim)  # Subject to be replaced depending on the task
19 |
20 | def features(self, x):
21 | x = self.linear(x.view(-1,self.in_dim))
22 | return x
23 |
24 | def logits(self, x):
25 | x = self.last(x)
26 | return x
27 |
28 | def forward(self, x):
29 | x = self.features(x)
30 | x = self.logits(x)
31 | return x
32 |
33 |
34 | def MLP100():
35 | return MLP(hidden_dim=100)
36 |
37 |
38 | def MLP400():
39 | return MLP(hidden_dim=400)
40 |
41 |
42 | def MLP1000():
43 | return MLP(hidden_dim=1000)
44 |
45 |
46 | def MLP2000():
47 | return MLP(hidden_dim=2000)
48 |
49 |
50 | def MLP5000():
51 | return MLP(hidden_dim=5000)
--------------------------------------------------------------------------------
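The MLP factories only fix hidden_dim; the default constructor flattens a 1x32x32 input into a 1024-dimensional vector before the two hidden layers. A small sketch under the same assumptions as above:

# Illustrative: MLP400 builds a 2-hidden-layer MLP with 400 units per layer.
import torch
from models.mlp import MLP, MLP400

net = MLP400()
print(net(torch.randn(16, 1, 32, 32)).shape)       # torch.Size([16, 10])

# For a different input size, call MLP directly (e.g. raw 28x28 MNIST):
net28 = MLP(out_dim=10, in_channel=1, img_sz=28, hidden_dim=400)
print(net28(torch.randn(16, 1, 28, 28)).shape)     # torch.Size([16, 10])

--------------------------------------------------------------------------------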
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import math
4 | from torch.nn import init
5 |
6 |
7 | def conv3x3(in_planes, out_planes, stride=1):
8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
9 |
10 |
11 | class PreActBlock(nn.Module):
12 | '''Pre-activation version of the BasicBlock.'''
13 | expansion = 1
14 |
15 | def __init__(self, in_planes, planes, stride=1, droprate=0):
16 | super(PreActBlock, self).__init__()
17 | self.bn1 = nn.BatchNorm2d(in_planes)
18 | self.conv1 = conv3x3(in_planes, planes, stride)
19 | self.drop = nn.Dropout(p=droprate) if droprate>0 else None
20 | self.bn2 = nn.BatchNorm2d(planes)
21 | self.conv2 = conv3x3(planes, planes)
22 |
23 | if stride != 1 or in_planes != self.expansion*planes:
24 | self.shortcut = nn.Sequential(
25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
26 | )
27 |
28 | def forward(self, x):
29 | out = F.relu(self.bn1(x))
30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
31 | out = self.conv1(out)
32 | if self.drop is not None:
33 | out = self.drop(out)
34 | out = self.conv2(F.relu(self.bn2(out)))
35 | out += shortcut
36 | return out
37 |
38 |
39 | class PreActBottleneck(nn.Module):
40 | '''Pre-activation version of the original Bottleneck module.'''
41 | expansion = 4
42 |
43 | def __init__(self, in_planes, planes, stride=1, droprate=None):
44 | super(PreActBottleneck, self).__init__()
45 | self.bn1 = nn.BatchNorm2d(in_planes)
46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
47 | self.bn2 = nn.BatchNorm2d(planes)
48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
49 | self.bn3 = nn.BatchNorm2d(planes)
50 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
51 |
52 | if stride != 1 or in_planes != self.expansion*planes:
53 | self.shortcut = nn.Sequential(
54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
55 | )
56 |
57 | def forward(self, x):
58 | out = F.relu(self.bn1(x))
59 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
60 | out = self.conv1(out)
61 | out = self.conv2(F.relu(self.bn2(out)))
62 | out = self.conv3(F.relu(self.bn3(out)))
63 | out += shortcut
64 | return out
65 |
66 |
67 | class PreActResNet(nn.Module):
68 | def __init__(self, block, num_blocks, num_classes=10, in_channels=3):
69 | super(PreActResNet, self).__init__()
70 | self.in_planes = 64
71 | last_planes = 512*block.expansion
72 |
73 | self.conv1 = conv3x3(in_channels, 64)
74 | self.stage1 = self._make_layer(block, 64, num_blocks[0], stride=1)
75 | self.stage2 = self._make_layer(block, 128, num_blocks[1], stride=2)
76 | self.stage3 = self._make_layer(block, 256, num_blocks[2], stride=2)
77 | self.stage4 = self._make_layer(block, 512, num_blocks[3], stride=2)
78 | self.bn_last = nn.BatchNorm2d(last_planes)
79 | self.last = nn.Linear(last_planes, num_classes)
80 |
81 | def _make_layer(self, block, planes, num_blocks, stride):
82 | strides = [stride] + [1]*(num_blocks-1)
83 | layers = []
84 | for stride in strides:
85 | layers.append(block(self.in_planes, planes, stride))
86 | self.in_planes = planes * block.expansion
87 | return nn.Sequential(*layers)
88 |
89 | def features(self, x):
90 | out = self.conv1(x)
91 | out = self.stage1(out)
92 | out = self.stage2(out)
93 | out = self.stage3(out)
94 | out = self.stage4(out)
95 | return out
96 |
97 | def logits(self, x):
98 | x = self.last(x)
99 | return x
100 |
101 | def forward(self, x):
102 | x = self.features(x)
103 | x = F.relu(self.bn_last(x))
104 | x = F.adaptive_avg_pool2d(x, 1)
105 | x = self.logits(x.view(x.size(0), -1))
106 | return x
107 |
108 |
109 | class PreActResNet_cifar(nn.Module):
110 | def __init__(self, block, num_blocks, filters, num_classes=10, droprate=0):
111 | super(PreActResNet_cifar, self).__init__()
112 | self.in_planes = 16
113 | last_planes = filters[2]*block.expansion
114 |
115 | self.conv1 = conv3x3(3, self.in_planes)
116 | self.stage1 = self._make_layer(block, filters[0], num_blocks[0], stride=1, droprate=droprate)
117 | self.stage2 = self._make_layer(block, filters[1], num_blocks[1], stride=2, droprate=droprate)
118 | self.stage3 = self._make_layer(block, filters[2], num_blocks[2], stride=2, droprate=droprate)
119 | self.bn_last = nn.BatchNorm2d(last_planes)
120 | self.last = nn.Linear(last_planes, num_classes)
121 |
122 | """
123 | for m in self.modules():
124 | if isinstance(m, nn.Conv2d):
125 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
126 | m.weight.data.normal_(0, math.sqrt(2. / n))
127 | # m.bias.data.zero_()
128 | elif isinstance(m, nn.BatchNorm2d):
129 | m.weight.data.fill_(1)
130 | m.bias.data.zero_()
131 | elif isinstance(m, nn.Linear):
132 | init.kaiming_normal(m.weight)
133 | m.bias.data.zero_()
134 | """
135 |
136 | def _make_layer(self, block, planes, num_blocks, stride, droprate):
137 | strides = [stride] + [1]*(num_blocks-1)
138 | layers = []
139 | for stride in strides:
140 | layers.append(block(self.in_planes, planes, stride, droprate))
141 | self.in_planes = planes * block.expansion
142 | return nn.Sequential(*layers)
143 |
144 | def features(self, x):
145 | out = self.conv1(x)
146 | out = self.stage1(out)
147 | out = self.stage2(out)
148 | out = self.stage3(out)
149 | return out
150 |
151 | def logits(self, x):
152 | x = self.last(x)
153 | return x
154 |
155 | def forward(self, x):
156 | out = self.features(x)
157 | out = F.relu(self.bn_last(out))
158 | out = F.avg_pool2d(out, 8)
159 | out = self.logits(out.view(out.size(0), -1))
160 | return out
161 |
162 |
163 | # ResNet for CIFAR-10/100 or other datasets with 32x32 images
164 |
165 | def ResNet20_cifar(out_dim=10):
166 | return PreActResNet_cifar(PreActBlock, [3 , 3 , 3 ], [16, 32, 64], num_classes=out_dim)
167 |
168 | def ResNet56_cifar(out_dim=10):
169 | return PreActResNet_cifar(PreActBlock, [9 , 9 , 9 ], [16, 32, 64], num_classes=out_dim)
170 |
171 | def ResNet110_cifar(out_dim=10):
172 | return PreActResNet_cifar(PreActBlock, [18, 18, 18], [16, 32, 64], num_classes=out_dim)
173 |
174 | def ResNet29_cifar(out_dim=10):
175 | return PreActResNet_cifar(PreActBottleneck, [3 , 3 , 3 ], [16, 32, 64], num_classes=out_dim)
176 |
177 | def ResNet164_cifar(out_dim=10):
178 | return PreActResNet_cifar(PreActBottleneck, [18, 18, 18], [16, 32, 64], num_classes=out_dim)
179 |
180 | def WideResNet_28_2_cifar(out_dim=10):
181 | return PreActResNet_cifar(PreActBlock, [4, 4, 4], [32, 64, 128], num_classes=out_dim)
182 |
183 | def WideResNet_28_2_drop_cifar(out_dim=10):
184 | return PreActResNet_cifar(PreActBlock, [4, 4, 4], [32, 64, 128], num_classes=out_dim, droprate=0.3)
185 |
186 | def WideResNet_28_10_cifar(out_dim=10):
187 | return PreActResNet_cifar(PreActBlock, [4, 4, 4], [160, 320, 640], num_classes=out_dim)
188 |
189 | # ResNet for general purposes, e.g. ImageNet
190 |
191 | def ResNet10(out_dim=10):
192 | return PreActResNet(PreActBlock, [1,1,1,1], num_classes=out_dim)
193 |
194 | def ResNet18S(out_dim=10):
195 | return PreActResNet(PreActBlock, [2,2,2,2], num_classes=out_dim, in_channels=1)
196 |
197 | def ResNet18(out_dim=10):
198 | return PreActResNet(PreActBlock, [2,2,2,2], num_classes=out_dim)
199 |
200 | def ResNet34(out_dim=10):
201 | return PreActResNet(PreActBlock, [3,4,6,3], num_classes=out_dim)
202 |
203 | def ResNet50(out_dim=10):
204 | return PreActResNet(PreActBottleneck, [3,4,6,3], num_classes=out_dim)
205 |
206 | def ResNet101(out_dim=10):
207 | return PreActResNet(PreActBottleneck, [3,4,23,3], num_classes=out_dim)
208 |
209 | def ResNet152(out_dim=10):
210 | return PreActResNet(PreActBottleneck, [3,8,36,3], num_classes=out_dim)
--------------------------------------------------------------------------------
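All of these models keep the classifier as a single `last` linear layer separate from features(), which is what lets the agents swap in per-task heads. A hedged sketch of that pattern with the WideResNet-28-2 used by the CIFAR scripts (the head replacement here is only an illustration, not the agents' actual bookkeeping):

# Illustrative: build the WideResNet-28-2 factory above and check output shapes.
import torch
import torch.nn as nn
from models.resnet import WideResNet_28_2_cifar

net = WideResNet_28_2_cifar(out_dim=20)            # 20 classes per split in the CIFAR scripts
x = torch.randn(4, 3, 32, 32)
print(net(x).shape)                                # torch.Size([4, 20])

# The backbone is untouched when the head is replaced, e.g. to grow the output dimension:
net.last = nn.Linear(net.last.in_features, 40)     # in_features == 128 for WideResNet-28-2
print(net(x).shape)                                # torch.Size([4, 40])

--------------------------------------------------------------------------------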
/models/senet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .resnet import conv3x3, PreActResNet, PreActResNet_cifar
5 |
6 |
7 | class SE_PreActBlock(nn.Module):
8 | '''Pre-activation version of the BasicBlock.'''
9 | expansion = 1
10 |
11 | def __init__(self, in_planes, planes, stride=1):
12 | super(SE_PreActBlock, self).__init__()
13 | self.bn1 = nn.BatchNorm2d(in_planes)
14 | self.conv1 = conv3x3(in_planes, planes, stride)
15 | self.bn2 = nn.BatchNorm2d(planes)
16 | self.conv2 = conv3x3(planes, planes)
17 |
18 | if stride != 1 or in_planes != self.expansion*planes:
19 | self.shortcut = nn.Sequential(
20 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
21 | )
22 |
23 | # SE layers
24 | self.fc1 = nn.Conv2d(planes, planes // 16, kernel_size=1)
25 | self.fc2 = nn.Conv2d(planes // 16, planes, kernel_size=1)
26 |
27 | def forward(self, x):
28 | out = F.relu(self.bn1(x))
29 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
30 | out = self.conv1(out)
31 | out = self.conv2(F.relu(self.bn2(out)))
32 | # Squeeze
33 | w = F.avg_pool2d(out, out.size(2))
34 | w = F.relu(self.fc1(w))
35 | w = torch.sigmoid(self.fc2(w))
36 | # Excitation
37 | out = out * w # New broadcasting feature from v0.2!
38 | out += shortcut
39 | return out
40 |
41 |
42 | class SE_PreActBottleneck(nn.Module):
43 | '''Pre-activation version of the original Bottleneck module.'''
44 | expansion = 4
45 |
46 | def __init__(self, in_planes, planes, stride=1):
47 | super(SE_PreActBottleneck, self).__init__()
48 | self.bn1 = nn.BatchNorm2d(in_planes)
49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
50 | self.bn2 = nn.BatchNorm2d(planes)
51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
52 | self.bn3 = nn.BatchNorm2d(planes)
53 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
54 |
55 | if stride != 1 or in_planes != self.expansion*planes:
56 | self.shortcut = nn.Sequential(
57 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
58 | )
59 |
60 | # SE layers
61 | self.fc1 = nn.Conv2d(self.expansion*planes, self.expansion*planes // 16, kernel_size=1)
62 | self.fc2 = nn.Conv2d(self.expansion*planes // 16, self.expansion*planes, kernel_size=1)
63 |
64 | def forward(self, x):
65 | out = F.relu(self.bn1(x))
66 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
67 | out = self.conv1(out)
68 | out = self.conv2(F.relu(self.bn2(out)))
69 | out = self.conv3(F.relu(self.bn3(out)))
70 | # Squeeze
71 | w = F.avg_pool2d(out, out.size(2))
72 | w = F.relu(self.fc1(w))
73 | w = torch.sigmoid(self.fc2(w))
74 | # Excitation
75 | out = out * w
76 | out += shortcut
77 | return out
78 |
79 |
80 | # SE-ResNet for CIFAR-10/100 or other datasets with 32x32 images
81 |
82 | def SE_ResNet20_cifar(out_dim=10):
83 | return PreActResNet_cifar(SE_PreActBlock, [3 , 3 , 3 ], [16, 32, 64], num_classes=out_dim)
84 |
85 | def SE_ResNet56_cifar(out_dim=10):
86 | return PreActResNet_cifar(SE_PreActBlock, [9 , 9 , 9 ], [16, 32, 64], num_classes=out_dim)
87 |
88 | def SE_ResNet110_cifar(out_dim=10):
89 |     return PreActResNet_cifar(SE_PreActBlock, [18, 18, 18], [16, 32, 64], num_classes=out_dim)
90 |
91 | def SE_ResNet29_cifar(out_dim=10):
92 | return PreActResNet_cifar(SE_PreActBottleneck, [3 , 3 , 3 ], [16, 32, 64], num_classes=out_dim)
93 |
94 | def SE_ResNet164_cifar(out_dim=10):
95 | return PreActResNet_cifar(SE_PreActBottleneck, [18, 18, 18], [16, 32, 64], num_classes=out_dim)
96 |
97 | def SE_WideResNet_28_2_cifar(out_dim=10):
98 | return PreActResNet_cifar(SE_PreActBlock, [4, 4, 4], [32, 64, 128], num_classes=out_dim)
99 |
100 | def SE_WideResNet_28_10_cifar(out_dim=10):
101 | return PreActResNet_cifar(SE_PreActBlock, [4, 4, 4], [160, 320, 640], num_classes=out_dim)
102 |
103 | # SE-ResNet for general purposes, e.g. ImageNet
104 |
105 | def SE_ResNet10(out_dim=10):
106 | return PreActResNet(SE_PreActBlock, [1,1,1,1], num_classes=out_dim)
107 |
108 | def SE_ResNet18S(out_dim=10):
109 | return PreActResNet(SE_PreActBlock, [2,2,2,2], num_classes=out_dim, in_channels=1)
110 |
111 | def SE_ResNet18(out_dim=10):
112 | return PreActResNet(SE_PreActBlock, [2,2,2,2], num_classes=out_dim)
113 |
114 | def SE_ResNet34(out_dim=10):
115 | return PreActResNet(SE_PreActBlock, [3,4,6,3], num_classes=out_dim)
116 |
117 | def SE_ResNet50(out_dim=10):
118 | return PreActResNet(SE_PreActBottleneck, [3,4,6,3], num_classes=out_dim)
119 |
120 | def SE_ResNet101(out_dim=10):
121 | return PreActResNet(SE_PreActBottleneck, [3,4,23,3], num_classes=out_dim)
122 |
123 | def SE_ResNet152(out_dim=10):
124 | return PreActResNet(SE_PreActBottleneck, [3,8,36,3], num_classes=out_dim)
125 |
126 |
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GT-RIPL/Continual-Learning-Benchmark/d78b9973b6ec0059b2d2577872db355ae2489f6b/modules/__init__.py
--------------------------------------------------------------------------------
/modules/criterions.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class BCEauto(torch.nn.BCEWithLogitsLoss):
4 | """
5 | BCE with logits loss + automatically convert the target from class label to one-hot vector
6 | """
7 | def forward(self, x, y):
8 | assert x.ndimension() == 2, 'Input size must be 2D'
9 |         assert y.numel() == x.size(0), 'The size of input and target does not match. Number of input:' + str(x.size(0)) + ' Number of target:' + str(y.numel())
10 | y_onehot = x.clone().zero_()
11 | y_onehot.scatter_(1, y.view(-1, 1), 1)
12 |
13 | return super(BCEauto, self).forward(x, y_onehot)
--------------------------------------------------------------------------------
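A short sketch of what BCEauto does with integer labels (illustrative; the tensors are made up): it scatters the class indices into a one-hot matrix shaped like the logits and then defers to BCEWithLogitsLoss.

# Illustrative: BCEauto converts integer targets to one-hot internally.
import torch
from modules.criterions import BCEauto

logits = torch.randn(4, 10)                        # 4 samples, 10 classes
targets = torch.tensor([3, 0, 7, 7])               # integer class labels
print(BCEauto()(logits, targets).item())

# Equivalent by hand: build the one-hot target and call the parent loss directly.
onehot = torch.zeros_like(logits).scatter_(1, targets.view(-1, 1), 1)
print(torch.nn.BCEWithLogitsLoss()(logits, onehot).item())   # same value

--------------------------------------------------------------------------------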
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=0.4.1
2 | torchvision>=0.2.1
3 | argparse
4 | quadprog
5 |
6 |
--------------------------------------------------------------------------------
/scripts/permuted_MNIST_incremental_class.sh:
--------------------------------------------------------------------------------
1 | GPUID=$1
2 | OUTDIR=outputs/permuted_MNIST_incremental_class
3 | REPEAT=10
4 | mkdir -p $OUTDIR
5 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.0001 --offline_training | tee ${OUTDIR}/Offline.log
6 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.0001 | tee ${OUTDIR}/Adam.log
7 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.001 | tee ${OUTDIR}/SGD.log
8 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adagrad --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.001 | tee ${OUTDIR}/Adagrad.log
9 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name EWC_online_mnist --lr 0.0001 --reg_coef 50 | tee ${OUTDIR}/EWC_online.log
10 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name EWC_mnist --lr 0.0001 --reg_coef 10 | tee ${OUTDIR}/EWC.log
11 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name SI --lr 0.0001 --reg_coef 0.3 | tee ${OUTDIR}/SI.log
12 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name L2 --lr 0.0001 --reg_coef 0 | tee ${OUTDIR}/L2.log
13 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name Naive_Rehearsal_4000 --lr 0.0001 | tee ${OUTDIR}/Naive_Rehearsal_4000.log
14 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name Naive_Rehearsal_16000 --lr 0.0001 | tee ${OUTDIR}/Naive_Rehearsal_16000.log
15 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name MAS --lr 0.0001 --reg_coef 0.003 | tee ${OUTDIR}/MAS.log
16 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name GEM_4000 --lr 0.1 --reg_coef 0.5 | tee ${OUTDIR}/GEM_4000.log
17 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --n_permutation 10 --force_out_dim 100 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name GEM_16000 --lr 0.1 --reg_coef 0.5 | tee ${OUTDIR}/GEM_16000.log
18 |
--------------------------------------------------------------------------------
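This script (like the ones that follow) is a thin wrapper around iBatchLearn.py; every line only varies the flags handed to get_args. A hedged sketch of the same SGD invocation driven from Python, assuming the requirements are installed and the snippet runs from the repository root (only argument parsing is exercised here, not run(); the flag values are copied from the SGD line of the script above, with gpuid set to 0):

# Illustrative: reproduce one script line programmatically via get_args from iBatchLearn.py.
from iBatchLearn import get_args

argv = ('--gpuid 0 --repeat 10 --incremental_class --optimizer SGD '
        '--n_permutation 10 --force_out_dim 100 --schedule 10 '
        '--batch_size 128 --model_name MLP1000 --lr 0.001').split()
args = get_args(argv)
print(args.optimizer, args.lr, args.n_permutation)   # SGD 0.001 10

--------------------------------------------------------------------------------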
/scripts/permuted_MNIST_incremental_domain.sh:
--------------------------------------------------------------------------------
1 | GPUID=$1
2 | OUTDIR=outputs/permuted_MNIST_incremental_domain
3 | REPEAT=10
4 | mkdir -p $OUTDIR
5 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.0001 --offline_training | tee ${OUTDIR}/Offline.log
6 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.0001 | tee ${OUTDIR}/Adam.log
7 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.001 | tee ${OUTDIR}/SGD.log
8 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adagrad --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.001 | tee ${OUTDIR}/Adagrad.log
9 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name EWC_online --lr 0.0001 --reg_coef 250 | tee ${OUTDIR}/EWC_online.log
10 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name EWC --lr 0.0001 --reg_coef 150 | tee ${OUTDIR}/EWC.log
11 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name SI --lr 0.0001 --reg_coef 10 | tee ${OUTDIR}/SI.log
12 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name L2 --lr 0.0001 --reg_coef 0.02 | tee ${OUTDIR}/L2.log
13 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name Naive_Rehearsal_4000 --lr 0.0001 | tee ${OUTDIR}/Naive_Rehearsal_4000.log
14 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name Naive_Rehearsal_16000 --lr 0.0001 | tee ${OUTDIR}/Naive_Rehearsal_16000.log
15 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name MAS --lr 0.0001 --reg_coef 0.1 | tee ${OUTDIR}/MAS.log
16 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name GEM_4000 --lr 0.1 --reg_coef 0.5 | tee ${OUTDIR}/GEM_4000.log
17 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --n_permutation 10 --no_class_remap --force_out_dim 10 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name GEM_16000 --lr 0.1 --reg_coef 0.5 | tee ${OUTDIR}/GEM_16000.log
--------------------------------------------------------------------------------
/scripts/permuted_MNIST_incremental_task.sh:
--------------------------------------------------------------------------------
1 | GPUID=$1
2 | OUTDIR=outputs/permuted_MNIST_incremental_task
3 | REPEAT=10
4 | mkdir -p $OUTDIR
5 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.0001 --offline_training | tee ${OUTDIR}/Offline.log
6 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.0001 | tee ${OUTDIR}/Adam.log
7 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.01 | tee ${OUTDIR}/SGD.log
8 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adagrad --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --lr 0.001 | tee ${OUTDIR}/Adagrad.log
9 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name EWC_online_mnist --lr 0.0001 --reg_coef 500 | tee ${OUTDIR}/EWC_online.log
10 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name EWC_mnist --lr 0.0001 --reg_coef 500 | tee ${OUTDIR}/EWC.log
11 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name SI --lr 0.0001 --reg_coef 1 | tee ${OUTDIR}/SI.log
12 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name L2 --lr 0.0001 --reg_coef 0.001 | tee ${OUTDIR}/L2.log
13 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name Naive_Rehearsal_4000 --lr 0.0001 | tee ${OUTDIR}/Naive_Rehearsal_4000.log
14 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name Naive_Rehearsal_16000 --lr 0.0001 | tee ${OUTDIR}/Naive_Rehearsal_16000.log
15 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type regularization --agent_name MAS --lr 0.0001 --reg_coef 0.01 | tee ${OUTDIR}/MAS.log
16 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name GEM_4000 --lr 0.1 --reg_coef 0.5 | tee ${OUTDIR}/GEM_4000.log
17 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --n_permutation 10 --no_class_remap --force_out_dim 0 --schedule 10 --batch_size 128 --model_name MLP1000 --agent_type customization --agent_name GEM_16000 --lr 0.1 --reg_coef 0.5 | tee ${OUTDIR}/GEM_16000.log
18 |
--------------------------------------------------------------------------------
/scripts/split_CIFAR100_incremental_class.sh:
--------------------------------------------------------------------------------
1 | GPUID=$1
2 | OUTDIR=outputs/split_CIFAR100_incremental_class
3 | REPEAT=5
4 | mkdir -p $OUTDIR
5 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.1 --momentum 0.9 --weight_decay 1e-4 --offline_training | tee ${OUTDIR}/Offline_SGD_WideResNet_28_2_cifar.log
6 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.001 --offline_training | tee ${OUTDIR}/Offline_Adam_WideResNet_28_2_cifar.log
7 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.001 | tee ${OUTDIR}/Adam.log
8 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.1 | tee ${OUTDIR}/SGD.log
9 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adagrad --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.1 | tee ${OUTDIR}/Adagrad.log
10 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC --lr 0.001 --reg_coef 2 | tee ${OUTDIR}/EWC.log
11 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC_online --lr 0.001 --reg_coef 2 | tee ${OUTDIR}/EWC_online.log
12 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name SI --lr 0.001 --reg_coef 0.001 | tee ${OUTDIR}/SI.log
13 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name L2 --lr 0.001 --reg_coef 500 | tee ${OUTDIR}/L2.log
14 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name Naive_Rehearsal_1400 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_1400.log
15 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name Naive_Rehearsal_5600 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_5600.log
16 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name MAS --lr 0.001 --reg_coef 0.001 |tee ${OUTDIR}/MAS.log
--------------------------------------------------------------------------------
/scripts/split_CIFAR100_incremental_domain.sh:
--------------------------------------------------------------------------------
1 | GPUID=$1
2 | OUTDIR=outputs/split_CIFAR100_incremental_domain
3 | REPEAT=5
4 | mkdir -p $OUTDIR
5 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --momentum 0.9 --weight_decay 1e-4 --lr 0.1 --offline_training | tee ${OUTDIR}/Offline_SGD.log
6 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.001 --offline_training | tee ${OUTDIR}/Offline_adam.log
7 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.001 | tee ${OUTDIR}/Adam.log
8 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.01 | tee ${OUTDIR}/SGD.log
9 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adagrad --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.01 | tee ${OUTDIR}/Adagrad.log
10 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC_online --lr 0.001 --reg_coef 20 | tee ${OUTDIR}/EWC_online.log
11 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC --lr 0.001 --reg_coef 10 | tee ${OUTDIR}/EWC.log
12 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name SI --lr 0.001 --reg_coef 10000 | tee ${OUTDIR}/SI.log
13 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name L2 --lr 0.001 --reg_coef 0.0001 | tee ${OUTDIR}/L2.log
14 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name Naive_Rehearsal_1400 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_1400.log
15 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name Naive_Rehearsal_5600 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_5600.log
16 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 20 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name MAS --lr 0.001 --reg_coef 1000000 | tee ${OUTDIR}/MAS.log
--------------------------------------------------------------------------------
/scripts/split_CIFAR100_incremental_task.sh:
--------------------------------------------------------------------------------
1 | GPUID=$1
2 | OUTDIR=outputs/split_CIFAR100_incremental_task
3 | REPEAT=5
4 | mkdir -p $OUTDIR
5 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --momentum 0.9 --weight_decay 1e-4 --lr 0.1 --offline_training | tee ${OUTDIR}/Offline_SGD.log
6 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.001 --offline_training | tee ${OUTDIR}/Offline_adam.log
7 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.001 | tee ${OUTDIR}/Adam.log
8 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.01 | tee ${OUTDIR}/SGD.log
9 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adagrad --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --lr 0.01 | tee ${OUTDIR}/Adagrad.log
10 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC_online --lr 0.001 --reg_coef 3000 | tee ${OUTDIR}/EWC_online.log
11 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC --lr 0.001 --reg_coef 100 | tee ${OUTDIR}/EWC.log
12 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name SI --lr 0.001 --reg_coef 2 | tee ${OUTDIR}/SI.log
13 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name L2 --lr 0.001 --reg_coef 1 | tee ${OUTDIR}/L2.log
14 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name Naive_Rehearsal_1400 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_1400.log
15 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name Naive_Rehearsal_5600 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_5600.log
16 | python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type regularization --agent_name MAS --lr 0.001 --reg_coef 10 | tee ${OUTDIR}/MAS.log
--------------------------------------------------------------------------------
/scripts/split_MNIST_incremental_class.sh:
--------------------------------------------------------------------------------
1 | GPUID=$1
2 | OUTDIR=outputs/split_MNIST_incremental_class
3 | REPEAT=10
4 | mkdir -p $OUTDIR
5 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.001 --offline_training | tee ${OUTDIR}/Offline.log
6 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.001 | tee ${OUTDIR}/Adam.log
7 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.01 | tee ${OUTDIR}/SGD.log
8 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adagrad --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.01 | tee ${OUTDIR}/Adagrad.log
9 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name EWC_mnist --lr 0.001 --reg_coef 600 | tee ${OUTDIR}/EWC.log
10 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name EWC_online_mnist --lr 0.001 --reg_coef 100 | tee ${OUTDIR}/EWC_online.log
11 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name SI --lr 0.001 --reg_coef 600 | tee ${OUTDIR}/SI.log
12 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name L2 --lr 0.001 --reg_coef 100 | tee ${OUTDIR}/L2.log
13 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name Naive_Rehearsal_1100 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_1100.log
14 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name Naive_Rehearsal_4400 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_4400.log
15 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer Adam --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name MAS --lr 0.001 --reg_coef 1 |tee ${OUTDIR}/MAS.log
16 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name GEM_1100 --lr 0.01 --reg_coef 0.5 |tee ${OUTDIR}/GEM_1100.log
17 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --incremental_class --optimizer SGD --force_out_dim 10 --no_class_remap --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name GEM_4400 --lr 0.01 --reg_coef 0.5 |tee ${OUTDIR}/GEM_4400.log
--------------------------------------------------------------------------------
/scripts/split_MNIST_incremental_domain.sh:
--------------------------------------------------------------------------------
1 | GPUID=$1
2 | OUTDIR=outputs/split_MNIST_incremental_domain
3 | REPEAT=10
4 | mkdir -p $OUTDIR
5 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.001 --offline_training | tee ${OUTDIR}/Offline.log
6 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.001 | tee ${OUTDIR}/Adam.log
7 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.01 | tee ${OUTDIR}/SGD.log
8 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adagrad --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.01 | tee ${OUTDIR}/Adagrad.log
9 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name EWC_online_mnist --lr 0.001 --reg_coef 700 | tee ${OUTDIR}/EWC_online.log
10 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name EWC_mnist --lr 0.001 --reg_coef 100 | tee ${OUTDIR}/EWC.log
11 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name SI --lr 0.001 --reg_coef 3000 | tee ${OUTDIR}/SI.log
12 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name L2 --lr 0.001 --reg_coef 0.5 | tee ${OUTDIR}/L2.log
13 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name Naive_Rehearsal_1100 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_1100.log
14 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name Naive_Rehearsal_4400 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_4400.log
15 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name MAS --lr 0.001 --reg_coef 10000 | tee ${OUTDIR}/MAS.log
16 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name GEM_1100 --lr 0.01 --reg_coef 0.5 | tee ${OUTDIR}/GEM_1100.log
17 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 2 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name GEM_4400 --lr 0.01 --reg_coef 0.5 | tee ${OUTDIR}/GEM_4400.log
--------------------------------------------------------------------------------
/scripts/split_MNIST_incremental_task.sh:
--------------------------------------------------------------------------------
1 | GPUID=$1
2 | OUTDIR=outputs/split_MNIST_incremental_task
3 | REPEAT=10
4 | mkdir -p $OUTDIR
5 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.001 --offline_training | tee ${OUTDIR}/Offline.log
6 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.001 | tee ${OUTDIR}/Adam.log
7 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.01 | tee ${OUTDIR}/SGD.log
8 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adagrad --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --lr 0.01 | tee ${OUTDIR}/Adagrad.log
9 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name EWC_online_mnist --lr 0.001 --reg_coef 400 | tee ${OUTDIR}/EWC_online.log
10 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name EWC_mnist --lr 0.001 --reg_coef 100 | tee ${OUTDIR}/EWC.log
11 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name SI --lr 0.001 --reg_coef 300 | tee ${OUTDIR}/SI.log
12 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name L2 --lr 0.001 --reg_coef 0.01 | tee ${OUTDIR}/L2.log
13 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name Naive_Rehearsal_1100 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_1100.log
14 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name Naive_Rehearsal_4400 --lr 0.001 | tee ${OUTDIR}/Naive_Rehearsal_4400.log
15 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer Adam --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type regularization --agent_name MAS --lr 0.001 --reg_coef 1 | tee ${OUTDIR}/MAS.log
16 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name GEM_1100 --lr 0.01 --reg_coef 0.5 | tee ${OUTDIR}/GEM_1100.log
17 | python -u iBatchLearn.py --gpuid $GPUID --repeat $REPEAT --optimizer SGD --force_out_dim 0 --first_split_size 2 --other_split_size 2 --schedule 4 --batch_size 128 --model_name MLP400 --agent_type customization --agent_name GEM_4400 --lr 0.01 --reg_coef 0.5 | tee ${OUTDIR}/GEM_4400.log
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GT-RIPL/Continual-Learning-Benchmark/d78b9973b6ec0059b2d2577872db355ae2489f6b/utils/__init__.py
--------------------------------------------------------------------------------
/utils/metric.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 |
4 | def accuracy(output, target, topk=(1,)):
5 | """Computes the precision@k for the specified values of k"""
6 | with torch.no_grad():
7 | maxk = max(topk)
8 | batch_size = target.size(0)
9 |
10 | _, pred = output.topk(maxk, 1, True, True)
11 | pred = pred.t()
12 | correct = pred.eq(target.view(1, -1).expand_as(pred))
13 |
14 | res = []
15 | for k in topk:
16 | correct_k = correct[:k].view(-1).float().sum().item()
17 | res.append(correct_k*100.0 / batch_size)
18 |
19 | if len(res)==1:
20 | return res[0]
21 | else:
22 | return res
23 |
24 |
25 | class AverageMeter(object):
26 | """Computes and stores the average and current value"""
27 |
28 | def __init__(self):
29 | self.reset()
30 |
31 | def reset(self):
32 | self.val = 0
33 | self.avg = 0
34 | self.sum = 0
35 | self.count = 0
36 |
37 | def update(self, val, n=1):
38 | self.val = val
39 | self.sum += val * n
40 | self.count += n
41 | self.avg = float(self.sum) / self.count
42 |
43 |
44 | class Timer(object):
45 | """
46 | """
47 |
48 | def __init__(self):
49 | self.reset()
50 |
51 | def reset(self):
52 | self.interval = 0
53 | self.time = time.time()
54 |
55 | def value(self):
56 | return time.time() - self.time
57 |
58 | def tic(self):
59 | self.time = time.time()
60 |
61 | def toc(self):
62 | self.interval = time.time() - self.time
63 | self.time = time.time()
64 | return self.interval
--------------------------------------------------------------------------------
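A short usage sketch for these helpers (the tensors are made up): accuracy() returns a percentage, AverageMeter keeps a running weighted mean, and Timer measures wall-clock intervals.

# Illustrative usage of the metric helpers above.
import torch
from utils.metric import accuracy, AverageMeter, Timer

output = torch.tensor([[0.1, 0.9],     # predicts class 1
                       [0.8, 0.2],     # predicts class 0
                       [0.3, 0.7]])    # predicts class 1
target = torch.tensor([1, 0, 0])
print(accuracy(output, target))        # 66.66... (top-1 accuracy in percent)

meter = AverageMeter()
meter.update(accuracy(output, target), n=target.size(0))
meter.update(100.0, n=3)               # pretend a second, perfect batch of 3
print(meter.avg)                       # ~83.3

timer = Timer()
timer.tic()
_ = output @ output.t()                # some work to time
print('%.6f sec' % timer.toc())

--------------------------------------------------------------------------------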