--base-epochs 500 --batch_mixup --batch_logitnoise --ema_alpha 0.001 --ema_frequency 0.003 --distill_version l2 --distill_weight 0.05 --distill_weight_buffer 0.001 --rep_noise_weight 1.0 --repnoise_prob 0.5 --finetune_weight 2 --representation_replay --replay_from 1 --sep_memory --num_workers 8 --csv_filename results.csv --memory-size 500 --tensorboard --epochs 500
27 | ```
28 |
29 | ### Hyperparameters for other settings:
30 |
31 | | Dataset | Num of Tasks | Buffer Size | ema_alpha | ema_frequency | distill_weight | distill_weight_buffer |
32 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
33 | | CIFAR-100 | 5 | 200 | 0.0005 | 0.001 | 0.05 | 0.01 |
34 | | | | 500 | 0.005 | 0.003 | 0.05 | 0.01 |
35 | | | 10 | 200 | 0.001 | 0.003 | 0.05 | 0.001 |
36 | | | | 500 | 0.001 | 0.003 | 0.05 | 0.001 |
37 | | | | 1000 | 0.0005 | 0.0008 | 0.05 | 0.01 |
38 | | | | 2000 | 0.0002 | 0.0015 | 0.05 | 0.01 |
39 | | | 20 | 200 | 0.005 | 0.001 | 0.05 | 0.08 |
40 | | | | 500 | 0.0005 | 0.003 | 0.05 | 0.1 |
41 | | TINYIMAGENET | 10 | 500 | 0.001 | 0.003 | 0.05 | 0.01 |
42 | | | | 1000 | 0.01 | 0.0008 | 0.01 | 0.001 |
43 | | | | 2000 | 0.0001 | 0.008 | 0.01 | 0.0008 |
44 | | IMAGENET-100 | 10 | 500 | 0.0001 | 0.003 | 0.05 | 0.001 |
45 | | | | 1000 | 0.0001 | 0.003 | 0.05 | 0.001 |
46 | | | | 2000 | 0.01 | 0.005 | 0.01 | 0.001 |
47 |
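For example, assuming the command above corresponds to the 10-task CIFAR-100 run with a 500-sample buffer (its flags match that row of the table), switching to a 2000-sample buffer only requires updating the flags below, with everything else left unchanged:

```
... --memory-size 2000 --ema_alpha 0.0002 --ema_frequency 0.0015 --distill_weight 0.05 --distill_weight_buffer 0.01
```
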
48 | ## Cite Our Work:
49 |
50 | If you find the code useful in your research, please consider citing our paper:
51 |
52 |
53 | @article{jeeveswaran2023birt,
54 | title={BiRT: Bio-inspired Replay in Vision Transformers for Continual Learning},
55 | author={Jeeveswaran, Kishaan and Bhat, Prashant and Zonooz, Bahram and Arani, Elahe},
56 | journal={arXiv preprint arXiv:2305.04769},
57 | year={2023}
58 | }
59 |
60 |
--------------------------------------------------------------------------------
/continual/__init__.py:
--------------------------------------------------------------------------------
1 | from continual import rehearsal
2 | from continual import classifier
3 | from continual import vit
4 | from continual import convit
5 | from continual import utils
6 | from continual import scaler
7 | from continual import cnn
8 | from continual import factory
9 | from continual import sam
10 | from continual import samplers
11 | from continual import mixup
12 |
--------------------------------------------------------------------------------
/continual/birt.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | from timm.models.layers import trunc_normal_
5 | from torch import nn
6 |
7 | import continual.utils as cutils
8 |
9 |
10 | class ContinualClassifier(nn.Module):
11 | """Your good old classifier to do continual."""
12 | def __init__(self, embed_dim, nb_classes):
13 | super().__init__()
14 |
15 | self.embed_dim = embed_dim
16 | self.nb_classes = nb_classes
17 | self.head = nn.Linear(embed_dim, nb_classes, bias=True)
18 | self.norm = nn.LayerNorm(embed_dim)
19 |
20 | def reset_parameters(self):
21 | self.head.reset_parameters()
22 | self.norm.reset_parameters()
23 |
24 | def forward(self, x):
25 | x = self.norm(x)
26 | return self.head(x)
27 |
28 | def add_new_outputs(self, n):
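        # Grow the linear head by n outputs while keeping the rows already learned for
        # the previous classes; only the rows for the new classes are freshly initialized.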
29 | head = nn.Linear(self.embed_dim, self.nb_classes + n, bias=True)
30 |         head.weight.data[:-n] = self.head.weight.data
31 |         head.bias.data[:-n] = self.head.bias.data
32 | head.to(self.head.weight.device)
33 | self.head = head
34 | self.nb_classes += n
35 |
36 |
37 | class BiRT(nn.Module):
38 | """"
39 | :param transformer: The base transformer.
40 | :param nb_classes: Thhe initial number of classes.
41 | :param individual_classifier: Classifier config, DyTox is in `1-1`.
42 | :param head_div: Whether to use the divergence head for improved diversity.
43 | :param head_div_mode: Use the divergence head in TRaining, FineTuning, or both.
44 | :param joint_tokens: Use a single TAB forward with masked attention (faster but a bit worse).
45 | """
46 | def __init__(
47 | self,
48 | transformer,
49 | nb_classes,
50 | individual_classifier='',
51 | head_div=False,
52 | head_div_mode=['tr', 'ft'],
53 | joint_tokens=False,
54 | num_blocks=None,
55 | multi_token_setup=False
56 | ):
57 | super().__init__()
58 |
59 | self.nb_classes = nb_classes # 2
60 | self.embed_dim = transformer.embed_dim
61 | self.individual_classifier = individual_classifier
62 | self.use_head_div = head_div # true
63 | self.head_div_mode = head_div_mode # tr
64 | self.head_div = None
65 | self.joint_tokens = joint_tokens # False
66 | self.in_finetuning = False
67 | self.multi_token_setup = multi_token_setup
68 |
69 | self.num_blocks = num_blocks
70 |
71 |
72 | self.nb_classes_per_task = [nb_classes]
73 |
74 | self.patch_embed = transformer.patch_embed
75 | self.pos_embed = transformer.pos_embed
76 | self.pos_drop = transformer.pos_drop
77 | self.sabs = transformer.blocks[:-1]
78 |
79 | self.tabs = transformer.blocks[-1:]
80 |
81 | self.task_tokens = nn.ParameterList([transformer.cls_token])
82 |
83 | if self.individual_classifier != '':
84 | in_dim, out_dim = self._get_ind_clf_dim() # 384, 10
85 | self.head = nn.ModuleList([
86 | ContinualClassifier(in_dim, out_dim).cuda()
87 | ])
88 | else:
89 | self.head = ContinualClassifier(
90 | self.embed_dim * len(self.task_tokens), sum(self.nb_classes_per_task)
91 | ).cuda()
92 |
93 |     def end_finetuning(self):
94 |         """End FT mode, usually run with the backbone frozen and balanced classes."""
95 |         self.in_finetuning = False
96 |
97 |     def begin_finetuning(self):
98 |         """Start FT mode, usually run with the backbone frozen and balanced classes."""
99 |         self.in_finetuning = True
100 |
101 | def add_model(self, nb_new_classes, multi_token_setup=False):
102 | """Expand model as per the DyTox framework given `nb_new_classes`.
103 |
104 | :param nb_new_classes: Number of new classes brought by the new task.
105 | """
106 | self.nb_classes_per_task.append(nb_new_classes)
107 |
108 | # Class tokens ---------------------------------------------------------
109 | new_task_token = copy.deepcopy(self.task_tokens[-1])
110 | trunc_normal_(new_task_token, std=.02)
111 | self.task_tokens.append(new_task_token)
112 | # ----------------------------------------------------------------------
113 |
114 | # Diversity head -------------------------------------------------------
115 | if self.use_head_div:
116 | self.head_div = ContinualClassifier(
117 | self.sabs[-1].dim, self.nb_classes_per_task[-1] + 1
118 | ).cuda()
119 | # ----------------------------------------------------------------------
120 |
121 | # Classifier -----------------------------------------------------------
122 | if self.individual_classifier != '' and not multi_token_setup:
123 | in_dim, out_dim = self._get_ind_clf_dim()
124 | self.head.append(
125 | ContinualClassifier(in_dim, out_dim).cuda()
126 | )
127 | elif not multi_token_setup:
128 | self.head = ContinualClassifier(
129 | self.embed_dim * len(self.task_tokens), sum(self.nb_classes_per_task)
130 | ).cuda()
131 | # ----------------------------------------------------------------------
132 |
133 | def _get_ind_clf_dim(self):
134 | """What are the input and output dim of classifier depending on its config.
135 |
136 | By default, DyTox is in 1-1.
137 | """
138 | if self.individual_classifier == '1-1':
139 | in_dim = self.sabs[-1].dim
140 | out_dim = self.nb_classes_per_task[-1]
141 | elif self.individual_classifier == '1-n':
142 | in_dim = self.embed_dim
143 | out_dim = sum(self.nb_classes_per_task)
144 | elif self.individual_classifier == 'n-n':
145 | in_dim = len(self.task_tokens) * self.embed_dim
146 | out_dim = sum(self.nb_classes_per_task)
147 | elif self.individual_classifier == 'n-1':
148 | in_dim = len(self.task_tokens) * self.embed_dim
149 | out_dim = self.nb_classes_per_task[-1]
150 | else:
151 | raise NotImplementedError(f'Unknown ind classifier {self.individual_classifier}')
152 | return in_dim, out_dim
153 |
154 | def freeze(self, names):
155 | """Choose what to freeze depending on the name of the module."""
156 | requires_grad = False
157 | cutils.freeze_parameters(self, requires_grad=not requires_grad)
158 | self.train()
159 |
160 | for name in names:
161 | if name == 'all':
162 | self.eval()
163 | return cutils.freeze_parameters(self)
164 | elif name == 'multitoken_all':
165 | # self.eval()
166 | return cutils.freeze_parameters(self)
167 | elif name == 'old_task_tokens':
168 | cutils.freeze_parameters(self.task_tokens[:-1], requires_grad=requires_grad)
169 | elif name == 'freeze_token':
170 | cutils.freeze_parameters(self.task_tokens[-1], requires_grad=requires_grad)
171 | elif name == 'task_tokens':
172 | cutils.freeze_parameters(self.task_tokens, requires_grad=requires_grad)
173 | elif name == 'sab':
174 | self.sabs.eval()
175 | cutils.freeze_parameters(self.patch_embed, requires_grad=requires_grad)
176 | cutils.freeze_parameters(self.pos_embed, requires_grad=requires_grad)
177 | cutils.freeze_parameters(self.sabs, requires_grad=requires_grad)
178 | elif name == 'partial_sab':
179 | cutils.freeze_parameters(self.patch_embed, requires_grad=requires_grad)
180 | cutils.freeze_parameters(self.pos_embed, requires_grad=requires_grad)
181 | cutils.freeze_parameters(self.sabs[:self.num_blocks], requires_grad=requires_grad)
182 | elif name == 'tab':
183 | self.tabs.eval()
184 | cutils.freeze_parameters(self.tabs, requires_grad=requires_grad)
185 | elif name == 'old_heads':
186 | self.head[:-1].eval()
187 | cutils.freeze_parameters(self.head[:-1], requires_grad=requires_grad)
188 | elif name == 'heads':
189 | self.head.eval()
190 | cutils.freeze_parameters(self.head, requires_grad=requires_grad)
191 | elif name == 'head_div':
192 | self.head_div.eval()
193 | cutils.freeze_parameters(self.head_div, requires_grad=requires_grad)
194 | else:
195 | raise NotImplementedError(f'Unknown name={name}.')
196 |
197 | def param_groups(self):
198 | return {
199 | 'all': self.parameters(),
200 | 'old_task_tokens': self.task_tokens[:-1],
201 | 'task_tokens': self.task_tokens.parameters(),
202 | 'new_task_tokens': [self.task_tokens[-1]],
203 | 'sa': self.sabs.parameters(),
204 | 'patch': self.patch_embed.parameters(),
205 | 'pos': [self.pos_embed],
206 | 'ca': self.tabs.parameters(),
207 | 'old_heads': self.head[:-self.nb_classes_per_task[-1]].parameters() \
208 | if self.individual_classifier else \
209 | self.head.parameters(),
210 | 'new_head': self.head[-1].parameters() if self.individual_classifier else self.head.parameters(),
211 | 'head': self.head.parameters(),
212 | 'head_div': self.head_div.parameters() if self.head_div is not None else None
213 | }
214 |
215 | def reset_classifier(self):
216 | if isinstance(self.head, nn.ModuleList):
217 | for head in self.head:
218 | head.reset_parameters()
219 | else:
220 | self.head.reset_parameters()
221 |
222 | def hook_before_update(self):
223 | pass
224 |
225 | def hook_after_update(self):
226 | pass
227 |
228 | def hook_after_epoch(self):
229 | pass
230 |
231 | def epoch_log(self):
232 | """Write here whatever you want to log on the internal state of the model."""
233 | log = {}
234 |
235 | # Compute mean distance between class tokens
236 | mean_dist, min_dist, max_dist = [], float('inf'), 0.
237 | with torch.no_grad():
238 | for i in range(len(self.task_tokens)):
239 | for j in range(i + 1, len(self.task_tokens)):
240 | dist = torch.norm(self.task_tokens[i] - self.task_tokens[j], p=2).item()
241 | mean_dist.append(dist)
242 |
243 | min_dist = min(dist, min_dist)
244 | max_dist = max(dist, max_dist)
245 |
246 | if len(mean_dist) > 0:
247 | mean_dist = sum(mean_dist) / len(mean_dist)
248 | else:
249 | mean_dist = 0.
250 | min_dist = 0.
251 |
252 | assert min_dist <= mean_dist <= max_dist, (min_dist, mean_dist, max_dist)
253 | log['token_mean_dist'] = round(mean_dist, 5)
254 | log['token_min_dist'] = round(min_dist, 5)
255 | log['token_max_dist'] = round(max_dist, 5)
256 | return log
257 |
258 | def get_internal_losses(self, clf_loss):
259 | """If you want to compute some internal loss, like a EWC loss for example.
260 |
261 | :param clf_loss: The main classification loss (if you wanted to use its gradient for example).
262 | :return: a dictionnary of losses, all values will be summed in the final loss.
263 | """
264 | int_losses = {}
265 | return int_losses
266 |
267 | def forward_initial(self, x):
268 | # Shared part, this is the ENCODER
269 | B = x.shape[0]
270 |
271 | x = self.patch_embed(x)
272 | if self.pos_embed is not None:
273 | x = x + self.pos_embed
274 | x = self.pos_drop(x)
275 |
276 | for blk in self.sabs[:self.num_blocks]:
277 | x, attn, v = blk(x)
278 |
279 | return x
280 |
281 | def forward_latter(self, x, args=None):
282 | B = x.shape[0]
283 | s_e, s_a, s_v = [], [], []
284 | for blk in self.sabs[self.num_blocks:]:
285 | x, attn, v = blk(x, args=args)
286 | s_e.append(x)
287 | s_a.append(attn)
288 | s_v.append(v)
289 |
290 | # Specific part, this is what we called the "task specific DECODER"
291 | if self.joint_tokens:
292 | return self.forward_features_jointtokens(x)
293 |
294 | tokens = []
295 | attentions = []
296 | mask_heads = None
297 |
298 | for task_token in self.task_tokens:
299 | task_token = task_token.expand(B, -1, -1)
300 |
301 | ca_blocks = self.tabs
302 |
303 | for blk in ca_blocks:
304 | task_token, attn, v = blk(torch.cat((task_token, x), dim=1), mask_heads=mask_heads)
305 |
306 | attentions.append(attn)
307 | tokens.append(task_token[:, 0])
308 |
309 | self._class_tokens = tokens
310 | return self.forward_classifier(tokens, tokens[-1], attentions)
311 |
312 | def forward_features_multitoken(self, x, batch_tasks=None):
313 | # Shared part, this is the ENCODER
314 | B = x.shape[0]
315 |
316 | x = self.patch_embed(x)
317 | x = x + self.pos_embed
318 | x = self.pos_drop(x)
319 |
320 | s_e, s_a, s_v = [], [], []
321 | for blk in self.sabs:
322 | x, attn, v = blk(x)
323 | s_e.append(x)
324 | if attn is not None:
325 | s_a.append(attn)
326 | s_v.append(v)
327 |
328 | # Specific part, this is what we called the "task specific DECODER"
329 | if self.joint_tokens:
330 | return self.forward_features_jointtokens(x)
331 |
332 | tokens = []
333 | attentions = []
334 | mask_heads = None
335 |
336 | if self.training:
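            # In training, each sample is matched with the token of the task it came from
            # (batch_tasks), so a single TAB forward processes the whole batch at once.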
337 | x_tokens = torch.cat((torch.cat([self.task_tokens[i] for i in batch_tasks]), x), dim=1)
338 | ca_blocks = self.tabs
339 | for blk in ca_blocks:
340 | task_token, attn, v = blk(x_tokens, mask_heads=mask_heads)
341 |
342 | attentions.append(attn.unsqueeze(dim=0))
343 | tokens.append(task_token[:, 0])
344 |
345 | self._class_tokens = tokens
346 | return tokens, tokens[-1], attentions
347 |
348 | else:
349 | for task_token in self.task_tokens:
350 | task_token = task_token.expand(B, -1, -1)
351 |
352 | ca_blocks = self.tabs
353 |
354 | for blk in ca_blocks:
355 | task_token, attn, v = blk(torch.cat((task_token, x), dim=1), mask_heads=mask_heads)
356 |
357 | attentions.append(attn.unsqueeze(dim=0))
358 | tokens.append(task_token[:, 0])
359 |
360 | self._class_tokens = tokens
361 | return tokens, tokens[-1], attentions
362 |
363 | def forward_features(self, x, batch_tasks=None):
364 | # Shared part, this is the ENCODER
365 | B = x.shape[0]
366 |
367 | x = self.patch_embed(x)
368 | if self.pos_embed is not None:
369 | x = x + self.pos_embed
370 | x = self.pos_drop(x)
371 |
372 | s_e, s_a, s_v = [], [], []
373 | for blk in self.sabs:
374 | x, attn, v = blk(x)
375 | s_e.append(x)
376 | if attn is not None:
377 | s_a.append(attn)
378 | if v is not None:
379 | s_v.append(v)
380 |
381 | # Specific part, this is what we called the "task specific DECODER"
382 | if self.joint_tokens:
383 | return self.forward_features_jointtokens(x)
384 |
385 | tokens = []
386 | attentions = []
387 | mask_heads = None
388 |
389 | for task_token in self.task_tokens:
390 | task_token = task_token.expand(B, -1, -1)
391 |
392 | ca_blocks = self.tabs
393 |
394 | for blk in ca_blocks:
395 | task_token, attn, v = blk(torch.cat((task_token, x), dim=1), mask_heads=mask_heads)
396 |
397 | attentions.append(attn.unsqueeze(dim=0))
398 | tokens.append(task_token[:, 0])
399 |
400 | self._class_tokens = tokens
401 | return tokens, tokens[-1], attentions
402 |
403 | def forward_features_jointtokens(self, x):
404 | """Method to do a single TAB forward with all task tokens.
405 |
406 |         A masking is used to avoid interaction between tasks. In theory it should
407 |         give the same results as multiple TAB forwards, but in practice it's a little
408 |         bit worse, not sure why. So if you have an idea, please tell me!
409 | """
410 | B = len(x)
411 |
412 | task_tokens = torch.cat(
413 | [task_token.expand(B, 1, -1) for task_token in self.task_tokens],
414 | dim=1
415 | )
416 |
417 | for blk in self.tabs:
418 | task_tokens, _, _ = blk(
419 | torch.cat((task_tokens, x), dim=1),
420 | task_index=len(self.task_tokens),
421 | attn_mask=True
422 | )
423 |
424 | if self.individual_classifier in ('1-1', '1-n'):
425 | return task_tokens.permute(1, 0, 2), task_tokens[:, -1], None
426 | return task_tokens.view(B, -1), task_tokens[:, -1], None
427 |
428 | def forward_classifier_multitoken(self, tokens, last_token, attentions):
429 | """Once all task embeddings e_1, ..., e_t are extracted, classify.
430 |
431 |         The classifier has different modes based on a pattern x-y:
432 |         - x is the number of task embeddings given as input
433 |         - y is the number of tasks to predict
434 |
435 |         So:
436 |         - n-n: predicts all tasks given all embeddings
437 |         But:
438 |         - 1-1: predicts 1 task given 1 embedding, which is the 'independent classifier' used in the paper.
439 |
440 |         :param tokens: A list of all task token embeddings.
441 |         :param last_token: The task token embedding of the latest task.
442 | """
443 | logits_div = None
444 |
445 | logits = []
446 |
447 | # assuming self.individual_classifier is always '1-1' here
448 | for i in range(len(tokens)):
449 | logits.append(self.head[0](tokens[i]))
450 |
451 | logits = torch.cat(logits, dim=1)
452 |
453 | attentions = torch.cat(attentions, dim=0)
454 |
455 | return {
456 | 'logits': logits,
457 | 'div': logits_div,
458 | 'tokens': tokens, # 128, 384
459 | 'attention': attentions # 128, 12, 1, 65
460 | }
461 |
462 | def forward_classifier(self, tokens, last_token, attentions):
463 | """Once all task embeddings e_1, ..., e_t are extracted, classify.
464 |
465 |         The classifier has different modes based on a pattern x-y:
466 |         - x is the number of task embeddings given as input
467 |         - y is the number of tasks to predict
468 |
469 |         So:
470 |         - n-n: predicts all tasks given all embeddings
471 |         But:
472 |         - 1-1: predicts 1 task given 1 embedding, which is the 'independent classifier' used in the paper.
473 |
474 |         :param tokens: A list of all task token embeddings.
475 |         :param last_token: The task token embedding of the latest task.
476 | """
477 | logits_div = None
478 |
479 | if self.individual_classifier != '':
480 | logits = []
481 |
482 | for i, head in enumerate(self.head):
483 | if self.individual_classifier in ('1-n', '1-1'):
484 | logits.append(head(tokens[i]))
485 | else: # n-1, n-n
486 | logits.append(head(torch.cat(tokens[:i+1], dim=1)))
487 |
488 | if self.individual_classifier in ('1-1', 'n-1'):
489 | logits = torch.cat(logits, dim=1)
490 | else: # 1-n, n-n
491 | final_logits = torch.zeros_like(logits[-1])
492 | for i in range(len(logits)):
493 | final_logits[:, :logits[i].shape[1]] += logits[i]
494 |
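                # Each logits[i] covers every class seen up to task i, so earlier class
                # columns have been summed by several heads; rescale them before returning.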
495 | for i, c in enumerate(self.nb_classes_per_task):
496 | final_logits[:, :c] /= len(self.nb_classes_per_task) - i
497 |
498 | logits = final_logits
499 | elif isinstance(tokens, torch.Tensor):
500 | logits = self.head(tokens)
501 | else:
502 | logits = self.head(torch.cat(tokens, dim=1))
503 |
504 | if self.head_div is not None and eval_training_finetuning(self.head_div_mode, self.in_finetuning):
505 | logits_div = self.head_div(last_token) # only last token
506 |
507 |         # concatenate the attention maps gathered for each task token
508 | attentions = torch.cat(attentions, dim=0)
509 |
510 | return {
511 | 'logits': logits,
512 | 'div': logits_div,
513 | 'tokens': tokens, # 128, 384
514 | 'attention': attentions # 128, 12, 1, 65
515 | }
516 |
517 | def forward(self, x, batch_tasks=None, initial=False, latter=False, args=None):
518 | if initial:
519 | return self.forward_initial(x)
520 | elif latter:
521 | return self.forward_latter(x, args=args)
522 | elif self.multi_token_setup:
523 | tokens, last_token, attentions = self.forward_features_multitoken(x, batch_tasks=batch_tasks)
524 | return self.forward_classifier_multitoken(tokens, last_token, attentions)
525 | else:
526 | tokens, last_token, attentions = self.forward_features(x, batch_tasks=batch_tasks)
527 | return self.forward_classifier(tokens, last_token, attentions)
528 |
529 |
530 | def eval_training_finetuning(mode, in_ft):
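    # The divergence head is active only when the current phase is listed in `mode`:
    # 'tr' enables it during standard training, 'ft' during the finetuning phase.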
531 | if 'tr' in mode and 'ft' in mode:
532 | return True
533 | if 'tr' in mode and not in_ft:
534 | return True
535 | if 'ft' in mode and in_ft:
536 | return True
537 | return False
538 |
--------------------------------------------------------------------------------
/continual/classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class Classifier(nn.Module):
7 | def __init__(self, embed_dim, nb_total_classes, nb_base_classes, increment, nb_tasks=None, bias=True, complete=True, cosine=False, norm=True):
8 | super().__init__()
9 |
10 | self.embed_dim = embed_dim
11 | self.nb_classes = nb_base_classes
12 | self.cosine = cosine # false
13 |
14 | if self.cosine not in (False, None, ''):
15 | self.scale = nn.Parameter(torch.tensor(1.))
16 | else:
17 | self.scale = 1
18 | self.head = nn.Linear(embed_dim, nb_base_classes, bias=not cosine)
19 |         self.norm = nn.LayerNorm(embed_dim) if norm else nn.Identity()
20 | self.increment = increment
21 |
22 | def reset_parameters(self):
23 | self.head.reset_parameters()
24 | self.norm.reset_parameters()
25 |
26 | def forward(self, x):
27 | x = self.norm(x)
28 |
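        # Cosine classifier: compare L2-normalized features and weights; with the 'pcc'
        # variant both are mean-centered first, making the score a Pearson correlation.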
29 | if self.cosine not in (False, None, ''):
30 | w = self.head.weight # (c, d)
31 |
32 | if self.cosine == 'pcc':
33 |                 x = x - x.mean(dim=1, keepdim=True)
34 |                 w = w - w.mean(dim=1, keepdim=True)
35 | x = F.normalize(x, p=2, dim=1) # (bs, d)
36 | w = F.normalize(w, p=2, dim=1) # (c, d)
37 | return self.scale * torch.mm(x, w.T)
38 |
39 | return self.head(x)
40 |
41 | def init_prev_head(self, head):
42 | w, b = head.weight.data, head.bias.data
43 | self.head.weight.data[:w.shape[0], :w.shape[1]] = w
44 | self.head.bias.data[:b.shape[0]] = b
45 |
46 | def init_prev_norm(self, norm):
47 | w, b = norm.weight.data, norm.bias.data
48 | self.norm.weight.data[:w.shape[0]] = w
49 | self.norm.bias.data[:b.shape[0]] = b
50 |
51 | @torch.no_grad()
52 | def weight_align(self, nb_new_classes):
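        # Weight aligning: rescale the weight vectors of the newly added classes so their
        # mean norm matches that of the old classes, reducing the bias towards new classes.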
53 | w = self.head.weight.data
54 | norms = torch.norm(w, dim=1)
55 |
56 | norm_old = norms[:-nb_new_classes]
57 | norm_new = norms[-nb_new_classes:]
58 |
59 | gamma = torch.mean(norm_old) / torch.mean(norm_new)
60 | w[-nb_new_classes:] = gamma * w[-nb_new_classes:]
61 |
62 | def add_classes(self):
63 | self.add_new_outputs(self.increment)
64 |
65 | def add_new_outputs(self, n):
66 | head = nn.Linear(self.embed_dim, self.nb_classes + n, bias=not self.cosine)
67 | head.weight.data[:-n] = self.head.weight.data
68 | if not self.cosine:
69 | head.bias.data[:-n] = self.head.bias.data
70 |
71 | head.to(self.head.weight.device)
72 | self.head = head
73 | self.nb_classes += n
74 |
--------------------------------------------------------------------------------
/continual/cnn/__init__.py:
--------------------------------------------------------------------------------
1 | from continual.cnn.abstract import AbstractCNN
2 | from continual.cnn.inception import InceptionV3
3 | from continual.cnn.senet import legacy_seresnet18 as seresnet18
4 | from continual.cnn.resnet import (
5 | resnet18, resnet34, resnet50, resnext50_32x4d, wide_resnet50_2
6 | )
7 | from continual.cnn.resnet_scs import resnet18_scs, resnet18_scs_avg, resnet18_scs_max
8 | from continual.cnn.vgg import vgg16_bn, vgg16
9 | from continual.cnn.resnet_rebuffi import CifarResNet as rebuffi
10 |
--------------------------------------------------------------------------------
/continual/cnn/abstract.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | import continual.utils as cutils
4 |
5 |
6 | class AbstractCNN(nn.Module):
7 | def reset_classifier(self):
8 | self.head.reset_parameters()
9 |
10 | def get_internal_losses(self, clf_loss):
11 | return {}
12 |
13 | def end_finetuning(self):
14 | pass
15 |
16 | def begin_finetuning(self):
17 | pass
18 |
19 | def epoch_log(self):
20 | return {}
21 |
22 | def get_classifier(self):
23 | return self.head
24 |
25 | def freeze(self, names):
26 | cutils.freeze_parameters(self, requires_grad=True)
27 | self.train()
28 |
29 | for name in names:
30 | if name == 'head':
31 | cutils.freeze_parameters(self.head)
32 | self.head.eval()
33 | elif name == 'backbone':
34 | for k, p in self.named_parameters():
35 | if not k.startswith('head'):
36 | cutils.freeze_parameters(p)
37 | elif name == 'all':
38 | cutils.freeze_parameters(self)
39 | self.eval()
40 | else:
41 | raise NotImplementedError(f'Unknown module name to freeze {name}')
42 |
--------------------------------------------------------------------------------
/continual/cnn/inception.py:
--------------------------------------------------------------------------------
1 | """ inceptionv3 in pytorch
2 |
3 |
4 | [1] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna
5 |
6 | Rethinking the Inception Architecture for Computer Vision
7 | https://arxiv.org/abs/1512.00567v3
8 | """
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 |
14 | from continual.cnn import AbstractCNN
15 |
16 |
17 | class BasicConv2d(nn.Module):
18 |
19 | def __init__(self, input_channels, output_channels, **kwargs):
20 | super().__init__()
21 | self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs)
22 | self.bn = nn.BatchNorm2d(output_channels)
23 | self.relu = nn.ReLU(inplace=True)
24 |
25 | def forward(self, x):
26 | x = self.conv(x)
27 | x = self.bn(x)
28 | x = self.relu(x)
29 |
30 | return x
31 |
32 | #same naive inception module
33 | class InceptionA(nn.Module):
34 |
35 | def __init__(self, input_channels, pool_features):
36 | super().__init__()
37 | self.branch1x1 = BasicConv2d(input_channels, 64, kernel_size=1)
38 |
39 | self.branch5x5 = nn.Sequential(
40 | BasicConv2d(input_channels, 48, kernel_size=1),
41 | BasicConv2d(48, 64, kernel_size=5, padding=2)
42 | )
43 |
44 | self.branch3x3 = nn.Sequential(
45 | BasicConv2d(input_channels, 64, kernel_size=1),
46 | BasicConv2d(64, 96, kernel_size=3, padding=1),
47 | BasicConv2d(96, 96, kernel_size=3, padding=1)
48 | )
49 |
50 | self.branchpool = nn.Sequential(
51 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
52 | BasicConv2d(input_channels, pool_features, kernel_size=3, padding=1)
53 | )
54 |
55 | def forward(self, x):
56 |
57 | #x -> 1x1(same)
58 | branch1x1 = self.branch1x1(x)
59 |
60 | #x -> 1x1 -> 5x5(same)
61 | branch5x5 = self.branch5x5(x)
62 | #branch5x5 = self.branch5x5_2(branch5x5)
63 |
64 | #x -> 1x1 -> 3x3 -> 3x3(same)
65 | branch3x3 = self.branch3x3(x)
66 |
67 | #x -> pool -> 1x1(same)
68 | branchpool = self.branchpool(x)
69 |
70 | outputs = [branch1x1, branch5x5, branch3x3, branchpool]
71 |
72 | return torch.cat(outputs, 1)
73 |
74 | #downsample
75 | #Factorization into smaller convolutions
76 | class InceptionB(nn.Module):
77 |
78 | def __init__(self, input_channels):
79 | super().__init__()
80 |
81 | self.branch3x3 = BasicConv2d(input_channels, 384, kernel_size=3, stride=2)
82 |
83 | self.branch3x3stack = nn.Sequential(
84 | BasicConv2d(input_channels, 64, kernel_size=1),
85 | BasicConv2d(64, 96, kernel_size=3, padding=1),
86 | BasicConv2d(96, 96, kernel_size=3, stride=2)
87 | )
88 |
89 | self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2)
90 |
91 | def forward(self, x):
92 |
93 | #x - > 3x3(downsample)
94 | branch3x3 = self.branch3x3(x)
95 |
96 | #x -> 3x3 -> 3x3(downsample)
97 | branch3x3stack = self.branch3x3stack(x)
98 |
99 |         #x -> maxpool(downsample)
100 | branchpool = self.branchpool(x)
101 |
102 | #"""We can use two parallel stride 2 blocks: P and C. P is a pooling
103 | #layer (either average or maximum pooling) the activation, both of
104 | #them are stride 2 the filter banks of which are concatenated as in
105 | #figure 10."""
106 | outputs = [branch3x3, branch3x3stack, branchpool]
107 |
108 | return torch.cat(outputs, 1)
109 |
110 | #Factorizing Convolutions with Large Filter Size
111 | class InceptionC(nn.Module):
112 | def __init__(self, input_channels, channels_7x7):
113 | super().__init__()
114 | self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1)
115 |
116 | c7 = channels_7x7
117 |
118 | #In theory, we could go even further and argue that one can replace any n × n
119 | #convolution by a 1 × n convolution followed by a n × 1 convolution and the
120 | #computational cost saving increases dramatically as n grows (see figure 6).
121 | self.branch7x7 = nn.Sequential(
122 | BasicConv2d(input_channels, c7, kernel_size=1),
123 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)),
124 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
125 | )
126 |
127 | self.branch7x7stack = nn.Sequential(
128 | BasicConv2d(input_channels, c7, kernel_size=1),
129 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)),
130 | BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)),
131 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)),
132 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
133 | )
134 |
135 | self.branch_pool = nn.Sequential(
136 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
137 | BasicConv2d(input_channels, 192, kernel_size=1),
138 | )
139 |
140 | def forward(self, x):
141 |
142 | #x -> 1x1(same)
143 | branch1x1 = self.branch1x1(x)
144 |
145 | #x -> 1layer 1*7 and 7*1 (same)
146 | branch7x7 = self.branch7x7(x)
147 |
148 | #x-> 2layer 1*7 and 7*1(same)
149 | branch7x7stack = self.branch7x7stack(x)
150 |
151 | #x-> avgpool (same)
152 | branchpool = self.branch_pool(x)
153 |
154 | outputs = [branch1x1, branch7x7, branch7x7stack, branchpool]
155 |
156 | return torch.cat(outputs, 1)
157 |
158 | class InceptionD(nn.Module):
159 |
160 | def __init__(self, input_channels):
161 | super().__init__()
162 |
163 | self.branch3x3 = nn.Sequential(
164 | BasicConv2d(input_channels, 192, kernel_size=1),
165 | BasicConv2d(192, 320, kernel_size=3, stride=2)
166 | )
167 |
168 | self.branch7x7 = nn.Sequential(
169 | BasicConv2d(input_channels, 192, kernel_size=1),
170 | BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
171 | BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)),
172 | BasicConv2d(192, 192, kernel_size=3, stride=2)
173 | )
174 |
175 | self.branchpool = nn.AvgPool2d(kernel_size=3, stride=2)
176 |
177 | def forward(self, x):
178 |
179 | #x -> 1x1 -> 3x3(downsample)
180 | branch3x3 = self.branch3x3(x)
181 |
182 | #x -> 1x1 -> 1x7 -> 7x1 -> 3x3 (downsample)
183 | branch7x7 = self.branch7x7(x)
184 |
185 | #x -> avgpool (downsample)
186 | branchpool = self.branchpool(x)
187 |
188 | outputs = [branch3x3, branch7x7, branchpool]
189 |
190 | return torch.cat(outputs, 1)
191 |
192 |
193 | #same
194 | class InceptionE(nn.Module):
195 | def __init__(self, input_channels):
196 | super().__init__()
197 | self.branch1x1 = BasicConv2d(input_channels, 320, kernel_size=1)
198 |
199 | self.branch3x3_1 = BasicConv2d(input_channels, 384, kernel_size=1)
200 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
201 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
202 |
203 | self.branch3x3stack_1 = BasicConv2d(input_channels, 448, kernel_size=1)
204 | self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
205 | self.branch3x3stack_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
206 | self.branch3x3stack_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
207 |
208 | self.branch_pool = nn.Sequential(
209 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
210 | BasicConv2d(input_channels, 192, kernel_size=1)
211 | )
212 |
213 | def forward(self, x):
214 |
215 | #x -> 1x1 (same)
216 | branch1x1 = self.branch1x1(x)
217 |
218 | # x -> 1x1 -> 3x1
219 | # x -> 1x1 -> 1x3
220 | # concatenate(3x1, 1x3)
221 | #"""7. Inception modules with expanded the filter bank outputs.
222 | #This architecture is used on the coarsest (8 × 8) grids to promote
223 | #high dimensional representations, as suggested by principle
224 | #2 of Section 2."""
225 | branch3x3 = self.branch3x3_1(x)
226 | branch3x3 = [
227 | self.branch3x3_2a(branch3x3),
228 | self.branch3x3_2b(branch3x3)
229 | ]
230 | branch3x3 = torch.cat(branch3x3, 1)
231 |
232 | # x -> 1x1 -> 3x3 -> 1x3
233 | # x -> 1x1 -> 3x3 -> 3x1
234 | #concatenate(1x3, 3x1)
235 | branch3x3stack = self.branch3x3stack_1(x)
236 | branch3x3stack = self.branch3x3stack_2(branch3x3stack)
237 | branch3x3stack = [
238 | self.branch3x3stack_3a(branch3x3stack),
239 | self.branch3x3stack_3b(branch3x3stack)
240 | ]
241 | branch3x3stack = torch.cat(branch3x3stack, 1)
242 |
243 | branchpool = self.branch_pool(x)
244 |
245 | outputs = [branch1x1, branch3x3, branch3x3stack, branchpool]
246 |
247 | return torch.cat(outputs, 1)
248 |
249 | class InceptionV3(AbstractCNN):
250 |
251 | def __init__(self, num_classes=100):
252 | super().__init__()
253 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, padding=1)
254 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
255 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
256 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
257 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
258 |
259 | #naive inception module
260 | self.Mixed_5b = InceptionA(192, pool_features=32)
261 | self.Mixed_5c = InceptionA(256, pool_features=64)
262 | self.Mixed_5d = InceptionA(288, pool_features=64)
263 |
264 | #downsample
265 | self.Mixed_6a = InceptionB(288)
266 |
267 | self.Mixed_6b = InceptionC(768, channels_7x7=128)
268 | self.Mixed_6c = InceptionC(768, channels_7x7=160)
269 | self.Mixed_6d = InceptionC(768, channels_7x7=160)
270 | self.Mixed_6e = InceptionC(768, channels_7x7=192)
271 |
272 | #downsample
273 | self.Mixed_7a = InceptionD(768)
274 |
275 | self.Mixed_7b = InceptionE(1280)
276 | self.Mixed_7c = InceptionE(2048)
277 |
278 | #6*6 feature size
279 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
280 | self.dropout = nn.Dropout2d()
281 | self.head = None
282 | self.embed_dim = 2048
283 |
284 | def forward(self, x):
285 |
286 | #32 -> 30
287 | x = self.Conv2d_1a_3x3(x)
288 | x = self.Conv2d_2a_3x3(x)
289 | x = self.Conv2d_2b_3x3(x)
290 | x = self.Conv2d_3b_1x1(x)
291 | x = self.Conv2d_4a_3x3(x)
292 |
293 | #30 -> 30
294 | x = self.Mixed_5b(x)
295 | x = self.Mixed_5c(x)
296 | x = self.Mixed_5d(x)
297 |
298 | #30 -> 14
299 | #Efficient Grid Size Reduction to avoid representation
300 | #bottleneck
301 | x = self.Mixed_6a(x)
302 |
303 | #14 -> 14
304 | #"""In practice, we have found that employing this factorization does not
305 | #work well on early layers, but it gives very good results on medium
306 | #grid-sizes (On m × m feature maps, where m ranges between 12 and 20).
307 | #On that level, very good results can be achieved by using 1 × 7 convolutions
308 | #followed by 7 × 1 convolutions."""
309 | x = self.Mixed_6b(x)
310 | x = self.Mixed_6c(x)
311 | x = self.Mixed_6d(x)
312 | x = self.Mixed_6e(x)
313 |
314 | #14 -> 6
315 | #Efficient Grid Size Reduction
316 | x = self.Mixed_7a(x)
317 |
318 | #6 -> 6
319 | #We are using this solution only on the coarsest grid,
320 | #since that is the place where producing high dimensional
321 | #sparse representation is the most critical as the ratio of
322 | #local processing (by 1 × 1 convolutions) is increased compared
323 | #to the spatial aggregation."""
324 | x = self.Mixed_7b(x)
325 | x = self.Mixed_7c(x)
326 |
327 | #6 -> 1
328 | x = self.avgpool(x)
329 | x = self.dropout(x)
330 | x = x.view(x.size(0), -1)
331 | x = self.head(x)
332 | return x
333 |
334 |
335 | def inceptionv3():
336 | return InceptionV3()
337 |
338 |
339 |
340 |
--------------------------------------------------------------------------------
/continual/cnn/resnet.py:
--------------------------------------------------------------------------------
1 | from torch.hub import load_state_dict_from_url
2 | from typing import Any, Callable, List, Optional, Type, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 | from continual.cnn import AbstractCNN
7 | from torch import Tensor
8 |
9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
11 | 'wide_resnet50_2', 'wide_resnet101_2']
12 |
13 |
14 | model_urls = {
15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
20 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
21 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
22 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
23 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
24 | }
25 |
26 |
27 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
28 | """3x3 convolution with padding"""
29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
30 | padding=dilation, groups=groups, bias=False, dilation=dilation)
31 |
32 |
33 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
34 | """1x1 convolution"""
35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
36 |
37 |
38 | class BasicBlock(nn.Module):
39 | expansion: int = 1
40 |
41 | def __init__(
42 | self,
43 | inplanes: int,
44 | planes: int,
45 | stride: int = 1,
46 | downsample: Optional[nn.Module] = None,
47 | groups: int = 1,
48 | base_width: int = 64,
49 | dilation: int = 1,
50 | norm_layer: Optional[Callable[..., nn.Module]] = None
51 | ) -> None:
52 | super(BasicBlock, self).__init__()
53 | if norm_layer is None:
54 | norm_layer = nn.BatchNorm2d
55 | if groups != 1 or base_width != 64:
56 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
57 | #if dilation > 1:
58 | # raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
59 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
60 | self.conv1 = conv3x3(inplanes, planes, stride)
61 | self.bn1 = norm_layer(planes)
62 | self.relu = nn.ReLU(inplace=True)
63 | self.conv2 = conv3x3(planes, planes)
64 | self.bn2 = norm_layer(planes)
65 | self.downsample = downsample
66 | self.stride = stride
67 |
68 | def forward(self, x: Tensor) -> Tensor:
69 | identity = x
70 |
71 | out = self.conv1(x)
72 | out = self.bn1(out)
73 | out = self.relu(out)
74 |
75 | out = self.conv2(out)
76 | out = self.bn2(out)
77 |
78 | if self.downsample is not None:
79 | identity = self.downsample(x)
80 |
81 | out += identity
82 | out = self.relu(out)
83 |
84 | return out
85 |
86 |
87 | class Bottleneck(nn.Module):
88 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
89 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
90 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
91 | # This variant is also known as ResNet V1.5 and improves accuracy according to
92 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
93 |
94 | expansion: int = 4
95 |
96 | def __init__(
97 | self,
98 | inplanes: int,
99 | planes: int,
100 | stride: int = 1,
101 | downsample: Optional[nn.Module] = None,
102 | groups: int = 1,
103 | base_width: int = 64,
104 | dilation: int = 1,
105 | norm_layer: Optional[Callable[..., nn.Module]] = None
106 | ) -> None:
107 | super(Bottleneck, self).__init__()
108 | if norm_layer is None:
109 | norm_layer = nn.BatchNorm2d
110 | width = int(planes * (base_width / 64.)) * groups
111 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
112 | self.conv1 = conv1x1(inplanes, width)
113 | self.bn1 = norm_layer(width)
114 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
115 | self.bn2 = norm_layer(width)
116 | self.conv3 = conv1x1(width, planes * self.expansion)
117 | self.bn3 = norm_layer(planes * self.expansion)
118 | self.relu = nn.ReLU(inplace=True)
119 | self.downsample = downsample
120 | self.stride = stride
121 |
122 | def forward(self, x: Tensor) -> Tensor:
123 | identity = x
124 |
125 | out = self.conv1(x)
126 | out = self.bn1(out)
127 | out = self.relu(out)
128 |
129 | out = self.conv2(out)
130 | out = self.bn2(out)
131 | out = self.relu(out)
132 |
133 | out = self.conv3(out)
134 | out = self.bn3(out)
135 |
136 | if self.downsample is not None:
137 | identity = self.downsample(x)
138 |
139 | out += identity
140 | out = self.relu(out)
141 |
142 | return out
143 |
144 |
145 | class ResNet(AbstractCNN):
146 |
147 | def __init__(
148 | self,
149 | block: Type[Union[BasicBlock, Bottleneck]],
150 | layers: List[int],
151 | num_classes: int = 1000,
152 | zero_init_residual: bool = False,
153 | groups: int = 1,
154 | width_per_group: int = 64,
155 | replace_stride_with_dilation: Optional[List[bool]] = None,
156 | norm_layer: Optional[Callable[..., nn.Module]] = None
157 | ) -> None:
158 | super(ResNet, self).__init__()
159 | if norm_layer is None:
160 | norm_layer = nn.BatchNorm2d
161 | self._norm_layer = norm_layer
162 |
163 | self.inplanes = 64
164 | self.dilation = 1
165 | if replace_stride_with_dilation is None:
166 | # each element in the tuple indicates if we should replace
167 | # the 2x2 stride with a dilated convolution instead
168 | replace_stride_with_dilation = [False, False, False]
169 | if len(replace_stride_with_dilation) != 3:
170 | raise ValueError("replace_stride_with_dilation should be None "
171 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
172 | self.groups = groups
173 | self.base_width = width_per_group
174 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1,
175 | bias=False)
176 | self.bn1 = norm_layer(self.inplanes)
177 | self.relu = nn.ReLU(inplace=True)
178 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
179 | self.layer1 = self._make_layer(block, 64, layers[0])
180 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
181 | dilate=replace_stride_with_dilation[0])
182 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
183 | dilate=replace_stride_with_dilation[1])
184 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
185 | dilate=replace_stride_with_dilation[2])
186 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
187 |
188 | #self.fc = nn.Linear(512 * block.expansion, num_classes)
189 | self.embed_dim = 512 * block.expansion
190 | self.head = None
191 |
192 | for m in self.modules():
193 | if isinstance(m, nn.Conv2d):
194 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
195 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
196 | nn.init.constant_(m.weight, 1)
197 | nn.init.constant_(m.bias, 0)
198 |
199 | # Zero-initialize the last BN in each residual branch,
200 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
201 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
202 | if zero_init_residual:
203 | for m in self.modules():
204 | if isinstance(m, Bottleneck):
205 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
206 | elif isinstance(m, BasicBlock):
207 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
208 |
209 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
210 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
211 | norm_layer = self._norm_layer
212 | downsample = None
213 | previous_dilation = self.dilation
214 | if dilate:
215 | self.dilation *= stride
216 | stride = 1
217 | if stride != 1 or self.inplanes != planes * block.expansion:
218 | downsample = nn.Sequential(
219 | conv1x1(self.inplanes, planes * block.expansion, stride),
220 | norm_layer(planes * block.expansion),
221 | )
222 |
223 | layers = []
224 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
225 | self.base_width, previous_dilation, norm_layer))
226 | self.inplanes = planes * block.expansion
227 | for _ in range(1, blocks):
228 | layers.append(block(self.inplanes, planes, groups=self.groups,
229 | base_width=self.base_width, dilation=self.dilation,
230 | norm_layer=norm_layer))
231 |
232 | return nn.Sequential(*layers)
233 |
234 | def _make_layer_nodown(self, inplanes: int, planes: int, blocks: int,
235 | stride: int = 1, dilation: int = 1) -> nn.Sequential:
236 | norm_layer = self._norm_layer
237 | downsample = nn.Conv2d(256, 512, kernel_size=1)
238 | previous_dilation = self.dilation = dilation
239 |
240 | layers = []
241 | layers.append(BasicBlock(inplanes, planes, stride, downsample, self.groups,
242 | self.base_width, previous_dilation, norm_layer))
243 | self.inplanes = planes * BasicBlock.expansion
244 | for _ in range(1, blocks):
245 | layers.append(BasicBlock(self.inplanes, planes, groups=self.groups,
246 | base_width=self.base_width, dilation=self.dilation,
247 | norm_layer=norm_layer))
248 |
249 | return nn.Sequential(*layers)
250 |
251 | def _forward_impl(self, x: Tensor) -> Tensor:
252 | # See note [TorchScript super()]
253 | x = self.conv1(x)
254 | x = self.bn1(x)
255 | x = self.relu(x)
256 | x = self.maxpool(x)
257 |
258 | x = self.layer1(x)
259 | x = self.layer2(x)
260 | x = self.layer3(x)
261 | x = self.layer4(x)
262 |
263 | x = self.avgpool(x)
264 | x = torch.flatten(x, 1)
265 | x = self.head(x)
266 |
267 | return x
268 |
269 | def forward(self, x: Tensor) -> Tensor:
270 | return self._forward_impl(x)
271 |
272 | def forward_tokens(self, x):
273 | x = self.conv1(x)
274 | x = self.bn1(x)
275 | x = self.relu(x)
276 | x = self.maxpool(x)
277 |
278 | x = self.layer1(x)
279 | x = self.layer2(x)
280 | x = self.layer3(x)
281 | x = self.layer4(x)
282 |
283 | x = self.head(x)
284 | return x.view(x.shape[0], self.embed_dim, -1).permute(0, 2, 1)
285 |
286 | def forward_features(self, x):
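        # Returns (features, None, None) to mirror the transformer backbone interface
        # (tokens, attention, values); a CNN has no attention maps or value tensors.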
287 | x = self.conv1(x)
288 | x = self.bn1(x)
289 | x = self.relu(x)
290 | x = self.maxpool(x)
291 |
292 | x = self.layer1(x)
293 | x = self.layer2(x)
294 | x = self.layer3(x)
295 | x = self.layer4(x)
296 |
297 | x = self.avgpool(x)
298 | x = torch.flatten(x, 1)
299 | return x, None, None
300 |
301 |
302 | def _resnet(
303 | arch: str,
304 | block: Type[Union[BasicBlock, Bottleneck]],
305 | layers: List[int],
306 | pretrained: bool,
307 | progress: bool,
308 | **kwargs: Any
309 | ) -> ResNet:
310 | model = ResNet(block, layers, **kwargs)
311 | if pretrained:
312 | state_dict = load_state_dict_from_url(model_urls[arch],
313 | progress=progress)
314 | model.load_state_dict(state_dict)
315 | return model
316 |
317 |
318 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
319 | r"""ResNet-18 model from
320 | `"Deep Residual Learning for Image Recognition" `_.
321 |
322 | Args:
323 | pretrained (bool): If True, returns a model pre-trained on ImageNet
324 | progress (bool): If True, displays a progress bar of the download to stderr
325 | """
326 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
327 | **kwargs)
328 |
329 |
330 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
331 | r"""ResNet-34 model from
332 | `"Deep Residual Learning for Image Recognition" `_.
333 |
334 | Args:
335 | pretrained (bool): If True, returns a model pre-trained on ImageNet
336 | progress (bool): If True, displays a progress bar of the download to stderr
337 | """
338 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
339 | **kwargs)
340 |
341 |
342 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
343 | r"""ResNet-50 model from
344 | `"Deep Residual Learning for Image Recognition" `_.
345 |
346 | Args:
347 | pretrained (bool): If True, returns a model pre-trained on ImageNet
348 | progress (bool): If True, displays a progress bar of the download to stderr
349 | """
350 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
351 | **kwargs)
352 |
353 |
354 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
355 | r"""ResNet-101 model from
356 | `"Deep Residual Learning for Image Recognition" `_.
357 |
358 | Args:
359 | pretrained (bool): If True, returns a model pre-trained on ImageNet
360 | progress (bool): If True, displays a progress bar of the download to stderr
361 | """
362 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
363 | **kwargs)
364 |
365 |
366 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
367 | r"""ResNet-152 model from
368 | `"Deep Residual Learning for Image Recognition" `_.
369 |
370 | Args:
371 | pretrained (bool): If True, returns a model pre-trained on ImageNet
372 | progress (bool): If True, displays a progress bar of the download to stderr
373 | """
374 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
375 | **kwargs)
376 |
377 |
378 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
379 | r"""ResNeXt-50 32x4d model from
380 | `"Aggregated Residual Transformation for Deep Neural Networks" `_.
381 |
382 | Args:
383 | pretrained (bool): If True, returns a model pre-trained on ImageNet
384 | progress (bool): If True, displays a progress bar of the download to stderr
385 | """
386 | kwargs['groups'] = 32
387 | kwargs['width_per_group'] = 4
388 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
389 | pretrained, progress, **kwargs)
390 |
391 |
392 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
393 | r"""ResNeXt-101 32x8d model from
394 | `"Aggregated Residual Transformation for Deep Neural Networks" `_.
395 |
396 | Args:
397 | pretrained (bool): If True, returns a model pre-trained on ImageNet
398 | progress (bool): If True, displays a progress bar of the download to stderr
399 | """
400 | kwargs['groups'] = 32
401 | kwargs['width_per_group'] = 8
402 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
403 | pretrained, progress, **kwargs)
404 |
405 |
406 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
407 | r"""Wide ResNet-50-2 model from
408 | `"Wide Residual Networks" `_.
409 |
410 | The model is the same as ResNet except for the bottleneck number of channels
411 | which is twice larger in every block. The number of channels in outer 1x1
412 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
413 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
414 |
415 | Args:
416 | pretrained (bool): If True, returns a model pre-trained on ImageNet
417 | progress (bool): If True, displays a progress bar of the download to stderr
418 | """
419 | kwargs['width_per_group'] = 64 * 2
420 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
421 | pretrained, progress, **kwargs)
422 |
423 |
424 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
425 | r"""Wide ResNet-101-2 model from
426 | `"Wide Residual Networks" `_.
427 |
428 | The model is the same as ResNet except for the bottleneck number of channels
429 | which is twice larger in every block. The number of channels in outer 1x1
430 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
431 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
432 |
433 | Args:
434 | pretrained (bool): If True, returns a model pre-trained on ImageNet
435 | progress (bool): If True, displays a progress bar of the download to stderr
436 | """
437 | kwargs['width_per_group'] = 64 * 2
438 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
439 | pretrained, progress, **kwargs)
440 |
--------------------------------------------------------------------------------
/continual/cnn/resnet_rebuffi.py:
--------------------------------------------------------------------------------
1 | """Pytorch port of the resnet used for CIFAR100 by iCaRL.
2 |
3 | https://github.com/srebuffi/iCaRL/blob/master/iCaRL-TheanoLasagne/utils_cifar100.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn import init
9 |
10 | from continual.cnn import AbstractCNN
11 |
12 |
13 | class DownsampleStride(nn.Module):
14 |
15 | def __init__(self, n=2):
16 | super(DownsampleStride, self).__init__()
17 | self._n = n
18 |
19 | def forward(self, x):
20 | return x[..., ::2, ::2]
21 |
22 |
23 | class DownsampleConv(nn.Module):
24 |
25 | def __init__(self, inplanes, planes):
26 | super().__init__()
27 |
28 | self.conv = nn.Sequential(
29 | nn.Conv2d(inplanes, planes, stride=2, kernel_size=1, bias=False),
30 | nn.BatchNorm2d(planes),
31 | )
32 |
33 | def forward(self, x):
34 | return self.conv(x)
35 |
36 |
37 | class ResidualBlock(nn.Module):
38 | expansion = 1
39 |
40 | def __init__(self, inplanes, increase_dim=False, last_relu=False, downsampling="stride"):
41 | super(ResidualBlock, self).__init__()
42 |
43 | self.increase_dim = increase_dim
44 |
45 | if increase_dim:
46 | first_stride = 2
47 | planes = inplanes * 2
48 | else:
49 | first_stride = 1
50 | planes = inplanes
51 |
52 | self.conv_a = nn.Conv2d(
53 | inplanes, planes, kernel_size=3, stride=first_stride, padding=1, bias=False
54 | )
55 | self.bn_a = nn.BatchNorm2d(planes)
56 |
57 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
58 | self.bn_b = nn.BatchNorm2d(planes)
59 |
60 | if increase_dim:
61 | if downsampling == "stride":
62 | self.downsampler = DownsampleStride()
63 | self._need_pad = True
64 | else:
65 | self.downsampler = DownsampleConv(inplanes, planes)
66 | self._need_pad = False
67 |
68 | self.last_relu = last_relu
69 |
70 | @staticmethod
71 | def pad(x):
72 | return torch.cat((x, x.mul(0)), 1)
73 |
74 | def forward(self, x):
75 | y = self.conv_a(x)
76 | y = self.bn_a(y)
77 | y = F.relu(y, inplace=True)
78 |
79 | y = self.conv_b(y)
80 | y = self.bn_b(y)
81 |
82 | if self.increase_dim:
83 | x = self.downsampler(x)
84 | if self._need_pad:
85 | x = self.pad(x)
86 |
87 | y = x + y
88 |
89 | if self.last_relu:
90 | y = F.relu(y, inplace=True)
91 |
92 | return y
93 |
94 |
95 | class PreActResidualBlock(nn.Module):
96 | expansion = 1
97 |
98 | def __init__(self, inplanes, increase_dim=False, last_relu=False):
99 | super().__init__()
100 |
101 | self.increase_dim = increase_dim
102 |
103 | if increase_dim:
104 | first_stride = 2
105 | planes = inplanes * 2
106 | else:
107 | first_stride = 1
108 | planes = inplanes
109 |
110 | self.bn_a = nn.BatchNorm2d(inplanes)
111 | self.conv_a = nn.Conv2d(
112 | inplanes, planes, kernel_size=3, stride=first_stride, padding=1, bias=False
113 | )
114 |
115 | self.bn_b = nn.BatchNorm2d(planes)
116 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
117 |
118 | if increase_dim:
119 | self.downsample = DownsampleStride()
120 | self.pad = lambda x: torch.cat((x, x.mul(0)), 1)
121 | self.last_relu = last_relu
122 |
123 | def forward(self, x):
124 | y = self.bn_a(x)
125 | y = F.relu(y, inplace=True)
126 |         y = self.conv_a(y)
127 |
128 | y = self.bn_b(y)
129 | y = F.relu(y, inplace=True)
130 | y = self.conv_b(y)
131 |
132 | if self.increase_dim:
133 | x = self.downsample(x)
134 | x = self.pad(x)
135 |
136 | y = x + y
137 |
138 | if self.last_relu:
139 | y = F.relu(y, inplace=True)
140 |
141 | return y
142 |
143 |
144 | class Stage(nn.Module):
145 |
146 | def __init__(self, blocks, block_relu=False):
147 | super().__init__()
148 |
149 | self.blocks = nn.ModuleList(blocks)
150 | self.block_relu = block_relu
151 |
152 | def forward(self, x):
153 | intermediary_features = []
154 |
155 | for b in self.blocks:
156 | x = b(x)
157 | intermediary_features.append(x)
158 |
159 | if self.block_relu:
160 | x = F.relu(x)
161 |
162 | return intermediary_features, x
163 |
164 |
165 | class CifarResNet(AbstractCNN):
166 | """
167 |     ResNet optimized for the CIFAR datasets, as specified in
168 |     https://arxiv.org/abs/1512.03385
169 | """
170 |
171 | def __init__(
172 | self,
173 | n=5,
174 | nf=16,
175 | channels=3,
176 | preact=False,
177 | zero_residual=True,
178 | pooling_config={"type": "avg"},
179 | downsampling="stride",
180 | all_attentions=False,
181 | last_relu=True,
182 | **kwargs
183 | ):
184 | """ Constructor
185 | Args:
186 |             n: number of residual blocks per stage.
187 |             nf: number of feature maps in the first stage.
188 |             channels: number of input image channels.
189 | """
190 | if kwargs:
191 | raise ValueError("Unused kwargs: {}.".format(kwargs))
192 |
193 | self.all_attentions = all_attentions
194 | self._downsampling_type = downsampling
195 | self.last_relu = last_relu
196 |
197 | Block = ResidualBlock if not preact else PreActResidualBlock
198 |
199 | super(CifarResNet, self).__init__()
200 |
201 | self.conv_1_3x3 = nn.Conv2d(channels, nf, kernel_size=3, stride=1, padding=1, bias=False)
202 | self.bn_1 = nn.BatchNorm2d(nf)
203 |
204 | self.stage_1 = self._make_layer(Block, nf, increase_dim=False, n=n)
205 | self.stage_2 = self._make_layer(Block, nf, increase_dim=True, n=n - 1)
206 | self.stage_3 = self._make_layer(Block, 2 * nf, increase_dim=True, n=n - 2)
207 | self.stage_4 = Block(
208 | 4 * nf, increase_dim=False, last_relu=False, downsampling=self._downsampling_type
209 | )
210 |
211 | if pooling_config["type"] == "avg":
212 | self.pool = nn.AdaptiveAvgPool2d((1, 1))
213 | else:
214 | raise ValueError("Unknown pooling type {}.".format(pooling_config["type"]))
215 |
216 | self.embed_dim = 4 * nf
217 | self.head = None
218 |
219 | for m in self.modules():
220 | if isinstance(m, nn.Conv2d):
221 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
222 | elif isinstance(m, nn.BatchNorm2d):
223 | nn.init.constant_(m.weight, 1)
224 | nn.init.constant_(m.bias, 0)
225 | elif isinstance(m, nn.Linear):
226 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
227 |
228 | if zero_residual:
229 | for m in self.modules():
230 | if isinstance(m, ResidualBlock):
231 | nn.init.constant_(m.bn_b.weight, 0)
232 |
233 | def _make_layer(self, Block, planes, increase_dim=False, n=None):
234 | layers = []
235 |
236 | if increase_dim:
237 | layers.append(
238 | Block(
239 | planes,
240 | increase_dim=True,
241 | last_relu=True,
242 | downsampling=self._downsampling_type
243 | )
244 | )
245 | planes = 2 * planes
246 |
247 | for i in range(n):
248 | layers.append(Block(planes, last_relu=True, downsampling=self._downsampling_type))
249 |
250 | return Stage(layers, block_relu=self.last_relu)
251 |
252 | @property
253 | def last_conv(self):
254 | return self.stage_4.conv_b
255 |
256 | def forward(self, x):
257 | x = self.conv_1_3x3(x)
258 | x = F.relu(self.bn_1(x), inplace=True)
259 |
260 | feats_s1, x = self.stage_1(x)
261 | feats_s2, x = self.stage_2(x)
262 | feats_s3, x = self.stage_3(x)
263 | x = self.stage_4(x)
264 |
265 | features = self.end_features(F.relu(x, inplace=False))
266 |
267 | return self.head(features)
268 |
269 | def end_features(self, x):
270 | x = self.pool(x)
271 | x = x.view(x.size(0), -1)
272 |
273 | return x
274 |
275 |
276 | def resnet_rebuffi(n=5, **kwargs):
277 | return CifarResNet(n=n, **kwargs)
278 |
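A minimal usage sketch of the backbone above (not part of the original file). It assumes that `rebuffi`, exported by `continual.cnn` and used in `continual/factory.py`, resolves to `resnet_rebuffi()`; the linear head is a throwaway stand-in for the classifier that the continual framework normally attaches.

```python
# Sketch: exercising the CIFAR ResNet backbone defined above.
import torch
from torch import nn
from continual.cnn import rebuffi  # assumed alias of resnet_rebuffi()

backbone = rebuffi()                                 # n=5, nf=16 -> embed_dim = 4 * nf = 64
backbone.head = nn.Linear(backbone.embed_dim, 10)    # placeholder head, set externally in this codebase

x = torch.randn(4, 3, 32, 32)                        # CIFAR-sized batch
logits = backbone(x)
print(logits.shape)                                  # torch.Size([4, 10])
```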
--------------------------------------------------------------------------------
/continual/cnn/senet.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from continual.cnn import AbstractCNN
4 |
5 | """
6 | SEResNet implementation from Cadene's pretrained models
7 | https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py
8 | Additional credit to https://github.com/creafz
9 |
10 | Original model: https://github.com/hujie-frank/SENet
11 |
12 | ResNet code gently borrowed from
13 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
14 |
15 | FIXME I'm deprecating this model and moving them to ResNet as I don't want to maintain duplicate
16 | support for extras like dilation, switchable BN/activations, feature extraction, etc that don't exist here.
17 | """
18 | import math
19 | from collections import OrderedDict
20 |
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 |
24 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
25 | from timm.models.helpers import build_model_with_cfg
26 | from timm.models.layers import create_classifier
27 | from timm.models.registry import register_model
28 |
29 | __all__ = ['SENet']
30 |
31 |
32 | def _cfg(url='', **kwargs):
33 | return {
34 | 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
35 | 'crop_pct': 0.875, 'interpolation': 'bilinear',
36 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
37 | 'first_conv': 'layer0.conv1', 'classifier': 'last_linear',
38 | **kwargs
39 | }
40 |
41 |
42 | default_cfgs = {
43 | 'legacy_senet154':
44 | _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'),
45 | 'legacy_seresnet18': _cfg(
46 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth',
47 | interpolation='bicubic'),
48 | 'legacy_seresnet34': _cfg(
49 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'),
50 | 'legacy_seresnet50': _cfg(
51 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'),
52 | 'legacy_seresnet101': _cfg(
53 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'),
54 | 'legacy_seresnet152': _cfg(
55 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'),
56 | 'legacy_seresnext26_32x4d': _cfg(
57 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth',
58 | interpolation='bicubic'),
59 | 'legacy_seresnext50_32x4d':
60 | _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'),
61 | 'legacy_seresnext101_32x4d':
62 | _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth'),
63 | }
64 |
65 |
66 | def _weight_init(m):
67 | if isinstance(m, nn.Conv2d):
68 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
69 | elif isinstance(m, nn.BatchNorm2d):
70 | nn.init.constant_(m.weight, 1.)
71 | nn.init.constant_(m.bias, 0.)
72 |
73 |
74 | class SEModule(nn.Module):
75 |
76 | def __init__(self, channels, reduction):
77 | super(SEModule, self).__init__()
78 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1)
81 | self.sigmoid = nn.Sigmoid()
82 |
83 | def forward(self, x):
84 | module_input = x
85 | x = x.mean((2, 3), keepdim=True)
86 | x = self.fc1(x)
87 | x = self.relu(x)
88 | x = self.fc2(x)
89 | x = self.sigmoid(x)
90 | return module_input * x
91 |
92 |
93 | class Bottleneck(nn.Module):
94 | """
95 |     Base class for bottlenecks that implements the `forward()` method.
96 | """
97 |
98 | def forward(self, x):
99 | residual = x
100 |
101 | out = self.conv1(x)
102 | out = self.bn1(out)
103 | out = self.relu(out)
104 |
105 | out = self.conv2(out)
106 | out = self.bn2(out)
107 | out = self.relu(out)
108 |
109 | out = self.conv3(out)
110 | out = self.bn3(out)
111 |
112 | if self.downsample is not None:
113 | residual = self.downsample(x)
114 |
115 | out = self.se_module(out) + residual
116 | out = self.relu(out)
117 |
118 | return out
119 |
120 |
121 | class SEBottleneck(Bottleneck):
122 | """
123 | Bottleneck for SENet154.
124 | """
125 | expansion = 4
126 |
127 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
128 | downsample=None):
129 | super(SEBottleneck, self).__init__()
130 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
131 | self.bn1 = nn.BatchNorm2d(planes * 2)
132 | self.conv2 = nn.Conv2d(
133 | planes * 2, planes * 4, kernel_size=3, stride=stride,
134 | padding=1, groups=groups, bias=False)
135 | self.bn2 = nn.BatchNorm2d(planes * 4)
136 | self.conv3 = nn.Conv2d(
137 | planes * 4, planes * 4, kernel_size=1, bias=False)
138 | self.bn3 = nn.BatchNorm2d(planes * 4)
139 | self.relu = nn.ReLU(inplace=True)
140 | self.se_module = SEModule(planes * 4, reduction=reduction)
141 | self.downsample = downsample
142 | self.stride = stride
143 |
144 |
145 | class SEResNetBottleneck(Bottleneck):
146 | """
147 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
148 | implementation and uses `stride=stride` in `conv1` and not in `conv2`
149 | (the latter is used in the torchvision implementation of ResNet).
150 | """
151 | expansion = 4
152 |
153 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
154 | downsample=None):
155 | super(SEResNetBottleneck, self).__init__()
156 | self.conv1 = nn.Conv2d(
157 | inplanes, planes, kernel_size=1, bias=False, stride=stride)
158 | self.bn1 = nn.BatchNorm2d(planes)
159 | self.conv2 = nn.Conv2d(
160 | planes, planes, kernel_size=3, padding=1, groups=groups, bias=False)
161 | self.bn2 = nn.BatchNorm2d(planes)
162 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
163 | self.bn3 = nn.BatchNorm2d(planes * 4)
164 | self.relu = nn.ReLU(inplace=True)
165 | self.se_module = SEModule(planes * 4, reduction=reduction)
166 | self.downsample = downsample
167 | self.stride = stride
168 |
169 |
170 | class SEResNeXtBottleneck(Bottleneck):
171 | """
172 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
173 | """
174 | expansion = 4
175 |
176 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
177 | downsample=None, base_width=4):
178 | super(SEResNeXtBottleneck, self).__init__()
179 | width = math.floor(planes * (base_width / 64)) * groups
180 | self.conv1 = nn.Conv2d(
181 | inplanes, width, kernel_size=1, bias=False, stride=1)
182 | self.bn1 = nn.BatchNorm2d(width)
183 | self.conv2 = nn.Conv2d(
184 | width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False)
185 | self.bn2 = nn.BatchNorm2d(width)
186 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
187 | self.bn3 = nn.BatchNorm2d(planes * 4)
188 | self.relu = nn.ReLU(inplace=True)
189 | self.se_module = SEModule(planes * 4, reduction=reduction)
190 | self.downsample = downsample
191 | self.stride = stride
192 |
193 |
194 | class SEResNetBlock(nn.Module):
195 | expansion = 1
196 |
197 | def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):
198 | super(SEResNetBlock, self).__init__()
199 | self.conv1 = nn.Conv2d(
200 | inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
201 | self.bn1 = nn.BatchNorm2d(planes)
202 | self.conv2 = nn.Conv2d(
203 | planes, planes, kernel_size=3, padding=1, groups=groups, bias=False)
204 | self.bn2 = nn.BatchNorm2d(planes)
205 | self.relu = nn.ReLU(inplace=True)
206 | self.se_module = SEModule(planes, reduction=reduction)
207 | self.downsample = downsample
208 | self.stride = stride
209 |
210 | def forward(self, x):
211 | residual = x
212 |
213 | out = self.conv1(x)
214 | out = self.bn1(out)
215 | out = self.relu(out)
216 |
217 | out = self.conv2(out)
218 | out = self.bn2(out)
219 | out = self.relu(out)
220 |
221 | if self.downsample is not None:
222 | residual = self.downsample(x)
223 |
224 | out = self.se_module(out) + residual
225 | out = self.relu(out)
226 |
227 | return out
228 |
229 |
230 | class SENet(AbstractCNN):
231 |
232 | def __init__(self, block, layers, groups, reduction, drop_rate=0.2,
233 | in_chans=3, inplanes=64, input_3x3=False, downsample_kernel_size=1,
234 | downsample_padding=0, num_classes=1000, global_pool='avg'):
235 | """
236 | Parameters
237 | ----------
238 | block (nn.Module): Bottleneck class.
239 | - For SENet154: SEBottleneck
240 | - For SE-ResNet models: SEResNetBottleneck
241 | - For SE-ResNeXt models: SEResNeXtBottleneck
242 | layers (list of ints): Number of residual blocks for 4 layers of the
243 | network (layer1...layer4).
244 | groups (int): Number of groups for the 3x3 convolution in each
245 | bottleneck block.
246 | - For SENet154: 64
247 | - For SE-ResNet models: 1
248 | - For SE-ResNeXt models: 32
249 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
250 | - For all models: 16
251 |         drop_rate (float): Drop probability for the Dropout layer.
252 |             If `0.` the Dropout layer is not used.
253 |             - For SENet154: 0.2
254 |             - For SE-ResNet models: 0.
255 |             - For SE-ResNeXt models: 0.
256 | inplanes (int): Number of input channels for layer1.
257 | - For SENet154: 128
258 | - For SE-ResNet models: 64
259 | - For SE-ResNeXt models: 64
260 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
261 | a single 7x7 convolution in layer0.
262 | - For SENet154: True
263 | - For SE-ResNet models: False
264 | - For SE-ResNeXt models: False
265 | downsample_kernel_size (int): Kernel size for downsampling convolutions
266 | in layer2, layer3 and layer4.
267 | - For SENet154: 3
268 | - For SE-ResNet models: 1
269 | - For SE-ResNeXt models: 1
270 | downsample_padding (int): Padding for downsampling convolutions in
271 | layer2, layer3 and layer4.
272 | - For SENet154: 1
273 | - For SE-ResNet models: 0
274 | - For SE-ResNeXt models: 0
275 | num_classes (int): Number of outputs in `last_linear` layer.
276 | - For all models: 1000
277 | """
278 | super(SENet, self).__init__()
279 | self.inplanes = inplanes
280 | self.num_classes = num_classes
281 | self.drop_rate = drop_rate
282 | if input_3x3:
283 | layer0_modules = [
284 | ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)),
285 | ('bn1', nn.BatchNorm2d(64)),
286 | ('relu1', nn.ReLU(inplace=True)),
287 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),
288 | ('bn2', nn.BatchNorm2d(64)),
289 | ('relu2', nn.ReLU(inplace=True)),
290 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)),
291 | ('bn3', nn.BatchNorm2d(inplanes)),
292 | ('relu3', nn.ReLU(inplace=True)),
293 | ]
294 | else:
295 | layer0_modules = [
296 | ('conv1', nn.Conv2d(
297 | in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
298 | ('bn1', nn.BatchNorm2d(inplanes)),
299 | ('relu1', nn.ReLU(inplace=True)),
300 | ]
301 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
302 | # To preserve compatibility with Caffe weights `ceil_mode=True` is used instead of `padding=1`.
303 | self.pool0 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
304 | self.feature_info = [dict(num_chs=inplanes, reduction=2, module='layer0')]
305 | self.layer1 = self._make_layer(
306 | block,
307 | planes=64,
308 | blocks=layers[0],
309 | groups=groups,
310 | reduction=reduction,
311 | downsample_kernel_size=1,
312 | downsample_padding=0
313 | )
314 | self.feature_info += [dict(num_chs=64 * block.expansion, reduction=4, module='layer1')]
315 | self.layer2 = self._make_layer(
316 | block,
317 | planes=128,
318 | blocks=layers[1],
319 | stride=2,
320 | groups=groups,
321 | reduction=reduction,
322 | downsample_kernel_size=downsample_kernel_size,
323 | downsample_padding=downsample_padding
324 | )
325 | self.feature_info += [dict(num_chs=128 * block.expansion, reduction=8, module='layer2')]
326 | self.layer3 = self._make_layer(
327 | block,
328 | planes=256,
329 | blocks=layers[2],
330 | stride=2,
331 | groups=groups,
332 | reduction=reduction,
333 | downsample_kernel_size=downsample_kernel_size,
334 | downsample_padding=downsample_padding
335 | )
336 | self.feature_info += [dict(num_chs=256 * block.expansion, reduction=16, module='layer3')]
337 | self.layer4 = self._make_layer(
338 | block,
339 | planes=512,
340 | blocks=layers[3],
341 | stride=2,
342 | groups=groups,
343 | reduction=reduction,
344 | downsample_kernel_size=downsample_kernel_size,
345 | downsample_padding=downsample_padding
346 | )
347 | self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')]
348 | self.num_features = 512 * block.expansion
349 | self.embed_dim = 512 * block.expansion
350 | self.global_pool, self.last_linear = create_classifier(
351 | self.num_features, self.num_classes, pool_type=global_pool)
352 |
353 | for m in self.modules():
354 | _weight_init(m)
355 |
356 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
357 | downsample_kernel_size=1, downsample_padding=0):
358 | downsample = None
359 | if stride != 1 or self.inplanes != planes * block.expansion:
360 | downsample = nn.Sequential(
361 | nn.Conv2d(
362 | self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size,
363 | stride=stride, padding=downsample_padding, bias=False),
364 | nn.BatchNorm2d(planes * block.expansion),
365 | )
366 |
367 | layers = [block(self.inplanes, planes, groups, reduction, stride, downsample)]
368 | self.inplanes = planes * block.expansion
369 | for i in range(1, blocks):
370 | layers.append(block(self.inplanes, planes, groups, reduction))
371 |
372 | return nn.Sequential(*layers)
373 |
374 | def get_classifier(self):
375 | return self.last_linear
376 |
377 | def reset_classifier(self, num_classes, global_pool='avg'):
378 | self.num_classes = num_classes
379 | self.global_pool, self.last_linear = create_classifier(
380 | self.num_features, self.num_classes, pool_type=global_pool)
381 |
382 | def forward_features(self, x):
383 | x = self.layer0(x)
384 | x = self.pool0(x)
385 | x = self.layer1(x)
386 | x = self.layer2(x)
387 | x = self.layer3(x)
388 | x = self.layer4(x)
389 | return x
390 |
391 | def logits(self, x):
392 | x = self.global_pool(x)
393 | if self.drop_rate > 0.:
394 | x = F.dropout(x, p=self.drop_rate, training=self.training)
395 | x = self.head(x)
396 | return x
397 |
398 | def forward(self, x):
399 | x = self.forward_features(x)
400 | x = self.logits(x)
401 | return x
402 |
403 |
404 | def _create_senet(variant, pretrained=False, **kwargs):
405 | return build_model_with_cfg(
406 | SENet, variant, pretrained,
407 | default_cfg=default_cfgs[variant],
408 | **kwargs)
409 |
410 |
411 | @register_model
412 | def legacy_seresnet18(pretrained=False, **kwargs):
413 | model_args = dict(
414 | block=SEResNetBlock, layers=[2, 2, 2, 2], groups=1, reduction=16, **kwargs)
415 | return _create_senet('legacy_seresnet18', pretrained, **model_args)
416 |
417 |
418 | @register_model
419 | def legacy_seresnet34(pretrained=False, **kwargs):
420 | model_args = dict(
421 | block=SEResNetBlock, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs)
422 | return _create_senet('legacy_seresnet34', pretrained, **model_args)
423 |
424 |
425 | @register_model
426 | def legacy_seresnet50(pretrained=False, **kwargs):
427 | model_args = dict(
428 | block=SEResNetBottleneck, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs)
429 | return _create_senet('legacy_seresnet50', pretrained, **model_args)
430 |
431 |
432 | @register_model
433 | def legacy_seresnet101(pretrained=False, **kwargs):
434 | model_args = dict(
435 | block=SEResNetBottleneck, layers=[3, 4, 23, 3], groups=1, reduction=16, **kwargs)
436 | return _create_senet('legacy_seresnet101', pretrained, **model_args)
437 |
438 |
439 | @register_model
440 | def legacy_seresnet152(pretrained=False, **kwargs):
441 | model_args = dict(
442 | block=SEResNetBottleneck, layers=[3, 8, 36, 3], groups=1, reduction=16, **kwargs)
443 | return _create_senet('legacy_seresnet152', pretrained, **model_args)
444 |
445 |
446 | @register_model
447 | def legacy_senet154(pretrained=False, **kwargs):
448 | model_args = dict(
449 | block=SEBottleneck, layers=[3, 8, 36, 3], groups=64, reduction=16,
450 | downsample_kernel_size=3, downsample_padding=1, inplanes=128, input_3x3=True, **kwargs)
451 | return _create_senet('legacy_senet154', pretrained, **model_args)
452 |
453 |
454 | @register_model
455 | def legacy_seresnext26_32x4d(pretrained=False, **kwargs):
456 | model_args = dict(
457 | block=SEResNeXtBottleneck, layers=[2, 2, 2, 2], groups=32, reduction=16, **kwargs)
458 | return _create_senet('legacy_seresnext26_32x4d', pretrained, **model_args)
459 |
460 |
461 | @register_model
462 | def legacy_seresnext50_32x4d(pretrained=False, **kwargs):
463 | model_args = dict(
464 | block=SEResNeXtBottleneck, layers=[3, 4, 6, 3], groups=32, reduction=16, **kwargs)
465 | return _create_senet('legacy_seresnext50_32x4d', pretrained, **model_args)
466 |
467 |
468 | @register_model
469 | def legacy_seresnext101_32x4d(pretrained=False, **kwargs):
470 | model_args = dict(
471 | block=SEResNeXtBottleneck, layers=[3, 4, 23, 3], groups=32, reduction=16, **kwargs)
472 | return _create_senet('legacy_seresnext101_32x4d', pretrained, **model_args)
473 |
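To make the squeeze-and-excitation recalibration used throughout this file concrete, here is a small, self-contained sketch (the import path `continual.cnn.senet` is assumed from the file location; nothing below comes from the original repository):

```python
# Sketch: SEModule rescales channels with an input-dependent sigmoid gate.
import torch
from continual.cnn.senet import SEModule

se = SEModule(channels=64, reduction=16)
feat = torch.randn(2, 64, 28, 28)   # [B, C, H, W] feature map
out = se(feat)                      # same shape; each channel multiplied by a gate in (0, 1)
print(out.shape)                    # torch.Size([2, 64, 28, 28])
```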
--------------------------------------------------------------------------------
/continual/cnn/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.hub import load_state_dict_from_url  # used when pretrained=True
4 | from typing import Union, List, Dict, Any, cast
5 |
6 | from continual.cnn import AbstractCNN
7 |
8 | __all__ = [
9 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
10 | 'vgg19_bn', 'vgg19',
11 | ]
12 |
13 |
14 | model_urls = {
15 | 'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth',
16 | 'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth',
17 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
18 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
19 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
20 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
21 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
22 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
23 | }
24 |
25 |
26 | class VGG(AbstractCNN):
27 |
28 | def __init__(
29 | self,
30 | features: nn.Module,
31 | num_classes: int = 1000,
32 | init_weights: bool = True
33 | ) -> None:
34 | super(VGG, self).__init__()
35 | self.features = features
36 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
37 | self.classifier = nn.Sequential(
38 | nn.Linear(512 * 7 * 7, 4096),
39 | nn.ReLU(True),
40 | nn.Dropout(),
41 | nn.Linear(4096, 4096),
42 | nn.ReLU(True),
43 | nn.Dropout(),
44 | )
45 |
46 | self.head = None
47 | self.embed_dim = 4096
48 |
49 | if init_weights:
50 | self._initialize_weights()
51 |
52 | def forward(self, x: torch.Tensor) -> torch.Tensor:
53 | x = self.features(x)
54 | x = self.avgpool(x)
55 | x = torch.flatten(x, 1)
56 | x = self.classifier(x)
57 | return self.head(x)
58 |
59 | def _initialize_weights(self) -> None:
60 | for m in self.modules():
61 | if isinstance(m, nn.Conv2d):
62 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
63 | if m.bias is not None:
64 | nn.init.constant_(m.bias, 0)
65 | elif isinstance(m, nn.BatchNorm2d):
66 | nn.init.constant_(m.weight, 1)
67 | nn.init.constant_(m.bias, 0)
68 | elif isinstance(m, nn.Linear):
69 | nn.init.normal_(m.weight, 0, 0.01)
70 | nn.init.constant_(m.bias, 0)
71 |
72 |
73 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
74 | layers: List[nn.Module] = []
75 | in_channels = 3
76 | for v in cfg:
77 | if v == 'M':
78 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
79 | else:
80 | v = cast(int, v)
81 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
82 | if batch_norm:
83 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
84 | else:
85 | layers += [conv2d, nn.ReLU(inplace=True)]
86 | in_channels = v
87 | return nn.Sequential(*layers)
88 |
89 |
90 | cfgs: Dict[str, List[Union[str, int]]] = {
91 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
92 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
93 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
94 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
95 | }
96 |
97 |
98 | def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
99 | if pretrained:
100 | kwargs['init_weights'] = False
101 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
102 | if pretrained:
103 | state_dict = load_state_dict_from_url(model_urls[arch],
104 | progress=progress)
105 | model.load_state_dict(state_dict)
106 | return model
107 |
108 |
109 | def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
110 | r"""VGG 11-layer model (configuration "A") from
111 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
112 |
113 | Args:
114 | pretrained (bool): If True, returns a model pre-trained on ImageNet
115 | progress (bool): If True, displays a progress bar of the download to stderr
116 | """
117 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
118 |
119 |
120 | def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
121 | r"""VGG 11-layer model (configuration "A") with batch normalization
122 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
123 |
124 | Args:
125 | pretrained (bool): If True, returns a model pre-trained on ImageNet
126 | progress (bool): If True, displays a progress bar of the download to stderr
127 | """
128 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
129 |
130 |
131 | def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
132 | r"""VGG 13-layer model (configuration "B")
133 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
134 |
135 | Args:
136 | pretrained (bool): If True, returns a model pre-trained on ImageNet
137 | progress (bool): If True, displays a progress bar of the download to stderr
138 | """
139 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
140 |
141 |
142 | def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
143 | r"""VGG 13-layer model (configuration "B") with batch normalization
144 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
145 |
146 | Args:
147 | pretrained (bool): If True, returns a model pre-trained on ImageNet
148 | progress (bool): If True, displays a progress bar of the download to stderr
149 | """
150 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
151 |
152 |
153 | def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
154 | r"""VGG 16-layer model (configuration "D")
155 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
156 |
157 | Args:
158 | pretrained (bool): If True, returns a model pre-trained on ImageNet
159 | progress (bool): If True, displays a progress bar of the download to stderr
160 | """
161 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
162 |
163 |
164 | def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
165 | r"""VGG 16-layer model (configuration "D") with batch normalization
166 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
167 |
168 | Args:
169 | pretrained (bool): If True, returns a model pre-trained on ImageNet
170 | progress (bool): If True, displays a progress bar of the download to stderr
171 | """
172 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
173 |
174 |
175 | def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
176 | r"""VGG 19-layer model (configuration "E")
177 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
178 |
179 | Args:
180 | pretrained (bool): If True, returns a model pre-trained on ImageNet
181 | progress (bool): If True, displays a progress bar of the download to stderr
182 | """
183 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
184 |
185 |
186 | def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
187 | r"""VGG 19-layer model (configuration 'E') with batch normalization
188 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
189 |
190 | Args:
191 | pretrained (bool): If True, returns a model pre-trained on ImageNet
192 | progress (bool): If True, displays a progress bar of the download to stderr
193 | """
194 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
195 |
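A short sketch of how these builders can serve as continual-learning backbones (assumptions: `vgg16` is exported by `continual.cnn` as in `continual/factory.py`, and the linear head below is a placeholder, since `head` is normally attached by the surrounding framework):

```python
# Sketch: VGG-16 backbone with a throwaway classification head.
import torch
from torch import nn
from continual.cnn import vgg16

model = vgg16(pretrained=False)
model.head = nn.Linear(model.embed_dim, 10)   # embed_dim == 4096, see VGG.__init__ above

x = torch.randn(2, 3, 224, 224)
print(model(x).shape)                          # torch.Size([2, 10])
```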
--------------------------------------------------------------------------------
/continual/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import json
4 | import os
5 | import warnings
6 |
7 | from continuum import ClassIncremental
8 | # from continuum import Permutations
9 | from continual.mycontinual import Rotations, IncrementalRotation, Permutations
10 | from continuum.datasets import CIFAR100, MNIST, ImageNet100, ImageFolderDataset, CIFAR10, TinyImageNet200, STL10
11 | from timm.data import create_transform
12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
13 | from torchvision import transforms
14 | from torchvision.datasets.folder import ImageFolder, default_loader
15 | from torchvision.transforms import functional as Fv
16 |
17 | from typing import Tuple, Union
18 |
19 | import numpy as np
20 |
21 | from continuum.datasets import ImageFolderDataset
22 | from continuum.download import download, unzip
23 |
24 | try:
25 | interpolation = Fv.InterpolationMode.BICUBIC
26 | except:
27 | interpolation = 3
28 |
29 |
30 | from torch.utils.data import DataLoader
31 | import torch.nn.functional as F
32 | from argparse import Namespace
33 | from copy import deepcopy
34 | import torch
35 | from PIL import Image
36 | # from datasets.utils.validation import get_train_val
37 | from typing import Tuple
38 |
39 |
40 | class ImageNet1000(ImageFolderDataset):
41 | """Continuum dataset for datasets with tree-like structure.
42 | :param train_folder: The folder of the train data.
43 | :param test_folder: The folder of the test data.
44 | :param download: Dummy parameter.
45 | """
46 |
47 | def __init__(
48 | self,
49 | data_path: str,
50 | train: bool = True,
51 | download: bool = False,
52 | ):
53 | super().__init__(data_path=data_path, train=train, download=download)
54 |
55 | def get_data(self):
56 | if self.train:
57 | self.data_path = os.path.join(self.data_path, "train")
58 | else:
59 | self.data_path = os.path.join(self.data_path, "val")
60 | return super().get_data()
61 |
62 |
63 | class INatDataset(ImageFolder):
64 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
65 | category='name', loader=default_loader):
66 | self.transform = transform
67 | self.loader = loader
68 | self.target_transform = target_transform
69 | self.year = year
70 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
71 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
72 | with open(path_json) as json_file:
73 | data = json.load(json_file)
74 |
75 | with open(os.path.join(root, 'categories.json')) as json_file:
76 | data_catg = json.load(json_file)
77 |
78 | path_json_for_targeter = os.path.join(root, f"train{year}.json")
79 |
80 | with open(path_json_for_targeter) as json_file:
81 | data_for_targeter = json.load(json_file)
82 |
83 | targeter = {}
84 | indexer = 0
85 | for elem in data_for_targeter['annotations']:
86 | king = []
87 | king.append(data_catg[int(elem['category_id'])][category])
88 | if king[0] not in targeter.keys():
89 | targeter[king[0]] = indexer
90 | indexer += 1
91 | self.nb_classes = len(targeter)
92 |
93 | self.samples = []
94 | for elem in data['images']:
95 | cut = elem['file_name'].split('/')
96 | target_current = int(cut[2])
97 | path_current = os.path.join(root, cut[0], cut[2], cut[3])
98 |
99 | categors = data_catg[target_current]
100 | target_current_true = targeter[categors[category]]
101 | self.samples.append((path_current, target_current_true))
102 |
103 |
104 | def build_dataset(is_train, args):
105 | transform = build_transform(is_train, args)
106 |
107 | if args.data_set.lower() == 'cifar10':
108 | dataset = CIFAR10(args.data_path, train=is_train, download=True)
109 | elif args.data_set.lower() == 'cifar':
110 | dataset = CIFAR100(args.data_path, train=is_train, download=True)
111 | elif args.data_set.lower() == 'tinyimg':
112 | dataset = TinyImageNet200(args.data_path, train=is_train, download=True)
113 | elif args.data_set.lower() == 'imagenet100':
114 | dataset = ImageNet100_local(
115 | args.data_path, train=is_train,
116 | data_subset=os.path.join('./imagenet100_splits', "train_100.txt" if is_train else "val_100.txt")
117 | )
118 | elif args.data_set.lower() == 'imagenet1000':
119 | dataset = ImageNet1000(args.data_path, train=is_train)
120 | else:
121 | raise ValueError(f'Unknown dataset {args.data_set}.')
122 |
123 | scenario = ClassIncremental(
124 | dataset,
125 | initial_increment=args.initial_increment,
126 | increment=args.increment,
127 | transformations=transform.transforms,
128 | class_order=args.class_order
129 | )
130 | nb_classes = scenario.nb_classes #100
131 |
132 | return scenario, nb_classes
133 |
134 |
135 | def build_transform(is_train, args):
136 | if args.aa == 'none':
137 | args.aa = None
138 |
139 | with warnings.catch_warnings():
140 | resize_im = args.input_size > 32
141 | if is_train:
142 | # this should always dispatch to transforms_imagenet_train
143 | transform = create_transform(
144 | input_size=args.input_size,
145 | is_training=True,
146 | color_jitter=args.color_jitter,
147 | auto_augment=args.aa,
148 | interpolation='bicubic',
149 | re_prob=args.reprob,
150 | re_mode=args.remode,
151 | re_count=args.recount,
152 | )
153 | if not resize_im:
154 | transform.transforms[0] = transforms.RandomCrop(
155 | args.input_size, padding=4)
156 |
157 | if args.input_size == 32 and (args.data_set == 'CIFAR' or args.data_set == 'CIFAR10'):
158 | transform.transforms[-1] = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
159 | elif args.data_set == 'STL10':
160 | transform.transforms[-1] = transforms.Normalize((0.4192, 0.4124, 0.3804), (0.2714, 0.2679, 0.2771))
161 | return transform
162 |
163 | t = []
164 | if resize_im and args.data_set != 'TINYIMG':
165 | size = int((256 / 224) * args.input_size)
166 | t.append(
167 | transforms.Resize(size, interpolation=interpolation), # to maintain same ratio w.r.t. 224 images
168 | )
169 | t.append(transforms.CenterCrop(args.input_size))
170 |
171 | t.append(transforms.ToTensor())
172 | if args.input_size == 32 and (args.data_set == 'CIFAR' or args.data_set == 'CIFAR10'):
173 |         # CIFAR-100 normalization statistics (also applied to CIFAR-10 here)
174 | t.append(transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)))
175 | else:
176 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
177 |
178 | composed_transforms = transforms.Compose(t)
179 | return composed_transforms
180 |
181 |
182 | class ImageNet1000_local(ImageFolderDataset):
183 | """ImageNet1000 dataset.
184 |
185 | Simple wrapper around ImageFolderDataset to provide a link to the download
186 | page.
187 | """
188 | def __init__(self, *args, **kwargs):
189 | super().__init__(*args, **kwargs)
190 | # if self.train:
191 | # self.data_path = os.path.join(self.data_path, "train")
192 | # else:
193 | # self.data_path = os.path.join(self.data_path, "val")
194 |
195 | @property
196 | def transformations(self):
197 | """Default transformations if nothing is provided to the scenario."""
198 | return [transforms.ToTensor(),
199 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
200 |
201 | def _download(self):
202 | if not os.path.exists(self.data_path):
203 | raise IOError(
204 | "You must download yourself the ImageNet dataset."
205 | " Please go to http://www.image-net.org/challenges/LSVRC/2012/downloads and"
206 | " download 'Training images (Task 1 & 2)' and 'Validation images (all tasks)'."
207 | )
208 | print("ImageNet already downloaded.")
209 |
210 |
211 | class ImageNet100_local(ImageNet1000_local):
212 | """Subset of ImageNet1000 made of only 100 classes.
213 |
214 |     You must download the ImageNet1000 dataset yourself, then provide the images subset.
215 |     If in doubt, use the option `download=True` at initialization and it will
216 |     automatically download the subset ids used in:
217 | * Small Task Incremental Learning
218 | Douillard et al. 2020
219 | """
220 |
221 | train_subset_url = "https://github.com/Continvvm/continuum/releases/download/v0.1/train_100.txt"
222 | test_subset_url = "https://github.com/Continvvm/continuum/releases/download/v0.1/val_100.txt"
223 |
224 | def __init__(
225 | self, *args, data_subset: Union[Tuple[np.array, np.array], str, None] = None, **kwargs
226 | ):
227 | self.data_subset = data_subset
228 | super().__init__(*args, **kwargs)
229 |
230 | def _download(self):
231 | super()._download()
232 |
233 | filename = "val_100.txt"
234 | self.subset_url = self.test_subset_url
235 | if self.train:
236 | filename = "train_100.txt"
237 | self.subset_url = self.train_subset_url
238 |
239 | if self.data_subset is None:
240 | self.data_subset = os.path.join(self.data_path, filename)
241 | download(self.subset_url, self.data_path)
242 |
243 | def get_data(self) -> Tuple[np.ndarray, np.ndarray, Union[np.ndarray, None]]:
244 | data = self._parse_subset(self.data_subset, train=self.train) # type: ignore
245 | return (*data, None)
246 |
247 | def _parse_subset(
248 | self,
249 | subset: Union[Tuple[np.array, np.array], str, None],
250 | train: bool = True
251 | ) -> Tuple[np.array, np.array]:
252 | if isinstance(subset, str):
253 | x, y = [], []
254 |
255 | with open(subset, "r") as f:
256 | for line in f:
257 | split_line = line.split(" ")
258 | path = split_line[0].strip()
259 | x.append(os.path.join(self.data_path, path))
260 | y.append(int(split_line[1].strip()))
261 | x = np.array(x)
262 | y = np.array(y)
263 | return x, y
264 | return subset # type: ignore
265 |
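For context, a minimal sketch of the class-incremental scenario that `build_dataset` assembles, built directly with continuum (CIFAR-100 split into 10 tasks of 10 classes; the data path and batch size are illustrative):

```python
# Sketch: a continuum ClassIncremental scenario like the one returned by build_dataset().
from torch.utils.data import DataLoader
from continuum import ClassIncremental
from continuum.datasets import CIFAR100

dataset = CIFAR100("./data", train=True, download=True)
scenario = ClassIncremental(dataset, initial_increment=10, increment=10)

for task_id, taskset in enumerate(scenario):
    loader = DataLoader(taskset, batch_size=128, shuffle=True)
    x, y, t = next(iter(loader))        # continuum batches are (image, label, task id) triplets
    print(task_id, x.shape, y.unique())
```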
--------------------------------------------------------------------------------
/continual/factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from continual import convit, birt, samplers, vit
4 | from continual.cnn import (InceptionV3, rebuffi, resnet18, resnet34, resnet50,
5 | resnext50_32x4d, seresnet18, vgg16, vgg16_bn,
6 | wide_resnet50_2, resnet18_scs, resnet18_scs_max, resnet18_scs_avg)
7 |
8 |
9 | def get_backbone(args):
10 | print(f"Creating model: {args.model}")
11 | if args.model == 'vit':
12 | model = vit.VisionTransformer(
13 | num_classes=args.nb_classes,
14 | drop_rate=args.drop,
15 | drop_path_rate=args.drop_path,
16 | img_size=args.input_size,
17 | patch_size=args.patch_size,
18 | embed_dim=args.embed_dim,
19 | depth=args.depth,
20 | num_heads=args.num_heads
21 | )
22 | elif args.model == 'convit':
23 | model = convit.ConVit(
24 | num_classes=args.nb_classes,
25 | drop_rate=args.drop,
26 | drop_path_rate=args.drop_path,
27 | img_size=args.input_size,
28 | patch_size=args.patch_size,
29 | embed_dim=args.embed_dim,
30 | depth=args.depth,
31 | num_heads=args.num_heads,
32 | local_up_to_layer=args.local_up_to_layer,
33 | locality_strength=args.locality_strength,
34 | class_attention=args.class_attention,
35 | ca_type='jointca' if args.joint_tokens else 'base',
36 | norm_layer=args.norm,
37 | dynamic_tokens=args.dynamic_tokens,
38 | num_blocks=args.replay_from,
39 | attn_version=args.attn_version
40 | )
41 | elif args.model == 'resnet18_scs': model = resnet18_scs()
42 |     elif args.model == 'resnet18_scs_avg': model = resnet18_scs_avg()
43 |     elif args.model == 'resnet18_scs_max': model = resnet18_scs_max()
44 | elif args.model == 'resnet18': model = resnet18()
45 | elif args.model == 'resnet34': model = resnet34()
46 | elif args.model == 'resnet50': model = resnet50()
47 | elif args.model == 'wide_resnet50': model = wide_resnet50_2()
48 | elif args.model == 'resnext50': model = resnext50_32x4d()
49 | elif args.model == 'seresnet18': model = seresnet18()
50 | elif args.model == 'inception3': model = InceptionV3()
51 | elif args.model == 'vgg16bn': model = vgg16_bn()
52 | elif args.model == 'vgg16': model = vgg16()
53 | elif args.model == 'rebuffi': model = rebuffi()
54 | else:
55 | raise NotImplementedError(f'Unknown backbone {args.model}')
56 |
57 | return model
58 |
59 |
60 |
61 | def get_loaders(dataset_train, dataset_val, args, drop_last=True):
62 | sampler_train, sampler_val = samplers.get_sampler(dataset_train, dataset_val, args)
63 |
64 | loader_train = torch.utils.data.DataLoader(
65 | dataset_train, sampler=sampler_train,
66 | batch_size=args.batch_size,
67 | num_workers=args.num_workers,
68 | pin_memory=args.pin_mem,
69 | drop_last=drop_last,
70 | )
71 |
72 | loader_val = torch.utils.data.DataLoader(
73 | dataset_val, sampler=sampler_val,
74 | batch_size=int(1.5 * args.batch_size),
75 | num_workers=args.num_workers,
76 | pin_memory=args.pin_mem,
77 | drop_last=False
78 | )
79 |
80 | return loader_train, loader_val
81 |
82 |
83 | def get_train_loaders(dataset_train, args, batch_size=None, drop_last=True):
84 | batch_size = batch_size or args.batch_size
85 |
86 | sampler_train = samplers.get_train_sampler(dataset_train, args)
87 |
88 | loader_train = torch.utils.data.DataLoader(
89 | dataset_train, sampler=sampler_train,
90 | batch_size=batch_size,
91 | num_workers=args.num_workers,
92 | pin_memory=False,
93 | drop_last=drop_last,
94 | )
95 |
96 | return loader_train
97 |
98 |
99 | class InfiniteLoader:
100 | def __init__(self, loader):
101 | self.loader = loader
102 | self.reset()
103 |
104 | def reset(self):
105 | self.it = iter(self.loader)
106 |
107 | def get(self):
108 | try:
109 | return next(self.it)
110 | except StopIteration:
111 | self.reset()
112 | return self.get()
113 |
114 |
115 | def update_birt(model_without_ddp, task_id, args):
116 | if task_id == 0:
117 | print(f'Creating BiRT!')
118 | model_without_ddp = birt.BiRT(
119 | model_without_ddp,
120 | nb_classes=args.initial_increment, # 10
121 | individual_classifier=args.ind_clf, #1-1
122 | head_div=args.head_div > 0., # 0.1
123 | head_div_mode=args.head_div_mode, # 'tr'
124 | joint_tokens=args.joint_tokens, # False
125 | num_blocks=args.replay_from, # from which block to replay representation
126 | multi_token_setup=args.multi_token_setup,
127 | )
128 | else:
129 | print(f'Updating ensemble, new embed dim {model_without_ddp.sabs[-1].dim}.')
130 | model_without_ddp.add_model(args.increment, args.multi_token_setup)
131 |
132 | return model_without_ddp
133 |
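A small, self-contained sketch of the `InfiniteLoader` helper above, which simply restarts its iterator when the underlying loader is exhausted (the import path `continual.factory` is assumed from the file location):

```python
# Sketch: InfiniteLoader keeps yielding batches past the end of the underlying DataLoader.
import torch
from torch.utils.data import DataLoader, TensorDataset
from continual.factory import InfiniteLoader

data = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))
loader = InfiniteLoader(DataLoader(data, batch_size=4, shuffle=False))

for _ in range(5):        # 5 batches of 4 > 10 samples: the loader silently wraps around
    x, y = loader.get()
    print(y.tolist())
```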
--------------------------------------------------------------------------------
/continual/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Implements the knowledge distillation loss
5 | """
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 |
10 |
11 | class DistillationLoss(torch.nn.Module):
12 | """
13 | This module wraps a standard criterion and adds an extra knowledge distillation loss by
14 | taking a teacher model prediction and using it as additional supervision.
15 | """
16 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
17 | distillation_type: str, alpha: float, tau: float):
18 | super().__init__()
19 | self.base_criterion = base_criterion
20 | self.teacher_model = teacher_model
21 | assert distillation_type in ['none', 'soft', 'hard']
22 | self.distillation_type = distillation_type
23 | self.alpha = alpha
24 | self.tau = tau
25 |
26 | def forward(self, inputs, outputs, labels):
27 | """
28 | Args:
29 |             inputs: The original inputs that are fed to the teacher model
30 | outputs: the outputs of the model to be trained. It is expected to be
31 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output
32 | in the first position and the distillation predictions as the second output
33 | labels: the labels for the base criterion
34 | """
35 | outputs_kd = None
36 | if not isinstance(outputs, torch.Tensor):
37 | # assume that the model outputs a tuple of [outputs, outputs_kd]
38 | outputs, outputs_kd = outputs
39 | base_loss = self.base_criterion(outputs, labels)
40 | if self.distillation_type == 'none':
41 | return base_loss
42 |
43 | if outputs_kd is None:
44 | raise ValueError("When knowledge distillation is enabled, the model is "
45 | "expected to return a Tuple[Tensor, Tensor] with the output of the "
46 | "class_token and the dist_token")
47 |         # don't backprop through the teacher
48 | with torch.no_grad():
49 | teacher_outputs = self.teacher_model(inputs)
50 |
51 | if self.distillation_type == 'soft':
52 | T = self.tau
53 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
54 | # with slight modifications
55 | distillation_loss = F.kl_div(
56 | F.log_softmax(outputs_kd / T, dim=1),
57 | F.log_softmax(teacher_outputs / T, dim=1),
58 | reduction='sum',
59 | log_target=True
60 | ) * (T * T) / outputs_kd.numel()
61 | elif self.distillation_type == 'hard':
62 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
63 |
64 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
65 | return loss
66 |
67 |
68 | def bce_with_logits(x, y):
69 | return F.binary_cross_entropy_with_logits(
70 | x,
71 | torch.eye(x.shape[1])[y].to(y.device)
72 | )
73 |
74 |
75 | def bce_smooth_pos_with_logits(smooth):
76 | def _func(x, y):
77 | return F.binary_cross_entropy_with_logits(
78 | x,
79 | torch.clamp(
80 | torch.eye(x.shape[1])[y].to(y.device) - smooth,
81 | min=0.0
82 | )
83 | )
84 | return _func
85 |
86 |
87 | def bce_smooth_posneg_with_logits(smooth):
88 | def _func(x, y):
89 | return F.binary_cross_entropy_with_logits(
90 | x,
91 | torch.clamp(
92 | torch.eye(x.shape[1])[y].to(y.device) + smooth,
93 | max=1 - smooth
94 | )
95 | )
96 | return _func
97 |
98 |
99 | class LabelSmoothingCrossEntropyBoosting(nn.Module):
100 | """
101 | NLL loss with label smoothing.
102 | """
103 | def __init__(self, smoothing=0.1, alpha=1, gamma=1):
104 | """
105 | Constructor for the LabelSmoothing module.
106 | :param smoothing: label smoothing factor
107 | """
108 | super().__init__()
109 | assert smoothing < 1.0
110 | self.smoothing = smoothing
111 | self.confidence = 1. - smoothing
112 |
113 | self.alpha = alpha
114 | self.gamma = gamma
115 |
116 | def forward(self, x, target, boosting_output=None, boosting_focal=None):
117 | if boosting_output is None:
118 | return self._base_loss(x, target)
119 | return self._focal_loss(x, target, boosting_output, boosting_focal)
120 |
121 | def _focal_loss(self, x, target, boosting_output, boosting_focal):
122 | logprobs = F.log_softmax(x, dim=-1)
123 |
124 | if boosting_focal == 'old':
125 | pt = boosting_output.softmax(-1)[..., :-1]
126 |
127 | f = torch.ones_like(logprobs)
128 | f[:, :boosting_output.shape[1] - 1] = self.alpha * (1 - pt) ** self.gamma
129 | logprobs = f * logprobs
130 | elif boosting_focal == 'new':
131 | pt = boosting_output.softmax(-1)[..., -1]
132 | nb_old_classes = boosting_output.shape[1] - 1
133 |
134 | f = torch.ones_like(logprobs)
135 | f[:, nb_old_classes:] = self.alpha * (1 - pt[:, None]) ** self.gamma
136 | logprobs = f * logprobs
137 | else:
138 | assert False, (boosting_focal)
139 |
140 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
141 | nll_loss = nll_loss.squeeze(1)
142 | smooth_loss = -logprobs.mean(dim=-1)
143 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
144 | return loss.mean()
145 |
146 | def _base_loss(self, x, target):
147 | logprobs = F.log_softmax(x, dim=-1)
148 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
149 | nll_loss = nll_loss.squeeze(1)
150 | smooth_loss = -logprobs.mean(dim=-1)
151 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
152 | return loss.mean()
153 |
154 |
155 | class SoftTargetCrossEntropyBoosting(nn.Module):
156 |
157 | def __init__(self, alpha=1, gamma=1):
158 | super().__init__()
159 | self.alpha = alpha
160 | self.gamma = gamma
161 |
162 | def forward(self, x, target, boosting_output=None, boosting_focal=None):
163 | if boosting_output is None:
164 | return torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1).mean()
165 |
166 | if boosting_focal == 'old':
167 | pt = boosting_output.softmax(-1)[..., :-1]
168 |
169 | f = torch.ones_like(x)
170 | f[:, :boosting_output.shape[1] - 1] = self.alpha * (1 - pt) ** self.gamma
171 | elif boosting_focal == 'new':
172 | pt = boosting_output.softmax(-1)[..., -1]
173 |
174 | nb_old_classes = boosting_output.shape[1] - 1
175 |
176 | f = torch.ones_like(x)
177 | f[:, nb_old_classes:] = self.alpha * (1 - pt[:, None]) ** self.gamma
178 | else:
179 | assert False, (boosting_focal)
180 |
181 | return torch.sum(-target * f * F.log_softmax(x, dim=-1), dim=-1).mean()
182 |
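A minimal sketch of how `DistillationLoss` above expects to be called with soft distillation: the trained model must return an `(outputs, outputs_kd)` pair and the teacher is queried under `no_grad`. The tiny linear layers below are placeholders, not the repository's actual networks, and the import path `continual.losses` is assumed from the file location.

```python
# Sketch: soft knowledge distillation with DistillationLoss.
import torch
from torch import nn
from continual.losses import DistillationLoss

teacher = nn.Linear(8, 5).eval()
student = nn.Linear(8, 5)

criterion = DistillationLoss(
    base_criterion=nn.CrossEntropyLoss(),
    teacher_model=teacher,
    distillation_type='soft',
    alpha=0.5,
    tau=2.0,
)

inputs = torch.randn(4, 8)
labels = torch.randint(0, 5, (4,))
logits = student(inputs)
loss = criterion(inputs, (logits, logits), labels)  # same logits reused as the distillation output here
loss.backward()
```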
--------------------------------------------------------------------------------
/continual/misc.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeurAI-Lab/BiRT/1340a97bff17fc02e228b754bd80e1fd649ff9cd/continual/misc.py
--------------------------------------------------------------------------------
/continual/mixup.py:
--------------------------------------------------------------------------------
1 | """ Mixup and Cutmix
2 |
3 | Papers:
4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5 |
6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7 |
8 | Code Reference:
9 | CutMix: https://github.com/clovaai/CutMix-PyTorch
10 |
11 | Hacked together by / Copyright 2020 Ross Wightman
12 | """
13 | import numpy as np
14 | import torch
15 |
16 |
17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
18 | x = x.long().view(-1, 1)
19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
20 |
21 |
22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda', old_target=None):
23 | off_value = smoothing / num_classes
24 | on_value = 1. - smoothing + off_value
25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
26 | if old_target is not None:
27 | y2 = one_hot(old_target, num_classes, on_value=on_value, off_value=off_value, device=device)
28 | else:
29 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
30 | return y1 * lam + y2 * (1. - lam)
31 |
32 |
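A concrete example of the soft labels `mixup_target` produces (sketch only; `device='cpu'` is used purely for illustration, and the import path `continual.mixup` is assumed):

```python
# Sketch: with lam=0.7 each soft label is 0.7 of its own one-hot vector plus 0.3 of the flipped one.
import torch
from continual.mixup import mixup_target

y = torch.tensor([0, 2])
print(mixup_target(y, num_classes=3, lam=0.7, smoothing=0.0, device='cpu'))
# tensor([[0.7000, 0.0000, 0.3000],
#         [0.3000, 0.0000, 0.7000]])
```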
33 | def rand_bbox(img_shape, lam, margin=0., count=None):
34 | """ Standard CutMix bounding-box
35 | Generates a random square bbox based on lambda value. This impl includes
36 | support for enforcing a border margin as percent of bbox dimensions.
37 |
38 | Args:
39 | img_shape (tuple): Image shape as tuple
40 | lam (float): Cutmix lambda value
41 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
42 | count (int): Number of bbox to generate
43 | """
44 | ratio = np.sqrt(1 - lam)
45 | img_h, img_w = img_shape[-2:]
46 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
47 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
48 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
49 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
50 | yl = np.clip(cy - cut_h // 2, 0, img_h)
51 | yh = np.clip(cy + cut_h // 2, 0, img_h)
52 | xl = np.clip(cx - cut_w // 2, 0, img_w)
53 | xh = np.clip(cx + cut_w // 2, 0, img_w)
54 | return yl, yh, xl, xh
55 |
56 |
57 | def rand_bbox_minmax(img_shape, minmax, count=None):
58 | """ Min-Max CutMix bounding-box
59 | Inspired by Darknet cutmix impl, generates a random rectangular bbox
60 | based on min/max percent values applied to each dimension of the input image.
61 |
62 |     Typical values for minmax are in the .2-.3 range for min and the .8-.9 range for max.
63 |
64 | Args:
65 | img_shape (tuple): Image shape as tuple
66 | minmax (tuple or list): Min and max bbox ratios (as percent of image size)
67 | count (int): Number of bbox to generate
68 | """
69 | assert len(minmax) == 2
70 | img_h, img_w = img_shape[-2:]
71 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
72 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
73 | yl = np.random.randint(0, img_h - cut_h, size=count)
74 | xl = np.random.randint(0, img_w - cut_w, size=count)
75 | yu = yl + cut_h
76 | xu = xl + cut_w
77 | return yl, yu, xl, xu
78 |
79 |
80 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
81 | """ Generate bbox and apply lambda correction.
82 | """
83 | if ratio_minmax is not None:
84 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
85 | else:
86 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
87 | if correct_lam or ratio_minmax is not None:
88 | bbox_area = (yu - yl) * (xu - xl)
89 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
90 | return (yl, yu, xl, xu), lam
91 |
92 |
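And a short sketch of the CutMix geometry: for a requested `lam`, `cutmix_bbox_and_lam` cuts a box covering roughly `1 - lam` of the image and, with `correct_lam=True`, returns `lam` recomputed from the clipped box area (import path assumed as above):

```python
# Sketch: CutMix box generation and lambda correction on a 32x32 image.
import numpy as np
from continual.mixup import cutmix_bbox_and_lam

np.random.seed(0)
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam((3, 32, 32), lam=0.7, correct_lam=True)
print((yl, yh, xl, xh), lam)   # box corners and the area-corrected mixing coefficient
```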
93 | class Mixup:
94 | """ Mixup/Cutmix that applies different params to each element or whole batch
95 |
96 | Args:
97 | mixup_alpha (float): mixup alpha value, mixup is active if > 0.
98 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
99 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
100 | prob (float): probability of applying mixup or cutmix per batch or element
101 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active
102 |         mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), or 'elem' (element))
103 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
104 | label_smoothing (float): apply label smoothing to the mixed target tensor
105 | num_classes (int): number of classes for target
106 | """
107 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
108 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000,
109 | loader_memory=None):
110 | self.mixup_alpha = mixup_alpha
111 | self.cutmix_alpha = cutmix_alpha
112 | self.cutmix_minmax = cutmix_minmax
113 | if self.cutmix_minmax is not None:
114 | assert len(self.cutmix_minmax) == 2
115 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
116 | self.cutmix_alpha = 1.0
117 | self.mix_prob = prob
118 | self.switch_prob = switch_prob
119 | self.label_smoothing = label_smoothing
120 | self.num_classes = num_classes
121 | self.mode = mode
122 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
123 |         self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)
124 | self.loader_memory = loader_memory
125 |
126 | def _params_per_elem(self, batch_size):
127 | lam = np.ones(batch_size, dtype=np.float32)
128 |         use_cutmix = np.zeros(batch_size, dtype=bool)  # np.bool was removed in recent NumPy
129 | if self.mixup_enabled:
130 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
131 | use_cutmix = np.random.rand(batch_size) < self.switch_prob
132 | lam_mix = np.where(
133 | use_cutmix,
134 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
135 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
136 | elif self.mixup_alpha > 0.:
137 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
138 | elif self.cutmix_alpha > 0.:
139 |                 use_cutmix = np.ones(batch_size, dtype=bool)
140 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
141 | else:
142 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
143 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
144 | return lam, use_cutmix
145 |
146 | def _params_per_batch(self):
147 | lam = 1.
148 | use_cutmix = False
149 | if self.mixup_enabled and np.random.rand() < self.mix_prob:
150 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
151 | use_cutmix = np.random.rand() < self.switch_prob
152 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
153 | np.random.beta(self.mixup_alpha, self.mixup_alpha)
154 | elif self.mixup_alpha > 0.:
155 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
156 | elif self.cutmix_alpha > 0.:
157 | use_cutmix = True
158 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
159 | else:
160 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
161 | lam = float(lam_mix)
162 | return lam, use_cutmix
163 |
164 | def _mix_elem(self, x):
165 | batch_size = len(x)
166 | lam_batch, use_cutmix = self._params_per_elem(batch_size)
167 | x_orig = x.clone() # need to keep an unmodified original for mixing source
168 | for i in range(batch_size):
169 | j = batch_size - i - 1
170 | lam = lam_batch[i]
171 | if lam != 1.:
172 | if use_cutmix[i]:
173 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
174 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
175 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
176 | lam_batch[i] = lam
177 | else:
178 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
179 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
180 |
181 | def _mix_pair(self, x):
182 | batch_size = len(x)
183 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
184 | x_orig = x.clone() # need to keep an unmodified original for mixing source
185 | for i in range(batch_size // 2):
186 | j = batch_size - i - 1
187 | lam = lam_batch[i]
188 | if lam != 1.:
189 | if use_cutmix[i]:
190 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
191 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
192 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
193 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
194 | lam_batch[i] = lam
195 | else:
196 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
197 | x[j] = x[j] * lam + x_orig[i] * (1 - lam)
198 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
199 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
200 |
201 | def _mix_batch(self, x):
202 | lam, use_cutmix = self._params_per_batch()
203 | if lam == 1.:
204 | return 1.
205 | if use_cutmix:
206 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
207 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
208 | x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
209 | else:
210 | x_flipped = x.flip(0).mul_(1. - lam)
211 | x.mul_(lam).add_(x_flipped)
212 | return lam
213 |
214 | def _mix_old(self, x, old_x):
215 | lam, use_cutmix = self._params_per_batch()
216 | if lam == 1.:
217 | return 1.
218 | if use_cutmix:
219 |             assert False, "cutmix is not supported when mixing with memory samples"
220 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
221 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
222 | x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
223 | else:
224 | x_flipped = x.flip(0).mul_(1. - lam)
225 | x.mul_(lam).add_(x_flipped)
226 | #x.mul_(lam).add_(old_x.mul_(1. - lam))
227 | return lam
228 |
229 | def __call__(self, x, target):
230 | assert len(x) % 2 == 0, 'Batch size should be even when using this'
231 | old_y = None
232 | if self.mode == 'elem':
233 | lam = self._mix_elem(x)
234 | elif self.mode == 'pair':
235 | lam = self._mix_pair(x)
236 | elif self.mode == 'batch' or (self.mode == 'old' and self.loader_memory is None):
237 | lam = self._mix_batch(x)
238 | else: # old
239 | old_x, old_y, _ = self.loader_memory.get()
240 | old_x, old_y = old_x.to(x.device), old_y.to(x.device)
241 | lam = self._mix_old(x, old_x)
242 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, old_target=old_y)
243 | return x, target, lam
244 |
245 |
246 | class FastCollateMixup(Mixup):
247 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
248 |
249 | A Mixup impl that's performed while collating the batches.
250 | """
251 |
252 | def _mix_elem_collate(self, output, batch, half=False):
253 | batch_size = len(batch)
254 | num_elem = batch_size // 2 if half else batch_size
255 | assert len(output) == num_elem
256 | lam_batch, use_cutmix = self._params_per_elem(num_elem)
257 | for i in range(num_elem):
258 | j = batch_size - i - 1
259 | lam = lam_batch[i]
260 | mixed = batch[i][0]
261 | if lam != 1.:
262 | if use_cutmix[i]:
263 | if not half:
264 | mixed = mixed.copy()
265 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
266 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
267 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
268 | lam_batch[i] = lam
269 | else:
270 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
271 | np.rint(mixed, out=mixed)
272 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
273 | if half:
274 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
275 | return torch.tensor(lam_batch).unsqueeze(1)
276 |
277 | def _mix_pair_collate(self, output, batch):
278 | batch_size = len(batch)
279 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
280 | for i in range(batch_size // 2):
281 | j = batch_size - i - 1
282 | lam = lam_batch[i]
283 | mixed_i = batch[i][0]
284 | mixed_j = batch[j][0]
285 | assert 0 <= lam <= 1.0
286 | if lam < 1.:
287 | if use_cutmix[i]:
288 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
289 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
290 | patch_i = mixed_i[:, yl:yh, xl:xh].copy()
291 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
292 | mixed_j[:, yl:yh, xl:xh] = patch_i
293 | lam_batch[i] = lam
294 | else:
295 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
296 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
297 | mixed_i = mixed_temp
298 | np.rint(mixed_j, out=mixed_j)
299 | np.rint(mixed_i, out=mixed_i)
300 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
301 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
302 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
303 | return torch.tensor(lam_batch).unsqueeze(1)
304 |
305 | def _mix_batch_collate(self, output, batch):
306 | batch_size = len(batch)
307 | lam, use_cutmix = self._params_per_batch()
308 | if use_cutmix:
309 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
310 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
311 | for i in range(batch_size):
312 | j = batch_size - i - 1
313 | mixed = batch[i][0]
314 | if lam != 1.:
315 | if use_cutmix:
316 | mixed = mixed.copy() # don't want to modify the original while iterating
317 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
318 | else:
319 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
320 | np.rint(mixed, out=mixed)
321 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
322 | return lam
323 |
324 | def __call__(self, batch, _=None):
325 | batch_size = len(batch)
326 | assert batch_size % 2 == 0, 'Batch size should be even when using this'
327 | half = 'half' in self.mode
328 | if half:
329 | batch_size //= 2
330 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
331 | if self.mode == 'elem' or self.mode == 'half':
332 | lam = self._mix_elem_collate(output, batch, half=half)
333 | elif self.mode == 'pair':
334 | lam = self._mix_pair_collate(output, batch)
335 | else:
336 | lam = self._mix_batch_collate(output, batch)
337 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
338 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
339 | target = target[:batch_size]
340 | return output, target
341 |
342 |
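
A brief usage sketch of the `Mixup` class above (not part of the original file): it mixes a random batch at the batch level and produces soft-label targets. The batch size, image size, and `num_classes` are arbitrary illustrative values.

```python
# Illustrative sketch: apply batch-level mixup/cutmix to a random batch.
import torch

from continual.mixup import Mixup

mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5,
                 mode='batch', label_smoothing=0.1, num_classes=100)

x = torch.rand(8, 3, 32, 32)               # batch size must be even
y = torch.randint(0, 100, (8,))
x_mixed, y_soft, lam = mixup_fn(x, y)      # y_soft: soft-label targets of shape (batch, num_classes)
```
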
--------------------------------------------------------------------------------
/continual/mycontinual/__init__.py:
--------------------------------------------------------------------------------
1 | from continual.mycontinual.transformation_incremental import TransformationIncremental
2 | from continual.mycontinual.rotations import Rotations
3 | from continual.mycontinual.permutations import Permutations
4 | from continual.mycontinual.incremental_rotation import IncrementalRotation
5 | from continual.mycontinual.custom_array_task_set import ArrayTaskSet
6 |
7 | __all__ = [
8 | "Rotations",
9 | "TransformationIncremental",
10 | "IncrementalRotation",
11 | "Permutations",
12 | "ArrayTaskSet"
13 | ]
--------------------------------------------------------------------------------
/continual/mycontinual/custom_array_task_set.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union, Optional, List
2 |
3 | import numpy as np
4 | import torch
5 | from PIL import Image
6 | from torchvision import transforms
7 |
8 | from continuum.viz import plot_samples
9 | from continuum.tasks.base import BaseTaskSet, _tensorize_list, TaskType
10 |
11 |
12 | class ArrayTaskSet(BaseTaskSet):
13 |     """A task dataset returned by the CLLoader, specialized for numpy/torch image-array data.
14 |
15 | :param x: The data, either image-arrays or paths to images saved on disk.
16 | :param y: The targets, not one-hot encoded.
17 | :param t: The task id of each sample.
18 | :param trsf: The transformations to apply on the images.
19 | :param target_trsf: The transformations to apply on the labels.
20 | :param bounding_boxes: The bounding boxes annotations to crop images
21 | """
22 |
23 | def __init__(
24 | self,
25 | x: np.ndarray,
26 | y: np.ndarray,
27 | t: np.ndarray,
28 | trsf: Union[transforms.Compose, List[transforms.Compose]],
29 | target_trsf: Optional[Union[transforms.Compose, List[transforms.Compose]]],
30 | bounding_boxes: Optional[np.ndarray] = None
31 | ):
32 | super().__init__(x, y, t, trsf, target_trsf, bounding_boxes=bounding_boxes)
33 | self.data_type = TaskType.IMAGE_ARRAY
34 |
35 | def plot(
36 | self,
37 | path: Union[str, None] = None,
38 | title: str = "",
39 | nb_samples: int = 100,
40 | shape: Optional[Tuple[int, int]] = None,
41 | ) -> None:
42 | """Plot samples of the current task, useful to check if everything is ok.
43 |
44 | :param path: If not None, save on disk at this path.
45 | :param title: The title of the figure.
46 | :param nb_samples: Amount of samples randomly selected.
47 | :param shape: Shape to resize the image before plotting.
48 | """
49 | plot_samples(self, title=title, path=path, nb_samples=nb_samples,
50 | shape=shape, data_type=self.data_type)
51 |
52 | def get_samples(self, indexes):
53 | samples, targets, tasks = [], [], []
54 |
55 | w, h = None, None
56 | for index in indexes:
57 | # we need to use __getitem__ to have the transform used
58 | sample, y, t = self[index]
59 |
60 | # we check dimension of images
61 | if w is None:
62 | w, h = sample.shape[:2]
63 | elif w != sample.shape[0] or h != sample.shape[1]:
64 | raise Exception(
65 |                     "Image dimensions are inconsistent, resize them to a "
66 | "common size using a transformation.\n"
67 | "For example, give to the scenario you're using as `transformations` argument "
68 | "the following: [transforms.Resize((224, 224)), transforms.ToTensor()]"
69 | )
70 |
71 | samples.append(sample)
72 | targets.append(y)
73 | tasks.append(t)
74 |
75 | return _tensorize_list(samples), _tensorize_list(targets), _tensorize_list(tasks)
76 |
77 | def get_sample(self, index: int) -> np.ndarray:
78 |         """Returns the raw sample corresponding to the given `index`.
79 | 
80 |         :param index: Index to query the sample.
81 |         :return: The raw sample array (no PIL conversion is applied here).
82 | """
83 | x = self._x[index]
84 | # x = Image.fromarray(x.astype("uint8"))
85 | return x
86 |
87 | def __getitem__(self, index: int) -> Tuple[np.ndarray, int, int]:
88 | """Method used by PyTorch's DataLoaders to query a sample and its target."""
89 | x = self.get_sample(index)
90 | y = self._y[index]
91 | t = self._t[index]
92 |
93 | if self.bounding_boxes is not None:
94 | bbox = self.bounding_boxes[index]
95 | x = x.crop((
96 | max(bbox[0], 0), # x1
97 | max(bbox[1], 0), # y1
98 | min(bbox[2], x.size[0]), # x2
99 | min(bbox[3], x.size[1]), # y2
100 | ))
101 |
102 | x, y, t = self._prepare_data(x, y, t)
103 |
104 | if self.target_trsf is not None:
105 | y = self.get_task_target_trsf(t)(y)
106 |
107 | return x, y, t
108 |
109 | def _prepare_data(self, x, y, t):
110 | if self.trsf is not None:
111 | x = self.get_task_trsf(t)(x)
112 | if not isinstance(x, torch.Tensor):
113 | x = self._to_tensor(x)
114 | return x, y, t
115 |
--------------------------------------------------------------------------------
/continual/mycontinual/incremental_rotation.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms.functional as F
2 | import numpy as np
3 |
4 |
5 | class IncrementalRotation(object):
6 | """
7 | Defines an incremental rotation for a numpy array.
8 | """
9 |
10 | def __init__(self, init_deg: int = 0, increase_per_iteration: float = 0.006) -> None:
11 | """
12 | Defines the initial angle as well as the increase for each rotation
13 |         :param init_deg: initial rotation angle in degrees.
14 |         :param increase_per_iteration: degrees added to the rotation at each call.
15 | """
16 | self.increase_per_iteration = increase_per_iteration
17 | self.iteration = 0
18 | self.degrees = init_deg
19 |
20 | def __call__(self, x: np.ndarray) -> np.ndarray:
21 | """
22 | Applies the rotation.
23 | :param x: image to be rotated
24 | :return: rotated image
25 | """
26 | degs = (self.iteration * self.increase_per_iteration + self.degrees) % 360
27 | self.iteration += 1
28 | return F.rotate(x, degs)
29 |
30 | def set_iteration(self, x: int) -> None:
31 | """
32 | Set the iteration to a given integer
33 | :param x: iteration index
34 | """
35 | self.iteration = x
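
A small illustrative sketch (not part of the file) of how `IncrementalRotation` can sit in a torchvision pipeline; the angle increment and image size are arbitrary.

```python
# Illustrative sketch: the rotation angle grows a little with every call.
from PIL import Image
from torchvision import transforms

from continual.mycontinual import IncrementalRotation

rotation = IncrementalRotation(init_deg=0, increase_per_iteration=0.5)
pipeline = transforms.Compose([rotation, transforms.ToTensor()])

img = Image.new("RGB", (32, 32))
tensors = [pipeline(img) for _ in range(3)]   # rotated by 0, 0.5 and 1.0 degrees
```
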
--------------------------------------------------------------------------------
/continual/mycontinual/permutations.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Callable, List, Union
3 |
4 | import numpy as np
5 | import torch
6 | from torchvision import transforms
7 |
8 | from continuum.datasets import _ContinuumDataset
9 | from continual.mycontinual import TransformationIncremental
10 |
11 |
12 | class Permutations(TransformationIncremental):
13 | """Continual Loader, generating datasets for the consecutive tasks.
14 |
15 |     Scenario: Permutation scenarios use the same data for all tasks, but with pixels permuted.
16 |     Each task gets a specific permutation, so that all tasks are different.
17 |
18 | :param cl_dataset: A continual dataset.
19 | :param nb_tasks: The scenario's number of tasks.
20 | :param base_transformations: List of transformations to apply to all tasks.
21 | :param seed: initialization seed for the permutations.
22 |     :param shared_label_space: If true, the same data under different transformations keeps the same label.
23 | """
24 |
25 | def __init__(
26 | self,
27 | cl_dataset: _ContinuumDataset,
28 | nb_tasks: Union[int, None] = None,
29 | base_transformations: List[Callable] = None,
30 | seed: Union[int, List[int]] = 0,
31 | shared_label_space=True
32 | ):
33 | trsfs = self._generate_transformations(seed, nb_tasks)
34 |
35 | super().__init__(
36 | cl_dataset=cl_dataset,
37 | incremental_transformations=trsfs,
38 | base_transformations=base_transformations,
39 | shared_label_space=shared_label_space
40 | )
41 |
42 | def _generate_transformations(self, seed, nb_tasks):
43 | if isinstance(seed, int):
44 | if nb_tasks is None:
45 | raise ValueError("You must specify a number of tasks if a single seed is provided.")
46 | rng = np.random.RandomState(seed=seed)
47 | seed = rng.permutation(100000)[:nb_tasks - 1]
48 | elif nb_tasks is not None and nb_tasks != len(seed) + 1:
49 | warnings.warn(
50 |                 f"Because a list of seeds was provided ({seed}), "
51 |                 f"the number of tasks is automatically set to "
52 |                 f"len(seeds) + 1 = {len(seed) + 1}"
53 | )
54 |
55 | return [PermutationTransform(seed=None)] + [PermutationTransform(seed=int(s)) for s in seed]
56 |
57 | def get_task_transformation(self, task_index):
58 | return transforms.Compose(self.trsf.transforms + [self.inc_trsf[task_index]])
59 |
60 |
61 | class PermutationTransform:
62 | """Permutation transformers.
63 |
64 |     This transformer is initialized with a seed such that the same seed yields the same permutation.
65 |     A seed of None means no permutation.
66 |
67 | :param seed: seed to initialize the random number generator
68 | """
69 |
70 | def __init__(self, seed: Union[int, None]):
71 | self.seed = seed
72 | self.g_cpu = torch.Generator()
73 |
74 | def __call__(self, x):
75 | shape = list(x.shape)
76 | x = x.reshape(-1)
77 | # if seed is None, no permutations
78 | if self.seed is not None:
79 | self.g_cpu.manual_seed(self.seed)
80 | perm = torch.randperm(x.numel(), generator=self.g_cpu).long()
81 | x = x[perm]
82 | return x.reshape(shape)
83 |
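
A quick sketch (not part of the file) of the determinism of `PermutationTransform`: the same seed always produces the same pixel permutation, and `seed=None` leaves the tensor untouched.

```python
# Illustrative sketch: PermutationTransform is deterministic per seed.
import torch

from continual.mycontinual.permutations import PermutationTransform

t = torch.arange(16.).reshape(4, 4)
p1 = PermutationTransform(seed=3)(t)
p2 = PermutationTransform(seed=3)(t)
assert torch.equal(p1, p2)                                 # same seed, same permutation
assert torch.equal(PermutationTransform(seed=None)(t), t)  # no permutation applied
```
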
--------------------------------------------------------------------------------
/continual/mycontinual/rotations.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Tuple, Union
2 |
3 | from torchvision import transforms
4 |
5 | from continuum.datasets import _ContinuumDataset
6 | from continual.mycontinual import TransformationIncremental
7 | import torchvision.transforms.functional as F
8 |
9 |
10 | class Rotations(TransformationIncremental):
11 | """Continual Loader, generating datasets for the consecutive tasks.
12 |
13 |     Scenario: The Rotations scenario is a new-instance scenario.
14 |     For each task, the data is rotated by a certain angle.
15 |
16 | :param cl_dataset: A continual dataset.
17 | :param nb_tasks: The scenario's number of tasks.
18 | :param list_degrees: list of rotation in degree (int) or list of range. e.g. (0, (40,45), 90).
19 |     :param base_transformations: Preprocessing transformations to apply to the data before rotation.
20 |     :param shared_label_space: If true, the same data under different transformations keeps the same label.
21 | """
22 |
23 | def __init__(
24 | self,
25 | cl_dataset: _ContinuumDataset,
26 | list_degrees: Union[List[Tuple], List[int]],
27 | nb_tasks: Union[int, None] = None,
28 | base_transformations: List[Callable] = None,
29 | shared_label_space=True
30 | ):
31 |
32 | if nb_tasks is not None and len(list_degrees) != nb_tasks:
33 | raise ValueError(
34 | f"The nb of tasks ({nb_tasks}) != number of angles "
35 | f"tuples ({len(list_degrees)}) set in the list"
36 | )
37 |
38 | trsfs = self._generate_transformations(list_degrees)
39 |
40 | super().__init__(
41 | cl_dataset=cl_dataset,
42 | incremental_transformations=trsfs,
43 | base_transformations=base_transformations,
44 | shared_label_space=shared_label_space
45 | )
46 |
47 | def _generate_transformations(self, degrees):
48 | trsfs = []
49 | min_deg, max_deg = None, None
50 |
51 | for deg in degrees:
52 | if isinstance(deg, int) or isinstance(deg, float):
53 | min_deg, max_deg = deg, deg
54 | elif len(deg) == 2:
55 | min_deg, max_deg = deg
56 | else:
57 | raise ValueError(
58 | f"Invalid list of degrees ({degrees}). "
59 | "It should contain either integers (-deg, +deg) or "
60 | "tuples (range) of integers (deg_a, deg_b)."
61 | )
62 |
63 | trsfs.append([transforms.RandomAffine(degrees=[min_deg, max_deg])])
64 |
65 | return trsfs
66 |
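
A hedged sketch (not part of the file) of building a small rotation scenario. The data path is a placeholder, and the degree list mixes fixed angles with a range, as the docstring describes; the `CIFAR100` constructor arguments reflect the usual continuum API and may need adjusting to the installed version.

```python
# Illustrative sketch: a 3-task Rotations scenario over CIFAR100.
from continuum.datasets import CIFAR100

from continual.mycontinual import Rotations

dataset = CIFAR100("/path/to/data", train=True, download=True)  # placeholder path
scenario = Rotations(dataset, list_degrees=[0, (40, 45), 90], nb_tasks=3)

for task_id, taskset in enumerate(scenario):
    print(task_id, len(taskset))
```
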
--------------------------------------------------------------------------------
/continual/mycontinual/transformation_incremental.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional
2 |
3 | import numpy as np
4 | from torchvision import transforms
5 |
6 | from continuum.datasets import _ContinuumDataset
7 | from continuum.scenarios import InstanceIncremental
8 | from continuum.tasks import TaskSet, TaskType
9 |
10 |
11 | class TransformationIncremental(InstanceIncremental):
12 | """Continual Loader, generating datasets for the consecutive tasks.
13 |
14 | Scenario: Every task contains the same data with different transformations.
15 | It is a cheap way to create instance incremental scenarios.
16 |     Moreover, it makes it easier to analyse what the algorithms forget and what they retain.
17 | Classic transformation incremental scenarios are "permutations" and "rotations".
18 |
19 | :param cl_dataset: A continual dataset.
20 | :param incremental_transformations: list of transformations to apply to specific tasks
21 | :param base_transformations: List of transformation to apply to all tasks.
22 |     :param shared_label_space: If true, the same data under different transformations keeps the same label.
23 | """
24 |
25 | def __init__(
26 | self,
27 | cl_dataset: _ContinuumDataset,
28 | incremental_transformations: List[List[Callable]],
29 | base_transformations: List[Callable] = None,
30 | shared_label_space=True
31 | ):
32 |         if incremental_transformations is None:
33 |             raise ValueError("For this scenario, a list of transformations must be provided.")
34 |         nb_tasks = len(incremental_transformations)
35 |
36 | if cl_dataset.data_type == TaskType.H5:
37 | raise NotImplementedError("TransformationIncremental are not compatible yet with h5 files.")
38 |
39 | self.inc_trsf = incremental_transformations
40 | #self._nb_tasks = self._setup(nb_tasks)
41 | self.shared_label_space = shared_label_space
42 |
43 | super().__init__(
44 | cl_dataset=cl_dataset, nb_tasks=nb_tasks, transformations=base_transformations
45 | )
46 |
47 |         self.num_classes_per_task = len(np.unique(self.dataset[1]))  # the number of classes is the same for every task in this scenario
48 |
49 | @property
50 | def nb_classes(self) -> int:
51 | """Total number of classes in the whole continual setting."""
52 | if self.shared_label_space:
53 | nb_classes = len(np.unique(self.dataset[1]))
54 | else:
55 | nb_classes = len(np.unique(self.dataset[1])) * self._nb_tasks
56 | return nb_classes
57 |
58 | def get_task_transformation(self, task_index):
59 | return transforms.Compose(self.inc_trsf[task_index] + self.trsf.transforms)
60 |
61 | def update_task_indexes(self, task_index):
62 | new_t = np.ones(len(self.dataset[1])) * task_index
63 | self.dataset = (self.dataset[0], self.dataset[1], new_t)
64 |
65 | def update_labels(self, task_index):
66 | # wrong
67 | # new_y = self.dataset[1] + task_index * self.num_classes_per_task
68 | # we update incrementally then update is simply:
69 | if task_index > 0:
70 | new_y = self.dataset[1] + self.num_classes_per_task
71 | self.dataset = (self.dataset[0], new_y, self.dataset[2])
72 |
73 | def __getitem__(self, task_index):
74 | """Returns a task by its unique index.
75 |
76 | :param task_index: The unique index of a task, between 0 and len(loader) - 1. Or it could
77 | be a list or a numpy array or even a slice.
78 | :return: A train PyTorch's Datasets.
79 | """
80 | x, y, _ = self.dataset
81 |
82 | if isinstance(task_index, slice):
83 | # Convert a slice to a list and respect the Python's advanced indexing conventions
84 | start = task_index.start if task_index.start is not None else 0
85 | stop = task_index.stop if task_index.stop is not None else len(self) + 1
86 | step = task_index.step if task_index.step is not None else 1
87 | task_index = list(range(start, stop, step))
88 | if len(task_index) == 0:
89 | raise ValueError(f"Invalid slicing resulting in no data (start={start}, end={stop}, step={step}).")
90 | elif isinstance(task_index, np.ndarray):
91 | task_index = list(task_index)
92 | elif isinstance(task_index, int):
93 | task_index = [task_index]
94 | else:
95 | raise TypeError(f"Invalid type of task index {type(task_index).__name__}.")
96 |
97 | task_index = set([_handle_negative_indexes(ti, len(self)) for ti in task_index])
98 |
99 | t = np.concatenate([
100 | (np.ones(len(x)) * ti).astype(np.int32) for ti in task_index
101 | ])
102 | x = np.concatenate([
103 | x for _ in range(len(task_index))
104 | ])
105 |
106 | if self.shared_label_space:
107 | y = np.concatenate([
108 | y for _ in range(len(task_index))
109 | ])
110 | else:
111 | # Different transformations have different labels even though
112 | # the original images were the same
113 | y = np.concatenate([
114 | y + ti * self.num_classes_per_task for ti in task_index
115 | ])
116 |
117 | # trsf = [ # Non-used tasks have a None trsf
118 | # self.get_task_transformation(ti)
119 | # if ti in task_index else None
120 | # for ti in range(len(self))
121 | # ]
122 |
123 | trsf = [ # Non-used tasks have a None trsf
124 | self.get_task_transformation(ti)
125 | # if ti in task_index else None
126 | for ti in range(len(self))
127 | ]
128 |
129 | return TaskSet(x, y, t, trsf, data_type=self.cl_dataset.data_type)
130 |
131 |
132 | def _handle_negative_indexes(index: int, total_len: int) -> int:
133 | while index < 0:
134 | index += total_len
135 | return index
136 |
--------------------------------------------------------------------------------
/continual/rehearsal.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy as np
4 | import torch
5 | from torchvision import transforms
6 |
7 | from continual.mycontinual import ArrayTaskSet
8 |
9 |
10 | class Memory:
11 | def __init__(self, memory_size, nb_total_classes, rehearsal, rep_replay=False, fixed=True):
12 | self.memory_size = memory_size # 2000
13 | self.nb_total_classes = nb_total_classes # 100
14 | self.rehearsal = rehearsal # icarl_all
15 | self.fixed = fixed # False
16 | self.rep_replay = rep_replay
17 |
18 | self.x = self.y = self.t = None
19 |
20 | self.nb_classes = 0
21 |
22 | @property
23 | def memory_per_class(self):
24 | if self.fixed:
25 | return self.memory_size // self.nb_total_classes
26 | return self.memory_size // self.nb_classes if self.nb_classes > 0 else self.memory_size
27 |
28 | def get_dataset_without_copy(self, base_dataset):
29 | dataset = base_dataset
30 | dataset._x = self.x
31 | dataset._y = self.y
32 | dataset._t = self.t
33 |
34 | return dataset
35 |
36 | def get_dataset(self, base_dataset):
37 | if self.rep_replay:
38 | dataset = ArrayTaskSet(x=self.x, y=self.y, t=self.t, trsf=None, target_trsf=None,
39 | bounding_boxes=None)
40 | else:
41 | dataset = copy.deepcopy(base_dataset)
42 | dataset._x = self.x
43 | dataset._y = self.y
44 | dataset._t = self.t
45 |
46 | return dataset
47 |
48 | def get(self):
49 | return self.x, self.y, self.t
50 |
51 | def __len__(self):
52 | return len(self.x) if self.x is not None else 0
53 |
54 | def save(self, path):
55 | np.savez(
56 | path,
57 | x=self.x, y=self.y, t=self.t
58 | )
59 |
60 | def load(self, path):
61 | data = np.load(path)
62 | self.x = data["x"]
63 | self.y = data["y"]
64 | self.t = data["t"]
65 |
66 | assert len(self) <= self.memory_size, len(self)
67 | self.nb_classes = len(np.unique(self.y))
68 |
69 | def reduce(self):
70 | x, y, t = [], [], []
71 | for class_id in np.unique(self.y):
72 | indexes = np.where(self.y == class_id)[0]
73 | x.append(self.x[indexes[:self.memory_per_class]])
74 | y.append(self.y[indexes[:self.memory_per_class]])
75 | t.append(self.t[indexes[:self.memory_per_class]])
76 |
77 | self.x = np.concatenate(x)
78 | self.y = np.concatenate(y)
79 | self.t = np.concatenate(t)
80 |
81 | def add(self, dataset, model, nb_new_classes):
82 | self.nb_classes += nb_new_classes
83 |
84 | x, y, t = herd_samples(dataset, model, self.memory_per_class, self.rehearsal, self.rep_replay)
85 |
86 | if self.x is None:
87 | self.x, self.y, self.t = x, y, t
88 | else:
89 | if not self.fixed:
90 | self.reduce()
91 | self.x = np.concatenate((self.x, x))
92 | self.y = np.concatenate((self.y, y))
93 | self.t = np.concatenate((self.t, t))
94 |
95 |
96 | def herd_samples(dataset, model, memory_per_class, rehearsal, rep_replay):
97 | x, y, t = dataset._x, dataset._y, dataset._t
98 |
99 | if rehearsal == "random":
100 | indexes = []
101 | for class_id in np.unique(y):
102 | class_indexes = np.where(y == class_id)[0]
103 | indexes.append(
104 | np.random.choice(class_indexes, size=memory_per_class)
105 | )
106 | indexes = np.concatenate(indexes)
107 |
108 | return x[indexes], y[indexes], t[indexes]
109 | elif "closest" in rehearsal:
110 | if rehearsal == 'closest_token':
111 | handling = 'last'
112 | else:
113 | handling = 'all'
114 |
115 | features, targets = extract_features(dataset, model, handling)
116 | indexes = []
117 |
118 | for class_id in np.unique(y):
119 | class_indexes = np.where(y == class_id)[0]
120 | class_features = features[class_indexes]
121 |
122 | class_mean = np.mean(class_features, axis=0, keepdims=True)
123 | distances = np.power(class_features - class_mean, 2).sum(-1)
124 | class_closest_indexes = np.argsort(distances)
125 |
126 | indexes.append(
127 | class_indexes[class_closest_indexes[:memory_per_class]]
128 | )
129 |
130 | indexes = np.concatenate(indexes)
131 | return x[indexes], y[indexes], t[indexes]
132 | elif "furthest" in rehearsal:
133 | if rehearsal == 'furthest_token':
134 | handling = 'last'
135 | else:
136 | handling = 'all'
137 |
138 | features, targets = extract_features(dataset, model, handling)
139 | indexes = []
140 |
141 | for class_id in np.unique(y):
142 | class_indexes = np.where(y == class_id)[0]
143 | class_features = features[class_indexes]
144 |
145 | class_mean = np.mean(class_features, axis=0, keepdims=True)
146 | distances = np.power(class_features - class_mean, 2).sum(-1)
147 | class_furthest_indexes = np.argsort(distances)[::-1]
148 |
149 | indexes.append(
150 | class_indexes[class_furthest_indexes[:memory_per_class]]
151 | )
152 |
153 | indexes = np.concatenate(indexes)
154 | return x[indexes], y[indexes], t[indexes]
155 |     elif "icarl" in rehearsal:
156 | if rehearsal == 'icarl_token':
157 | handling = 'last'
158 | else:
159 | handling = 'all'
160 |
161 | features, targets = extract_features(dataset, model, handling)
162 | indexes = []
163 |
164 | for class_id in np.unique(y):
165 | class_indexes = np.where(y == class_id)[0]
166 | class_features = features[class_indexes]
167 |
168 | indexes.append(
169 | class_indexes[icarl_selection(class_features, memory_per_class)]
170 | )
171 |
172 | indexes = np.concatenate(indexes)
173 |
174 | # store representations for the samples
175 | if rep_replay:
176 |
177 |             dataset.trsf = transforms.Compose([dataset.trsf.transforms[tf] for tf in [0, 3, 4]])  # keep only a subset of the original transforms (the indices are specific to the training pipeline)
178 | loader = torch.utils.data.DataLoader(
179 | dataset,
180 | batch_size=128,
181 | num_workers=2,
182 | pin_memory=True,
183 | drop_last=False,
184 | shuffle=False
185 | )
186 |
187 | features, targets = [], []
188 |
189 | with torch.no_grad():
190 | for x, y, _ in loader:
191 | if hasattr(model, 'module'):
192 | reps = model.module.forward_initial(x.cuda())
193 | else:
194 | reps = model.forward_initial(x.cuda())
195 |
196 | reps = reps.detach().cpu().numpy()
197 | y = y.numpy()
198 |
199 | features.append(reps.reshape((reps.shape[0], int(reps.shape[1] ** 0.5),
200 | int(reps.shape[1] ** 0.5), reps.shape[-1])))
201 | targets.append(y)
202 |
203 | features = np.vstack(features)
204 | targets = np.concatenate(targets)
205 |
206 | return features[indexes], targets[indexes], t[indexes]
207 | else:
208 | return x[indexes], y[indexes], t[indexes]
209 | else:
210 | raise ValueError(f"Unknown rehearsal method {rehearsal}!")
211 |
212 |
213 | def extract_features(dataset, model, ensemble_handling='last'):
214 | loader = torch.utils.data.DataLoader(
215 | dataset,
216 | batch_size=128,
217 | num_workers=2,
218 | pin_memory=True,
219 | drop_last=False,
220 | shuffle=False
221 | )
222 |
223 | features, targets = [], []
224 |
225 | with torch.no_grad():
226 | for x, y, _ in loader:
227 | if hasattr(model, 'module'):
228 | feats, _, _ = model.module.forward_features(x.cuda())
229 | else:
230 | feats, _, _ = model.forward_features(x.cuda())
231 |
232 | if isinstance(feats, list):
233 | if ensemble_handling == 'last':
234 | feats = feats[-1]
235 | elif ensemble_handling == 'all':
236 | feats = torch.cat(feats, dim=1)
237 | else:
238 | raise NotImplementedError(f'Unknown handling of multiple features {ensemble_handling}')
239 | elif len(feats.shape) == 3: # joint tokens
240 | if ensemble_handling == 'last':
241 | feats = feats[-1]
242 | elif ensemble_handling == 'all':
243 | feats = feats.permute(1, 0, 2).view(len(x), -1)
244 | else:
245 | raise NotImplementedError(f'Unknown handling of multiple features {ensemble_handling}')
246 |
247 | feats = feats.cpu().numpy()
248 | y = y.numpy()
249 |
250 | features.append(feats)
251 | targets.append(y)
252 |
253 | features = np.concatenate(features)
254 | targets = np.concatenate(targets)
255 |
256 | return features, targets
257 |
258 |
259 | def icarl_selection(features, nb_examplars):
260 | D = features.T
261 | D = D / (np.linalg.norm(D, axis=0) + 1e-8)
262 | mu = np.mean(D, axis=1)
263 | herding_matrix = np.zeros((features.shape[0],))
264 |
265 | w_t = mu
266 | iter_herding, iter_herding_eff = 0, 0
267 |
268 | while not (
269 | np.sum(herding_matrix != 0) == min(nb_examplars, features.shape[0])
270 | ) and iter_herding_eff < 1000:
271 | tmp_t = np.dot(w_t, D)
272 | ind_max = np.argmax(tmp_t)
273 | iter_herding_eff += 1
274 | if herding_matrix[ind_max] == 0:
275 | herding_matrix[ind_max] = 1 + iter_herding
276 | iter_herding += 1
277 |
278 | w_t = w_t + mu - D[:, ind_max]
279 |
280 | herding_matrix[np.where(herding_matrix == 0)[0]] = 10000
281 |
282 | return herding_matrix.argsort()[:nb_examplars]
283 |
284 |
285 | def get_finetuning_dataset(dataset, memory, finetuning='balanced', rep_replay=False):
286 | if finetuning == 'balanced':
287 | x, y, t = memory.get()
288 |
289 | if rep_replay:
290 | # current task samples
291 | new_dataset = ArrayTaskSet(x=x, y=y, t=t, trsf=None,
292 | target_trsf=None, bounding_boxes=None)
293 | else:
294 | new_dataset = copy.deepcopy(dataset)
295 | new_dataset._x = x
296 | new_dataset._y = y
297 | new_dataset._t = t
298 | elif finetuning in ('all', 'none'):
299 | new_dataset = dataset
300 | else:
301 | raise NotImplementedError(f'Unknown finetuning method {finetuning}')
302 |
303 | return new_dataset
304 |
305 |
306 | def get_separate_finetuning_dataset(dataset, memory, finetuning='balanced', rep_replay=False):
307 | if finetuning == 'balanced':
308 | x, y, t = memory.get()
309 |
310 | # extract current and old task samples from memory
311 | cur_task_idx = t == max(np.unique(t))
312 | old_task_idx = t != max(np.unique(t))
313 |
314 | if rep_replay:
315 | # current task samples
316 | first_dataset = ArrayTaskSet(x=x[cur_task_idx], y=y[cur_task_idx], t=t[cur_task_idx], trsf=None,
317 | target_trsf=None, bounding_boxes=None)
318 |
319 | # old task samples
320 | second_dataset = ArrayTaskSet(x=x[old_task_idx], y=y[old_task_idx], t=t[old_task_idx], trsf=None,
321 | target_trsf=None, bounding_boxes=None)
322 | else:
323 | first_dataset = copy.deepcopy(dataset)
324 | first_dataset._x = x[cur_task_idx]
325 | first_dataset._y = y[cur_task_idx]
326 | first_dataset._t = t[cur_task_idx]
327 |
328 | second_dataset = copy.deepcopy(dataset)
329 | second_dataset._x = x[old_task_idx]
330 | second_dataset._y = y[old_task_idx]
331 | second_dataset._t = t[old_task_idx]
332 |
333 | elif finetuning in ('all', 'none'):
334 |         # 'all' / 'none' are no longer supported since this function returns two datasets (current / old tasks)
335 |         raise NotImplementedError(f'Finetuning method {finetuning} is not supported with separate finetuning datasets')
336 | else:
337 | raise NotImplementedError(f'Unknown finetuning method {finetuning}')
338 |
339 | return first_dataset, second_dataset
340 |
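
A hedged sketch (not part of the file) of how the `Memory` buffer above is typically driven: after each task, exemplars are herded into the buffer, which can then be turned back into a class-balanced rehearsal dataset. `task_set` and `model` are placeholders for a continuum task set and a trained network.

```python
# Illustrative sketch of the Memory API; task_set and model are placeholders.
from continual.rehearsal import Memory, get_finetuning_dataset

memory = Memory(memory_size=500, nb_total_classes=100,
                rehearsal="icarl_all", fixed=True)

# after training on a task, herd exemplars for its 10 new classes
memory.add(task_set, model, nb_new_classes=10)

# build a class-balanced dataset from the buffer for finetuning / replay
balanced_set = get_finetuning_dataset(task_set, memory, finetuning='balanced')
memory.save("memory_task_0.npz")
```
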
--------------------------------------------------------------------------------
/continual/sam.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class SAM:
5 | """SAM, ASAM, and Look-SAM
6 |
7 | Modified version of: https://github.com/davda54/sam
8 | Only Look-SAM has been added.
9 |
10 |     It speeds up SAM quite a lot, but the alpha needs to be tuned to reach the same performance.
11 | """
12 | def __init__(self, base_optimizer, model_without_ddp, rho=0.05, adaptive=False, div='', use_look_sam=False, look_sam_alpha=0., **kwargs):
13 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
14 |
15 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
16 |
17 | self.base_optimizer = base_optimizer
18 | self.param_groups = self.base_optimizer.param_groups
19 | self.model_without_ddp = model_without_ddp
20 |
21 | self.rho = rho
22 | self.adaptive = adaptive
23 | self.div = div
24 | self.look_sam_alpha = look_sam_alpha
25 | self.use_look_sam = use_look_sam
26 |
27 | self.g_v = dict()
28 |
29 | @torch.no_grad()
30 | def first_step(self):
31 | self.e_w = dict()
32 | self.g = dict()
33 |
34 | grad_norm = self._grad_norm()
35 | for group in self.param_groups:
36 | scale = self.rho / (grad_norm + 1e-12)
37 |
38 | for p in group["params"]:
39 | if p.grad is None: continue
40 | e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p)
41 | p.add_(e_w) # climb to the local maximum "w + e(w)"
42 | self.e_w[p] = e_w
43 | self.g[p] = p.grad.clone()
44 |
45 | @torch.no_grad()
46 | def second_step(self, look_sam_update=False):
47 | if self.use_look_sam and look_sam_update:
48 | self.g_v = dict()
49 |
50 | for group in self.param_groups:
51 | for p in group["params"]:
52 | if p.grad is None: continue
53 |
54 | if not self.use_look_sam or look_sam_update:
55 | p.sub_(self.e_w[p])
56 |
57 | if self.use_look_sam and look_sam_update:
58 | cos = self._cos(self.g[p], p.grad)
59 | norm_gs = p.grad.norm(p=2)
60 | norm_g = self.g[p].norm(p=2)
61 | self.g_v[p] = p.grad - norm_gs * cos * self.g[p] / norm_g
62 | elif self.use_look_sam:
63 | norm_g = p.grad.norm(p=2)
64 | norm_gv = self.g_v[p].norm(p=2)
65 | p.grad.add_(self.look_sam_alpha * (norm_g / norm_gv) * self.g_v[p])
66 |
67 | self.e_w = None
68 | self.g = None
69 |
70 | def _cos(self, a, b):
71 | return torch.dot(a.view(-1), b.view(-1)) / (a.norm() * b.norm())
72 |
73 | @torch.no_grad()
74 | def step(self, closure=None):
75 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
76 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
77 |
78 |         self.first_step()
79 | closure()
80 | self.second_step()
81 |
82 | def _grad_norm(self):
83 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
84 | norm = torch.norm(
85 | torch.stack([
86 | ((torch.abs(p) if self.adaptive else 1.0) * p.grad).norm(p=2).to(shared_device)
87 | for group in self.param_groups for p in group["params"]
88 | if p.grad is not None
89 | ]),
90 | p=2
91 | )
92 | return norm
93 |
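
A minimal runnable sketch (not part of the file) of the two-pass update the `SAM` wrapper implements: the first backward pass computes gradients at `w`, `first_step` perturbs the weights, a second pass computes gradients at the perturbed point, and `second_step` restores the weights. The toy model and data are arbitrary; note that this implementation leaves the final `base_optimizer.step()` to the training loop.

```python
# Illustrative sketch of the two-pass SAM update with a toy model.
import torch
from torch import nn

from continual.sam import SAM

model = nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss_fn = nn.CrossEntropyLoss()

base_opt = torch.optim.SGD(model.parameters(), lr=0.01)
sam = SAM(base_opt, model, rho=0.05)

loss_fn(model(x), y).backward()   # gradients at w
sam.first_step()                  # perturb weights to w + e(w)
base_opt.zero_grad()
loss_fn(model(x), y).backward()   # gradients at the perturbed point
sam.second_step()                 # restore w; grads now hold the SAM direction
base_opt.step()                   # the final update is left to the caller
```
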
--------------------------------------------------------------------------------
/continual/samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import torch
4 | import torch.distributed as dist
5 | import math
6 | import numpy as np
7 |
8 | import continual.utils as utils
9 |
10 |
11 | class SingleRASampler(torch.utils.data.Sampler):
12 | """Sampler that restricts data loading to a subset of the dataset for distributed,
13 | with repeated augmentation.
14 |     It ensures that each augmented version of a sample will be visible to a
15 |     different process (GPU).
16 | Heavily based on torch.utils.data.DistributedSampler
17 | """
18 |
19 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
20 | if num_replicas is None:
21 | if not dist.is_available():
22 | raise RuntimeError("Requires distributed package to be available")
23 | num_replicas = dist.get_world_size()
24 | if rank is None:
25 | if not dist.is_available():
26 | raise RuntimeError("Requires distributed package to be available")
27 | rank = dist.get_rank()
28 | self.dataset = dataset
29 | self.num_replicas = num_replicas
30 | self.rank = rank
31 | self.epoch = 0
32 | # self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
33 | self.num_samples = int(math.ceil(len(self.dataset)))
34 | self.total_size = self.num_samples * self.num_replicas
35 | self.num_selected_samples = len(self.dataset)
36 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
37 | # self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
38 | self.shuffle = shuffle
39 |
40 | def __iter__(self):
41 | # deterministically shuffle based on epoch
42 | g = torch.Generator()
43 | g.manual_seed(self.epoch)
44 | if self.shuffle:
45 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
46 | else:
47 | indices = list(range(len(self.dataset)))
48 |
49 | # add extra samples to make it evenly divisible
50 | # indices = indices + indices
51 | indices = [ele for ele in indices for i in range(self.num_replicas)]
52 | # indices += indices[:(self.total_size - len(indices))]
53 | assert len(indices) == self.total_size
54 |
55 | # subsample
56 | # indices = indices[self.rank:self.total_size:self.num_replicas]
57 | # assert len(indices) == self.num_samples
58 |
59 | return iter(indices)
60 |
61 | def __len__(self):
62 | return self.total_size
63 |
64 | def set_epoch(self, epoch):
65 | self.epoch = epoch
66 |
67 |
68 | class RASampler(torch.utils.data.Sampler):
69 | """Sampler that restricts data loading to a subset of the dataset for distributed,
70 | with repeated augmentation.
71 |     It ensures that each augmented version of a sample will be visible to a
72 |     different process (GPU).
73 | Heavily based on torch.utils.data.DistributedSampler
74 | """
75 |
76 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
77 | if num_replicas is None:
78 | if not dist.is_available():
79 | raise RuntimeError("Requires distributed package to be available")
80 | num_replicas = dist.get_world_size()
81 | if rank is None:
82 | if not dist.is_available():
83 | raise RuntimeError("Requires distributed package to be available")
84 | rank = dist.get_rank()
85 | self.dataset = dataset
86 | self.num_replicas = num_replicas
87 | self.rank = rank
88 | self.epoch = 0
89 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
90 | self.total_size = self.num_samples * self.num_replicas
91 | self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
92 | # self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
93 | self.shuffle = shuffle
94 |
95 | def __iter__(self):
96 | # deterministically shuffle based on epoch
97 | g = torch.Generator()
98 | g.manual_seed(self.epoch)
99 | if self.shuffle:
100 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
101 | else:
102 | indices = list(range(len(self.dataset)))
103 |
104 | # add extra samples to make it evenly divisible
105 | indices = [ele for ele in indices for i in range(3)]
106 | indices += indices[:(self.total_size - len(indices))]
107 | assert len(indices) == self.total_size
108 |
109 | # subsample
110 | indices = indices[self.rank:self.total_size:self.num_replicas]
111 | assert len(indices) == self.num_samples
112 |
113 | return iter(indices[:self.num_selected_samples])
114 |
115 | def __len__(self):
116 | return self.num_selected_samples
117 |
118 | def set_epoch(self, epoch):
119 | self.epoch = epoch
120 |
121 |
122 | def get_sampler(dataset_train, dataset_val, args):
123 | if args.distributed:
124 | num_tasks = utils.get_world_size()
125 | global_rank = utils.get_rank()
126 | if args.repeated_aug:
127 | sampler_train = RASampler(
128 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
129 | )
130 | else:
131 | sampler_train = torch.utils.data.DistributedSampler(
132 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
133 | )
134 | if args.dist_eval:
135 | if len(dataset_val) % num_tasks != 0:
136 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
137 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
138 | 'equal num of samples per-process.')
139 | sampler_val = torch.utils.data.DistributedSampler(
140 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
141 | else:
142 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
143 | else:
144 | if args.use_repeatedaug_single:
145 | sampler_train = SingleRASampler(dataset_train, num_replicas=2, rank=0, shuffle=True)
146 | else:
147 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
148 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
149 |
150 | return sampler_train, sampler_val
151 |
152 |
153 | def get_train_sampler(dataset_train, args):
154 | if args.distributed:
155 | num_tasks = utils.get_world_size()
156 | global_rank = utils.get_rank()
157 | if args.repeated_aug:
158 | sampler_train = RASampler(
159 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
160 | )
161 | else:
162 | sampler_train = torch.utils.data.DistributedSampler(
163 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
164 | )
165 | else:
166 | if args.use_repeatedaug_single:
167 | sampler_train = SingleRASampler(dataset_train, num_replicas=2, rank=0, shuffle=True)
168 | else:
169 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
170 |
171 | return sampler_train
172 |
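
A tiny sketch (not part of the file) of the repeated-augmentation behaviour of `SingleRASampler`: in a single process, every index is emitted `num_replicas` times, so each sample is drawn (and augmented) several times per epoch. The dataset here is a toy placeholder.

```python
# Illustrative sketch: SingleRASampler repeats every index num_replicas times.
import torch

from continual.samplers import SingleRASampler

dataset = torch.utils.data.TensorDataset(torch.arange(6))
sampler = SingleRASampler(dataset, num_replicas=2, rank=0, shuffle=False)
print(list(sampler))   # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
```
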
--------------------------------------------------------------------------------
/continual/scaler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from timm.utils import dispatch_clip_grad
4 |
5 |
6 | class ContinualScaler:
7 | state_dict_key = "amp_scaler"
8 |
9 | def __init__(self, disable_amp):
10 | self._scaler = torch.cuda.amp.GradScaler(enabled=not disable_amp)
11 |
12 | def __call__(
13 | self, loss, optimizer, model_without_ddp, clip_grad=None, clip_mode='norm',
14 | parameters=None, create_graph=False,
15 | hook=True
16 | ):
17 | self.pre_step(loss, optimizer, parameters, create_graph, clip_grad, clip_mode)
18 | self.post_step(optimizer, model_without_ddp, hook)
19 |
20 | def pre_step(self, loss, optimizer, parameters=None, create_graph=False, clip_grad=None, clip_mode='norm'):
21 | self._scaler.scale(loss).backward(create_graph=create_graph)
22 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
23 | if clip_grad is not None:
24 | assert parameters is not None
25 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
26 |
27 | def post_step(self, optimizer, model_without_ddp, hook=True):
28 | if hook and hasattr(model_without_ddp, 'hook_before_update'):
29 | model_without_ddp.hook_before_update()
30 |
31 | self._scaler.step(optimizer)
32 |
33 | if hook and hasattr(model_without_ddp, 'hook_after_update'):
34 | model_without_ddp.hook_after_update()
35 |
36 | self.update()
37 |
38 | def update(self):
39 | self._scaler.update()
40 |
41 | def state_dict(self):
42 | return self._scaler.state_dict()
43 |
44 | def load_state_dict(self, state_dict):
45 | self._scaler.load_state_dict(state_dict)
46 |
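
A minimal sketch (not part of the file) of one training step with `ContinualScaler`: a single call runs backward, unscales the gradients, optionally clips them, and steps the optimizer. AMP is disabled here so the toy example also runs on CPU; the model and data are arbitrary.

```python
# Illustrative sketch: one optimization step through ContinualScaler.
import torch
from torch import nn

from continual.scaler import ContinualScaler

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = ContinualScaler(disable_amp=True)   # True -> plain FP32, CPU-friendly

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(x), y)

optimizer.zero_grad()
scaler(loss, optimizer, model, clip_grad=1.0, parameters=model.parameters())
```
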
--------------------------------------------------------------------------------
/convert_memory.py:
--------------------------------------------------------------------------------
1 | """
2 | Use this script in case you saved rehearsal memory on a computer A, but then want
3 | to resume training, using those rehearsal samples, on a computer B.
4 |
5 | This is needed because, for ImageNet, we store image paths, which may differ between computers.
6 | """
7 |
8 | import sys
9 | import glob
10 | import os
11 | import shutil
12 |
13 | import numpy as np
14 |
15 | memory_path = sys.argv[1]
16 | new_base_path = sys.argv[2]
17 |
18 | if os.path.isdir(memory_path):
19 | memory_paths = glob.glob(os.path.abspath(os.path.join(memory_path, "memory_*.npz")))
20 | else:
21 | memory_paths = [memory_path]
22 |
23 | print(memory_paths)
24 |
25 | for p in sorted(memory_paths):
26 | psrc = p
27 | if not os.path.exists(f"{p}_original"):
28 | shutil.copy(p, f"{p}_original")
29 | else:
30 | psrc = f"{p}_original"
31 | print(p)
32 |
33 | data = np.load(p)
34 | x = []
35 | for img_path in data["x"]:
36 | id_ = str(img_path).lstrip("b'").rstrip("'").split("train")[-1][1:]
37 |
38 | x.append(os.path.join(new_base_path, "train", id_))
39 |
40 | np.savez(
41 | p,
42 | x=np.array(x), y=data["y"], t=data["t"]
43 | )
44 | print("Done!")
45 |
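
Assuming it is run from the repository root, the script takes two positional arguments: the saved memory file (or a directory containing `memory_*.npz` files) and the new dataset root, e.g. `python convert_memory.py memories/ /new/path/to/imagenet` (both paths are illustrative placeholders).
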
--------------------------------------------------------------------------------
/images/BiRT_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeurAI-Lab/BiRT/1340a97bff17fc02e228b754bd80e1fd649ff9cd/images/BiRT_architecture.png
--------------------------------------------------------------------------------
/options/arthur.yaml:
--------------------------------------------------------------------------------
1 | data_path: /local/douillard/
2 | output_basedir: /local/douillard/transformer/checkpoints
3 |
--------------------------------------------------------------------------------
/options/data/cifar100_10-10.yaml:
--------------------------------------------------------------------------------
1 | data_set: CIFAR
2 | initial_increment: 10
3 | increment: 10
4 | #memory_size: 2000
5 |
6 | log_category: 10-10
7 |
--------------------------------------------------------------------------------
/options/data/cifar100_10-10_500.yaml:
--------------------------------------------------------------------------------
1 | data_set: CIFAR
2 | initial_increment: 10
3 | increment: 10
4 | memory_size: 500
5 |
6 | log_category: 10-10
7 |
--------------------------------------------------------------------------------
/options/data/cifar100_2-2.yaml:
--------------------------------------------------------------------------------
1 | data_set: CIFAR
2 | initial_increment: 2
3 | increment: 2
4 | #memory_size: 2000
5 |
6 | log_category: 2-2
7 |
--------------------------------------------------------------------------------
/options/data/cifar100_20-20.yaml:
--------------------------------------------------------------------------------
1 | data_set: CIFAR
2 | initial_increment: 20
3 | increment: 20
4 | #memory_size: 2000
5 |
6 | log_category: 20-20
7 |
--------------------------------------------------------------------------------
/options/data/cifar100_5-5.yaml:
--------------------------------------------------------------------------------
1 | data_set: CIFAR
2 | initial_increment: 5
3 | increment: 5
4 | #memory_size: 2000
5 |
6 | log_category: 5-5
7 |
--------------------------------------------------------------------------------
/options/data/cifar100_joint.yaml:
--------------------------------------------------------------------------------
1 | data_set: CIFAR
2 |
3 | initial_increment: 100
4 | increment: 100
5 |
6 | log_category: joint
7 |
--------------------------------------------------------------------------------
/options/data/cifar100_order1.yaml:
--------------------------------------------------------------------------------
1 | class_order: [87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, 24, 32, 45, 88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, 25, 20, 80, 73, 1, 28, 6, 46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, 60, 19, 70, 90, 89, 43, 5, 42, 65, 76, 40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, 34, 55, 54, 26, 35, 39]
2 |
--------------------------------------------------------------------------------
/options/data/cifar100_order2.yaml:
--------------------------------------------------------------------------------
1 | class_order: [58, 30, 93, 69, 21, 77, 3, 78, 12, 71, 65, 40, 16, 49, 89, 46, 24, 66, 19, 41, 5, 29, 15, 73, 11, 70, 90, 63, 67, 25, 59, 72, 80, 94, 54, 33, 18, 96, 2, 10, 43, 9, 57, 81, 76, 50, 32, 6, 37, 7, 68, 91, 88, 95, 85, 4, 60, 36, 22, 27, 39, 42, 34, 51, 55, 28, 53, 48, 38, 17, 83, 86, 56, 35, 45, 79, 99, 84, 97, 82, 98, 26, 47, 44, 62, 13, 31, 0, 75, 14, 52, 74, 8, 20, 1, 92, 87, 23, 64, 61]
2 |
--------------------------------------------------------------------------------
/options/data/cifar100_order3.yaml:
--------------------------------------------------------------------------------
1 | class_order: [71, 54, 45, 32, 4, 8, 48, 66, 1, 91, 28, 82, 29, 22, 80, 27, 86, 23, 37, 47, 55, 9, 14, 68, 25, 96, 36, 90, 58, 21, 57, 81, 12, 26, 16, 89, 79, 49, 31, 38, 46, 20, 92, 88, 40, 39, 98, 94, 19, 95, 72, 24, 64, 18, 60, 50, 63, 61, 83, 76, 69, 35, 0, 52, 7, 65, 42, 73, 74, 30, 41, 3, 6, 53, 13, 56, 70, 77, 34, 97, 75, 2, 17, 93, 33, 84, 99, 51, 62, 87, 5, 15, 10, 78, 67, 44, 59, 85, 43, 11]
2 |
--------------------------------------------------------------------------------
/options/data/cifar100_order4.yaml:
--------------------------------------------------------------------------------
1 | class_order: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
--------------------------------------------------------------------------------
/options/data/cifar100_order5.yaml:
--------------------------------------------------------------------------------
1 | class_order: [ 5, 2, 7, 93, 9, 99, 12, 6, 21, 15, 53, 46, 31, 48, 39, 45, 92, 25, 61, 73, 70, 32, 55, 91, 90, 42, 76, 52, 83, 75, 43, 47, 24, 11, 13, 37, 18, 38, 40, 10, 67, 66, 41, 33, 74, 16, 79, 86, 78, 97, 80, 64, 54, 28, 14, 50, 35, 71, 1, 62, 88, 0, 60, 95, 36, 51, 87, 58, 56, 65, 59, 68, 82, 4, 81, 29, 27, 3, 34, 22, 72, 30, 89, 77, 63, 8, 84, 98, 17, 26, 23, 94, 20, 85, 19, 96, 49, 57, 44, 69]
--------------------------------------------------------------------------------
/options/data/cifar10_2-2.yaml:
--------------------------------------------------------------------------------
1 | data_set: CIFAR10
2 | initial_increment: 2
3 | increment: 2
4 | #memory_size: 2000
5 |
6 | log_category: 2-2
7 |
--------------------------------------------------------------------------------
/options/data/cifar10_2-2_500.yaml:
--------------------------------------------------------------------------------
1 | data_set: CIFAR10
2 | initial_increment: 2
3 | increment: 2
4 | memory_size: 500
5 |
6 | log_category: 2-2
7 |
--------------------------------------------------------------------------------
/options/data/cifar10_joint.yaml:
--------------------------------------------------------------------------------
1 | data_set: CIFAR10
2 | initial_increment: 10
3 | increment: 10
4 |
5 | log_category: joint
6 |
--------------------------------------------------------------------------------
/options/data/imagenet1000_100-100.yaml:
--------------------------------------------------------------------------------
1 | data_set: imagenet1000
2 | initial_increment: 100
3 | increment: 100
4 | #memory_size: 20000
5 |
6 | log_category: 100-100
7 |
--------------------------------------------------------------------------------
/options/data/imagenet1000_joint.yaml:
--------------------------------------------------------------------------------
1 | data_set: imagenet1000
2 | initial_increment: 1000
3 | increment: 1000
4 |
5 | log_category: joint
6 |
--------------------------------------------------------------------------------
/options/data/imagenet1000_order1.yaml:
--------------------------------------------------------------------------------
1 | class_order: [54, 7, 894, 512, 126, 337, 988, 11, 284, 493, 133, 783, 192, 979, 622, 215, 240, 548, 238, 419, 274, 108,
2 | 928, 856, 494, 836, 473, 650, 85, 262, 508, 590, 390, 174, 637, 288, 658, 219, 912, 142, 852, 160, 704, 289,
3 | 123, 323, 600, 542, 999, 634, 391, 761, 490, 842, 127, 850, 665, 990, 597, 722, 748, 14, 77, 437, 394, 859,
4 | 279, 539, 75, 466, 886, 312, 303, 62, 966, 413, 959, 782, 509, 400, 471, 632, 275, 730, 105, 523, 224, 186,
5 | 478, 507, 470, 906, 699, 989, 324, 812, 260, 911, 446, 44, 765, 759, 67, 36, 5, 30, 184, 797, 159, 741, 954,
6 | 465, 533, 585, 150, 101, 897, 363, 818, 620, 824, 154, 956, 176, 588, 986, 172, 223, 461, 94, 141, 621, 659,
7 | 360, 136, 578, 163, 427, 70, 226, 925, 596, 336, 412, 731, 755, 381, 810, 69, 898, 310, 120, 752, 93, 39,
8 | 326, 537, 905, 448, 347, 51, 615, 601, 229, 947, 348, 220, 949, 972, 73, 913, 522, 193, 753, 921, 257, 957,
9 | 691, 155, 820, 584, 948, 92, 582, 89, 379, 392, 64, 904, 169, 216, 694, 103, 410, 374, 515, 484, 624, 409,
10 | 156, 455, 846, 344, 371, 468, 844, 276, 740, 562, 503, 831, 516, 663, 630, 763, 456, 179, 996, 936, 248,
11 | 333, 941, 63, 738, 802, 372, 828, 74, 540, 299, 750, 335, 177, 822, 643, 593, 800, 459, 580, 933, 306, 378,
12 | 76, 227, 426, 403, 322, 321, 808, 393, 27, 200, 764, 651, 244, 479, 3, 415, 23, 964, 671, 195, 569, 917,
13 | 611, 644, 707, 355, 855, 8, 534, 657, 571, 811, 681, 543, 313, 129, 978, 592, 573, 128, 243, 520, 887, 892,
14 | 696, 26, 551, 168, 71, 398, 778, 529, 526, 792, 868, 266, 443, 24, 57, 15, 871, 678, 745, 845, 208, 188,
15 | 674, 175, 406, 421, 833, 106, 994, 815, 581, 676, 49, 619, 217, 631, 934, 932, 568, 353, 863, 827, 425, 420,
16 | 99, 823, 113, 974, 438, 874, 343, 118, 340, 472, 552, 937, 0, 10, 675, 316, 879, 561, 387, 726, 255, 407,
17 | 56, 927, 655, 809, 839, 640, 297, 34, 497, 210, 606, 971, 589, 138, 263, 587, 993, 973, 382, 572, 735, 535,
18 | 139, 524, 314, 463, 895, 376, 939, 157, 858, 457, 935, 183, 114, 903, 767, 666, 22, 525, 902, 233, 250, 825,
19 | 79, 843, 221, 214, 205, 166, 431, 860, 292, 976, 739, 899, 475, 242, 961, 531, 110, 769, 55, 701, 532, 586,
20 | 729, 253, 486, 787, 774, 165, 627, 32, 291, 962, 922, 222, 705, 454, 356, 445, 746, 776, 404, 950, 241, 452,
21 | 245, 487, 706, 2, 137, 6, 98, 647, 50, 91, 202, 556, 38, 68, 649, 258, 345, 361, 464, 514, 958, 504, 826,
22 | 668, 880, 28, 920, 918, 339, 315, 320, 768, 201, 733, 575, 781, 864, 617, 171, 795, 132, 145, 368, 147, 327,
23 | 713, 688, 848, 690, 975, 354, 853, 148, 648, 300, 436, 780, 693, 682, 246, 449, 492, 162, 97, 59, 357, 198,
24 | 519, 90, 236, 375, 359, 230, 476, 784, 117, 940, 396, 849, 102, 122, 282, 181, 130, 467, 88, 271, 793, 151,
25 | 847, 914, 42, 834, 521, 121, 29, 806, 607, 510, 837, 301, 669, 78, 256, 474, 840, 52, 505, 547, 641, 987,
26 | 801, 629, 491, 605, 112, 429, 401, 742, 528, 87, 442, 910, 638, 785, 264, 711, 369, 428, 805, 744, 380, 725,
27 | 480, 318, 997, 153, 384, 252, 985, 538, 654, 388, 100, 432, 832, 565, 908, 367, 591, 294, 272, 231, 213,
28 | 196, 743, 817, 433, 328, 970, 969, 4, 613, 182, 685, 724, 915, 311, 931, 865, 86, 119, 203, 268, 718, 317,
29 | 926, 269, 161, 209, 807, 645, 513, 261, 518, 305, 758, 872, 58, 65, 146, 395, 481, 747, 41, 283, 204, 564,
30 | 185, 777, 33, 500, 609, 286, 567, 80, 228, 683, 757, 942, 134, 673, 616, 960, 450, 350, 544, 830, 736, 170,
31 | 679, 838, 819, 485, 430, 190, 566, 511, 482, 232, 527, 411, 560, 281, 342, 614, 662, 47, 771, 861, 692, 686,
32 | 277, 373, 16, 946, 265, 35, 9, 884, 909, 610, 358, 18, 737, 977, 677, 803, 595, 135, 458, 12, 46, 418, 599,
33 | 187, 107, 992, 770, 298, 104, 351, 893, 698, 929, 502, 273, 20, 96, 791, 636, 708, 267, 867, 772, 604, 618,
34 | 346, 330, 554, 816, 664, 716, 189, 31, 721, 712, 397, 43, 943, 804, 296, 109, 576, 869, 955, 17, 506, 963,
35 | 786, 720, 628, 779, 982, 633, 891, 734, 980, 386, 365, 794, 325, 841, 878, 370, 695, 293, 951, 66, 594, 717,
36 | 116, 488, 796, 983, 646, 499, 53, 1, 603, 45, 424, 875, 254, 237, 199, 414, 307, 362, 557, 866, 341, 19,
37 | 965, 143, 555, 687, 235, 790, 125, 173, 364, 882, 727, 728, 563, 495, 21, 558, 709, 719, 877, 352, 83, 998,
38 | 991, 469, 967, 760, 498, 814, 612, 715, 290, 72, 131, 259, 441, 924, 773, 48, 625, 501, 440, 82, 684, 862,
39 | 574, 309, 408, 680, 623, 439, 180, 652, 968, 889, 334, 61, 766, 399, 598, 798, 653, 930, 149, 249, 890, 308,
40 | 881, 40, 835, 577, 422, 703, 813, 857, 995, 602, 583, 167, 670, 212, 751, 496, 608, 84, 639, 579, 178, 489,
41 | 37, 197, 789, 530, 111, 876, 570, 700, 444, 287, 366, 883, 385, 536, 460, 851, 81, 144, 60, 251, 13, 953,
42 | 270, 944, 319, 885, 710, 952, 517, 278, 656, 919, 377, 550, 207, 660, 984, 447, 553, 338, 234, 383, 749,
43 | 916, 626, 462, 788, 434, 714, 799, 821, 477, 549, 661, 206, 667, 541, 642, 689, 194, 152, 981, 938, 854,
44 | 483, 332, 280, 546, 389, 405, 545, 239, 896, 672, 923, 402, 423, 907, 888, 140, 870, 559, 756, 25, 211, 158,
45 | 723, 635, 302, 702, 453, 218, 164, 829, 247, 775, 191, 732, 115, 331, 901, 416, 873, 754, 900, 435, 762,
46 | 124, 304, 329, 349, 295, 95, 451, 285, 225, 945, 697, 417]
47 |
--------------------------------------------------------------------------------
/options/data/imagenet100_10-10.yaml:
--------------------------------------------------------------------------------
1 | data_set: imagenet100
2 | initial_increment: 10
3 | increment: 10
4 | #memory_size: 2000
5 |
6 | log_category: 10-10
7 |
--------------------------------------------------------------------------------
/options/data/imagenet100_joint.yaml:
--------------------------------------------------------------------------------
1 | data_set: imagenet100
2 | initial_increment: 100
3 | increment: 100
4 |
5 | log_category: joint
6 |
--------------------------------------------------------------------------------
/options/data/imagenet100_order1.yaml:
--------------------------------------------------------------------------------
1 | class_order: [68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50,
2 | 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,
3 | 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69,
4 | 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33]
5 |
--------------------------------------------------------------------------------
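The `*_order{1,2,3}.yaml` files fix the permutation in which classes are presented, so results can be averaged over several class orders while staying reproducible. A small, purely illustrative sanity check (not part of the repo) that an order file is a valid permutation before handing it to the scenario builder:

```python
# Illustrative only: verify a class-order YAML and reuse it for the scenario.
import yaml

def load_class_order(path, num_classes):
    with open(path) as f:
        order = yaml.safe_load(f)["class_order"]
    assert sorted(order) == list(range(num_classes)), f"{path} is not a permutation"
    return order

order = load_class_order("options/data/imagenet100_order1.yaml", 100)
# e.g. continuum's ClassIncremental(..., class_order=order) then presents
# classes in this fixed sequence across all tasks.
```
--------------------------------------------------------------------------------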
/options/data/imagenet100_order2.yaml:
--------------------------------------------------------------------------------
1 | class_order: [38, 34, 87, 75, 81, 17, 99, 67, 69, 12, 28, 25, 32, 42, 61, 5, 82,
2 | 58, 13, 94, 47, 51, 84, 39, 49, 8, 57, 1, 55, 36, 10, 20, 64, 2,
3 | 78, 44, 43, 48, 23, 21, 53, 37, 74, 27, 92, 77, 98, 18, 79, 66, 90,
4 | 46, 68, 26, 9, 83, 80, 30, 22, 16, 29, 97, 41, 85, 0, 52, 15, 14,
5 | 86, 63, 24, 59, 54, 11, 76, 70, 6, 7, 56, 93, 89, 50, 71, 4, 19,
6 | 88, 96, 62, 3, 31, 91, 95, 72, 60, 73, 65, 33, 45, 40, 35]
7 |
--------------------------------------------------------------------------------
/options/data/imagenet100_order3.yaml:
--------------------------------------------------------------------------------
1 | class_order: [86, 37, 98, 64, 38, 90, 70, 27, 9, 44, 3, 59, 94, 57, 52, 89, 5,
2 | 99, 72, 10, 78, 41, 32, 7, 61, 71, 23, 46, 29, 74, 42, 11, 21, 54,
3 | 45, 53, 77, 24, 35, 22, 88, 2, 49, 1, 17, 31, 58, 0, 50, 73, 96,
4 | 33, 62, 56, 97, 87, 20, 6, 55, 80, 51, 66, 85, 63, 18, 15, 67, 76,
5 | 26, 75, 65, 14, 36, 68, 43, 40, 92, 82, 39, 93, 25, 91, 84, 79, 34,
6 | 12, 69, 16, 48, 81, 47, 60, 8, 95, 19, 4, 13, 30, 83, 28]
--------------------------------------------------------------------------------
/options/data/tinyimg_20-20.yaml:
--------------------------------------------------------------------------------
1 | data_set: TINYIMG
2 | initial_increment: 20
3 | increment: 20
4 | #memory_size: 2000
5 |
6 | log_category: 20-20
7 |
--------------------------------------------------------------------------------
/options/data/tinyimg_joint.yaml:
--------------------------------------------------------------------------------
1 | data_set: TINYIMG
2 | initial_increment: 200
3 | increment: 200
4 |
5 | log_category: joint
6 |
--------------------------------------------------------------------------------
/options/data/tinyimg_order1.yaml:
--------------------------------------------------------------------------------
1 | class_order: [154, 32, 34, 177, 45, 56, 167, 194, 22, 81, 118, 190, 124, 101, 78, 33, 153, 83, 55, 76, 179, 64, 163, 66, 59, 156, 88, 108, 48, 171, 119, 123, 44, 172, 19, 120, 182, 93, 39, 187, 164, 180, 30, 100, 188, 175, 192, 129, 104, 4, 86, 72, 102, 173, 98, 110, 132, 122, 144, 99, 197, 155, 106, 165, 89, 54, 152, 11, 189, 195, 9, 184, 79, 109, 28, 134, 84, 27, 121, 136, 169, 117, 21, 199, 142, 133, 157, 43, 51, 73, 87, 68, 50, 25, 139, 112, 168, 77, 131, 196, 12, 170, 75, 38, 185, 31, 15, 46, 115, 166, 159, 130, 160, 113, 186, 94, 80, 96, 37, 97, 65, 82, 63, 49, 181, 114, 18, 91, 71, 191, 7, 126, 135, 127, 10, 16, 176, 42, 17, 1, 5, 41, 8, 151, 74, 26, 149, 14, 69, 47, 36, 85, 150, 174, 183, 161, 52, 29, 58, 60, 128, 116, 111, 13, 95, 107, 70, 140, 162, 148, 20, 23, 53, 158, 67, 2, 57, 92, 0, 3, 141, 90, 147, 145, 198, 105, 146, 24, 35, 61, 103, 138, 178, 137, 193, 143, 62, 6, 40, 125]
--------------------------------------------------------------------------------
/options/data/tinyimg_order2.yaml:
--------------------------------------------------------------------------------
1 | class_order: [ 48, 131, 195, 39, 16, 103, 164, 179, 59, 198, 171, 94, 134, 112, 185, 0, 168, 100, 163, 83, 50, 28, 141, 13, 110, 189, 91, 26, 150, 142, 82, 180, 101, 144, 57, 1, 122, 71, 160, 104, 174, 86, 14, 9, 191, 19, 32, 132, 33, 135, 7, 97, 44, 115, 66, 116, 165, 156, 178, 52, 80, 63, 118, 22, 42, 113, 69, 123, 151, 84, 31, 36, 126, 18, 47, 133, 169, 24, 147, 85, 148, 67, 92, 172, 106, 54, 99, 111, 5, 70, 89, 53, 23, 117, 196, 75, 170, 96, 11, 76, 199, 152, 2, 109, 61, 8, 15, 114, 186, 68, 6, 130, 159, 21, 173, 137, 167, 162, 38, 41, 65, 102, 51, 35, 3, 4, 149, 98, 176, 56, 139, 193, 20, 45, 64, 143, 157, 90, 153, 37, 95, 155, 183, 128, 81, 182, 60, 34, 87, 175, 29, 194, 12, 58, 181, 197, 55, 120, 72, 124, 184, 161, 88, 17, 93, 78, 30, 79, 138, 127, 154, 10, 121, 25, 145, 46, 187, 108, 158, 77, 177, 136, 40, 140, 105, 107, 73, 125, 146, 119, 27, 188, 166, 74, 192, 190, 129, 43, 49, 62]
--------------------------------------------------------------------------------
/options/data/tinyimg_order3.yaml:
--------------------------------------------------------------------------------
1 | class_order: [107, 74, 135, 156, 180, 57, 79, 174, 115, 29, 40, 170, 164, 98, 31, 133, 68, 161, 188, 52, 53, 0, 165, 55, 197, 87, 147, 82, 3, 75, 172, 27, 124, 142, 45, 84, 67, 121, 138, 182, 94, 69, 118, 5, 105, 83, 100, 176, 61, 95, 37, 85, 18, 187, 117, 152, 104, 43, 51, 139, 23, 126, 70, 193, 177, 88, 185, 71, 48, 63, 9, 160, 155, 158, 81, 129, 12, 72, 130, 134, 33, 169, 14, 89, 166, 140, 157, 16, 32, 28, 15, 7, 6, 131, 183, 30, 90, 78, 8, 109, 21, 119, 191, 120, 179, 49, 175, 114, 116, 122, 17, 111, 159, 24, 132, 137, 145, 44, 58, 141, 150, 198, 192, 143, 10, 110, 195, 60, 136, 92, 144, 153, 127, 20, 39, 4, 86, 154, 181, 184, 125, 50, 108, 151, 13, 80, 19, 194, 64, 2, 47, 1, 123, 46, 38, 77, 76, 22, 42, 199, 171, 162, 103, 168, 106, 26, 35, 128, 112, 66, 41, 59, 62, 73, 186, 36, 178, 163, 97, 56, 190, 149, 101, 34, 102, 173, 91, 189, 93, 54, 65, 167, 148, 196, 146, 96, 25, 99, 113, 11]
--------------------------------------------------------------------------------
/options/model/cifar_birt.yaml:
--------------------------------------------------------------------------------
1 | #######################
2 | # DyTox, for CIFAR100 #
3 | #######################
4 |
5 | # Model definition
6 | model: convit
7 | embed_dim: 384
8 | depth: 6
9 | num_heads: 12
10 | patch_size: 4
11 | input_size: 32
12 | local_up_to_layer: 5
13 | class_attention: true
14 |
15 | # Training setting
16 | no_amp: true
17 | eval_every: 50
18 |
19 | # Base hyperparameter
20 | weight_decay: 0.000001
21 | batch_size: 128
22 | incremental_batch_size: 128
23 | incremental_lr: 0.0005
24 | rehearsal: icarl_all
25 |
26 | # Knowledge Distillation
27 | auto_kd: true
28 |
29 | # Finetuning
30 | finetuning: balanced
31 | finetuning_epochs: 20
32 |
33 | # Dytox model
34 | dytox: true
35 | freeze_task: [old_task_tokens, old_heads]
36 | freeze_ft: [sab]
37 |
38 | # Divergence head to get diversity
39 | # head_div: 0
40 | head_div: 0.1
41 | head_div_mode: tr
42 |
43 | # Independent Classifiers
44 | ind_clf: 1-1
45 | bce_loss: true
46 |
47 |
48 | # Advanced Augmentations, here disabled
49 |
50 | ## Erasing
51 | reprob: 0.0
52 | remode: pixel
53 | recount: 1
54 | resplit: false
55 |
56 | ## MixUp & CutMix
57 | mixup: 0.0
58 | cutmix: 0.0
59 |
--------------------------------------------------------------------------------
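The model YAMLs override the training script's command-line defaults in the DyTox style of option files: backbone definition (ConViT with class attention), batch sizes, iCaRL-style rehearsal, knowledge distillation, balanced finetuning, and the divergence head. A minimal sketch of the usual way such YAML options are merged over an argparse namespace; the helper name is illustrative and BiRT's `main.py` may implement this differently:

```python
# Illustrative only: merge option YAMLs (data config + model config) over
# argparse defaults, with the last file winning on conflicting keys.
import argparse
import yaml

def apply_option_files(args: argparse.Namespace, paths):
    for path in paths:
        with open(path) as f:
            overrides = yaml.safe_load(f) or {}
        for key, value in overrides.items():
            setattr(args, key, value)  # YAML value replaces the CLI default
    return args

# args = apply_option_files(parser.parse_args(),
#                           ["options/data/cifar10_2-2.yaml",
#                            "options/model/cifar_birt.yaml"])
```
--------------------------------------------------------------------------------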
/options/model/imagenet_birt.yaml:
--------------------------------------------------------------------------------
1 | #######################
2 | # DyTox, for ImageNet #
3 | #######################
4 |
5 | # Model definition
6 | model: convit
7 | embed_dim: 384
8 | depth: 6
9 | num_heads: 12
10 | patch_size: 16
11 | input_size: 224
12 | local_up_to_layer: 5
13 | class_attention: true
14 |
15 | #batch_size: 64
16 | #incremental_batch_size: 64
17 |
18 | # Training setting
19 | no_amp: false
20 | eval_every: 250
21 |
22 | # Base hyperparameter
23 | weight_decay: 0.000001
24 | batch_size: 128
25 | incremental_batch_size: 128
26 | incremental_lr: 0.0005
27 | rehearsal: icarl_all
28 |
29 | # Knowledge Distillation
30 | auto_kd: true
31 |
32 | # Finetuning
33 | finetuning: balanced
34 | finetuning_epochs: 20
35 |
36 | # Dytox model
37 | dytox: true
38 | freeze_task: [old_task_tokens, old_heads]
39 | freeze_ft: [sab]
40 |
41 | # Divergence head to get diversity
42 | head_div: 0.1
43 | head_div_mode: tr
44 |
45 | # Independent Classifiers
46 | ind_clf: 1-1
47 | bce_loss: true
48 |
49 |
50 | # Advanced Augmentations, here disabled
51 |
52 | ## Erasing
53 | reprob: 0.0
54 | remode: pixel
55 | recount: 1
56 | resplit: false
57 |
58 | ## MixUp & CutMix
59 | mixup: 0.0
60 | cutmix: 0.0
61 |
--------------------------------------------------------------------------------
/options/model/tinyimg_birt.yaml:
--------------------------------------------------------------------------------
1 | ###########################
2 | # DyTox, for TinyImageNet #
3 | ###########################
4 |
5 | # Model definition
6 | model: convit
7 | embed_dim: 384
8 | depth: 6
9 | num_heads: 12
10 | patch_size: 8
11 | input_size: 64
12 | local_up_to_layer: 5
13 | class_attention: true
14 |
15 | # Training setting
16 | no_amp: true
17 | eval_every: 50
18 |
19 | # Base hyperparameter
20 | weight_decay: 0.000001
21 | batch_size: 128
22 | incremental_batch_size: 128
23 | incremental_lr: 0.0005
24 | rehearsal: icarl_all
25 |
26 | # Knowledge Distillation
27 | auto_kd: true
28 |
29 | # Finetuning
30 | finetuning: balanced
31 | finetuning_epochs: 20
32 |
33 | # Dytox model
34 | dytox: true
35 | freeze_task: [old_task_tokens, old_heads]
36 | freeze_ft: [sab]
37 |
38 | # Divergence head to get diversity
39 | # head_div: 0
40 | head_div: 0.1
41 | head_div_mode: tr
42 |
43 | # Independent Classifiers
44 | ind_clf: 1-1
45 | bce_loss: true
46 |
47 |
48 | # Advanced Augmentations, here disabled
49 |
50 | ## Erasing
51 | reprob: 0.0
52 | remode: pixel
53 | recount: 1
54 | resplit: false
55 |
56 | ## MixUp & CutMix
57 | mixup: 0.0
58 | cutmix: 0.0
59 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.7.0
2 | torchvision>=0.8.1
3 | timm>=0.3.2
4 | continuum>=1.0.27
5 |
--------------------------------------------------------------------------------
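The pinned dependencies can be installed with a standard pip workflow; the torch/torchvision builds should additionally match the local CUDA toolkit:

```
pip install -r requirements.txt
```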