├── .gitignore
├── LICENSE
├── README.md
├── continual
│   ├── __init__.py
│   ├── birt.py
│   ├── classifier.py
│   ├── cnn
│   │   ├── __init__.py
│   │   ├── abstract.py
│   │   ├── inception.py
│   │   ├── resnet.py
│   │   ├── resnet_rebuffi.py
│   │   ├── resnet_scs.py
│   │   ├── senet.py
│   │   └── vgg.py
│   ├── convit.py
│   ├── datasets.py
│   ├── engine.py
│   ├── factory.py
│   ├── losses.py
│   ├── misc.py
│   ├── mixup.py
│   ├── mycontinual
│   │   ├── __init__.py
│   │   ├── custom_array_task_set.py
│   │   ├── incremental_rotation.py
│   │   ├── permutations.py
│   │   ├── rotations.py
│   │   └── transformation_incremental.py
│   ├── rehearsal.py
│   ├── sam.py
│   ├── samplers.py
│   ├── scaler.py
│   ├── utils.py
│   └── vit.py
├── convert_memory.py
├── imagenet100_splits
│   ├── train_100.txt
│   └── val_100.txt
├── images
│   └── BiRT_architecture.png
├── main.py
├── options
│   ├── arthur.yaml
│   ├── data
│   │   ├── cifar100_10-10.yaml
│   │   ├── cifar100_10-10_500.yaml
│   │   ├── cifar100_2-2.yaml
│   │   ├── cifar100_20-20.yaml
│   │   ├── cifar100_5-5.yaml
│   │   ├── cifar100_joint.yaml
│   │   ├── cifar100_order1.yaml
│   │   ├── cifar100_order2.yaml
│   │   ├── cifar100_order3.yaml
│   │   ├── cifar100_order4.yaml
│   │   ├── cifar100_order5.yaml
│   │   ├── cifar10_2-2.yaml
│   │   ├── cifar10_2-2_500.yaml
│   │   ├── cifar10_joint.yaml
│   │   ├── imagenet1000_100-100.yaml
│   │   ├── imagenet1000_joint.yaml
│   │   ├── imagenet1000_order1.yaml
│   │   ├── imagenet100_10-10.yaml
│   │   ├── imagenet100_joint.yaml
│   │   ├── imagenet100_order1.yaml
│   │   ├── imagenet100_order2.yaml
│   │   ├── imagenet100_order3.yaml
│   │   ├── tinyimg_20-20.yaml
│   │   ├── tinyimg_joint.yaml
│   │   ├── tinyimg_order1.yaml
│   │   ├── tinyimg_order2.yaml
│   │   └── tinyimg_order3.yaml
│   └── model
│       ├── cifar_birt.yaml
│       ├── imagenet_birt.yaml
│       └── tinyimg_birt.yaml
└── requirements.txt
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | **/__pycache__/**
3 | checkpoints/
4 | logs/
5 | .vscode
6 | .ipynb_checkpoints
7 | outputs/
8 | .idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BiRT
2 | 
3 | This is the official repository of the ICML 2023 paper "[BiRT: Bio-inspired Replay in Vision Transformers for Continual Learning](https://arxiv.org/abs/2305.04769)" by [Kishaan Jeeveswaran](https://scholar.google.com/citations?user=JcqW3_QAAAAJ&hl=en), [Prashant Bhat](https://scholar.google.com/citations?hl=en&user=jrEETfgAAAAJ), [Bahram Zonooz](https://scholar.google.com/citations?hl=en&user=FZmIlY8AAAAJ), and [Elahe Arani](https://scholar.google.com/citations?user=e_I_v6cAAAAJ&hl=en).
4 | 
5 | TL;DR: A novel representation rehearsal-based continual learning approach that incorporates constructive noise at various stages of the vision transformer, together with regularization, to enable effective and memory-efficient continual learning.
6 | 
7 | ### Schematic of BiRT:
8 | ![image info](./images/BiRT_architecture.png)
9 | 
10 | ## Setup:
11 | 
12 | OUTPUT_DIR: Directory to save output contents.
13 | DATA_DIR: Directory containing the datasets.
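Both are passed to `main.py` through the `--data-path` and `--output-basedir` flags shown in the training command below.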
14 | 15 | ### Datasets supported:
16 | 
17 | * CIFAR-100
18 | * ImageNet-100
19 | * Tiny ImageNet
20 | 
21 | 
22 | ### BiRT Training Script:
23 | 
24 | To train BiRT on the CIFAR-100 dataset in the 10-task setting with a buffer size of 500 (replace `<DATA_DIR>` and `<OUTPUT_DIR>` with the paths from Setup):
25 | ```
26 | python main.py --seed 42 --options options/data/cifar100_10-10.yaml options/data/cifar100_order1.yaml options/model/cifar_birt.yaml --data-path <DATA_DIR> --output-basedir <OUTPUT_DIR> --base-epochs 500 --batch_mixup --batch_logitnoise --ema_alpha 0.001 --ema_frequency 0.003 --distill_version l2 --distill_weight 0.05 --distill_weight_buffer 0.001 --rep_noise_weight 1.0 --repnoise_prob 0.5 --finetune_weight 2 --representation_replay --replay_from 1 --sep_memory --num_workers 8 --csv_filename results.csv --memory-size 500 --tensorboard --epochs 500
27 | ```
28 | 
29 | ### Hyperparameters for other settings:
30 | 
31 | | Dataset | Num of Tasks | Buffer Size | ema_alpha | ema_frequency | distill_weight | distill_weight_buffer |
32 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
33 | | CIFAR-100 | 5 | 200 | 0.0005 | 0.001 | 0.05 | 0.01 |
34 | | | | 500 | 0.005 | 0.003 | 0.05 | 0.01 |
35 | | | 10 | 200 | 0.001 | 0.003 | 0.05 | 0.001 |
36 | | | | 500 | 0.001 | 0.003 | 0.05 | 0.001 |
37 | | | | 1000 | 0.0005 | 0.0008 | 0.05 | 0.01 |
38 | | | | 2000 | 0.0002 | 0.0015 | 0.05 | 0.01 |
39 | | | 20 | 200 | 0.005 | 0.001 | 0.05 | 0.08 |
40 | | | | 500 | 0.0005 | 0.003 | 0.05 | 0.1 |
41 | | Tiny ImageNet | 10 | 500 | 0.001 | 0.003 | 0.05 | 0.01 |
42 | | | | 1000 | 0.01 | 0.0008 | 0.01 | 0.001 |
43 | | | | 2000 | 0.0001 | 0.008 | 0.01 | 0.0008 |
44 | | ImageNet-100 | 10 | 500 | 0.0001 | 0.003 | 0.05 | 0.001 |
45 | | | | 1000 | 0.0001 | 0.003 | 0.05 | 0.001 |
46 | | | | 2000 | 0.01 | 0.005 | 0.01 | 0.001 |
47 | 
48 | ## Cite Our Work:
49 | 
50 | If you find the code useful in your research, please consider citing our paper:
51 | 
52 | ```
53 | @article{jeeveswaran2023birt,
54 |   title={BiRT: Bio-inspired Replay in Vision Transformers for Continual Learning},
55 |   author={Jeeveswaran, Kishaan and Bhat, Prashant and Zonooz, Bahram and Arani, Elahe},
56 |   journal={arXiv preprint arXiv:2305.04769},
57 |   year={2023}
58 | }
59 | ```
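
### Note on `ema_alpha` and `ema_frequency`:

These flags control the exponential-moving-average (semantic) copy of the working model that BiRT maintains for replay and consistency regularization. As a rough illustration of a probability-gated EMA update (a sketch only: the function and model names here are hypothetical, and the exact rule and schedule live in the training code):

```
import random

import torch


@torch.no_grad()
def maybe_ema_update(working_model, semantic_model, ema_alpha, ema_frequency):
    # With probability `ema_frequency`, nudge the stable (semantic) weights
    # toward the plastic (working) weights by a factor of `ema_alpha`.
    if random.random() < ema_frequency:
        for p_sem, p_work in zip(semantic_model.parameters(), working_model.parameters()):
            p_sem.mul_(1.0 - ema_alpha).add_(p_work, alpha=ema_alpha)
```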
60 | 
--------------------------------------------------------------------------------
/continual/__init__.py:
--------------------------------------------------------------------------------
1 | from continual import rehearsal
2 | from continual import classifier
3 | from continual import vit
4 | from continual import convit
5 | from continual import utils
6 | from continual import scaler
7 | from continual import cnn
8 | from continual import factory
9 | from continual import sam
10 | from continual import samplers
11 | from continual import mixup
--------------------------------------------------------------------------------
/continual/birt.py:
--------------------------------------------------------------------------------
1 | import copy
2 | 
3 | import torch
4 | from timm.models.layers import trunc_normal_
5 | from torch import nn
6 | 
7 | import continual.utils as cutils
8 | 
9 | 
10 | class ContinualClassifier(nn.Module):
11 |     """Your good old classifier to do continual learning."""
12 |     def __init__(self, embed_dim, nb_classes):
13 |         super().__init__()
14 | 
15 |         self.embed_dim = embed_dim
16 |         self.nb_classes = nb_classes
17 |         self.head = nn.Linear(embed_dim, nb_classes, bias=True)
18 |         self.norm = nn.LayerNorm(embed_dim)
19 | 
20 |     def reset_parameters(self):
21 |         self.head.reset_parameters()
22 |         self.norm.reset_parameters()
23 | 
24 |     def forward(self, x):
25 |         x = self.norm(x)
26 |         return self.head(x)
27 | 
28 |     def add_new_outputs(self, n):
29 |         head = nn.Linear(self.embed_dim, self.nb_classes + n, bias=True)
30 |         head.weight.data[:-n] = self.head.weight.data
31 | 
32 |         head.to(self.head.weight.device)
33 |         self.head = head
34 |         self.nb_classes += n
35 | 
36 | 
37 | class BiRT(nn.Module):
38 |     """
39 |     :param transformer: The base transformer.
40 |     :param nb_classes: The initial number of classes.
41 |     :param individual_classifier: Classifier config, DyTox is in `1-1`.
42 |     :param head_div: Whether to use the divergence head for improved diversity.
43 |     :param head_div_mode: Use the divergence head in TRaining, FineTuning, or both.
44 |     :param joint_tokens: Use a single TAB forward with masked attention (faster but a bit worse). 
45 | """ 46 | def __init__( 47 | self, 48 | transformer, 49 | nb_classes, 50 | individual_classifier='', 51 | head_div=False, 52 | head_div_mode=['tr', 'ft'], 53 | joint_tokens=False, 54 | num_blocks=None, 55 | multi_token_setup=False 56 | ): 57 | super().__init__() 58 | 59 | self.nb_classes = nb_classes # 2 60 | self.embed_dim = transformer.embed_dim 61 | self.individual_classifier = individual_classifier 62 | self.use_head_div = head_div # true 63 | self.head_div_mode = head_div_mode # tr 64 | self.head_div = None 65 | self.joint_tokens = joint_tokens # False 66 | self.in_finetuning = False 67 | self.multi_token_setup = multi_token_setup 68 | 69 | self.num_blocks = num_blocks 70 | 71 | 72 | self.nb_classes_per_task = [nb_classes] 73 | 74 | self.patch_embed = transformer.patch_embed 75 | self.pos_embed = transformer.pos_embed 76 | self.pos_drop = transformer.pos_drop 77 | self.sabs = transformer.blocks[:-1] 78 | 79 | self.tabs = transformer.blocks[-1:] 80 | 81 | self.task_tokens = nn.ParameterList([transformer.cls_token]) 82 | 83 | if self.individual_classifier != '': 84 | in_dim, out_dim = self._get_ind_clf_dim() # 384, 10 85 | self.head = nn.ModuleList([ 86 | ContinualClassifier(in_dim, out_dim).cuda() 87 | ]) 88 | else: 89 | self.head = ContinualClassifier( 90 | self.embed_dim * len(self.task_tokens), sum(self.nb_classes_per_task) 91 | ).cuda() 92 | 93 | def end_finetuning(self): 94 | """Start FT mode, usually with backbone freezed and balanced classes.""" 95 | self.in_finetuning = False 96 | 97 | def begin_finetuning(self): 98 | """End FT mode, usually with backbone freezed and balanced classes.""" 99 | self.in_finetuning = True 100 | 101 | def add_model(self, nb_new_classes, multi_token_setup=False): 102 | """Expand model as per the DyTox framework given `nb_new_classes`. 103 | 104 | :param nb_new_classes: Number of new classes brought by the new task. 105 | """ 106 | self.nb_classes_per_task.append(nb_new_classes) 107 | 108 | # Class tokens --------------------------------------------------------- 109 | new_task_token = copy.deepcopy(self.task_tokens[-1]) 110 | trunc_normal_(new_task_token, std=.02) 111 | self.task_tokens.append(new_task_token) 112 | # ---------------------------------------------------------------------- 113 | 114 | # Diversity head ------------------------------------------------------- 115 | if self.use_head_div: 116 | self.head_div = ContinualClassifier( 117 | self.sabs[-1].dim, self.nb_classes_per_task[-1] + 1 118 | ).cuda() 119 | # ---------------------------------------------------------------------- 120 | 121 | # Classifier ----------------------------------------------------------- 122 | if self.individual_classifier != '' and not multi_token_setup: 123 | in_dim, out_dim = self._get_ind_clf_dim() 124 | self.head.append( 125 | ContinualClassifier(in_dim, out_dim).cuda() 126 | ) 127 | elif not multi_token_setup: 128 | self.head = ContinualClassifier( 129 | self.embed_dim * len(self.task_tokens), sum(self.nb_classes_per_task) 130 | ).cuda() 131 | # ---------------------------------------------------------------------- 132 | 133 | def _get_ind_clf_dim(self): 134 | """What are the input and output dim of classifier depending on its config. 135 | 136 | By default, DyTox is in 1-1. 
137 | """ 138 | if self.individual_classifier == '1-1': 139 | in_dim = self.sabs[-1].dim 140 | out_dim = self.nb_classes_per_task[-1] 141 | elif self.individual_classifier == '1-n': 142 | in_dim = self.embed_dim 143 | out_dim = sum(self.nb_classes_per_task) 144 | elif self.individual_classifier == 'n-n': 145 | in_dim = len(self.task_tokens) * self.embed_dim 146 | out_dim = sum(self.nb_classes_per_task) 147 | elif self.individual_classifier == 'n-1': 148 | in_dim = len(self.task_tokens) * self.embed_dim 149 | out_dim = self.nb_classes_per_task[-1] 150 | else: 151 | raise NotImplementedError(f'Unknown ind classifier {self.individual_classifier}') 152 | return in_dim, out_dim 153 | 154 | def freeze(self, names): 155 | """Choose what to freeze depending on the name of the module.""" 156 | requires_grad = False 157 | cutils.freeze_parameters(self, requires_grad=not requires_grad) 158 | self.train() 159 | 160 | for name in names: 161 | if name == 'all': 162 | self.eval() 163 | return cutils.freeze_parameters(self) 164 | elif name == 'multitoken_all': 165 | # self.eval() 166 | return cutils.freeze_parameters(self) 167 | elif name == 'old_task_tokens': 168 | cutils.freeze_parameters(self.task_tokens[:-1], requires_grad=requires_grad) 169 | elif name == 'freeze_token': 170 | cutils.freeze_parameters(self.task_tokens[-1], requires_grad=requires_grad) 171 | elif name == 'task_tokens': 172 | cutils.freeze_parameters(self.task_tokens, requires_grad=requires_grad) 173 | elif name == 'sab': 174 | self.sabs.eval() 175 | cutils.freeze_parameters(self.patch_embed, requires_grad=requires_grad) 176 | cutils.freeze_parameters(self.pos_embed, requires_grad=requires_grad) 177 | cutils.freeze_parameters(self.sabs, requires_grad=requires_grad) 178 | elif name == 'partial_sab': 179 | cutils.freeze_parameters(self.patch_embed, requires_grad=requires_grad) 180 | cutils.freeze_parameters(self.pos_embed, requires_grad=requires_grad) 181 | cutils.freeze_parameters(self.sabs[:self.num_blocks], requires_grad=requires_grad) 182 | elif name == 'tab': 183 | self.tabs.eval() 184 | cutils.freeze_parameters(self.tabs, requires_grad=requires_grad) 185 | elif name == 'old_heads': 186 | self.head[:-1].eval() 187 | cutils.freeze_parameters(self.head[:-1], requires_grad=requires_grad) 188 | elif name == 'heads': 189 | self.head.eval() 190 | cutils.freeze_parameters(self.head, requires_grad=requires_grad) 191 | elif name == 'head_div': 192 | self.head_div.eval() 193 | cutils.freeze_parameters(self.head_div, requires_grad=requires_grad) 194 | else: 195 | raise NotImplementedError(f'Unknown name={name}.') 196 | 197 | def param_groups(self): 198 | return { 199 | 'all': self.parameters(), 200 | 'old_task_tokens': self.task_tokens[:-1], 201 | 'task_tokens': self.task_tokens.parameters(), 202 | 'new_task_tokens': [self.task_tokens[-1]], 203 | 'sa': self.sabs.parameters(), 204 | 'patch': self.patch_embed.parameters(), 205 | 'pos': [self.pos_embed], 206 | 'ca': self.tabs.parameters(), 207 | 'old_heads': self.head[:-self.nb_classes_per_task[-1]].parameters() \ 208 | if self.individual_classifier else \ 209 | self.head.parameters(), 210 | 'new_head': self.head[-1].parameters() if self.individual_classifier else self.head.parameters(), 211 | 'head': self.head.parameters(), 212 | 'head_div': self.head_div.parameters() if self.head_div is not None else None 213 | } 214 | 215 | def reset_classifier(self): 216 | if isinstance(self.head, nn.ModuleList): 217 | for head in self.head: 218 | head.reset_parameters() 219 | else: 220 | 
221 | 
222 |     def hook_before_update(self):
223 |         pass
224 | 
225 |     def hook_after_update(self):
226 |         pass
227 | 
228 |     def hook_after_epoch(self):
229 |         pass
230 | 
231 |     def epoch_log(self):
232 |         """Write here whatever you want to log on the internal state of the model."""
233 |         log = {}
234 | 
235 |         # Compute mean distance between class tokens
236 |         mean_dist, min_dist, max_dist = [], float('inf'), 0.
237 |         with torch.no_grad():
238 |             for i in range(len(self.task_tokens)):
239 |                 for j in range(i + 1, len(self.task_tokens)):
240 |                     dist = torch.norm(self.task_tokens[i] - self.task_tokens[j], p=2).item()
241 |                     mean_dist.append(dist)
242 | 
243 |                     min_dist = min(dist, min_dist)
244 |                     max_dist = max(dist, max_dist)
245 | 
246 |         if len(mean_dist) > 0:
247 |             mean_dist = sum(mean_dist) / len(mean_dist)
248 |         else:
249 |             mean_dist = 0.
250 |             min_dist = 0.
251 | 
252 |         assert min_dist <= mean_dist <= max_dist, (min_dist, mean_dist, max_dist)
253 |         log['token_mean_dist'] = round(mean_dist, 5)
254 |         log['token_min_dist'] = round(min_dist, 5)
255 |         log['token_max_dist'] = round(max_dist, 5)
256 |         return log
257 | 
258 |     def get_internal_losses(self, clf_loss):
259 |         """If you want to compute some internal loss, like an EWC loss for example.
260 | 
261 |         :param clf_loss: The main classification loss (if you wanted to use its gradient for example).
262 |         :return: a dictionary of losses, all values will be summed in the final loss.
263 |         """
264 |         int_losses = {}
265 |         return int_losses
266 | 
267 |     def forward_initial(self, x):
268 |         # Shared part, this is the ENCODER
269 |         B = x.shape[0]
270 | 
271 |         x = self.patch_embed(x)
272 |         if self.pos_embed is not None:
273 |             x = x + self.pos_embed
274 |         x = self.pos_drop(x)
275 | 
276 |         for blk in self.sabs[:self.num_blocks]:
277 |             x, attn, v = blk(x)
278 | 
279 |         return x
280 | 
281 |     def forward_latter(self, x, args=None):
282 |         B = x.shape[0]
283 |         s_e, s_a, s_v = [], [], []
284 |         for blk in self.sabs[self.num_blocks:]:
285 |             x, attn, v = blk(x, args=args)
286 |             s_e.append(x)
287 |             s_a.append(attn)
288 |             s_v.append(v)
289 | 
290 |         # Specific part, this is what we called the "task specific DECODER"
291 |         if self.joint_tokens:
292 |             return self.forward_features_jointtokens(x)
293 | 
294 |         tokens = []
295 |         attentions = []
296 |         mask_heads = None
297 | 
298 |         for task_token in self.task_tokens:
299 |             task_token = task_token.expand(B, -1, -1)
300 | 
301 |             ca_blocks = self.tabs
302 | 
303 |             for blk in ca_blocks:
304 |                 task_token, attn, v = blk(torch.cat((task_token, x), dim=1), mask_heads=mask_heads)
305 | 
306 |             attentions.append(attn)
307 |             tokens.append(task_token[:, 0])
308 | 
309 |         self._class_tokens = tokens
310 |         return self.forward_classifier(tokens, tokens[-1], attentions)
311 | 
312 |     def forward_features_multitoken(self, x, batch_tasks=None):
313 |         # Shared part, this is the ENCODER
314 |         B = x.shape[0]
315 | 
316 |         x = self.patch_embed(x)
317 |         x = x + self.pos_embed
318 |         x = self.pos_drop(x)
319 | 
320 |         s_e, s_a, s_v = [], [], []
321 |         for blk in self.sabs:
322 |             x, attn, v = blk(x)
323 |             s_e.append(x)
324 |             if attn is not None:
325 |                 s_a.append(attn)
326 |             s_v.append(v)
327 | 
328 |         # Specific part, this is what we called the "task specific DECODER"
329 |         if self.joint_tokens:
330 |             return self.forward_features_jointtokens(x)
331 | 
332 |         tokens = []
333 |         attentions = []
334 |         mask_heads = None
335 | 
336 |         if self.training:
337 |             x_tokens = torch.cat((torch.cat([self.task_tokens[i] for i in batch_tasks]), x), dim=1)
338 |             ca_blocks = self.tabs
339 |             for blk in ca_blocks:
340 |                 task_token, attn, v = blk(x_tokens, mask_heads=mask_heads)
341 | 
342 |             attentions.append(attn.unsqueeze(dim=0))
343 |             tokens.append(task_token[:, 0])
344 | 
345 |             self._class_tokens = tokens
346 |             return tokens, tokens[-1], attentions
347 | 
348 |         else:
349 |             for task_token in self.task_tokens:
350 |                 task_token = task_token.expand(B, -1, -1)
351 | 
352 |                 ca_blocks = self.tabs
353 | 
354 |                 for blk in ca_blocks:
355 |                     task_token, attn, v = blk(torch.cat((task_token, x), dim=1), mask_heads=mask_heads)
356 | 
357 |                 attentions.append(attn.unsqueeze(dim=0))
358 |                 tokens.append(task_token[:, 0])
359 | 
360 |             self._class_tokens = tokens
361 |             return tokens, tokens[-1], attentions
362 | 
363 |     def forward_features(self, x, batch_tasks=None):
364 |         # Shared part, this is the ENCODER
365 |         B = x.shape[0]
366 | 
367 |         x = self.patch_embed(x)
368 |         if self.pos_embed is not None:
369 |             x = x + self.pos_embed
370 |         x = self.pos_drop(x)
371 | 
372 |         s_e, s_a, s_v = [], [], []
373 |         for blk in self.sabs:
374 |             x, attn, v = blk(x)
375 |             s_e.append(x)
376 |             if attn is not None:
377 |                 s_a.append(attn)
378 |             if v is not None:
379 |                 s_v.append(v)
380 | 
381 |         # Specific part, this is what we called the "task specific DECODER"
382 |         if self.joint_tokens:
383 |             return self.forward_features_jointtokens(x)
384 | 
385 |         tokens = []
386 |         attentions = []
387 |         mask_heads = None
388 | 
389 |         for task_token in self.task_tokens:
390 |             task_token = task_token.expand(B, -1, -1)
391 | 
392 |             ca_blocks = self.tabs
393 | 
394 |             for blk in ca_blocks:
395 |                 task_token, attn, v = blk(torch.cat((task_token, x), dim=1), mask_heads=mask_heads)
396 | 
397 |             attentions.append(attn.unsqueeze(dim=0))
398 |             tokens.append(task_token[:, 0])
399 | 
400 |         self._class_tokens = tokens
401 |         return tokens, tokens[-1], attentions
402 | 
403 |     def forward_features_jointtokens(self, x):
404 |         """Method to do a single TAB forward with all task tokens.
405 | 
406 |         A masking is used to avoid interaction between tasks. In theory it should
407 |         give the same results as multiple TAB forward, but in practice it's a little
408 |         bit worse, not sure why. So if you have an idea, please tell me!
409 |         """
410 |         B = len(x)
411 | 
412 |         task_tokens = torch.cat(
413 |             [task_token.expand(B, 1, -1) for task_token in self.task_tokens],
414 |             dim=1
415 |         )
416 | 
417 |         for blk in self.tabs:
418 |             task_tokens, _, _ = blk(
419 |                 torch.cat((task_tokens, x), dim=1),
420 |                 task_index=len(self.task_tokens),
421 |                 attn_mask=True
422 |             )
423 | 
424 |         if self.individual_classifier in ('1-1', '1-n'):
425 |             return task_tokens.permute(1, 0, 2), task_tokens[:, -1], None
426 |         return task_tokens.view(B, -1), task_tokens[:, -1], None
427 | 
428 |     def forward_classifier_multitoken(self, tokens, last_token, attentions):
429 |         """Once all task embeddings e_1, ..., e_t are extracted, classify.
430 | 
431 |         The classifier has different modes based on a pattern x-y:
432 |         - x means the number of task embeddings in input
433 |         - y means the number of tasks to predict
434 | 
435 |         So:
436 |         - n-n: predicts all tasks given all embeddings
437 |         But:
438 |         - 1-1: predicts 1 task given 1 embedding, which is the 'independent classifier' used in the paper.
439 | 
440 |         :param tokens: A list of all task tokens embeddings.
441 |         :param last_token: The ultimate task token embedding from the latest task.
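        :param attentions: The attention maps collected from the TAB forward passes.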
442 | """ 443 | logits_div = None 444 | 445 | logits = [] 446 | 447 | # assuming self.individual_classifier is always '1-1' here 448 | for i in range(len(tokens)): 449 | logits.append(self.head[0](tokens[i])) 450 | 451 | logits = torch.cat(logits, dim=1) 452 | 453 | attentions = torch.cat(attentions, dim=0) 454 | 455 | return { 456 | 'logits': logits, 457 | 'div': logits_div, 458 | 'tokens': tokens, # 128, 384 459 | 'attention': attentions # 128, 12, 1, 65 460 | } 461 | 462 | def forward_classifier(self, tokens, last_token, attentions): 463 | """Once all task embeddings e_1, ..., e_t are extracted, classify. 464 | 465 | Classifier has different mode based on a pattern x-y: 466 | - x means the number of task embeddings in input 467 | - y means the number of task to predict 468 | 469 | So: 470 | - n-n: predicts all task given all embeddings 471 | But: 472 | - 1-1: predict 1 task given 1 embedding, which is the 'independent classifier' used in the paper. 473 | 474 | :param tokens: A list of all task tokens embeddings. 475 | :param last_token: The ultimate task token embedding from the latest task. 476 | """ 477 | logits_div = None 478 | 479 | if self.individual_classifier != '': 480 | logits = [] 481 | 482 | for i, head in enumerate(self.head): 483 | if self.individual_classifier in ('1-n', '1-1'): 484 | logits.append(head(tokens[i])) 485 | else: # n-1, n-n 486 | logits.append(head(torch.cat(tokens[:i+1], dim=1))) 487 | 488 | if self.individual_classifier in ('1-1', 'n-1'): 489 | logits = torch.cat(logits, dim=1) 490 | else: # 1-n, n-n 491 | final_logits = torch.zeros_like(logits[-1]) 492 | for i in range(len(logits)): 493 | final_logits[:, :logits[i].shape[1]] += logits[i] 494 | 495 | for i, c in enumerate(self.nb_classes_per_task): 496 | final_logits[:, :c] /= len(self.nb_classes_per_task) - i 497 | 498 | logits = final_logits 499 | elif isinstance(tokens, torch.Tensor): 500 | logits = self.head(tokens) 501 | else: 502 | logits = self.head(torch.cat(tokens, dim=1)) 503 | 504 | if self.head_div is not None and eval_training_finetuning(self.head_div_mode, self.in_finetuning): 505 | logits_div = self.head_div(last_token) # only last token 506 | 507 | # modify attentions list to extract only the first element 508 | attentions = torch.cat(attentions, dim=0) 509 | 510 | return { 511 | 'logits': logits, 512 | 'div': logits_div, 513 | 'tokens': tokens, # 128, 384 514 | 'attention': attentions # 128, 12, 1, 65 515 | } 516 | 517 | def forward(self, x, batch_tasks=None, initial=False, latter=False, args=None): 518 | if initial: 519 | return self.forward_initial(x) 520 | elif latter: 521 | return self.forward_latter(x, args=args) 522 | elif self.multi_token_setup: 523 | tokens, last_token, attentions = self.forward_features_multitoken(x, batch_tasks=batch_tasks) 524 | return self.forward_classifier_multitoken(tokens, last_token, attentions) 525 | else: 526 | tokens, last_token, attentions = self.forward_features(x, batch_tasks=batch_tasks) 527 | return self.forward_classifier(tokens, last_token, attentions) 528 | 529 | 530 | def eval_training_finetuning(mode, in_ft): 531 | if 'tr' in mode and 'ft' in mode: 532 | return True 533 | if 'tr' in mode and not in_ft: 534 | return True 535 | if 'ft' in mode and in_ft: 536 | return True 537 | return False 538 | -------------------------------------------------------------------------------- /continual/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import 
4 | 
5 | 
6 | class Classifier(nn.Module):
7 |     def __init__(self, embed_dim, nb_total_classes, nb_base_classes, increment, nb_tasks=None, bias=True, complete=True, cosine=False, norm=True):
8 |         super().__init__()
9 | 
10 |         self.embed_dim = embed_dim
11 |         self.nb_classes = nb_base_classes
12 |         self.cosine = cosine  # false
13 | 
14 |         if self.cosine not in (False, None, ''):
15 |             self.scale = nn.Parameter(torch.tensor(1.))
16 |         else:
17 |             self.scale = 1
18 |         self.head = nn.Linear(embed_dim, nb_base_classes, bias=not cosine)
19 |         self.norm = nn.LayerNorm(embed_dim) if norm else nn.Identity()
20 |         self.increment = increment
21 | 
22 |     def reset_parameters(self):
23 |         self.head.reset_parameters()
24 |         self.norm.reset_parameters()
25 | 
26 |     def forward(self, x):
27 |         x = self.norm(x)
28 | 
29 |         if self.cosine not in (False, None, ''):
30 |             w = self.head.weight  # (c, d)
31 | 
32 |             if self.cosine == 'pcc':
33 |                 x = x - x.mean(dim=1, keepdims=True)
34 |                 w = w - w.mean(dim=1, keepdims=True)
35 |             x = F.normalize(x, p=2, dim=1)  # (bs, d)
36 |             w = F.normalize(w, p=2, dim=1)  # (c, d)
37 |             return self.scale * torch.mm(x, w.T)
38 | 
39 |         return self.head(x)
40 | 
41 |     def init_prev_head(self, head):
42 |         w, b = head.weight.data, head.bias.data
43 |         self.head.weight.data[:w.shape[0], :w.shape[1]] = w
44 |         self.head.bias.data[:b.shape[0]] = b
45 | 
46 |     def init_prev_norm(self, norm):
47 |         w, b = norm.weight.data, norm.bias.data
48 |         self.norm.weight.data[:w.shape[0]] = w
49 |         self.norm.bias.data[:b.shape[0]] = b
50 | 
51 |     @torch.no_grad()
52 |     def weight_align(self, nb_new_classes):
53 |         w = self.head.weight.data
54 |         norms = torch.norm(w, dim=1)
55 | 
56 |         norm_old = norms[:-nb_new_classes]
57 |         norm_new = norms[-nb_new_classes:]
58 | 
59 |         gamma = torch.mean(norm_old) / torch.mean(norm_new)
60 |         w[-nb_new_classes:] = gamma * w[-nb_new_classes:]
61 | 
62 |     def add_classes(self):
63 |         self.add_new_outputs(self.increment)
64 | 
65 |     def add_new_outputs(self, n):
66 |         head = nn.Linear(self.embed_dim, self.nb_classes + n, bias=not self.cosine)
67 |         head.weight.data[:-n] = self.head.weight.data
68 |         if not self.cosine:
69 |             head.bias.data[:-n] = self.head.bias.data
70 | 
71 |         head.to(self.head.weight.device)
72 |         self.head = head
73 |         self.nb_classes += n
74 | 
--------------------------------------------------------------------------------
/continual/cnn/__init__.py:
--------------------------------------------------------------------------------
1 | from continual.cnn.abstract import AbstractCNN
2 | from continual.cnn.inception import InceptionV3
3 | from continual.cnn.senet import legacy_seresnet18 as seresnet18
4 | from continual.cnn.resnet import (
5 |     resnet18, resnet34, resnet50, resnext50_32x4d, wide_resnet50_2
6 | )
7 | from continual.cnn.resnet_scs import resnet18_scs, resnet18_scs_avg, resnet18_scs_max
8 | from continual.cnn.vgg import vgg16_bn, vgg16
9 | from continual.cnn.resnet_rebuffi import CifarResNet as rebuffi
--------------------------------------------------------------------------------
/continual/cnn/abstract.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | 
3 | import continual.utils as cutils
4 | 
5 | 
6 | class AbstractCNN(nn.Module):
7 |     def reset_classifier(self):
8 |         self.head.reset_parameters()
9 | 
10 |     def get_internal_losses(self, clf_loss):
11 |         return {}
12 | 
13 |     def end_finetuning(self):
14 |         pass
15 | 
16 |     def begin_finetuning(self):
17 |         pass
18 | 
19 |     def epoch_log(self):
20 |         return {}
21 | 
22 |     def get_classifier(self):
23 |         return self.head
24 | 
25 |     def freeze(self, names):
26 |         cutils.freeze_parameters(self, requires_grad=True)
27 |         self.train()
28 | 
29 |         for name in names:
30 |             if name == 'head':
31 |                 cutils.freeze_parameters(self.head)
32 |                 self.head.eval()
33 |             elif name == 'backbone':
34 |                 for k, p in self.named_parameters():
35 |                     if not k.startswith('head'):
36 |                         cutils.freeze_parameters(p)
37 |             elif name == 'all':
38 |                 cutils.freeze_parameters(self)
39 |                 self.eval()
40 |             else:
41 |                 raise NotImplementedError(f'Unknown module name to freeze {name}')
42 | 
--------------------------------------------------------------------------------
/continual/cnn/inception.py:
--------------------------------------------------------------------------------
1 | """ inceptionv3 in pytorch
2 | 
3 | 
4 | [1] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna
5 | 
6 |     Rethinking the Inception Architecture for Computer Vision
7 |     https://arxiv.org/abs/1512.00567v3
8 | """
9 | 
10 | import torch
11 | import torch.nn as nn
12 | 
13 | 
14 | from continual.cnn import AbstractCNN
15 | 
16 | 
17 | class BasicConv2d(nn.Module):
18 | 
19 |     def __init__(self, input_channels, output_channels, **kwargs):
20 |         super().__init__()
21 |         self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs)
22 |         self.bn = nn.BatchNorm2d(output_channels)
23 |         self.relu = nn.ReLU(inplace=True)
24 | 
25 |     def forward(self, x):
26 |         x = self.conv(x)
27 |         x = self.bn(x)
28 |         x = self.relu(x)
29 | 
30 |         return x
31 | 
32 | #same naive inception module
33 | class InceptionA(nn.Module):
34 | 
35 |     def __init__(self, input_channels, pool_features):
36 |         super().__init__()
37 |         self.branch1x1 = BasicConv2d(input_channels, 64, kernel_size=1)
38 | 
39 |         self.branch5x5 = nn.Sequential(
40 |             BasicConv2d(input_channels, 48, kernel_size=1),
41 |             BasicConv2d(48, 64, kernel_size=5, padding=2)
42 |         )
43 | 
44 |         self.branch3x3 = nn.Sequential(
45 |             BasicConv2d(input_channels, 64, kernel_size=1),
46 |             BasicConv2d(64, 96, kernel_size=3, padding=1),
47 |             BasicConv2d(96, 96, kernel_size=3, padding=1)
48 |         )
49 | 
50 |         self.branchpool = nn.Sequential(
51 |             nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
52 |             BasicConv2d(input_channels, pool_features, kernel_size=3, padding=1)
53 |         )
54 | 
55 |     def forward(self, x):
56 | 
57 |         #x -> 1x1(same)
58 |         branch1x1 = self.branch1x1(x)
59 | 
60 |         #x -> 1x1 -> 5x5(same)
61 |         branch5x5 = self.branch5x5(x)
62 |         #branch5x5 = self.branch5x5_2(branch5x5)
63 | 
64 |         #x -> 1x1 -> 3x3 -> 3x3(same)
65 |         branch3x3 = self.branch3x3(x)
66 | 
67 |         #x -> pool -> 1x1(same)
68 |         branchpool = self.branchpool(x)
69 | 
70 |         outputs = [branch1x1, branch5x5, branch3x3, branchpool]
71 | 
72 |         return torch.cat(outputs, 1)
73 | 
74 | #downsample
75 | #Factorization into smaller convolutions
76 | class InceptionB(nn.Module):
77 | 
78 |     def __init__(self, input_channels):
79 |         super().__init__()
80 | 
81 |         self.branch3x3 = BasicConv2d(input_channels, 384, kernel_size=3, stride=2)
82 | 
83 |         self.branch3x3stack = nn.Sequential(
84 |             BasicConv2d(input_channels, 64, kernel_size=1),
85 |             BasicConv2d(64, 96, kernel_size=3, padding=1),
86 |             BasicConv2d(96, 96, kernel_size=3, stride=2)
87 |         )
88 | 
89 |         self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2)
90 | 
91 |     def forward(self, x):
92 | 
93 |         #x -> 3x3(downsample)
94 |         branch3x3 = self.branch3x3(x)
95 | 
96 |         #x -> 3x3 -> 3x3(downsample)
97 |         branch3x3stack = self.branch3x3stack(x)
98 | 
99 |         #x -> maxpool(downsample)
100 |         branchpool = self.branchpool(x)
101 | 
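        # (added note) InceptionB is the grid-size-reduction block: it halves the
        # spatial grid (30x30 -> 14x14 in InceptionV3.forward below) while
        # concatenating 384 + 96 + input_channels output channels, so the 288
        # channels coming from Mixed_5d become the 768 consumed by the
        # InceptionC(768, ...) blocks.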
102 | #"""We can use two parallel stride 2 blocks: P and C. P is a pooling 103 | #layer (either average or maximum pooling) the activation, both of 104 | #them are stride 2 the filter banks of which are concatenated as in 105 | #figure 10.""" 106 | outputs = [branch3x3, branch3x3stack, branchpool] 107 | 108 | return torch.cat(outputs, 1) 109 | 110 | #Factorizing Convolutions with Large Filter Size 111 | class InceptionC(nn.Module): 112 | def __init__(self, input_channels, channels_7x7): 113 | super().__init__() 114 | self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1) 115 | 116 | c7 = channels_7x7 117 | 118 | #In theory, we could go even further and argue that one can replace any n × n 119 | #convolution by a 1 × n convolution followed by a n × 1 convolution and the 120 | #computational cost saving increases dramatically as n grows (see figure 6). 121 | self.branch7x7 = nn.Sequential( 122 | BasicConv2d(input_channels, c7, kernel_size=1), 123 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 124 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 125 | ) 126 | 127 | self.branch7x7stack = nn.Sequential( 128 | BasicConv2d(input_channels, c7, kernel_size=1), 129 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 130 | BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)), 131 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 132 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 133 | ) 134 | 135 | self.branch_pool = nn.Sequential( 136 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 137 | BasicConv2d(input_channels, 192, kernel_size=1), 138 | ) 139 | 140 | def forward(self, x): 141 | 142 | #x -> 1x1(same) 143 | branch1x1 = self.branch1x1(x) 144 | 145 | #x -> 1layer 1*7 and 7*1 (same) 146 | branch7x7 = self.branch7x7(x) 147 | 148 | #x-> 2layer 1*7 and 7*1(same) 149 | branch7x7stack = self.branch7x7stack(x) 150 | 151 | #x-> avgpool (same) 152 | branchpool = self.branch_pool(x) 153 | 154 | outputs = [branch1x1, branch7x7, branch7x7stack, branchpool] 155 | 156 | return torch.cat(outputs, 1) 157 | 158 | class InceptionD(nn.Module): 159 | 160 | def __init__(self, input_channels): 161 | super().__init__() 162 | 163 | self.branch3x3 = nn.Sequential( 164 | BasicConv2d(input_channels, 192, kernel_size=1), 165 | BasicConv2d(192, 320, kernel_size=3, stride=2) 166 | ) 167 | 168 | self.branch7x7 = nn.Sequential( 169 | BasicConv2d(input_channels, 192, kernel_size=1), 170 | BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)), 171 | BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)), 172 | BasicConv2d(192, 192, kernel_size=3, stride=2) 173 | ) 174 | 175 | self.branchpool = nn.AvgPool2d(kernel_size=3, stride=2) 176 | 177 | def forward(self, x): 178 | 179 | #x -> 1x1 -> 3x3(downsample) 180 | branch3x3 = self.branch3x3(x) 181 | 182 | #x -> 1x1 -> 1x7 -> 7x1 -> 3x3 (downsample) 183 | branch7x7 = self.branch7x7(x) 184 | 185 | #x -> avgpool (downsample) 186 | branchpool = self.branchpool(x) 187 | 188 | outputs = [branch3x3, branch7x7, branchpool] 189 | 190 | return torch.cat(outputs, 1) 191 | 192 | 193 | #same 194 | class InceptionE(nn.Module): 195 | def __init__(self, input_channels): 196 | super().__init__() 197 | self.branch1x1 = BasicConv2d(input_channels, 320, kernel_size=1) 198 | 199 | self.branch3x3_1 = BasicConv2d(input_channels, 384, kernel_size=1) 200 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 201 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 202 | 203 | 
self.branch3x3stack_1 = BasicConv2d(input_channels, 448, kernel_size=1) 204 | self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 205 | self.branch3x3stack_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 206 | self.branch3x3stack_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 207 | 208 | self.branch_pool = nn.Sequential( 209 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 210 | BasicConv2d(input_channels, 192, kernel_size=1) 211 | ) 212 | 213 | def forward(self, x): 214 | 215 | #x -> 1x1 (same) 216 | branch1x1 = self.branch1x1(x) 217 | 218 | # x -> 1x1 -> 3x1 219 | # x -> 1x1 -> 1x3 220 | # concatenate(3x1, 1x3) 221 | #"""7. Inception modules with expanded the filter bank outputs. 222 | #This architecture is used on the coarsest (8 × 8) grids to promote 223 | #high dimensional representations, as suggested by principle 224 | #2 of Section 2.""" 225 | branch3x3 = self.branch3x3_1(x) 226 | branch3x3 = [ 227 | self.branch3x3_2a(branch3x3), 228 | self.branch3x3_2b(branch3x3) 229 | ] 230 | branch3x3 = torch.cat(branch3x3, 1) 231 | 232 | # x -> 1x1 -> 3x3 -> 1x3 233 | # x -> 1x1 -> 3x3 -> 3x1 234 | #concatenate(1x3, 3x1) 235 | branch3x3stack = self.branch3x3stack_1(x) 236 | branch3x3stack = self.branch3x3stack_2(branch3x3stack) 237 | branch3x3stack = [ 238 | self.branch3x3stack_3a(branch3x3stack), 239 | self.branch3x3stack_3b(branch3x3stack) 240 | ] 241 | branch3x3stack = torch.cat(branch3x3stack, 1) 242 | 243 | branchpool = self.branch_pool(x) 244 | 245 | outputs = [branch1x1, branch3x3, branch3x3stack, branchpool] 246 | 247 | return torch.cat(outputs, 1) 248 | 249 | class InceptionV3(AbstractCNN): 250 | 251 | def __init__(self, num_classes=100): 252 | super().__init__() 253 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, padding=1) 254 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1) 255 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 256 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 257 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 258 | 259 | #naive inception module 260 | self.Mixed_5b = InceptionA(192, pool_features=32) 261 | self.Mixed_5c = InceptionA(256, pool_features=64) 262 | self.Mixed_5d = InceptionA(288, pool_features=64) 263 | 264 | #downsample 265 | self.Mixed_6a = InceptionB(288) 266 | 267 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 268 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 269 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 270 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 271 | 272 | #downsample 273 | self.Mixed_7a = InceptionD(768) 274 | 275 | self.Mixed_7b = InceptionE(1280) 276 | self.Mixed_7c = InceptionE(2048) 277 | 278 | #6*6 feature size 279 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 280 | self.dropout = nn.Dropout2d() 281 | self.head = None 282 | self.embed_dim = 2048 283 | 284 | def forward(self, x): 285 | 286 | #32 -> 30 287 | x = self.Conv2d_1a_3x3(x) 288 | x = self.Conv2d_2a_3x3(x) 289 | x = self.Conv2d_2b_3x3(x) 290 | x = self.Conv2d_3b_1x1(x) 291 | x = self.Conv2d_4a_3x3(x) 292 | 293 | #30 -> 30 294 | x = self.Mixed_5b(x) 295 | x = self.Mixed_5c(x) 296 | x = self.Mixed_5d(x) 297 | 298 | #30 -> 14 299 | #Efficient Grid Size Reduction to avoid representation 300 | #bottleneck 301 | x = self.Mixed_6a(x) 302 | 303 | #14 -> 14 304 | #"""In practice, we have found that employing this factorization does not 305 | #work well on early layers, but it gives very good results on medium 306 | #grid-sizes (On m × m 
feature maps, where m ranges between 12 and 20). 307 | #On that level, very good results can be achieved by using 1 × 7 convolutions 308 | #followed by 7 × 1 convolutions.""" 309 | x = self.Mixed_6b(x) 310 | x = self.Mixed_6c(x) 311 | x = self.Mixed_6d(x) 312 | x = self.Mixed_6e(x) 313 | 314 | #14 -> 6 315 | #Efficient Grid Size Reduction 316 | x = self.Mixed_7a(x) 317 | 318 | #6 -> 6 319 | #We are using this solution only on the coarsest grid, 320 | #since that is the place where producing high dimensional 321 | #sparse representation is the most critical as the ratio of 322 | #local processing (by 1 × 1 convolutions) is increased compared 323 | #to the spatial aggregation.""" 324 | x = self.Mixed_7b(x) 325 | x = self.Mixed_7c(x) 326 | 327 | #6 -> 1 328 | x = self.avgpool(x) 329 | x = self.dropout(x) 330 | x = x.view(x.size(0), -1) 331 | x = self.head(x) 332 | return x 333 | 334 | 335 | def inceptionv3(): 336 | return InceptionV3() 337 | 338 | 339 | 340 | -------------------------------------------------------------------------------- /continual/cnn/resnet.py: -------------------------------------------------------------------------------- 1 | #from .utils import load_state_dict_from_url 2 | from typing import Any, Callable, List, Optional, Type, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from continual.cnn import AbstractCNN 7 | from torch import Tensor 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 11 | 'wide_resnet50_2', 'wide_resnet101_2'] 12 | 13 | 14 | model_urls = { 15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', 20 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 21 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 22 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 23 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 24 | } 25 | 26 | 27 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=dilation, groups=groups, bias=False, dilation=dilation) 31 | 32 | 33 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 34 | """1x1 convolution""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion: int = 1 40 | 41 | def __init__( 42 | self, 43 | inplanes: int, 44 | planes: int, 45 | stride: int = 1, 46 | downsample: Optional[nn.Module] = None, 47 | groups: int = 1, 48 | base_width: int = 64, 49 | dilation: int = 1, 50 | norm_layer: Optional[Callable[..., nn.Module]] = None 51 | ) -> None: 52 | super(BasicBlock, self).__init__() 53 | if norm_layer is None: 54 | norm_layer = nn.BatchNorm2d 55 | if groups != 1 or base_width != 64: 56 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 57 | #if dilation > 1: 58 | # raise NotImplementedError("Dilation > 1 not supported in 
BasicBlock") 59 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 60 | self.conv1 = conv3x3(inplanes, planes, stride) 61 | self.bn1 = norm_layer(planes) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.conv2 = conv3x3(planes, planes) 64 | self.bn2 = norm_layer(planes) 65 | self.downsample = downsample 66 | self.stride = stride 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | identity = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | 78 | if self.downsample is not None: 79 | identity = self.downsample(x) 80 | 81 | out += identity 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class Bottleneck(nn.Module): 88 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 89 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 90 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 91 | # This variant is also known as ResNet V1.5 and improves accuracy according to 92 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 93 | 94 | expansion: int = 4 95 | 96 | def __init__( 97 | self, 98 | inplanes: int, 99 | planes: int, 100 | stride: int = 1, 101 | downsample: Optional[nn.Module] = None, 102 | groups: int = 1, 103 | base_width: int = 64, 104 | dilation: int = 1, 105 | norm_layer: Optional[Callable[..., nn.Module]] = None 106 | ) -> None: 107 | super(Bottleneck, self).__init__() 108 | if norm_layer is None: 109 | norm_layer = nn.BatchNorm2d 110 | width = int(planes * (base_width / 64.)) * groups 111 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 112 | self.conv1 = conv1x1(inplanes, width) 113 | self.bn1 = norm_layer(width) 114 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 115 | self.bn2 = norm_layer(width) 116 | self.conv3 = conv1x1(width, planes * self.expansion) 117 | self.bn3 = norm_layer(planes * self.expansion) 118 | self.relu = nn.ReLU(inplace=True) 119 | self.downsample = downsample 120 | self.stride = stride 121 | 122 | def forward(self, x: Tensor) -> Tensor: 123 | identity = x 124 | 125 | out = self.conv1(x) 126 | out = self.bn1(out) 127 | out = self.relu(out) 128 | 129 | out = self.conv2(out) 130 | out = self.bn2(out) 131 | out = self.relu(out) 132 | 133 | out = self.conv3(out) 134 | out = self.bn3(out) 135 | 136 | if self.downsample is not None: 137 | identity = self.downsample(x) 138 | 139 | out += identity 140 | out = self.relu(out) 141 | 142 | return out 143 | 144 | 145 | class ResNet(AbstractCNN): 146 | 147 | def __init__( 148 | self, 149 | block: Type[Union[BasicBlock, Bottleneck]], 150 | layers: List[int], 151 | num_classes: int = 1000, 152 | zero_init_residual: bool = False, 153 | groups: int = 1, 154 | width_per_group: int = 64, 155 | replace_stride_with_dilation: Optional[List[bool]] = None, 156 | norm_layer: Optional[Callable[..., nn.Module]] = None 157 | ) -> None: 158 | super(ResNet, self).__init__() 159 | if norm_layer is None: 160 | norm_layer = nn.BatchNorm2d 161 | self._norm_layer = norm_layer 162 | 163 | self.inplanes = 64 164 | self.dilation = 1 165 | if replace_stride_with_dilation is None: 166 | # each element in the tuple indicates if we should replace 167 | # the 2x2 stride with a dilated convolution instead 168 | replace_stride_with_dilation = [False, False, False] 169 | if len(replace_stride_with_dilation) != 3: 
170 | raise ValueError("replace_stride_with_dilation should be None " 171 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 172 | self.groups = groups 173 | self.base_width = width_per_group 174 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 175 | bias=False) 176 | self.bn1 = norm_layer(self.inplanes) 177 | self.relu = nn.ReLU(inplace=True) 178 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 179 | self.layer1 = self._make_layer(block, 64, layers[0]) 180 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 181 | dilate=replace_stride_with_dilation[0]) 182 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 183 | dilate=replace_stride_with_dilation[1]) 184 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 185 | dilate=replace_stride_with_dilation[2]) 186 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 187 | 188 | #self.fc = nn.Linear(512 * block.expansion, num_classes) 189 | self.embed_dim = 512 * block.expansion 190 | self.head = None 191 | 192 | for m in self.modules(): 193 | if isinstance(m, nn.Conv2d): 194 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 195 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 196 | nn.init.constant_(m.weight, 1) 197 | nn.init.constant_(m.bias, 0) 198 | 199 | # Zero-initialize the last BN in each residual branch, 200 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 201 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 202 | if zero_init_residual: 203 | for m in self.modules(): 204 | if isinstance(m, Bottleneck): 205 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 206 | elif isinstance(m, BasicBlock): 207 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 208 | 209 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 210 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 211 | norm_layer = self._norm_layer 212 | downsample = None 213 | previous_dilation = self.dilation 214 | if dilate: 215 | self.dilation *= stride 216 | stride = 1 217 | if stride != 1 or self.inplanes != planes * block.expansion: 218 | downsample = nn.Sequential( 219 | conv1x1(self.inplanes, planes * block.expansion, stride), 220 | norm_layer(planes * block.expansion), 221 | ) 222 | 223 | layers = [] 224 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 225 | self.base_width, previous_dilation, norm_layer)) 226 | self.inplanes = planes * block.expansion 227 | for _ in range(1, blocks): 228 | layers.append(block(self.inplanes, planes, groups=self.groups, 229 | base_width=self.base_width, dilation=self.dilation, 230 | norm_layer=norm_layer)) 231 | 232 | return nn.Sequential(*layers) 233 | 234 | def _make_layer_nodown(self, inplanes: int, planes: int, blocks: int, 235 | stride: int = 1, dilation: int = 1) -> nn.Sequential: 236 | norm_layer = self._norm_layer 237 | downsample = nn.Conv2d(256, 512, kernel_size=1) 238 | previous_dilation = self.dilation = dilation 239 | 240 | layers = [] 241 | layers.append(BasicBlock(inplanes, planes, stride, downsample, self.groups, 242 | self.base_width, previous_dilation, norm_layer)) 243 | self.inplanes = planes * BasicBlock.expansion 244 | for _ in range(1, blocks): 245 | layers.append(BasicBlock(self.inplanes, planes, groups=self.groups, 246 | base_width=self.base_width, dilation=self.dilation, 247 | norm_layer=norm_layer)) 248 | 249 | return 
nn.Sequential(*layers) 250 | 251 | def _forward_impl(self, x: Tensor) -> Tensor: 252 | # See note [TorchScript super()] 253 | x = self.conv1(x) 254 | x = self.bn1(x) 255 | x = self.relu(x) 256 | x = self.maxpool(x) 257 | 258 | x = self.layer1(x) 259 | x = self.layer2(x) 260 | x = self.layer3(x) 261 | x = self.layer4(x) 262 | 263 | x = self.avgpool(x) 264 | x = torch.flatten(x, 1) 265 | x = self.head(x) 266 | 267 | return x 268 | 269 | def forward(self, x: Tensor) -> Tensor: 270 | return self._forward_impl(x) 271 | 272 | def forward_tokens(self, x): 273 | x = self.conv1(x) 274 | x = self.bn1(x) 275 | x = self.relu(x) 276 | x = self.maxpool(x) 277 | 278 | x = self.layer1(x) 279 | x = self.layer2(x) 280 | x = self.layer3(x) 281 | x = self.layer4(x) 282 | 283 | x = self.head(x) 284 | return x.view(x.shape[0], self.embed_dim, -1).permute(0, 2, 1) 285 | 286 | def forward_features(self, x): 287 | x = self.conv1(x) 288 | x = self.bn1(x) 289 | x = self.relu(x) 290 | x = self.maxpool(x) 291 | 292 | x = self.layer1(x) 293 | x = self.layer2(x) 294 | x = self.layer3(x) 295 | x = self.layer4(x) 296 | 297 | x = self.avgpool(x) 298 | x = torch.flatten(x, 1) 299 | return x, None, None 300 | 301 | 302 | def _resnet( 303 | arch: str, 304 | block: Type[Union[BasicBlock, Bottleneck]], 305 | layers: List[int], 306 | pretrained: bool, 307 | progress: bool, 308 | **kwargs: Any 309 | ) -> ResNet: 310 | model = ResNet(block, layers, **kwargs) 311 | if pretrained: 312 | state_dict = load_state_dict_from_url(model_urls[arch], 313 | progress=progress) 314 | model.load_state_dict(state_dict) 315 | return model 316 | 317 | 318 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 319 | r"""ResNet-18 model from 320 | `"Deep Residual Learning for Image Recognition" `_. 321 | 322 | Args: 323 | pretrained (bool): If True, returns a model pre-trained on ImageNet 324 | progress (bool): If True, displays a progress bar of the download to stderr 325 | """ 326 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 327 | **kwargs) 328 | 329 | 330 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 331 | r"""ResNet-34 model from 332 | `"Deep Residual Learning for Image Recognition" `_. 333 | 334 | Args: 335 | pretrained (bool): If True, returns a model pre-trained on ImageNet 336 | progress (bool): If True, displays a progress bar of the download to stderr 337 | """ 338 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 339 | **kwargs) 340 | 341 | 342 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 343 | r"""ResNet-50 model from 344 | `"Deep Residual Learning for Image Recognition" `_. 345 | 346 | Args: 347 | pretrained (bool): If True, returns a model pre-trained on ImageNet 348 | progress (bool): If True, displays a progress bar of the download to stderr 349 | """ 350 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 351 | **kwargs) 352 | 353 | 354 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 355 | r"""ResNet-101 model from 356 | `"Deep Residual Learning for Image Recognition" `_. 
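Note (editor's sketch, not part of the upstream torchvision API): every builder in this file
    returns a headless backbone -- ``self.head`` is ``None`` and the feature width is exposed as
    ``embed_dim`` -- so a classifier must be attached before ``forward`` is called. A minimal,
    hypothetical usage:

        import torch
        from torch import nn

        model = resnet18()                            # any builder here works the same way
        model.head = nn.Linear(model.embed_dim, 100)  # embed_dim == 512 * block.expansion
        logits = model(torch.randn(2, 3, 32, 32))     # conv1 above is the 3x3 CIFAR-style stem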
357 | 358 | Args: 359 | pretrained (bool): If True, returns a model pre-trained on ImageNet 360 | progress (bool): If True, displays a progress bar of the download to stderr 361 | """ 362 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 363 | **kwargs) 364 | 365 | 366 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 367 | r"""ResNet-152 model from 368 | `"Deep Residual Learning for Image Recognition" `_. 369 | 370 | Args: 371 | pretrained (bool): If True, returns a model pre-trained on ImageNet 372 | progress (bool): If True, displays a progress bar of the download to stderr 373 | """ 374 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 375 | **kwargs) 376 | 377 | 378 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 379 | r"""ResNeXt-50 32x4d model from 380 | `"Aggregated Residual Transformations for Deep Neural Networks" `_. 381 | 382 | Args: 383 | pretrained (bool): If True, returns a model pre-trained on ImageNet 384 | progress (bool): If True, displays a progress bar of the download to stderr 385 | """ 386 | kwargs['groups'] = 32 387 | kwargs['width_per_group'] = 4 388 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 389 | pretrained, progress, **kwargs) 390 | 391 | 392 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 393 | r"""ResNeXt-101 32x8d model from 394 | `"Aggregated Residual Transformations for Deep Neural Networks" `_. 395 | 396 | Args: 397 | pretrained (bool): If True, returns a model pre-trained on ImageNet 398 | progress (bool): If True, displays a progress bar of the download to stderr 399 | """ 400 | kwargs['groups'] = 32 401 | kwargs['width_per_group'] = 8 402 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 403 | pretrained, progress, **kwargs) 404 | 405 | 406 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 407 | r"""Wide ResNet-50-2 model from 408 | `"Wide Residual Networks" `_. 409 | 410 | The model is the same as ResNet except for the bottleneck number of channels 411 | which is twice as large in every block. The number of channels in outer 1x1 412 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 413 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 414 | 415 | Args: 416 | pretrained (bool): If True, returns a model pre-trained on ImageNet 417 | progress (bool): If True, displays a progress bar of the download to stderr 418 | """ 419 | kwargs['width_per_group'] = 64 * 2 420 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 421 | pretrained, progress, **kwargs) 422 | 423 | 424 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 425 | r"""Wide ResNet-101-2 model from 426 | `"Wide Residual Networks" `_. 427 | 428 | The model is the same as ResNet except for the bottleneck number of channels 429 | which is twice as large in every block. The number of channels in outer 1x1 430 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 431 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
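Editor's note, tying the numbers above to the code: ``Bottleneck`` computes its inner width as
    ``int(planes * (base_width / 64.)) * groups``, so with ``width_per_group = 128`` the last stage
    gets ``int(512 * 128 / 64) * 1 = 1024`` inner channels, while the outer 1x1 convolution still
    expands to ``512 * 4 = 2048``.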
432 | 433 | Args: 434 | pretrained (bool): If True, returns a model pre-trained on ImageNet 435 | progress (bool): If True, displays a progress bar of the download to stderr 436 | """ 437 | kwargs['width_per_group'] = 64 * 2 438 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 439 | pretrained, progress, **kwargs) 440 | -------------------------------------------------------------------------------- /continual/cnn/resnet_rebuffi.py: -------------------------------------------------------------------------------- 1 | """Pytorch port of the resnet used for CIFAR100 by iCaRL. 2 | 3 | https://github.com/srebuffi/iCaRL/blob/master/iCaRL-TheanoLasagne/utils_cifar100.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn import init 9 | 10 | from continual.cnn import AbstractCNN 11 | 12 | 13 | class DownsampleStride(nn.Module): 14 | 15 | def __init__(self, n=2): 16 | super(DownsampleStride, self).__init__() 17 | self._n = n 18 | 19 | def forward(self, x): 20 | return x[..., ::2, ::2] 21 | 22 | 23 | class DownsampleConv(nn.Module): 24 | 25 | def __init__(self, inplanes, planes): 26 | super().__init__() 27 | 28 | self.conv = nn.Sequential( 29 | nn.Conv2d(inplanes, planes, stride=2, kernel_size=1, bias=False), 30 | nn.BatchNorm2d(planes), 31 | ) 32 | 33 | def forward(self, x): 34 | return self.conv(x) 35 | 36 | 37 | class ResidualBlock(nn.Module): 38 | expansion = 1 39 | 40 | def __init__(self, inplanes, increase_dim=False, last_relu=False, downsampling="stride"): 41 | super(ResidualBlock, self).__init__() 42 | 43 | self.increase_dim = increase_dim 44 | 45 | if increase_dim: 46 | first_stride = 2 47 | planes = inplanes * 2 48 | else: 49 | first_stride = 1 50 | planes = inplanes 51 | 52 | self.conv_a = nn.Conv2d( 53 | inplanes, planes, kernel_size=3, stride=first_stride, padding=1, bias=False 54 | ) 55 | self.bn_a = nn.BatchNorm2d(planes) 56 | 57 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 58 | self.bn_b = nn.BatchNorm2d(planes) 59 | 60 | if increase_dim: 61 | if downsampling == "stride": 62 | self.downsampler = DownsampleStride() 63 | self._need_pad = True 64 | else: 65 | self.downsampler = DownsampleConv(inplanes, planes) 66 | self._need_pad = False 67 | 68 | self.last_relu = last_relu 69 | 70 | @staticmethod 71 | def pad(x): 72 | return torch.cat((x, x.mul(0)), 1) 73 | 74 | def forward(self, x): 75 | y = self.conv_a(x) 76 | y = self.bn_a(y) 77 | y = F.relu(y, inplace=True) 78 | 79 | y = self.conv_b(y) 80 | y = self.bn_b(y) 81 | 82 | if self.increase_dim: 83 | x = self.downsampler(x) 84 | if self._need_pad: 85 | x = self.pad(x) 86 | 87 | y = x + y 88 | 89 | if self.last_relu: 90 | y = F.relu(y, inplace=True) 91 | 92 | return y 93 | 94 | 95 | class PreActResidualBlock(nn.Module): 96 | expansion = 1 97 | 98 | def __init__(self, inplanes, increase_dim=False, last_relu=False): 99 | super().__init__() 100 | 101 | self.increase_dim = increase_dim 102 | 103 | if increase_dim: 104 | first_stride = 2 105 | planes = inplanes * 2 106 | else: 107 | first_stride = 1 108 | planes = inplanes 109 | 110 | self.bn_a = nn.BatchNorm2d(inplanes) 111 | self.conv_a = nn.Conv2d( 112 | inplanes, planes, kernel_size=3, stride=first_stride, padding=1, bias=False 113 | ) 114 | 115 | self.bn_b = nn.BatchNorm2d(planes) 116 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 117 | 118 | if increase_dim: 119 | self.downsample = DownsampleStride() 120 | self.pad = lambda x: 
torch.cat((x, x.mul(0)), 1) 121 | self.last_relu = last_relu 122 | 123 | def forward(self, x): 124 | y = self.bn_a(x) 125 | y = F.relu(y, inplace=True) 126 | y = self.conv_a(y)  # convolve the pre-activated tensor, not the raw input 127 | 128 | y = self.bn_b(y) 129 | y = F.relu(y, inplace=True) 130 | y = self.conv_b(y) 131 | 132 | if self.increase_dim: 133 | x = self.downsample(x) 134 | x = self.pad(x) 135 | 136 | y = x + y 137 | 138 | if self.last_relu: 139 | y = F.relu(y, inplace=True) 140 | 141 | return y 142 | 143 | 144 | class Stage(nn.Module): 145 | 146 | def __init__(self, blocks, block_relu=False): 147 | super().__init__() 148 | 149 | self.blocks = nn.ModuleList(blocks) 150 | self.block_relu = block_relu 151 | 152 | def forward(self, x): 153 | intermediary_features = [] 154 | 155 | for b in self.blocks: 156 | x = b(x) 157 | intermediary_features.append(x) 158 | 159 | if self.block_relu: 160 | x = F.relu(x) 161 | 162 | return intermediary_features, x 163 | 164 | 165 | class CifarResNet(AbstractCNN): 166 | """ 167 | ResNet optimized for the Cifar Dataset, as specified in 168 | https://arxiv.org/abs/1512.03385 169 | """ 170 | 171 | def __init__( 172 | self, 173 | n=5, 174 | nf=16, 175 | channels=3, 176 | preact=False, 177 | zero_residual=True, 178 | pooling_config={"type": "avg"}, 179 | downsampling="stride", 180 | all_attentions=False, 181 | last_relu=True, 182 | **kwargs 183 | ): 184 | """ Constructor 185 | Args: 186 | n: number of residual blocks per stage. 187 | nf: number of feature maps in the first stage. 188 | channels: number of input image channels. 189 | """ 190 | if kwargs: 191 | raise ValueError("Unused kwargs: {}.".format(kwargs)) 192 | 193 | self.all_attentions = all_attentions 194 | self._downsampling_type = downsampling 195 | self.last_relu = last_relu 196 | 197 | Block = ResidualBlock if not preact else PreActResidualBlock 198 | 199 | super(CifarResNet, self).__init__() 200 | 201 | self.conv_1_3x3 = nn.Conv2d(channels, nf, kernel_size=3, stride=1, padding=1, bias=False) 202 | self.bn_1 = nn.BatchNorm2d(nf) 203 | 204 | self.stage_1 = self._make_layer(Block, nf, increase_dim=False, n=n) 205 | self.stage_2 = self._make_layer(Block, nf, increase_dim=True, n=n - 1) 206 | self.stage_3 = self._make_layer(Block, 2 * nf, increase_dim=True, n=n - 2) 207 | self.stage_4 = Block( 208 | 4 * nf, increase_dim=False, last_relu=False, downsampling=self._downsampling_type 209 | ) 210 | 211 | if pooling_config["type"] == "avg": 212 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) 213 | else: 214 | raise ValueError("Unknown pooling type {}.".format(pooling_config["type"])) 215 | 216 | self.embed_dim = 4 * nf 217 | self.head = None 218 | 219 | for m in self.modules(): 220 | if isinstance(m, nn.Conv2d): 221 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 222 | elif isinstance(m, nn.BatchNorm2d): 223 | nn.init.constant_(m.weight, 1) 224 | nn.init.constant_(m.bias, 0) 225 | elif isinstance(m, nn.Linear): 226 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 227 | 228 | if zero_residual: 229 | for m in self.modules(): 230 | if isinstance(m, ResidualBlock): 231 | nn.init.constant_(m.bn_b.weight, 0) 232 | 233 | def _make_layer(self, Block, planes, increase_dim=False, n=None): 234 | layers = [] 235 | 236 | if increase_dim: 237 | layers.append( 238 | Block( 239 | planes, 240 | increase_dim=True, 241 | last_relu=True, 242 | downsampling=self._downsampling_type 243 | ) 244 | ) 245 | planes = 2 * planes 246 | 247 | for _ in range(n): 248 | layers.append(Block(planes, last_relu=True, downsampling=self._downsampling_type)) 249 | 250 | return Stage(layers,
block_relu=self.last_relu) 251 | 252 | @property 253 | def last_conv(self): 254 | return self.stage_4.conv_b 255 | 256 | def forward(self, x): 257 | x = self.conv_1_3x3(x) 258 | x = F.relu(self.bn_1(x), inplace=True) 259 | 260 | feats_s1, x = self.stage_1(x) 261 | feats_s2, x = self.stage_2(x) 262 | feats_s3, x = self.stage_3(x) 263 | x = self.stage_4(x) 264 | 265 | features = self.end_features(F.relu(x, inplace=False)) 266 | 267 | return self.head(features) 268 | 269 | def end_features(self, x): 270 | x = self.pool(x) 271 | x = x.view(x.size(0), -1) 272 | 273 | return x 274 | 275 | 276 | def resnet_rebuffi(n=5, **kwargs): 277 | return CifarResNet(n=n, **kwargs) 278 | -------------------------------------------------------------------------------- /continual/cnn/senet.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from continual.cnn import AbstractCNN 4 | 5 | """ 6 | SEResNet implementation from Cadene's pretrained models 7 | https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py 8 | Additional credit to https://github.com/creafz 9 | 10 | Original model: https://github.com/hujie-frank/SENet 11 | 12 | ResNet code gently borrowed from 13 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 14 | 15 | FIXME I'm deprecating this model and moving them to ResNet as I don't want to maintain duplicate 16 | support for extras like dilation, switchable BN/activations, feature extraction, etc that don't exist here. 17 | """ 18 | import math 19 | from collections import OrderedDict 20 | 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | 24 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 25 | from timm.models.helpers import build_model_with_cfg 26 | from timm.models.layers import create_classifier 27 | from timm.models.registry import register_model 28 | 29 | __all__ = ['SENet'] 30 | 31 | 32 | def _cfg(url='', **kwargs): 33 | return { 34 | 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 35 | 'crop_pct': 0.875, 'interpolation': 'bilinear', 36 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 37 | 'first_conv': 'layer0.conv1', 'classifier': 'last_linear', 38 | **kwargs 39 | } 40 | 41 | 42 | default_cfgs = { 43 | 'legacy_senet154': 44 | _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'), 45 | 'legacy_seresnet18': _cfg( 46 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth', 47 | interpolation='bicubic'), 48 | 'legacy_seresnet34': _cfg( 49 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'), 50 | 'legacy_seresnet50': _cfg( 51 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'), 52 | 'legacy_seresnet101': _cfg( 53 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'), 54 | 'legacy_seresnet152': _cfg( 55 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'), 56 | 'legacy_seresnext26_32x4d': _cfg( 57 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth', 58 | interpolation='bicubic'), 59 | 'legacy_seresnext50_32x4d': 60 | _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'), 61 | 
'legacy_seresnext101_32x4d': 62 | _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth'), 63 | } 64 | 65 | 66 | def _weight_init(m): 67 | if isinstance(m, nn.Conv2d): 68 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 69 | elif isinstance(m, nn.BatchNorm2d): 70 | nn.init.constant_(m.weight, 1.) 71 | nn.init.constant_(m.bias, 0.) 72 | 73 | 74 | class SEModule(nn.Module): 75 | 76 | def __init__(self, channels, reduction): 77 | super(SEModule, self).__init__() 78 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1) 81 | self.sigmoid = nn.Sigmoid() 82 | 83 | def forward(self, x): 84 | module_input = x 85 | x = x.mean((2, 3), keepdim=True) 86 | x = self.fc1(x) 87 | x = self.relu(x) 88 | x = self.fc2(x) 89 | x = self.sigmoid(x) 90 | return module_input * x 91 | 92 | 93 | class Bottleneck(nn.Module): 94 | """ 95 | Base class for bottlenecks that implements `forward()` method. 96 | """ 97 | 98 | def forward(self, x): 99 | residual = x 100 | 101 | out = self.conv1(x) 102 | out = self.bn1(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv2(out) 106 | out = self.bn2(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv3(out) 110 | out = self.bn3(out) 111 | 112 | if self.downsample is not None: 113 | residual = self.downsample(x) 114 | 115 | out = self.se_module(out) + residual 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | 121 | class SEBottleneck(Bottleneck): 122 | """ 123 | Bottleneck for SENet154. 124 | """ 125 | expansion = 4 126 | 127 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 128 | downsample=None): 129 | super(SEBottleneck, self).__init__() 130 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) 131 | self.bn1 = nn.BatchNorm2d(planes * 2) 132 | self.conv2 = nn.Conv2d( 133 | planes * 2, planes * 4, kernel_size=3, stride=stride, 134 | padding=1, groups=groups, bias=False) 135 | self.bn2 = nn.BatchNorm2d(planes * 4) 136 | self.conv3 = nn.Conv2d( 137 | planes * 4, planes * 4, kernel_size=1, bias=False) 138 | self.bn3 = nn.BatchNorm2d(planes * 4) 139 | self.relu = nn.ReLU(inplace=True) 140 | self.se_module = SEModule(planes * 4, reduction=reduction) 141 | self.downsample = downsample 142 | self.stride = stride 143 | 144 | 145 | class SEResNetBottleneck(Bottleneck): 146 | """ 147 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe 148 | implementation and uses `stride=stride` in `conv1` and not in `conv2` 149 | (the latter is used in the torchvision implementation of ResNet). 
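A minimal shape check (editor's sketch; the ``downsample`` projection below is an assumed
    example, mirroring what ``SENet._make_layer`` builds):

        import torch
        import torch.nn as nn

        block = SEResNetBottleneck(
            64, 64, groups=1, reduction=16, stride=2,
            downsample=nn.Sequential(nn.Conv2d(64, 256, 1, stride=2, bias=False),
                                     nn.BatchNorm2d(256)))
        out = block(torch.randn(1, 64, 56, 56))  # -> torch.Size([1, 256, 28, 28])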
150 | """ 151 | expansion = 4 152 | 153 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 154 | downsample=None): 155 | super(SEResNetBottleneck, self).__init__() 156 | self.conv1 = nn.Conv2d( 157 | inplanes, planes, kernel_size=1, bias=False, stride=stride) 158 | self.bn1 = nn.BatchNorm2d(planes) 159 | self.conv2 = nn.Conv2d( 160 | planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) 161 | self.bn2 = nn.BatchNorm2d(planes) 162 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 163 | self.bn3 = nn.BatchNorm2d(planes * 4) 164 | self.relu = nn.ReLU(inplace=True) 165 | self.se_module = SEModule(planes * 4, reduction=reduction) 166 | self.downsample = downsample 167 | self.stride = stride 168 | 169 | 170 | class SEResNeXtBottleneck(Bottleneck): 171 | """ 172 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 173 | """ 174 | expansion = 4 175 | 176 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 177 | downsample=None, base_width=4): 178 | super(SEResNeXtBottleneck, self).__init__() 179 | width = math.floor(planes * (base_width / 64)) * groups 180 | self.conv1 = nn.Conv2d( 181 | inplanes, width, kernel_size=1, bias=False, stride=1) 182 | self.bn1 = nn.BatchNorm2d(width) 183 | self.conv2 = nn.Conv2d( 184 | width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) 185 | self.bn2 = nn.BatchNorm2d(width) 186 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 187 | self.bn3 = nn.BatchNorm2d(planes * 4) 188 | self.relu = nn.ReLU(inplace=True) 189 | self.se_module = SEModule(planes * 4, reduction=reduction) 190 | self.downsample = downsample 191 | self.stride = stride 192 | 193 | 194 | class SEResNetBlock(nn.Module): 195 | expansion = 1 196 | 197 | def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): 198 | super(SEResNetBlock, self).__init__() 199 | self.conv1 = nn.Conv2d( 200 | inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False) 201 | self.bn1 = nn.BatchNorm2d(planes) 202 | self.conv2 = nn.Conv2d( 203 | planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) 204 | self.bn2 = nn.BatchNorm2d(planes) 205 | self.relu = nn.ReLU(inplace=True) 206 | self.se_module = SEModule(planes, reduction=reduction) 207 | self.downsample = downsample 208 | self.stride = stride 209 | 210 | def forward(self, x): 211 | residual = x 212 | 213 | out = self.conv1(x) 214 | out = self.bn1(out) 215 | out = self.relu(out) 216 | 217 | out = self.conv2(out) 218 | out = self.bn2(out) 219 | out = self.relu(out) 220 | 221 | if self.downsample is not None: 222 | residual = self.downsample(x) 223 | 224 | out = self.se_module(out) + residual 225 | out = self.relu(out) 226 | 227 | return out 228 | 229 | 230 | class SENet(AbstractCNN): 231 | 232 | def __init__(self, block, layers, groups, reduction, drop_rate=0.2, 233 | in_chans=3, inplanes=64, input_3x3=False, downsample_kernel_size=1, 234 | downsample_padding=0, num_classes=1000, global_pool='avg'): 235 | """ 236 | Parameters 237 | ---------- 238 | block (nn.Module): Bottleneck class. 239 | - For SENet154: SEBottleneck 240 | - For SE-ResNet models: SEResNetBottleneck 241 | - For SE-ResNeXt models: SEResNeXtBottleneck 242 | layers (list of ints): Number of residual blocks for 4 layers of the 243 | network (layer1...layer4). 244 | groups (int): Number of groups for the 3x3 convolution in each 245 | bottleneck block. 
246 | - For SENet154: 64 247 | - For SE-ResNet models: 1 248 | - For SE-ResNeXt models: 32 249 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules. 250 | - For all models: 16 251 | drop_rate (float): Drop probability for the Dropout layer applied before 252 | the classifier. If `0.` the Dropout layer is not used. 253 | - For SENet154: 0.2 254 | - For SE-ResNet models: 0. 255 | - For SE-ResNeXt models: 0. 256 | inplanes (int): Number of input channels for layer1. 257 | - For SENet154: 128 258 | - For SE-ResNet models: 64 259 | - For SE-ResNeXt models: 64 260 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of 261 | a single 7x7 convolution in layer0. 262 | - For SENet154: True 263 | - For SE-ResNet models: False 264 | - For SE-ResNeXt models: False 265 | downsample_kernel_size (int): Kernel size for downsampling convolutions 266 | in layer2, layer3 and layer4. 267 | - For SENet154: 3 268 | - For SE-ResNet models: 1 269 | - For SE-ResNeXt models: 1 270 | downsample_padding (int): Padding for downsampling convolutions in 271 | layer2, layer3 and layer4. 272 | - For SENet154: 1 273 | - For SE-ResNet models: 0 274 | - For SE-ResNeXt models: 0 275 | num_classes (int): Number of outputs in `last_linear` layer. 276 | - For all models: 1000 277 | """ 278 | super(SENet, self).__init__() 279 | self.inplanes = inplanes 280 | self.num_classes = num_classes 281 | self.drop_rate = drop_rate 282 | if input_3x3: 283 | layer0_modules = [ 284 | ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)), 285 | ('bn1', nn.BatchNorm2d(64)), 286 | ('relu1', nn.ReLU(inplace=True)), 287 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)), 288 | ('bn2', nn.BatchNorm2d(64)), 289 | ('relu2', nn.ReLU(inplace=True)), 290 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)), 291 | ('bn3', nn.BatchNorm2d(inplanes)), 292 | ('relu3', nn.ReLU(inplace=True)), 293 | ] 294 | else: 295 | layer0_modules = [ 296 | ('conv1', nn.Conv2d( 297 | in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), 298 | ('bn1', nn.BatchNorm2d(inplanes)), 299 | ('relu1', nn.ReLU(inplace=True)), 300 | ] 301 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) 302 | # To preserve compatibility with Caffe weights `ceil_mode=True` is used instead of `padding=1`.
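# (editor's note) Worked example: on a 112x112 map, MaxPool2d(3, stride=2)
        # gives floor((112 - 3) / 2) + 1 = 55, whereas ceil_mode=True gives 56 --
        # the same size padding=1 would produce, so the Caffe-trained weights line up.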
303 | self.pool0 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 304 | self.feature_info = [dict(num_chs=inplanes, reduction=2, module='layer0')] 305 | self.layer1 = self._make_layer( 306 | block, 307 | planes=64, 308 | blocks=layers[0], 309 | groups=groups, 310 | reduction=reduction, 311 | downsample_kernel_size=1, 312 | downsample_padding=0 313 | ) 314 | self.feature_info += [dict(num_chs=64 * block.expansion, reduction=4, module='layer1')] 315 | self.layer2 = self._make_layer( 316 | block, 317 | planes=128, 318 | blocks=layers[1], 319 | stride=2, 320 | groups=groups, 321 | reduction=reduction, 322 | downsample_kernel_size=downsample_kernel_size, 323 | downsample_padding=downsample_padding 324 | ) 325 | self.feature_info += [dict(num_chs=128 * block.expansion, reduction=8, module='layer2')] 326 | self.layer3 = self._make_layer( 327 | block, 328 | planes=256, 329 | blocks=layers[2], 330 | stride=2, 331 | groups=groups, 332 | reduction=reduction, 333 | downsample_kernel_size=downsample_kernel_size, 334 | downsample_padding=downsample_padding 335 | ) 336 | self.feature_info += [dict(num_chs=256 * block.expansion, reduction=16, module='layer3')] 337 | self.layer4 = self._make_layer( 338 | block, 339 | planes=512, 340 | blocks=layers[3], 341 | stride=2, 342 | groups=groups, 343 | reduction=reduction, 344 | downsample_kernel_size=downsample_kernel_size, 345 | downsample_padding=downsample_padding 346 | ) 347 | self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')] 348 | self.num_features = 512 * block.expansion 349 | self.embed_dim = 512 * block.expansion 350 | self.global_pool, self.last_linear = create_classifier( 351 | self.num_features, self.num_classes, pool_type=global_pool) 352 | 353 | for m in self.modules(): 354 | _weight_init(m) 355 | 356 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 357 | downsample_kernel_size=1, downsample_padding=0): 358 | downsample = None 359 | if stride != 1 or self.inplanes != planes * block.expansion: 360 | downsample = nn.Sequential( 361 | nn.Conv2d( 362 | self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size, 363 | stride=stride, padding=downsample_padding, bias=False), 364 | nn.BatchNorm2d(planes * block.expansion), 365 | ) 366 | 367 | layers = [block(self.inplanes, planes, groups, reduction, stride, downsample)] 368 | self.inplanes = planes * block.expansion 369 | for i in range(1, blocks): 370 | layers.append(block(self.inplanes, planes, groups, reduction)) 371 | 372 | return nn.Sequential(*layers) 373 | 374 | def get_classifier(self): 375 | return self.last_linear 376 | 377 | def reset_classifier(self, num_classes, global_pool='avg'): 378 | self.num_classes = num_classes 379 | self.global_pool, self.last_linear = create_classifier( 380 | self.num_features, self.num_classes, pool_type=global_pool) 381 | 382 | def forward_features(self, x): 383 | x = self.layer0(x) 384 | x = self.pool0(x) 385 | x = self.layer1(x) 386 | x = self.layer2(x) 387 | x = self.layer3(x) 388 | x = self.layer4(x) 389 | return x 390 | 391 | def logits(self, x): 392 | x = self.global_pool(x) 393 | if self.drop_rate > 0.: 394 | x = F.dropout(x, p=self.drop_rate, training=self.training) 395 | x = self.head(x) 396 | return x 397 | 398 | def forward(self, x): 399 | x = self.forward_features(x) 400 | x = self.logits(x) 401 | return x 402 | 403 | 404 | def _create_senet(variant, pretrained=False, **kwargs): 405 | return build_model_with_cfg( 406 | SENet, variant, pretrained, 407 | 
default_cfg=default_cfgs[variant], 408 | **kwargs) 409 | 410 | 411 | @register_model 412 | def legacy_seresnet18(pretrained=False, **kwargs): 413 | model_args = dict( 414 | block=SEResNetBlock, layers=[2, 2, 2, 2], groups=1, reduction=16, **kwargs) 415 | return _create_senet('legacy_seresnet18', pretrained, **model_args) 416 | 417 | 418 | @register_model 419 | def legacy_seresnet34(pretrained=False, **kwargs): 420 | model_args = dict( 421 | block=SEResNetBlock, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs) 422 | return _create_senet('legacy_seresnet34', pretrained, **model_args) 423 | 424 | 425 | @register_model 426 | def legacy_seresnet50(pretrained=False, **kwargs): 427 | model_args = dict( 428 | block=SEResNetBottleneck, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs) 429 | return _create_senet('legacy_seresnet50', pretrained, **model_args) 430 | 431 | 432 | @register_model 433 | def legacy_seresnet101(pretrained=False, **kwargs): 434 | model_args = dict( 435 | block=SEResNetBottleneck, layers=[3, 4, 23, 3], groups=1, reduction=16, **kwargs) 436 | return _create_senet('legacy_seresnet101', pretrained, **model_args) 437 | 438 | 439 | @register_model 440 | def legacy_seresnet152(pretrained=False, **kwargs): 441 | model_args = dict( 442 | block=SEResNetBottleneck, layers=[3, 8, 36, 3], groups=1, reduction=16, **kwargs) 443 | return _create_senet('legacy_seresnet152', pretrained, **model_args) 444 | 445 | 446 | @register_model 447 | def legacy_senet154(pretrained=False, **kwargs): 448 | model_args = dict( 449 | block=SEBottleneck, layers=[3, 8, 36, 3], groups=64, reduction=16, 450 | downsample_kernel_size=3, downsample_padding=1, inplanes=128, input_3x3=True, **kwargs) 451 | return _create_senet('legacy_senet154', pretrained, **model_args) 452 | 453 | 454 | @register_model 455 | def legacy_seresnext26_32x4d(pretrained=False, **kwargs): 456 | model_args = dict( 457 | block=SEResNeXtBottleneck, layers=[2, 2, 2, 2], groups=32, reduction=16, **kwargs) 458 | return _create_senet('legacy_seresnext26_32x4d', pretrained, **model_args) 459 | 460 | 461 | @register_model 462 | def legacy_seresnext50_32x4d(pretrained=False, **kwargs): 463 | model_args = dict( 464 | block=SEResNeXtBottleneck, layers=[3, 4, 6, 3], groups=32, reduction=16, **kwargs) 465 | return _create_senet('legacy_seresnext50_32x4d', pretrained, **model_args) 466 | 467 | 468 | @register_model 469 | def legacy_seresnext101_32x4d(pretrained=False, **kwargs): 470 | model_args = dict( 471 | block=SEResNeXtBottleneck, layers=[3, 4, 23, 3], groups=32, reduction=16, **kwargs) 472 | return _create_senet('legacy_seresnext101_32x4d', pretrained, **model_args) 473 | -------------------------------------------------------------------------------- /continual/cnn/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.hub import load_state_dict_from_url  # used by _vgg when pretrained=True 4 | from typing import Union, List, Dict, Any, cast 5 | 6 | from continual.cnn import AbstractCNN 7 | 8 | __all__ = [ 9 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 10 | 'vgg19_bn', 'vgg19', 11 | ] 12 | 13 | 14 | model_urls = { 15 | 'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth', 16 | 'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth', 17 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 18 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 19 | 'vgg11_bn':
'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 20 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 21 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 22 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 23 | } 24 | 25 | 26 | class VGG(AbstractCNN): 27 | 28 | def __init__( 29 | self, 30 | features: nn.Module, 31 | num_classes: int = 1000, 32 | init_weights: bool = True 33 | ) -> None: 34 | super(VGG, self).__init__() 35 | self.features = features 36 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 37 | self.classifier = nn.Sequential( 38 | nn.Linear(512 * 7 * 7, 4096), 39 | nn.ReLU(True), 40 | nn.Dropout(), 41 | nn.Linear(4096, 4096), 42 | nn.ReLU(True), 43 | nn.Dropout(), 44 | ) 45 | 46 | self.head = None 47 | self.embed_dim = 4096 48 | 49 | if init_weights: 50 | self._initialize_weights() 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | x = self.features(x) 54 | x = self.avgpool(x) 55 | x = torch.flatten(x, 1) 56 | x = self.classifier(x) 57 | return self.head(x) 58 | 59 | def _initialize_weights(self) -> None: 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 63 | if m.bias is not None: 64 | nn.init.constant_(m.bias, 0) 65 | elif isinstance(m, nn.BatchNorm2d): 66 | nn.init.constant_(m.weight, 1) 67 | nn.init.constant_(m.bias, 0) 68 | elif isinstance(m, nn.Linear): 69 | nn.init.normal_(m.weight, 0, 0.01) 70 | nn.init.constant_(m.bias, 0) 71 | 72 | 73 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: 74 | layers: List[nn.Module] = [] 75 | in_channels = 3 76 | for v in cfg: 77 | if v == 'M': 78 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 79 | else: 80 | v = cast(int, v) 81 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 82 | if batch_norm: 83 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 84 | else: 85 | layers += [conv2d, nn.ReLU(inplace=True)] 86 | in_channels = v 87 | return nn.Sequential(*layers) 88 | 89 | 90 | cfgs: Dict[str, List[Union[str, int]]] = { 91 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 92 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 93 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 94 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 95 | } 96 | 97 | 98 | def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG: 99 | if pretrained: 100 | kwargs['init_weights'] = False 101 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 102 | if pretrained: 103 | state_dict = load_state_dict_from_url(model_urls[arch], 104 | progress=progress) 105 | model.load_state_dict(state_dict) 106 | return model 107 | 108 | 109 | def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 110 | r"""VGG 11-layer model (configuration "A") from 111 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 
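Editor's sketch of how ``make_layers`` (defined above) expands a cfg list, where 'M'
    inserts a 2x2 max-pool and integers are conv output channels:

        tiny = make_layers([16, 'M', 32], batch_norm=True)
        # Conv2d(3, 16, 3, padding=1) -> BatchNorm2d(16) -> ReLU
        # -> MaxPool2d(2, 2) -> Conv2d(16, 32, 3, padding=1) -> BatchNorm2d(32) -> ReLU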
112 | 113 | Args: 114 | pretrained (bool): If True, returns a model pre-trained on ImageNet 115 | progress (bool): If True, displays a progress bar of the download to stderr 116 | """ 117 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 118 | 119 | 120 | def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 121 | r"""VGG 11-layer model (configuration "A") with batch normalization 122 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 123 | 124 | Args: 125 | pretrained (bool): If True, returns a model pre-trained on ImageNet 126 | progress (bool): If True, displays a progress bar of the download to stderr 127 | """ 128 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 129 | 130 | 131 | def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 132 | r"""VGG 13-layer model (configuration "B") 133 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 134 | 135 | Args: 136 | pretrained (bool): If True, returns a model pre-trained on ImageNet 137 | progress (bool): If True, displays a progress bar of the download to stderr 138 | """ 139 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 140 | 141 | 142 | def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 143 | r"""VGG 13-layer model (configuration "B") with batch normalization 144 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 145 | 146 | Args: 147 | pretrained (bool): If True, returns a model pre-trained on ImageNet 148 | progress (bool): If True, displays a progress bar of the download to stderr 149 | """ 150 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 151 | 152 | 153 | def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 154 | r"""VGG 16-layer model (configuration "D") 155 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 156 | 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | progress (bool): If True, displays a progress bar of the download to stderr 160 | """ 161 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 162 | 163 | 164 | def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 165 | r"""VGG 16-layer model (configuration "D") with batch normalization 166 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 167 | 168 | Args: 169 | pretrained (bool): If True, returns a model pre-trained on ImageNet 170 | progress (bool): If True, displays a progress bar of the download to stderr 171 | """ 172 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 173 | 174 | 175 | def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 176 | r"""VGG 19-layer model (configuration "E") 177 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 178 | 179 | Args: 180 | pretrained (bool): If True, returns a model pre-trained on ImageNet 181 | progress (bool): If True, displays a progress bar of the download to stderr 182 | """ 183 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 184 | 185 | 186 | def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 187 | r"""VGG 19-layer model (configuration 'E') with batch normalization 188 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 
189 | 190 | Args: 191 | pretrained (bool): If True, returns a model pre-trained on ImageNet 192 | progress (bool): If True, displays a progress bar of the download to stderr 193 | """ 194 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) 195 | -------------------------------------------------------------------------------- /continual/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import json 4 | import os 5 | import warnings 6 | 7 | from continuum import ClassIncremental 8 | # from continuum import Permutations 9 | from continual.mycontinual import Rotations, IncrementalRotation, Permutations 10 | from continuum.datasets import CIFAR100, MNIST, ImageNet100, ImageFolderDataset, CIFAR10, TinyImageNet200, STL10 11 | from timm.data import create_transform 12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 13 | from torchvision import transforms 14 | from torchvision.datasets.folder import ImageFolder, default_loader 15 | from torchvision.transforms import functional as Fv 16 | 17 | from typing import Tuple, Union 18 | 19 | import numpy as np 20 | 21 | from continuum.datasets import ImageFolderDataset 22 | from continuum.download import download, unzip 23 | 24 | try: 25 | interpolation = Fv.InterpolationMode.BICUBIC 26 | except AttributeError:  # older torchvision without InterpolationMode 27 | interpolation = 3 28 | 29 | 30 | from torch.utils.data import DataLoader 31 | import torch.nn.functional as F 32 | from argparse import Namespace 33 | from copy import deepcopy 34 | import torch 35 | from PIL import Image 36 | # from datasets.utils.validation import get_train_val 37 | from typing import Tuple 38 | 39 | 40 | class ImageNet1000(ImageFolderDataset): 41 | """Continuum dataset for datasets with tree-like structure. 42 | :param data_path: The root folder holding the `train` and `val` subfolders. 43 | :param train: If True, loads the train split, otherwise the val split. 44 | :param download: Dummy parameter kept for interface compatibility.
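Expected on-disk layout (editor's note, inferred from ``get_data`` below):

        data_path/
            train/<class_folder>/xxx.JPEG
            val/<class_folder>/yyy.JPEG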
45 | """ 46 | 47 | def __init__( 48 | self, 49 | data_path: str, 50 | train: bool = True, 51 | download: bool = False, 52 | ): 53 | super().__init__(data_path=data_path, train=train, download=download) 54 | 55 | def get_data(self): 56 | if self.train: 57 | self.data_path = os.path.join(self.data_path, "train") 58 | else: 59 | self.data_path = os.path.join(self.data_path, "val") 60 | return super().get_data() 61 | 62 | 63 | class INatDataset(ImageFolder): 64 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 65 | category='name', loader=default_loader): 66 | self.transform = transform 67 | self.loader = loader 68 | self.target_transform = target_transform 69 | self.year = year 70 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 71 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 72 | with open(path_json) as json_file: 73 | data = json.load(json_file) 74 | 75 | with open(os.path.join(root, 'categories.json')) as json_file: 76 | data_catg = json.load(json_file) 77 | 78 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 79 | 80 | with open(path_json_for_targeter) as json_file: 81 | data_for_targeter = json.load(json_file) 82 | 83 | targeter = {} 84 | indexer = 0 85 | for elem in data_for_targeter['annotations']: 86 | king = [] 87 | king.append(data_catg[int(elem['category_id'])][category]) 88 | if king[0] not in targeter.keys(): 89 | targeter[king[0]] = indexer 90 | indexer += 1 91 | self.nb_classes = len(targeter) 92 | 93 | self.samples = [] 94 | for elem in data['images']: 95 | cut = elem['file_name'].split('/') 96 | target_current = int(cut[2]) 97 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 98 | 99 | categors = data_catg[target_current] 100 | target_current_true = targeter[categors[category]] 101 | self.samples.append((path_current, target_current_true)) 102 | 103 | 104 | def build_dataset(is_train, args): 105 | transform = build_transform(is_train, args) 106 | 107 | if args.data_set.lower() == 'cifar10': 108 | dataset = CIFAR10(args.data_path, train=is_train, download=True) 109 | elif args.data_set.lower() == 'cifar': 110 | dataset = CIFAR100(args.data_path, train=is_train, download=True) 111 | elif args.data_set.lower() == 'tinyimg': 112 | dataset = TinyImageNet200(args.data_path, train=is_train, download=True) 113 | elif args.data_set.lower() == 'imagenet100': 114 | dataset = ImageNet100_local( 115 | args.data_path, train=is_train, 116 | data_subset=os.path.join('./imagenet100_splits', "train_100.txt" if is_train else "val_100.txt") 117 | ) 118 | elif args.data_set.lower() == 'imagenet1000': 119 | dataset = ImageNet1000(args.data_path, train=is_train) 120 | else: 121 | raise ValueError(f'Unknown dataset {args.data_set}.') 122 | 123 | scenario = ClassIncremental( 124 | dataset, 125 | initial_increment=args.initial_increment, 126 | increment=args.increment, 127 | transformations=transform.transforms, 128 | class_order=args.class_order 129 | ) 130 | nb_classes = scenario.nb_classes #100 131 | 132 | return scenario, nb_classes 133 | 134 | 135 | def build_transform(is_train, args): 136 | if args.aa == 'none': 137 | args.aa = None 138 | 139 | with warnings.catch_warnings(): 140 | resize_im = args.input_size > 32 141 | if is_train: 142 | # this should always dispatch to transforms_imagenet_train 143 | transform = create_transform( 144 | input_size=args.input_size, 145 | is_training=True, 146 | color_jitter=args.color_jitter, 147 | 
auto_augment=args.aa, 148 | interpolation='bicubic', 149 | re_prob=args.reprob, 150 | re_mode=args.remode, 151 | re_count=args.recount, 152 | ) 153 | if not resize_im: 154 | transform.transforms[0] = transforms.RandomCrop( 155 | args.input_size, padding=4) 156 | 157 | if args.input_size == 32 and (args.data_set == 'CIFAR' or args.data_set == 'CIFAR10'): 158 | transform.transforms[-1] = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) 159 | elif args.data_set == 'STL10': 160 | transform.transforms[-1] = transforms.Normalize((0.4192, 0.4124, 0.3804), (0.2714, 0.2679, 0.2771)) 161 | return transform 162 | 163 | t = [] 164 | if resize_im and args.data_set != 'TINYIMG': 165 | size = int((256 / 224) * args.input_size) 166 | t.append( 167 | transforms.Resize(size, interpolation=interpolation), # to maintain same ratio w.r.t. 224 images 168 | ) 169 | t.append(transforms.CenterCrop(args.input_size)) 170 | 171 | t.append(transforms.ToTensor()) 172 | if args.input_size == 32 and (args.data_set == 'CIFAR' or args.data_set == 'CIFAR10'): 173 | # Normalization values for CIFAR100 174 | t.append(transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))) 175 | else: 176 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 177 | 178 | composed_transforms = transforms.Compose(t) 179 | return composed_transforms 180 | 181 | 182 | class ImageNet1000_local(ImageFolderDataset): 183 | """ImageNet1000 dataset. 184 | 185 | Simple wrapper around ImageFolderDataset to provide a link to the download 186 | page. 187 | """ 188 | def __init__(self, *args, **kwargs): 189 | super().__init__(*args, **kwargs) 190 | # if self.train: 191 | # self.data_path = os.path.join(self.data_path, "train") 192 | # else: 193 | # self.data_path = os.path.join(self.data_path, "val") 194 | 195 | @property 196 | def transformations(self): 197 | """Default transformations if nothing is provided to the scenario.""" 198 | return [transforms.ToTensor(), 199 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] 200 | 201 | def _download(self): 202 | if not os.path.exists(self.data_path): 203 | raise IOError( 204 | "You must download yourself the ImageNet dataset." 205 | " Please go to http://www.image-net.org/challenges/LSVRC/2012/downloads and" 206 | " download 'Training images (Task 1 & 2)' and 'Validation images (all tasks)'." 207 | ) 208 | print("ImageNet already downloaded.") 209 | 210 | 211 | class ImageNet100_local(ImageNet1000_local): 212 | """Subset of ImageNet1000 made of only 100 classes. 213 | 214 | You must download the ImageNet1000 dataset then provide the images subset. 215 | If in doubt, use the option at initialization `download=True` and it will 216 | auto-download for you the subset ids used in: 217 | * Small Task Incremental Learning 218 | Douillard et al. 
2020 219 | """ 220 | 221 | train_subset_url = "https://github.com/Continvvm/continuum/releases/download/v0.1/train_100.txt" 222 | test_subset_url = "https://github.com/Continvvm/continuum/releases/download/v0.1/val_100.txt" 223 | 224 | def __init__( 225 | self, *args, data_subset: Union[Tuple[np.array, np.array], str, None] = None, **kwargs 226 | ): 227 | self.data_subset = data_subset 228 | super().__init__(*args, **kwargs) 229 | 230 | def _download(self): 231 | super()._download() 232 | 233 | filename = "val_100.txt" 234 | self.subset_url = self.test_subset_url 235 | if self.train: 236 | filename = "train_100.txt" 237 | self.subset_url = self.train_subset_url 238 | 239 | if self.data_subset is None: 240 | self.data_subset = os.path.join(self.data_path, filename) 241 | download(self.subset_url, self.data_path) 242 | 243 | def get_data(self) -> Tuple[np.ndarray, np.ndarray, Union[np.ndarray, None]]: 244 | data = self._parse_subset(self.data_subset, train=self.train) # type: ignore 245 | return (*data, None) 246 | 247 | def _parse_subset( 248 | self, 249 | subset: Union[Tuple[np.array, np.array], str, None], 250 | train: bool = True 251 | ) -> Tuple[np.array, np.array]: 252 | if isinstance(subset, str): 253 | x, y = [], [] 254 | 255 | with open(subset, "r") as f: 256 | for line in f: 257 | split_line = line.split(" ") 258 | path = split_line[0].strip() 259 | x.append(os.path.join(self.data_path, path)) 260 | y.append(int(split_line[1].strip())) 261 | x = np.array(x) 262 | y = np.array(y) 263 | return x, y 264 | return subset # type: ignore 265 | -------------------------------------------------------------------------------- /continual/factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from continual import convit, birt, samplers, vit 4 | from continual.cnn import (InceptionV3, rebuffi, resnet18, resnet34, resnet50, 5 | resnext50_32x4d, seresnet18, vgg16, vgg16_bn, 6 | wide_resnet50_2, resnet18_scs, resnet18_scs_max, resnet18_scs_avg) 7 | 8 | 9 | def get_backbone(args): 10 | print(f"Creating model: {args.model}") 11 | if args.model == 'vit': 12 | model = vit.VisionTransformer( 13 | num_classes=args.nb_classes, 14 | drop_rate=args.drop, 15 | drop_path_rate=args.drop_path, 16 | img_size=args.input_size, 17 | patch_size=args.patch_size, 18 | embed_dim=args.embed_dim, 19 | depth=args.depth, 20 | num_heads=args.num_heads 21 | ) 22 | elif args.model == 'convit': 23 | model = convit.ConVit( 24 | num_classes=args.nb_classes, 25 | drop_rate=args.drop, 26 | drop_path_rate=args.drop_path, 27 | img_size=args.input_size, 28 | patch_size=args.patch_size, 29 | embed_dim=args.embed_dim, 30 | depth=args.depth, 31 | num_heads=args.num_heads, 32 | local_up_to_layer=args.local_up_to_layer, 33 | locality_strength=args.locality_strength, 34 | class_attention=args.class_attention, 35 | ca_type='jointca' if args.joint_tokens else 'base', 36 | norm_layer=args.norm, 37 | dynamic_tokens=args.dynamic_tokens, 38 | num_blocks=args.replay_from, 39 | attn_version=args.attn_version 40 | ) 41 | elif args.model == 'resnet18_scs': model = resnet18_scs() 42 | elif args.model == 'resnet18_scs_avg': model = resnet18_scs_avg() 43 | elif args.model == 'resnet18_scs_max': model = resnet18_scs_max() 44 | elif args.model == 'resnet18': model = resnet18() 45 | elif args.model == 'resnet34': model = resnet34() 46 | elif args.model == 'resnet50': model = resnet50() 47 | elif args.model == 'wide_resnet50': model = wide_resnet50_2() 48 | elif args.model == 'resnext50':
model = resnext50_32x4d() 49 | elif args.model == 'seresnet18': model = seresnet18() 50 | elif args.model == 'inception3': model = InceptionV3() 51 | elif args.model == 'vgg16bn': model = vgg16_bn() 52 | elif args.model == 'vgg16': model = vgg16() 53 | elif args.model == 'rebuffi': model = rebuffi() 54 | else: 55 | raise NotImplementedError(f'Unknown backbone {args.model}') 56 | 57 | return model 58 | 59 | 60 | 61 | def get_loaders(dataset_train, dataset_val, args, drop_last=True): 62 | sampler_train, sampler_val = samplers.get_sampler(dataset_train, dataset_val, args) 63 | 64 | loader_train = torch.utils.data.DataLoader( 65 | dataset_train, sampler=sampler_train, 66 | batch_size=args.batch_size, 67 | num_workers=args.num_workers, 68 | pin_memory=args.pin_mem, 69 | drop_last=drop_last, 70 | ) 71 | 72 | loader_val = torch.utils.data.DataLoader( 73 | dataset_val, sampler=sampler_val, 74 | batch_size=int(1.5 * args.batch_size), 75 | num_workers=args.num_workers, 76 | pin_memory=args.pin_mem, 77 | drop_last=False 78 | ) 79 | 80 | return loader_train, loader_val 81 | 82 | 83 | def get_train_loaders(dataset_train, args, batch_size=None, drop_last=True): 84 | batch_size = batch_size or args.batch_size 85 | 86 | sampler_train = samplers.get_train_sampler(dataset_train, args) 87 | 88 | loader_train = torch.utils.data.DataLoader( 89 | dataset_train, sampler=sampler_train, 90 | batch_size=batch_size, 91 | num_workers=args.num_workers, 92 | pin_memory=False, 93 | drop_last=drop_last, 94 | ) 95 | 96 | return loader_train 97 | 98 | 99 | class InfiniteLoader: 100 | def __init__(self, loader): 101 | self.loader = loader 102 | self.reset() 103 | 104 | def reset(self): 105 | self.it = iter(self.loader) 106 | 107 | def get(self): 108 | try: 109 | return next(self.it) 110 | except StopIteration: 111 | self.reset() 112 | return self.get() 113 | 114 | 115 | def update_birt(model_without_ddp, task_id, args): 116 | if task_id == 0: 117 | print(f'Creating BiRT!') 118 | model_without_ddp = birt.BiRT( 119 | model_without_ddp, 120 | nb_classes=args.initial_increment, # 10 121 | individual_classifier=args.ind_clf, #1-1 122 | head_div=args.head_div > 0., # 0.1 123 | head_div_mode=args.head_div_mode, # 'tr' 124 | joint_tokens=args.joint_tokens, # False 125 | num_blocks=args.replay_from, # from which block to replay representation 126 | multi_token_setup=args.multi_token_setup, 127 | ) 128 | else: 129 | print(f'Updating ensemble, new embed dim {model_without_ddp.sabs[-1].dim}.') 130 | model_without_ddp.add_model(args.increment, args.multi_token_setup) 131 | 132 | return model_without_ddp 133 | -------------------------------------------------------------------------------- /continual/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | 11 | class DistillationLoss(torch.nn.Module): 12 | """ 13 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 14 | taking a teacher model prediction and using it as additional supervision. 
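With ``distillation_type='soft'`` the total loss computed below is, roughly,
    ``(1 - alpha) * base_loss + alpha * tau**2 * KL(student / tau || teacher / tau)``
    (the KL term is averaged over batch and classes). A hypothetical usage sketch, where
    ``teacher`` is an assumed existing model and ``student`` is assumed to return an
    ``(outputs, outputs_kd)`` tuple:

        criterion = DistillationLoss(nn.CrossEntropyLoss(), teacher,
                                     distillation_type='soft', alpha=0.5, tau=2.0)
        loss = criterion(images, student(images), targets)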
15 |     """
16 |     def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
17 |                  distillation_type: str, alpha: float, tau: float):
18 |         super().__init__()
19 |         self.base_criterion = base_criterion
20 |         self.teacher_model = teacher_model
21 |         assert distillation_type in ['none', 'soft', 'hard']
22 |         self.distillation_type = distillation_type
23 |         self.alpha = alpha
24 |         self.tau = tau
25 | 
26 |     def forward(self, inputs, outputs, labels):
27 |         """
28 |         Args:
29 |             inputs: The original inputs that are fed to the teacher model
30 |             outputs: the outputs of the model to be trained. It is expected to be
31 |                 either a Tensor, or a Tuple[Tensor, Tensor], with the original output
32 |                 in the first position and the distillation predictions as the second output
33 |             labels: the labels for the base criterion
34 |         """
35 |         outputs_kd = None
36 |         if not isinstance(outputs, torch.Tensor):
37 |             # assume that the model outputs a tuple of [outputs, outputs_kd]
38 |             outputs, outputs_kd = outputs
39 |         base_loss = self.base_criterion(outputs, labels)
40 |         if self.distillation_type == 'none':
41 |             return base_loss
42 | 
43 |         if outputs_kd is None:
44 |             raise ValueError("When knowledge distillation is enabled, the model is "
45 |                              "expected to return a Tuple[Tensor, Tensor] with the output of the "
46 |                              "class_token and the dist_token")
47 |         # don't backprop through the teacher
48 |         with torch.no_grad():
49 |             teacher_outputs = self.teacher_model(inputs)
50 | 
51 |         if self.distillation_type == 'soft':
52 |             T = self.tau
53 |             # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
54 |             # with slight modifications
55 |             distillation_loss = F.kl_div(
56 |                 F.log_softmax(outputs_kd / T, dim=1),
57 |                 F.log_softmax(teacher_outputs / T, dim=1),
58 |                 reduction='sum',
59 |                 log_target=True
60 |             ) * (T * T) / outputs_kd.numel()
61 |         elif self.distillation_type == 'hard':
62 |             distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
63 | 
64 |         loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
65 |         return loss
66 | 
67 | 
68 | def bce_with_logits(x, y):
69 |     return F.binary_cross_entropy_with_logits(
70 |         x,
71 |         torch.eye(x.shape[1])[y].to(y.device)
72 |     )
73 | 
74 | 
75 | def bce_smooth_pos_with_logits(smooth):
76 |     def _func(x, y):
77 |         return F.binary_cross_entropy_with_logits(
78 |             x,
79 |             torch.clamp(
80 |                 torch.eye(x.shape[1])[y].to(y.device) - smooth,
81 |                 min=0.0
82 |             )
83 |         )
84 |     return _func
85 | 
86 | 
87 | def bce_smooth_posneg_with_logits(smooth):
88 |     def _func(x, y):
89 |         return F.binary_cross_entropy_with_logits(
90 |             x,
91 |             torch.clamp(
92 |                 torch.eye(x.shape[1])[y].to(y.device) + smooth,
93 |                 max=1 - smooth
94 |             )
95 |         )
96 |     return _func
97 | 
98 | 
99 | class LabelSmoothingCrossEntropyBoosting(nn.Module):
100 |     """
101 |     NLL loss with label smoothing.
102 |     """
103 |     def __init__(self, smoothing=0.1, alpha=1, gamma=1):
104 |         """
105 |         Constructor for the LabelSmoothing module.
106 |         :param smoothing: label smoothing factor
107 |         """
108 |         super().__init__()
109 |         assert smoothing < 1.0
110 |         self.smoothing = smoothing
111 |         self.confidence = 1.
- smoothing 112 | 113 | self.alpha = alpha 114 | self.gamma = gamma 115 | 116 | def forward(self, x, target, boosting_output=None, boosting_focal=None): 117 | if boosting_output is None: 118 | return self._base_loss(x, target) 119 | return self._focal_loss(x, target, boosting_output, boosting_focal) 120 | 121 | def _focal_loss(self, x, target, boosting_output, boosting_focal): 122 | logprobs = F.log_softmax(x, dim=-1) 123 | 124 | if boosting_focal == 'old': 125 | pt = boosting_output.softmax(-1)[..., :-1] 126 | 127 | f = torch.ones_like(logprobs) 128 | f[:, :boosting_output.shape[1] - 1] = self.alpha * (1 - pt) ** self.gamma 129 | logprobs = f * logprobs 130 | elif boosting_focal == 'new': 131 | pt = boosting_output.softmax(-1)[..., -1] 132 | nb_old_classes = boosting_output.shape[1] - 1 133 | 134 | f = torch.ones_like(logprobs) 135 | f[:, nb_old_classes:] = self.alpha * (1 - pt[:, None]) ** self.gamma 136 | logprobs = f * logprobs 137 | else: 138 | assert False, (boosting_focal) 139 | 140 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 141 | nll_loss = nll_loss.squeeze(1) 142 | smooth_loss = -logprobs.mean(dim=-1) 143 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 144 | return loss.mean() 145 | 146 | def _base_loss(self, x, target): 147 | logprobs = F.log_softmax(x, dim=-1) 148 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 149 | nll_loss = nll_loss.squeeze(1) 150 | smooth_loss = -logprobs.mean(dim=-1) 151 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 152 | return loss.mean() 153 | 154 | 155 | class SoftTargetCrossEntropyBoosting(nn.Module): 156 | 157 | def __init__(self, alpha=1, gamma=1): 158 | super().__init__() 159 | self.alpha = alpha 160 | self.gamma = gamma 161 | 162 | def forward(self, x, target, boosting_output=None, boosting_focal=None): 163 | if boosting_output is None: 164 | return torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1).mean() 165 | 166 | if boosting_focal == 'old': 167 | pt = boosting_output.softmax(-1)[..., :-1] 168 | 169 | f = torch.ones_like(x) 170 | f[:, :boosting_output.shape[1] - 1] = self.alpha * (1 - pt) ** self.gamma 171 | elif boosting_focal == 'new': 172 | pt = boosting_output.softmax(-1)[..., -1] 173 | 174 | nb_old_classes = boosting_output.shape[1] - 1 175 | 176 | f = torch.ones_like(x) 177 | f[:, nb_old_classes:] = self.alpha * (1 - pt[:, None]) ** self.gamma 178 | else: 179 | assert False, (boosting_focal) 180 | 181 | return torch.sum(-target * f * F.log_softmax(x, dim=-1), dim=-1).mean() 182 | -------------------------------------------------------------------------------- /continual/misc.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeurAI-Lab/BiRT/1340a97bff17fc02e228b754bd80e1fd649ff9cd/continual/misc.py -------------------------------------------------------------------------------- /continual/mixup.py: -------------------------------------------------------------------------------- 1 | """ Mixup and Cutmix 2 | 3 | Papers: 4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 5 | 6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) 7 | 8 | Code Reference: 9 | CutMix: https://github.com/clovaai/CutMix-PyTorch 10 | 11 | Hacked together by / Copyright 2020 Ross Wightman 12 | """ 13 | import numpy as np 14 | import torch 15 | 16 | 17 | def one_hot(x, num_classes, on_value=1., off_value=0., 
device='cuda'): 18 | x = x.long().view(-1, 1) 19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 20 | 21 | 22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda', old_target=None): 23 | off_value = smoothing / num_classes 24 | on_value = 1. - smoothing + off_value 25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 26 | if old_target is not None: 27 | y2 = one_hot(old_target, num_classes, on_value=on_value, off_value=off_value, device=device) 28 | else: 29 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 30 | return y1 * lam + y2 * (1. - lam) 31 | 32 | 33 | def rand_bbox(img_shape, lam, margin=0., count=None): 34 | """ Standard CutMix bounding-box 35 | Generates a random square bbox based on lambda value. This impl includes 36 | support for enforcing a border margin as percent of bbox dimensions. 37 | 38 | Args: 39 | img_shape (tuple): Image shape as tuple 40 | lam (float): Cutmix lambda value 41 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) 42 | count (int): Number of bbox to generate 43 | """ 44 | ratio = np.sqrt(1 - lam) 45 | img_h, img_w = img_shape[-2:] 46 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 47 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 48 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) 49 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) 50 | yl = np.clip(cy - cut_h // 2, 0, img_h) 51 | yh = np.clip(cy + cut_h // 2, 0, img_h) 52 | xl = np.clip(cx - cut_w // 2, 0, img_w) 53 | xh = np.clip(cx + cut_w // 2, 0, img_w) 54 | return yl, yh, xl, xh 55 | 56 | 57 | def rand_bbox_minmax(img_shape, minmax, count=None): 58 | """ Min-Max CutMix bounding-box 59 | Inspired by Darknet cutmix impl, generates a random rectangular bbox 60 | based on min/max percent values applied to each dimension of the input image. 61 | 62 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. 63 | 64 | Args: 65 | img_shape (tuple): Image shape as tuple 66 | minmax (tuple or list): Min and max bbox ratios (as percent of image size) 67 | count (int): Number of bbox to generate 68 | """ 69 | assert len(minmax) == 2 70 | img_h, img_w = img_shape[-2:] 71 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) 72 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) 73 | yl = np.random.randint(0, img_h - cut_h, size=count) 74 | xl = np.random.randint(0, img_w - cut_w, size=count) 75 | yu = yl + cut_h 76 | xu = xl + cut_w 77 | return yl, yu, xl, xu 78 | 79 | 80 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): 81 | """ Generate bbox and apply lambda correction. 82 | """ 83 | if ratio_minmax is not None: 84 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) 85 | else: 86 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) 87 | if correct_lam or ratio_minmax is not None: 88 | bbox_area = (yu - yl) * (xu - xl) 89 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) 90 | return (yl, yu, xl, xu), lam 91 | 92 | 93 | class Mixup: 94 | """ Mixup/Cutmix that applies different params to each element or whole batch 95 | 96 | Args: 97 | mixup_alpha (float): mixup alpha value, mixup is active if > 0. 
98 |         cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
99 |         cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
100 |         prob (float): probability of applying mixup or cutmix per batch or element
101 |         switch_prob (float): probability of switching to cutmix instead of mixup when both are active
102 |         mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
103 |         correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
104 |         label_smoothing (float): apply label smoothing to the mixed target tensor
105 |         num_classes (int): number of classes for target
106 |     """
107 |     def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
108 |                  mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000,
109 |                  loader_memory=None):
110 |         self.mixup_alpha = mixup_alpha
111 |         self.cutmix_alpha = cutmix_alpha
112 |         self.cutmix_minmax = cutmix_minmax
113 |         if self.cutmix_minmax is not None:
114 |             assert len(self.cutmix_minmax) == 2
115 |             # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
116 |             self.cutmix_alpha = 1.0
117 |         self.mix_prob = prob
118 |         self.switch_prob = switch_prob
119 |         self.label_smoothing = label_smoothing
120 |         self.num_classes = num_classes
121 |         self.mode = mode
122 |         self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
123 |         self.mixup_enabled = True  # set to False to disable mixing (intended to be set by the train loop)
124 |         self.loader_memory = loader_memory
125 | 
126 |     def _params_per_elem(self, batch_size):
127 |         lam = np.ones(batch_size, dtype=np.float32)
128 |         use_cutmix = np.zeros(batch_size, dtype=bool)
129 |         if self.mixup_enabled:
130 |             if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
131 |                 use_cutmix = np.random.rand(batch_size) < self.switch_prob
132 |                 lam_mix = np.where(
133 |                     use_cutmix,
134 |                     np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
135 |                     np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
136 |             elif self.mixup_alpha > 0.:
137 |                 lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
138 |             elif self.cutmix_alpha > 0.:
139 |                 use_cutmix = np.ones(batch_size, dtype=bool)
140 |                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
141 |             else:
142 |                 assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
143 |             lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
144 |         return lam, use_cutmix
145 | 
146 |     def _params_per_batch(self):
147 |         lam = 1.
148 |         use_cutmix = False
149 |         if self.mixup_enabled and np.random.rand() < self.mix_prob:
150 |             if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
151 |                 use_cutmix = np.random.rand() < self.switch_prob
152 |                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
153 |                     np.random.beta(self.mixup_alpha, self.mixup_alpha)
154 |             elif self.mixup_alpha > 0.:
155 |                 lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
156 |             elif self.cutmix_alpha > 0.:
157 |                 use_cutmix = True
158 |                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
159 |             else:
160 |                 assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
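            # Note: lam_mix is drawn from Beta(alpha, alpha) (or Beta(cutmix_alpha, cutmix_alpha));
            # if mixing is disabled or the mix_prob draw fails, lam stays at 1.0 and no mixing happens.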
161 | lam = float(lam_mix) 162 | return lam, use_cutmix 163 | 164 | def _mix_elem(self, x): 165 | batch_size = len(x) 166 | lam_batch, use_cutmix = self._params_per_elem(batch_size) 167 | x_orig = x.clone() # need to keep an unmodified original for mixing source 168 | for i in range(batch_size): 169 | j = batch_size - i - 1 170 | lam = lam_batch[i] 171 | if lam != 1.: 172 | if use_cutmix[i]: 173 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 174 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 175 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] 176 | lam_batch[i] = lam 177 | else: 178 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 179 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 180 | 181 | def _mix_pair(self, x): 182 | batch_size = len(x) 183 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 184 | x_orig = x.clone() # need to keep an unmodified original for mixing source 185 | for i in range(batch_size // 2): 186 | j = batch_size - i - 1 187 | lam = lam_batch[i] 188 | if lam != 1.: 189 | if use_cutmix[i]: 190 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 191 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 192 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] 193 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh] 194 | lam_batch[i] = lam 195 | else: 196 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 197 | x[j] = x[j] * lam + x_orig[i] * (1 - lam) 198 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 199 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 200 | 201 | def _mix_batch(self, x): 202 | lam, use_cutmix = self._params_per_batch() 203 | if lam == 1.: 204 | return 1. 205 | if use_cutmix: 206 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 207 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 208 | x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh] 209 | else: 210 | x_flipped = x.flip(0).mul_(1. - lam) 211 | x.mul_(lam).add_(x_flipped) 212 | return lam 213 | 214 | def _mix_old(self, x, old_x): 215 | lam, use_cutmix = self._params_per_batch() 216 | if lam == 1.: 217 | return 1. 218 | if use_cutmix: 219 | assert False 220 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 221 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 222 | x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh] 223 | else: 224 | x_flipped = x.flip(0).mul_(1. - lam) 225 | x.mul_(lam).add_(x_flipped) 226 | #x.mul_(lam).add_(old_x.mul_(1. - lam)) 227 | return lam 228 | 229 | def __call__(self, x, target): 230 | assert len(x) % 2 == 0, 'Batch size should be even when using this' 231 | old_y = None 232 | if self.mode == 'elem': 233 | lam = self._mix_elem(x) 234 | elif self.mode == 'pair': 235 | lam = self._mix_pair(x) 236 | elif self.mode == 'batch' or (self.mode == 'old' and self.loader_memory is None): 237 | lam = self._mix_batch(x) 238 | else: # old 239 | old_x, old_y, _ = self.loader_memory.get() 240 | old_x, old_y = old_x.to(x.device), old_y.to(x.device) 241 | lam = self._mix_old(x, old_x) 242 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, old_target=old_y) 243 | return x, target, lam 244 | 245 | 246 | class FastCollateMixup(Mixup): 247 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch 248 | 249 | A Mixup impl that's performed while collating the batches. 
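    A usage sketch (illustrative only; the constructor arguments are example values, and
    the dataset is assumed to yield (uint8 numpy image, int target) pairs):

        collate = FastCollateMixup(mixup_alpha=0.8, cutmix_alpha=1.0,
                                   label_smoothing=0.1, num_classes=100)
        loader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                             collate_fn=collate, drop_last=True)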
250 | """ 251 | 252 | def _mix_elem_collate(self, output, batch, half=False): 253 | batch_size = len(batch) 254 | num_elem = batch_size // 2 if half else batch_size 255 | assert len(output) == num_elem 256 | lam_batch, use_cutmix = self._params_per_elem(num_elem) 257 | for i in range(num_elem): 258 | j = batch_size - i - 1 259 | lam = lam_batch[i] 260 | mixed = batch[i][0] 261 | if lam != 1.: 262 | if use_cutmix[i]: 263 | if not half: 264 | mixed = mixed.copy() 265 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 266 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 267 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] 268 | lam_batch[i] = lam 269 | else: 270 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 271 | np.rint(mixed, out=mixed) 272 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 273 | if half: 274 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) 275 | return torch.tensor(lam_batch).unsqueeze(1) 276 | 277 | def _mix_pair_collate(self, output, batch): 278 | batch_size = len(batch) 279 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 280 | for i in range(batch_size // 2): 281 | j = batch_size - i - 1 282 | lam = lam_batch[i] 283 | mixed_i = batch[i][0] 284 | mixed_j = batch[j][0] 285 | assert 0 <= lam <= 1.0 286 | if lam < 1.: 287 | if use_cutmix[i]: 288 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 289 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 290 | patch_i = mixed_i[:, yl:yh, xl:xh].copy() 291 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] 292 | mixed_j[:, yl:yh, xl:xh] = patch_i 293 | lam_batch[i] = lam 294 | else: 295 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) 296 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) 297 | mixed_i = mixed_temp 298 | np.rint(mixed_j, out=mixed_j) 299 | np.rint(mixed_i, out=mixed_i) 300 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) 301 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) 302 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 303 | return torch.tensor(lam_batch).unsqueeze(1) 304 | 305 | def _mix_batch_collate(self, output, batch): 306 | batch_size = len(batch) 307 | lam, use_cutmix = self._params_per_batch() 308 | if use_cutmix: 309 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 310 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 311 | for i in range(batch_size): 312 | j = batch_size - i - 1 313 | mixed = batch[i][0] 314 | if lam != 1.: 315 | if use_cutmix: 316 | mixed = mixed.copy() # don't want to modify the original while iterating 317 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] 318 | else: 319 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 320 | np.rint(mixed, out=mixed) 321 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 322 | return lam 323 | 324 | def __call__(self, batch, _=None): 325 | batch_size = len(batch) 326 | assert batch_size % 2 == 0, 'Batch size should be even when using this' 327 | half = 'half' in self.mode 328 | if half: 329 | batch_size //= 2 330 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 331 | if self.mode == 'elem' or self.mode == 'half': 332 | lam = self._mix_elem_collate(output, batch, half=half) 333 | elif self.mode == 'pair': 334 | lam = self._mix_pair_collate(output, batch) 335 | else: 336 | lam = 
self._mix_batch_collate(output, batch) 337 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64) 338 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') 339 | target = target[:batch_size] 340 | return output, target 341 | 342 | -------------------------------------------------------------------------------- /continual/mycontinual/__init__.py: -------------------------------------------------------------------------------- 1 | from continual.mycontinual.transformation_incremental import TransformationIncremental 2 | from continual.mycontinual.rotations import Rotations 3 | from continual.mycontinual.permutations import Permutations 4 | from continual.mycontinual.incremental_rotation import IncrementalRotation 5 | from continual.mycontinual.custom_array_task_set import ArrayTaskSet 6 | 7 | __all__ = [ 8 | "Rotations", 9 | "TransformationIncremental", 10 | "IncrementalRotation", 11 | "Permutations", 12 | "ArrayTaskSet" 13 | ] -------------------------------------------------------------------------------- /continual/mycontinual/custom_array_task_set.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union, Optional, List 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | from torchvision import transforms 7 | 8 | from continuum.viz import plot_samples 9 | from continuum.tasks.base import BaseTaskSet, _tensorize_list, TaskType 10 | 11 | 12 | class ArrayTaskSet(BaseTaskSet): 13 | """A task dataset returned by the CLLoader specialized into numpy/torch image arrays data. 14 | 15 | :param x: The data, either image-arrays or paths to images saved on disk. 16 | :param y: The targets, not one-hot encoded. 17 | :param t: The task id of each sample. 18 | :param trsf: The transformations to apply on the images. 19 | :param target_trsf: The transformations to apply on the labels. 20 | :param bounding_boxes: The bounding boxes annotations to crop images 21 | """ 22 | 23 | def __init__( 24 | self, 25 | x: np.ndarray, 26 | y: np.ndarray, 27 | t: np.ndarray, 28 | trsf: Union[transforms.Compose, List[transforms.Compose]], 29 | target_trsf: Optional[Union[transforms.Compose, List[transforms.Compose]]], 30 | bounding_boxes: Optional[np.ndarray] = None 31 | ): 32 | super().__init__(x, y, t, trsf, target_trsf, bounding_boxes=bounding_boxes) 33 | self.data_type = TaskType.IMAGE_ARRAY 34 | 35 | def plot( 36 | self, 37 | path: Union[str, None] = None, 38 | title: str = "", 39 | nb_samples: int = 100, 40 | shape: Optional[Tuple[int, int]] = None, 41 | ) -> None: 42 | """Plot samples of the current task, useful to check if everything is ok. 43 | 44 | :param path: If not None, save on disk at this path. 45 | :param title: The title of the figure. 46 | :param nb_samples: Amount of samples randomly selected. 47 | :param shape: Shape to resize the image before plotting. 
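        Example (illustrative): `task_set.plot(path="./samples.png", nb_samples=25)`
        would save a grid of 25 randomly selected samples to disk.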
48 |         """
49 |         plot_samples(self, title=title, path=path, nb_samples=nb_samples,
50 |                      shape=shape, data_type=self.data_type)
51 | 
52 |     def get_samples(self, indexes):
53 |         samples, targets, tasks = [], [], []
54 | 
55 |         w, h = None, None
56 |         for index in indexes:
57 |             # we need to use __getitem__ to have the transform used
58 |             sample, y, t = self[index]
59 | 
60 |             # we check dimension of images
61 |             if w is None:
62 |                 w, h = sample.shape[:2]
63 |             elif w != sample.shape[0] or h != sample.shape[1]:
64 |                 raise Exception(
65 |                     "Image dimensions are inconsistent; resize them to a "
66 |                     "common size using a transformation.\n"
67 |                     "For example, give to the scenario you're using as `transformations` argument "
68 |                     "the following: [transforms.Resize((224, 224)), transforms.ToTensor()]"
69 |                 )
70 | 
71 |             samples.append(sample)
72 |             targets.append(y)
73 |             tasks.append(t)
74 | 
75 |         return _tensorize_list(samples), _tensorize_list(targets), _tensorize_list(tasks)
76 | 
77 |     def get_sample(self, index: int) -> np.ndarray:
78 |         """Returns the raw image array corresponding to the given `index`.
79 | 
80 |         :param index: Index to query the image.
81 |         :return: A numpy image array (not converted to a Pillow image).
82 |         """
83 |         x = self._x[index]
84 |         # x = Image.fromarray(x.astype("uint8"))
85 |         return x
86 | 
87 |     def __getitem__(self, index: int) -> Tuple[np.ndarray, int, int]:
88 |         """Method used by PyTorch's DataLoaders to query a sample and its target."""
89 |         x = self.get_sample(index)
90 |         y = self._y[index]
91 |         t = self._t[index]
92 | 
93 |         if self.bounding_boxes is not None:
94 |             bbox = self.bounding_boxes[index]
95 |             x = x.crop((
96 |                 max(bbox[0], 0),  # x1
97 |                 max(bbox[1], 0),  # y1
98 |                 min(bbox[2], x.size[0]),  # x2
99 |                 min(bbox[3], x.size[1]),  # y2
100 |             ))
101 | 
102 |         x, y, t = self._prepare_data(x, y, t)
103 | 
104 |         if self.target_trsf is not None:
105 |             y = self.get_task_target_trsf(t)(y)
106 | 
107 |         return x, y, t
108 | 
109 |     def _prepare_data(self, x, y, t):
110 |         if self.trsf is not None:
111 |             x = self.get_task_trsf(t)(x)
112 |         if not isinstance(x, torch.Tensor):
113 |             x = self._to_tensor(x)
114 |         return x, y, t
115 | 
--------------------------------------------------------------------------------
/continual/mycontinual/incremental_rotation.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms.functional as F
2 | import numpy as np
3 | 
4 | 
5 | class IncrementalRotation(object):
6 |     """
7 |     Defines an incremental rotation for a numpy array.
8 |     """
9 | 
10 |     def __init__(self, init_deg: int = 0, increase_per_iteration: float = 0.006) -> None:
11 |         """
12 |         Defines the initial angle as well as the increase for each rotation.
13 |         :param init_deg: initial rotation angle in degrees
14 |         :param increase_per_iteration: angle increase applied at every call
15 |         """
16 |         self.increase_per_iteration = increase_per_iteration
17 |         self.iteration = 0
18 |         self.degrees = init_deg
19 | 
20 |     def __call__(self, x: np.ndarray) -> np.ndarray:
21 |         """
22 |         Applies the rotation.
23 |         :param x: image to be rotated
24 |         :return: rotated image
25 |         """
26 |         degs = (self.iteration * self.increase_per_iteration + self.degrees) % 360
27 |         self.iteration += 1
28 |         return F.rotate(x, degs)
29 | 
30 |     def set_iteration(self, x: int) -> None:
31 |         """
32 |         Set the iteration to a given integer.
33 |         :param x: iteration index
34 |         """
35 |         self.iteration = x
--------------------------------------------------------------------------------
/continual/mycontinual/permutations.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Callable, List, Union
3 | 
4 | import numpy as np
5 | import torch
6 | from torchvision import transforms
7 | 
8 | from continuum.datasets import _ContinuumDataset
9 | from continual.mycontinual import TransformationIncremental
10 | 
11 | 
12 | class Permutations(TransformationIncremental):
13 |     """Continual Loader, generating datasets for the consecutive tasks.
14 | 
15 |     Scenario: Permutation scenarios use the same data for every task, but with the pixels permuted.
16 |     Each task gets a specific permutation, so that all tasks are different.
17 | 
18 |     :param cl_dataset: A continual dataset.
19 |     :param nb_tasks: The scenario's number of tasks.
20 |     :param base_transformations: List of transformations to apply to all tasks.
21 |     :param seed: initialization seed for the permutations.
22 |     :param shared_label_space: If True, the same data seen under different transformations keeps the same label.
23 |     """
24 | 
25 |     def __init__(
26 |         self,
27 |         cl_dataset: _ContinuumDataset,
28 |         nb_tasks: Union[int, None] = None,
29 |         base_transformations: List[Callable] = None,
30 |         seed: Union[int, List[int]] = 0,
31 |         shared_label_space=True
32 |     ):
33 |         trsfs = self._generate_transformations(seed, nb_tasks)
34 | 
35 |         super().__init__(
36 |             cl_dataset=cl_dataset,
37 |             incremental_transformations=trsfs,
38 |             base_transformations=base_transformations,
39 |             shared_label_space=shared_label_space
40 |         )
41 | 
42 |     def _generate_transformations(self, seed, nb_tasks):
43 |         if isinstance(seed, int):
44 |             if nb_tasks is None:
45 |                 raise ValueError("You must specify a number of tasks if a single seed is provided.")
46 |             rng = np.random.RandomState(seed=seed)
47 |             seed = rng.permutation(100000)[:nb_tasks - 1]
48 |         elif nb_tasks is not None and nb_tasks != len(seed) + 1:
49 |             warnings.warn(
50 |                 f"Because a list of seeds was provided {seed}, "
51 |                 f"the number of tasks is automatically set to "
52 |                 f"len(number of seeds) + 1 = {len(seed) + 1}"
53 |             )
54 | 
55 |         return [PermutationTransform(seed=None)] + [PermutationTransform(seed=int(s)) for s in seed]
56 | 
57 |     def get_task_transformation(self, task_index):
58 |         return transforms.Compose(self.trsf.transforms + [self.inc_trsf[task_index]])
59 | 
60 | 
61 | class PermutationTransform:
62 |     """Permutation transformer.
63 | 
64 |     This transformer is initialized with a seed, such that the same seed gives the same permutation.
65 |     A seed of None means no permutation is applied (this identity transform is used for the first task).
66 | 
67 |     :param seed: seed to initialize the random number generator
68 |     """
69 | 
70 |     def __init__(self, seed: Union[int, None]):
71 |         self.seed = seed
72 |         self.g_cpu = torch.Generator()
73 | 
74 |     def __call__(self, x):
75 |         shape = list(x.shape)
76 |         x = x.reshape(-1)
77 |         # if seed is None, no permutation is applied
78 |         if self.seed is not None:
79 |             self.g_cpu.manual_seed(self.seed)
80 |             perm = torch.randperm(x.numel(), generator=self.g_cpu).long()
81 |             x = x[perm]
82 |         return x.reshape(shape)
83 | 
--------------------------------------------------------------------------------
/continual/mycontinual/rotations.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Tuple, Union
2 | 
3 | from torchvision import transforms
4 | 
5 | from continuum.datasets import _ContinuumDataset
6 | from continual.mycontinual import TransformationIncremental
7 | import torchvision.transforms.functional as F
8 | 
9 | 
10 | class Rotations(TransformationIncremental):
11 |     """Continual Loader, generating datasets for the consecutive tasks.
12 | 
13 |     Scenario: The rotations scenario is a new-instances scenario.
14 |     For each task, the data is rotated by a given angle.
15 | 
16 |     :param cl_dataset: A continual dataset.
17 |     :param nb_tasks: The scenario's number of tasks.
18 |     :param list_degrees: list of rotations in degrees (int) or list of ranges, e.g. (0, (40, 45), 90).
19 |     :param base_transformations: Preprocessing transformations applied to the data before rotation.
20 |     :param shared_label_space: If True, the same data seen under different transformations keeps the same label.
21 |     """
22 | 
23 |     def __init__(
24 |         self,
25 |         cl_dataset: _ContinuumDataset,
26 |         list_degrees: Union[List[Tuple], List[int]],
27 |         nb_tasks: Union[int, None] = None,
28 |         base_transformations: List[Callable] = None,
29 |         shared_label_space=True
30 |     ):
31 | 
32 |         if nb_tasks is not None and len(list_degrees) != nb_tasks:
33 |             raise ValueError(
34 |                 f"The number of tasks ({nb_tasks}) != the number of angle "
35 |                 f"tuples ({len(list_degrees)}) set in the list"
36 |             )
37 | 
38 |         trsfs = self._generate_transformations(list_degrees)
39 | 
40 |         super().__init__(
41 |             cl_dataset=cl_dataset,
42 |             incremental_transformations=trsfs,
43 |             base_transformations=base_transformations,
44 |             shared_label_space=shared_label_space
45 |         )
46 | 
47 |     def _generate_transformations(self, degrees):
48 |         trsfs = []
49 |         min_deg, max_deg = None, None
50 | 
51 |         for deg in degrees:
52 |             if isinstance(deg, int) or isinstance(deg, float):
53 |                 min_deg, max_deg = deg, deg
54 |             elif len(deg) == 2:
55 |                 min_deg, max_deg = deg
56 |             else:
57 |                 raise ValueError(
58 |                     f"Invalid list of degrees ({degrees}). "
59 |                     "It should contain either integers (-deg, +deg) or "
60 |                     "tuples (range) of integers (deg_a, deg_b)."
61 |                 )
62 | 
63 |             trsfs.append([transforms.RandomAffine(degrees=[min_deg, max_deg])])
64 | 
65 |         return trsfs
66 | 
--------------------------------------------------------------------------------
/continual/mycontinual/transformation_incremental.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional
2 | 
3 | import numpy as np
4 | from torchvision import transforms
5 | 
6 | from continuum.datasets import _ContinuumDataset
7 | from continuum.scenarios import InstanceIncremental
8 | from continuum.tasks import TaskSet, TaskType
9 | 
10 | 
11 | class TransformationIncremental(InstanceIncremental):
12 |     """Continual Loader, generating datasets for the consecutive tasks.
13 | 
14 |     Scenario: Every task contains the same data with different transformations.
15 |     It is a cheap way to create instance incremental scenarios.
16 |     Moreover, it makes it easier to analyse what algorithms forget or not.
17 |     Classic transformation incremental scenarios are "permutations" and "rotations".
18 | 
19 |     :param cl_dataset: A continual dataset.
20 |     :param incremental_transformations: list of transformations to apply to specific tasks
21 |     :param base_transformations: List of transformations to apply to all tasks.
22 |     :param shared_label_space: If True, the same data seen under different transformations keeps the same label.
23 |     """
24 | 
25 |     def __init__(
26 |         self,
27 |         cl_dataset: _ContinuumDataset,
28 |         incremental_transformations: List[List[Callable]],
29 |         base_transformations: List[Callable] = None,
30 |         shared_label_space=True
31 |     ):
32 |         if incremental_transformations is None:
33 |             raise ValueError("For this scenario, a list of transformations must be provided.")
34 |         nb_tasks = len(incremental_transformations)
35 | 
36 |         if cl_dataset.data_type == TaskType.H5:
37 |             raise NotImplementedError("TransformationIncremental is not compatible yet with h5 files.")
38 | 
39 |         self.inc_trsf = incremental_transformations
40 |         #self._nb_tasks = self._setup(nb_tasks)
41 |         self.shared_label_space = shared_label_space
42 | 
43 |         super().__init__(
44 |             cl_dataset=cl_dataset, nb_tasks=nb_tasks, transformations=base_transformations
45 |         )
46 | 
47 |         self.num_classes_per_task = len(np.unique(self.dataset[1]))  # the number of classes is the same for every task in this scenario
48 | 
49 |     @property
50 |     def nb_classes(self) -> int:
51 |         """Total number of classes in the whole continual setting."""
52 |         if self.shared_label_space:
53 |             nb_classes = len(np.unique(self.dataset[1]))
54 |         else:
55 |             nb_classes = len(np.unique(self.dataset[1])) * self._nb_tasks
56 |         return nb_classes
57 | 
58 |     def get_task_transformation(self, task_index):
59 |         return transforms.Compose(self.inc_trsf[task_index] + self.trsf.transforms)
60 | 
61 |     def update_task_indexes(self, task_index):
62 |         new_t = np.ones(len(self.dataset[1])) * task_index
63 |         self.dataset = (self.dataset[0], self.dataset[1], new_t)
64 | 
65 |     def update_labels(self, task_index):
66 |         # `new_y = self.dataset[1] + task_index * self.num_classes_per_task` would be wrong here:
67 |         # labels are updated incrementally at every task,
68 |         # so the update is simply:
69 |         if task_index > 0:
70 |             new_y = self.dataset[1] + self.num_classes_per_task
71 |             self.dataset = (self.dataset[0], new_y, self.dataset[2])
72 | 
73 |     def __getitem__(self, task_index):
74 |         """Returns a task by its unique index.
75 | 
76 |         :param task_index: The unique index of a task, between 0 and len(loader) - 1. It can also
77 |             be a list, a numpy array, or a slice.
78 |         :return: A train PyTorch Dataset.
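        Example (illustrative; `scenario` stands for any TransformationIncremental
        instance, e.g. a Rotations scenario):

            task_set = scenario[0]    # first task only
            task_set = scenario[:3]   # tasks 0, 1 and 2 merged
            task_set = scenario[-1]   # last task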
79 | """ 80 | x, y, _ = self.dataset 81 | 82 | if isinstance(task_index, slice): 83 | # Convert a slice to a list and respect the Python's advanced indexing conventions 84 | start = task_index.start if task_index.start is not None else 0 85 | stop = task_index.stop if task_index.stop is not None else len(self) + 1 86 | step = task_index.step if task_index.step is not None else 1 87 | task_index = list(range(start, stop, step)) 88 | if len(task_index) == 0: 89 | raise ValueError(f"Invalid slicing resulting in no data (start={start}, end={stop}, step={step}).") 90 | elif isinstance(task_index, np.ndarray): 91 | task_index = list(task_index) 92 | elif isinstance(task_index, int): 93 | task_index = [task_index] 94 | else: 95 | raise TypeError(f"Invalid type of task index {type(task_index).__name__}.") 96 | 97 | task_index = set([_handle_negative_indexes(ti, len(self)) for ti in task_index]) 98 | 99 | t = np.concatenate([ 100 | (np.ones(len(x)) * ti).astype(np.int32) for ti in task_index 101 | ]) 102 | x = np.concatenate([ 103 | x for _ in range(len(task_index)) 104 | ]) 105 | 106 | if self.shared_label_space: 107 | y = np.concatenate([ 108 | y for _ in range(len(task_index)) 109 | ]) 110 | else: 111 | # Different transformations have different labels even though 112 | # the original images were the same 113 | y = np.concatenate([ 114 | y + ti * self.num_classes_per_task for ti in task_index 115 | ]) 116 | 117 | # trsf = [ # Non-used tasks have a None trsf 118 | # self.get_task_transformation(ti) 119 | # if ti in task_index else None 120 | # for ti in range(len(self)) 121 | # ] 122 | 123 | trsf = [ # Non-used tasks have a None trsf 124 | self.get_task_transformation(ti) 125 | # if ti in task_index else None 126 | for ti in range(len(self)) 127 | ] 128 | 129 | return TaskSet(x, y, t, trsf, data_type=self.cl_dataset.data_type) 130 | 131 | 132 | def _handle_negative_indexes(index: int, total_len: int) -> int: 133 | while index < 0: 134 | index += total_len 135 | return index 136 | -------------------------------------------------------------------------------- /continual/rehearsal.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import torch 5 | from torchvision import transforms 6 | 7 | from continual.mycontinual import ArrayTaskSet 8 | 9 | 10 | class Memory: 11 | def __init__(self, memory_size, nb_total_classes, rehearsal, rep_replay=False, fixed=True): 12 | self.memory_size = memory_size # 2000 13 | self.nb_total_classes = nb_total_classes # 100 14 | self.rehearsal = rehearsal # icarl_all 15 | self.fixed = fixed # False 16 | self.rep_replay = rep_replay 17 | 18 | self.x = self.y = self.t = None 19 | 20 | self.nb_classes = 0 21 | 22 | @property 23 | def memory_per_class(self): 24 | if self.fixed: 25 | return self.memory_size // self.nb_total_classes 26 | return self.memory_size // self.nb_classes if self.nb_classes > 0 else self.memory_size 27 | 28 | def get_dataset_without_copy(self, base_dataset): 29 | dataset = base_dataset 30 | dataset._x = self.x 31 | dataset._y = self.y 32 | dataset._t = self.t 33 | 34 | return dataset 35 | 36 | def get_dataset(self, base_dataset): 37 | if self.rep_replay: 38 | dataset = ArrayTaskSet(x=self.x, y=self.y, t=self.t, trsf=None, target_trsf=None, 39 | bounding_boxes=None) 40 | else: 41 | dataset = copy.deepcopy(base_dataset) 42 | dataset._x = self.x 43 | dataset._y = self.y 44 | dataset._t = self.t 45 | 46 | return dataset 47 | 48 | def get(self): 49 | return self.x, self.y, 
self.t
50 | 
51 |     def __len__(self):
52 |         return len(self.x) if self.x is not None else 0
53 | 
54 |     def save(self, path):
55 |         np.savez(
56 |             path,
57 |             x=self.x, y=self.y, t=self.t
58 |         )
59 | 
60 |     def load(self, path):
61 |         data = np.load(path)
62 |         self.x = data["x"]
63 |         self.y = data["y"]
64 |         self.t = data["t"]
65 | 
66 |         assert len(self) <= self.memory_size, len(self)
67 |         self.nb_classes = len(np.unique(self.y))
68 | 
69 |     def reduce(self):
70 |         x, y, t = [], [], []
71 |         for class_id in np.unique(self.y):
72 |             indexes = np.where(self.y == class_id)[0]
73 |             x.append(self.x[indexes[:self.memory_per_class]])
74 |             y.append(self.y[indexes[:self.memory_per_class]])
75 |             t.append(self.t[indexes[:self.memory_per_class]])
76 | 
77 |         self.x = np.concatenate(x)
78 |         self.y = np.concatenate(y)
79 |         self.t = np.concatenate(t)
80 | 
81 |     def add(self, dataset, model, nb_new_classes):
82 |         self.nb_classes += nb_new_classes
83 | 
84 |         x, y, t = herd_samples(dataset, model, self.memory_per_class, self.rehearsal, self.rep_replay)
85 | 
86 |         if self.x is None:
87 |             self.x, self.y, self.t = x, y, t
88 |         else:
89 |             if not self.fixed:
90 |                 self.reduce()
91 |             self.x = np.concatenate((self.x, x))
92 |             self.y = np.concatenate((self.y, y))
93 |             self.t = np.concatenate((self.t, t))
94 | 
95 | 
96 | def herd_samples(dataset, model, memory_per_class, rehearsal, rep_replay):
97 |     x, y, t = dataset._x, dataset._y, dataset._t
98 | 
99 |     if rehearsal == "random":
100 |         indexes = []
101 |         for class_id in np.unique(y):
102 |             class_indexes = np.where(y == class_id)[0]
103 |             indexes.append(
104 |                 np.random.choice(class_indexes, size=memory_per_class)
105 |             )
106 |         indexes = np.concatenate(indexes)
107 | 
108 |         return x[indexes], y[indexes], t[indexes]
109 |     elif "closest" in rehearsal:
110 |         if rehearsal == 'closest_token':
111 |             handling = 'last'
112 |         else:
113 |             handling = 'all'
114 | 
115 |         features, targets = extract_features(dataset, model, handling)
116 |         indexes = []
117 | 
118 |         for class_id in np.unique(y):
119 |             class_indexes = np.where(y == class_id)[0]
120 |             class_features = features[class_indexes]
121 | 
122 |             class_mean = np.mean(class_features, axis=0, keepdims=True)
123 |             distances = np.power(class_features - class_mean, 2).sum(-1)
124 |             class_closest_indexes = np.argsort(distances)
125 | 
126 |             indexes.append(
127 |                 class_indexes[class_closest_indexes[:memory_per_class]]
128 |             )
129 | 
130 |         indexes = np.concatenate(indexes)
131 |         return x[indexes], y[indexes], t[indexes]
132 |     elif "furthest" in rehearsal:
133 |         if rehearsal == 'furthest_token':
134 |             handling = 'last'
135 |         else:
136 |             handling = 'all'
137 | 
138 |         features, targets = extract_features(dataset, model, handling)
139 |         indexes = []
140 | 
141 |         for class_id in np.unique(y):
142 |             class_indexes = np.where(y == class_id)[0]
143 |             class_features = features[class_indexes]
144 | 
145 |             class_mean = np.mean(class_features, axis=0, keepdims=True)
146 |             distances = np.power(class_features - class_mean, 2).sum(-1)
147 |             class_furthest_indexes = np.argsort(distances)[::-1]
148 | 
149 |             indexes.append(
150 |                 class_indexes[class_furthest_indexes[:memory_per_class]]
151 |             )
152 | 
153 |         indexes = np.concatenate(indexes)
154 |         return x[indexes], y[indexes], t[indexes]
155 |     elif "icarl" in rehearsal:
156 |         if rehearsal == 'icarl_token':
157 |             handling = 'last'
158 |         else:
159 |             handling = 'all'
160 | 
161 |         features, targets = extract_features(dataset, model, handling)
162 |         indexes = []
163 | 
164 |         for class_id in np.unique(y):
165 |             class_indexes = np.where(y == class_id)[0]
166 |
class_features = features[class_indexes] 167 | 168 | indexes.append( 169 | class_indexes[icarl_selection(class_features, memory_per_class)] 170 | ) 171 | 172 | indexes = np.concatenate(indexes) 173 | 174 | # store representations for the samples 175 | if rep_replay: 176 | 177 | dataset.trsf = transforms.Compose([dataset.trsf.transforms[tf] for tf in [0,3,4]]) 178 | loader = torch.utils.data.DataLoader( 179 | dataset, 180 | batch_size=128, 181 | num_workers=2, 182 | pin_memory=True, 183 | drop_last=False, 184 | shuffle=False 185 | ) 186 | 187 | features, targets = [], [] 188 | 189 | with torch.no_grad(): 190 | for x, y, _ in loader: 191 | if hasattr(model, 'module'): 192 | reps = model.module.forward_initial(x.cuda()) 193 | else: 194 | reps = model.forward_initial(x.cuda()) 195 | 196 | reps = reps.detach().cpu().numpy() 197 | y = y.numpy() 198 | 199 | features.append(reps.reshape((reps.shape[0], int(reps.shape[1] ** 0.5), 200 | int(reps.shape[1] ** 0.5), reps.shape[-1]))) 201 | targets.append(y) 202 | 203 | features = np.vstack(features) 204 | targets = np.concatenate(targets) 205 | 206 | return features[indexes], targets[indexes], t[indexes] 207 | else: 208 | return x[indexes], y[indexes], t[indexes] 209 | else: 210 | raise ValueError(f"Unknown rehearsal method {rehearsal}!") 211 | 212 | 213 | def extract_features(dataset, model, ensemble_handling='last'): 214 | loader = torch.utils.data.DataLoader( 215 | dataset, 216 | batch_size=128, 217 | num_workers=2, 218 | pin_memory=True, 219 | drop_last=False, 220 | shuffle=False 221 | ) 222 | 223 | features, targets = [], [] 224 | 225 | with torch.no_grad(): 226 | for x, y, _ in loader: 227 | if hasattr(model, 'module'): 228 | feats, _, _ = model.module.forward_features(x.cuda()) 229 | else: 230 | feats, _, _ = model.forward_features(x.cuda()) 231 | 232 | if isinstance(feats, list): 233 | if ensemble_handling == 'last': 234 | feats = feats[-1] 235 | elif ensemble_handling == 'all': 236 | feats = torch.cat(feats, dim=1) 237 | else: 238 | raise NotImplementedError(f'Unknown handling of multiple features {ensemble_handling}') 239 | elif len(feats.shape) == 3: # joint tokens 240 | if ensemble_handling == 'last': 241 | feats = feats[-1] 242 | elif ensemble_handling == 'all': 243 | feats = feats.permute(1, 0, 2).view(len(x), -1) 244 | else: 245 | raise NotImplementedError(f'Unknown handling of multiple features {ensemble_handling}') 246 | 247 | feats = feats.cpu().numpy() 248 | y = y.numpy() 249 | 250 | features.append(feats) 251 | targets.append(y) 252 | 253 | features = np.concatenate(features) 254 | targets = np.concatenate(targets) 255 | 256 | return features, targets 257 | 258 | 259 | def icarl_selection(features, nb_examplars): 260 | D = features.T 261 | D = D / (np.linalg.norm(D, axis=0) + 1e-8) 262 | mu = np.mean(D, axis=1) 263 | herding_matrix = np.zeros((features.shape[0],)) 264 | 265 | w_t = mu 266 | iter_herding, iter_herding_eff = 0, 0 267 | 268 | while not ( 269 | np.sum(herding_matrix != 0) == min(nb_examplars, features.shape[0]) 270 | ) and iter_herding_eff < 1000: 271 | tmp_t = np.dot(w_t, D) 272 | ind_max = np.argmax(tmp_t) 273 | iter_herding_eff += 1 274 | if herding_matrix[ind_max] == 0: 275 | herding_matrix[ind_max] = 1 + iter_herding 276 | iter_herding += 1 277 | 278 | w_t = w_t + mu - D[:, ind_max] 279 | 280 | herding_matrix[np.where(herding_matrix == 0)[0]] = 10000 281 | 282 | return herding_matrix.argsort()[:nb_examplars] 283 | 284 | 285 | def get_finetuning_dataset(dataset, memory, finetuning='balanced', rep_replay=False): 286 
|     if finetuning == 'balanced':
287 |         x, y, t = memory.get()
288 | 
289 |         if rep_replay:
290 |             # current task samples
291 |             new_dataset = ArrayTaskSet(x=x, y=y, t=t, trsf=None,
292 |                                        target_trsf=None, bounding_boxes=None)
293 |         else:
294 |             new_dataset = copy.deepcopy(dataset)
295 |             new_dataset._x = x
296 |             new_dataset._y = y
297 |             new_dataset._t = t
298 |     elif finetuning in ('all', 'none'):
299 |         new_dataset = dataset
300 |     else:
301 |         raise NotImplementedError(f'Unknown finetuning method {finetuning}')
302 | 
303 |     return new_dataset
304 | 
305 | 
306 | def get_separate_finetuning_dataset(dataset, memory, finetuning='balanced', rep_replay=False):
307 |     if finetuning == 'balanced':
308 |         x, y, t = memory.get()
309 | 
310 |         # extract current and old task samples from memory
311 |         cur_task_idx = t == max(np.unique(t))
312 |         old_task_idx = t != max(np.unique(t))
313 | 
314 |         if rep_replay:
315 |             # current task samples
316 |             first_dataset = ArrayTaskSet(x=x[cur_task_idx], y=y[cur_task_idx], t=t[cur_task_idx], trsf=None,
317 |                                          target_trsf=None, bounding_boxes=None)
318 | 
319 |             # old task samples
320 |             second_dataset = ArrayTaskSet(x=x[old_task_idx], y=y[old_task_idx], t=t[old_task_idx], trsf=None,
321 |                                           target_trsf=None, bounding_boxes=None)
322 |         else:
323 |             first_dataset = copy.deepcopy(dataset)
324 |             first_dataset._x = x[cur_task_idx]
325 |             first_dataset._y = y[cur_task_idx]
326 |             first_dataset._t = t[cur_task_idx]
327 | 
328 |             second_dataset = copy.deepcopy(dataset)
329 |             second_dataset._x = x[old_task_idx]
330 |             second_dataset._y = y[old_task_idx]
331 |             second_dataset._t = t[old_task_idx]
332 | 
333 |     elif finetuning in ('all', 'none'):
334 |         # not supported after the change to two returned datasets
335 |         raise NotImplementedError(f'Finetuning method {finetuning} is not supported by this function.')
336 |     else:
337 |         raise NotImplementedError(f'Unknown finetuning method {finetuning}')
338 | 
339 |     return first_dataset, second_dataset
340 | 
--------------------------------------------------------------------------------
/continual/sam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | class SAM:
5 |     """SAM, ASAM, and Look-SAM
6 | 
7 |     Modified version of: https://github.com/davda54/sam
8 |     Only Look-SAM has been added.
9 | 
10 |     It speeds up SAM quite a lot, but the alpha needs to be tuned to reach the same performance.
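    A two-pass usage sketch (illustrative only; `model`, `loss_fn`, `x`, `y` and the
    wrapped `base_optimizer` are assumed to already exist):

        sam = SAM(base_optimizer, model, rho=0.05)
        loss_fn(model(x), y).backward()
        sam.first_step()                 # perturb weights to w + e(w)
        base_optimizer.zero_grad()
        loss_fn(model(x), y).backward()  # gradient at the perturbed point
        sam.second_step()                # undo the perturbation, keep the SAM gradient
        base_optimizer.step()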
11 |     """
12 |     def __init__(self, base_optimizer, model_without_ddp, rho=0.05, adaptive=False, div='', use_look_sam=False, look_sam_alpha=0., **kwargs):
13 |         assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
14 | 
15 |         defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
16 | 
17 |         self.base_optimizer = base_optimizer
18 |         self.param_groups = self.base_optimizer.param_groups
19 |         self.model_without_ddp = model_without_ddp
20 | 
21 |         self.rho = rho
22 |         self.adaptive = adaptive
23 |         self.div = div
24 |         self.look_sam_alpha = look_sam_alpha
25 |         self.use_look_sam = use_look_sam
26 | 
27 |         self.g_v = dict()
28 | 
29 |     @torch.no_grad()
30 |     def first_step(self):
31 |         self.e_w = dict()
32 |         self.g = dict()
33 | 
34 |         grad_norm = self._grad_norm()
35 |         for group in self.param_groups:
36 |             scale = self.rho / (grad_norm + 1e-12)
37 | 
38 |             for p in group["params"]:
39 |                 if p.grad is None: continue
40 |                 e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p)
41 |                 p.add_(e_w)  # climb to the local maximum "w + e(w)"
42 |                 self.e_w[p] = e_w
43 |                 self.g[p] = p.grad.clone()
44 | 
45 |     @torch.no_grad()
46 |     def second_step(self, look_sam_update=False):
47 |         if self.use_look_sam and look_sam_update:
48 |             self.g_v = dict()
49 | 
50 |         for group in self.param_groups:
51 |             for p in group["params"]:
52 |                 if p.grad is None: continue
53 | 
54 |                 if not self.use_look_sam or look_sam_update:
55 |                     p.sub_(self.e_w[p])
56 | 
57 |                 if self.use_look_sam and look_sam_update:
58 |                     cos = self._cos(self.g[p], p.grad)
59 |                     norm_gs = p.grad.norm(p=2)
60 |                     norm_g = self.g[p].norm(p=2)
61 |                     self.g_v[p] = p.grad - norm_gs * cos * self.g[p] / norm_g
62 |                 elif self.use_look_sam:
63 |                     norm_g = p.grad.norm(p=2)
64 |                     norm_gv = self.g_v[p].norm(p=2)
65 |                     p.grad.add_(self.look_sam_alpha * (norm_g / norm_gv) * self.g_v[p])
66 | 
67 |         self.e_w = None
68 |         self.g = None
69 | 
70 |     def _cos(self, a, b):
71 |         return torch.dot(a.view(-1), b.view(-1)) / (a.norm() * b.norm())
72 | 
73 |     @torch.no_grad()
74 |     def step(self, closure=None):
75 |         assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
76 |         closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass
77 | 
78 |         self.first_step()
79 |         closure()
80 |         self.second_step()
81 | 
82 |     def _grad_norm(self):
83 |         shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
84 |         norm = torch.norm(
85 |             torch.stack([
86 |                 ((torch.abs(p) if self.adaptive else 1.0) * p.grad).norm(p=2).to(shared_device)
87 |                 for group in self.param_groups for p in group["params"]
88 |                 if p.grad is not None
89 |             ]),
90 |             p=2
91 |         )
92 |         return norm
93 | 
--------------------------------------------------------------------------------
/continual/samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import torch
4 | import torch.distributed as dist
5 | import math
6 | import numpy as np
7 | 
8 | import continual.utils as utils
9 | 
10 | 
11 | class SingleRASampler(torch.utils.data.Sampler):
12 |     """Sampler that restricts data loading to a subset of the dataset for distributed,
13 |     with repeated augmentation.
14 |     It ensures that each augmented version of a sample will be visible to a
15 |     different process (GPU).
16 |     Heavily based on torch.utils.data.DistributedSampler
17 |     """
18 | 
19 |     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
20 |         if num_replicas is None:
21 |             if not dist.is_available():
22 |                 raise RuntimeError("Requires distributed package to be available")
23 |             num_replicas = dist.get_world_size()
24 |         if rank is None:
25 |             if not dist.is_available():
26 |                 raise RuntimeError("Requires distributed package to be available")
27 |             rank = dist.get_rank()
28 |         self.dataset = dataset
29 |         self.num_replicas = num_replicas
30 |         self.rank = rank
31 |         self.epoch = 0
32 |         # self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
33 |         self.num_samples = int(math.ceil(len(self.dataset)))
34 |         self.total_size = self.num_samples * self.num_replicas
35 |         self.num_selected_samples = len(self.dataset)
36 |         # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
37 |         # self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
38 |         self.shuffle = shuffle
39 | 
40 |     def __iter__(self):
41 |         # deterministically shuffle based on epoch
42 |         g = torch.Generator()
43 |         g.manual_seed(self.epoch)
44 |         if self.shuffle:
45 |             indices = torch.randperm(len(self.dataset), generator=g).tolist()
46 |         else:
47 |             indices = list(range(len(self.dataset)))
48 | 
49 |         # repeat each index once per replica, so every process sees a different augmentation
50 |         # indices = indices + indices
51 |         indices = [ele for ele in indices for i in range(self.num_replicas)]
52 |         # indices += indices[:(self.total_size - len(indices))]
53 |         assert len(indices) == self.total_size
54 | 
55 |         # subsample
56 |         # indices = indices[self.rank:self.total_size:self.num_replicas]
57 |         # assert len(indices) == self.num_samples
58 | 
59 |         return iter(indices)
60 | 
61 |     def __len__(self):
62 |         return self.total_size
63 | 
64 |     def set_epoch(self, epoch):
65 |         self.epoch = epoch
66 | 
67 | 
68 | class RASampler(torch.utils.data.Sampler):
69 |     """Sampler that restricts data loading to a subset of the dataset for distributed,
70 |     with repeated augmentation.
71 |     It ensures that each augmented version of a sample will be visible to a
72 |     different process (GPU).
73 |     Heavily based on torch.utils.data.DistributedSampler
74 |     """
75 | 
76 |     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
77 |         if num_replicas is None:
78 |             if not dist.is_available():
79 |                 raise RuntimeError("Requires distributed package to be available")
80 |             num_replicas = dist.get_world_size()
81 |         if rank is None:
82 |             if not dist.is_available():
83 |                 raise RuntimeError("Requires distributed package to be available")
84 |             rank = dist.get_rank()
85 |         self.dataset = dataset
86 |         self.num_replicas = num_replicas
87 |         self.rank = rank
88 |         self.epoch = 0
89 |         self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
90 |         self.total_size = self.num_samples * self.num_replicas
91 |         self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
92 |         # self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
93 |         self.shuffle = shuffle
94 | 
95 |     def __iter__(self):
96 |         # deterministically shuffle based on epoch
97 |         g = torch.Generator()
98 |         g.manual_seed(self.epoch)
99 |         if self.shuffle:
100 |             indices = torch.randperm(len(self.dataset), generator=g).tolist()
101 |         else:
102 |             indices = list(range(len(self.dataset)))
103 | 
104 |         # add extra samples to make it evenly divisible
105 |         indices = [ele for ele in indices for i in range(3)]
106 |         indices += indices[:(self.total_size - len(indices))]
107 |         assert len(indices) == self.total_size
108 | 
109 |         # subsample
110 |         indices = indices[self.rank:self.total_size:self.num_replicas]
111 |         assert len(indices) == self.num_samples
112 | 
113 |         return iter(indices[:self.num_selected_samples])
114 | 
115 |     def __len__(self):
116 |         return self.num_selected_samples
117 | 
118 |     def set_epoch(self, epoch):
119 |         self.epoch = epoch
120 | 
121 | 
122 | def get_sampler(dataset_train, dataset_val, args):
123 |     if args.distributed:
124 |         num_tasks = utils.get_world_size()
125 |         global_rank = utils.get_rank()
126 |         if args.repeated_aug:
127 |             sampler_train = RASampler(
128 |                 dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
129 |             )
130 |         else:
131 |             sampler_train = torch.utils.data.DistributedSampler(
132 |                 dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
133 |             )
134 |         if args.dist_eval:
135 |             if len(dataset_val) % num_tasks != 0:
136 |                 print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number.
' 137 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 138 | 'equal num of samples per-process.') 139 | sampler_val = torch.utils.data.DistributedSampler( 140 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 141 | else: 142 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 143 | else: 144 | if args.use_repeatedaug_single: 145 | sampler_train = SingleRASampler(dataset_train, num_replicas=2, rank=0, shuffle=True) 146 | else: 147 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 148 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 149 | 150 | return sampler_train, sampler_val 151 | 152 | 153 | def get_train_sampler(dataset_train, args): 154 | if args.distributed: 155 | num_tasks = utils.get_world_size() 156 | global_rank = utils.get_rank() 157 | if args.repeated_aug: 158 | sampler_train = RASampler( 159 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 160 | ) 161 | else: 162 | sampler_train = torch.utils.data.DistributedSampler( 163 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 164 | ) 165 | else: 166 | if args.use_repeatedaug_single: 167 | sampler_train = SingleRASampler(dataset_train, num_replicas=2, rank=0, shuffle=True) 168 | else: 169 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 170 | 171 | return sampler_train 172 | -------------------------------------------------------------------------------- /continual/scaler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm.utils import dispatch_clip_grad 4 | 5 | 6 | class ContinualScaler: 7 | state_dict_key = "amp_scaler" 8 | 9 | def __init__(self, disable_amp): 10 | self._scaler = torch.cuda.amp.GradScaler(enabled=not disable_amp) 11 | 12 | def __call__( 13 | self, loss, optimizer, model_without_ddp, clip_grad=None, clip_mode='norm', 14 | parameters=None, create_graph=False, 15 | hook=True 16 | ): 17 | self.pre_step(loss, optimizer, parameters, create_graph, clip_grad, clip_mode) 18 | self.post_step(optimizer, model_without_ddp, hook) 19 | 20 | def pre_step(self, loss, optimizer, parameters=None, create_graph=False, clip_grad=None, clip_mode='norm'): 21 | self._scaler.scale(loss).backward(create_graph=create_graph) 22 | self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place 23 | if clip_grad is not None: 24 | assert parameters is not None 25 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) 26 | 27 | def post_step(self, optimizer, model_without_ddp, hook=True): 28 | if hook and hasattr(model_without_ddp, 'hook_before_update'): 29 | model_without_ddp.hook_before_update() 30 | 31 | self._scaler.step(optimizer) 32 | 33 | if hook and hasattr(model_without_ddp, 'hook_after_update'): 34 | model_without_ddp.hook_after_update() 35 | 36 | self.update() 37 | 38 | def update(self): 39 | self._scaler.update() 40 | 41 | def state_dict(self): 42 | return self._scaler.state_dict() 43 | 44 | def load_state_dict(self, state_dict): 45 | self._scaler.load_state_dict(state_dict) 46 | -------------------------------------------------------------------------------- /convert_memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use this script if you saved rehearsal memory on computer A but then want 3 | to resume training, using those rehearsal samples, on computer B.
4 | 5 | This is needed because, for ImageNet, we save image paths, which may differ from one computer to another. 6 | """ 7 | 8 | import sys 9 | import glob 10 | import os 11 | import shutil 12 | 13 | import numpy as np 14 | 15 | memory_path = sys.argv[1] 16 | new_base_path = sys.argv[2] 17 | 18 | if os.path.isdir(memory_path): 19 | memory_paths = glob.glob(os.path.abspath(os.path.join(memory_path, "memory_*.npz"))) 20 | else: 21 | memory_paths = [memory_path] 22 | 23 | print(memory_paths) 24 | 25 | for p in sorted(memory_paths): 26 | psrc = p 27 | if not os.path.exists(f"{p}_original"): 28 | shutil.copy(p, f"{p}_original") 29 | else: 30 | psrc = f"{p}_original" 31 | print(p) 32 | 33 | data = np.load(p) 34 | x = [] 35 | for img_path in data["x"]: 36 | id_ = str(img_path).lstrip("b'").rstrip("'").split("train")[-1][1:]  # strip the b'...' bytes repr, keep the path relative to the "train" folder 37 | 38 | x.append(os.path.join(new_base_path, "train", id_)) 39 | 40 | np.savez( 41 | p, 42 | x=np.array(x), y=data["y"], t=data["t"] 43 | ) 44 | print("Done!") 45 | -------------------------------------------------------------------------------- /images/BiRT_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeurAI-Lab/BiRT/1340a97bff17fc02e228b754bd80e1fd649ff9cd/images/BiRT_architecture.png -------------------------------------------------------------------------------- /options/arthur.yaml: -------------------------------------------------------------------------------- 1 | data_path: /local/douillard/ 2 | output_basedir: /local/douillard/transformer/checkpoints 3 | -------------------------------------------------------------------------------- /options/data/cifar100_10-10.yaml: -------------------------------------------------------------------------------- 1 | data_set: CIFAR 2 | initial_increment: 10 3 | increment: 10 4 | #memory_size: 2000 5 | 6 | log_category: 10-10 7 | -------------------------------------------------------------------------------- /options/data/cifar100_10-10_500.yaml: -------------------------------------------------------------------------------- 1 | data_set: CIFAR 2 | initial_increment: 10 3 | increment: 10 4 | memory_size: 500 5 | 6 | log_category: 10-10 7 | -------------------------------------------------------------------------------- /options/data/cifar100_2-2.yaml: -------------------------------------------------------------------------------- 1 | data_set: CIFAR 2 | initial_increment: 2 3 | increment: 2 4 | #memory_size: 2000 5 | 6 | log_category: 2-2 7 | -------------------------------------------------------------------------------- /options/data/cifar100_20-20.yaml: -------------------------------------------------------------------------------- 1 | data_set: CIFAR 2 | initial_increment: 20 3 | increment: 20 4 | #memory_size: 2000 5 | 6 | log_category: 20-20 7 | -------------------------------------------------------------------------------- /options/data/cifar100_5-5.yaml: -------------------------------------------------------------------------------- 1 | data_set: CIFAR 2 | initial_increment: 5 3 | increment: 5 4 | #memory_size: 2000 5 | 6 | log_category: 5-5 7 | -------------------------------------------------------------------------------- /options/data/cifar100_joint.yaml: -------------------------------------------------------------------------------- 1 | data_set: CIFAR 2 | 3 | initial_increment: 100 4 | increment: 100 5 | 6 | log_category: joint 7 | -------------------------------------------------------------------------------- /options/data/cifar100_order1.yaml:
-------------------------------------------------------------------------------- 1 | class_order: [87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, 24, 32, 45, 88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, 25, 20, 80, 73, 1, 28, 6, 46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, 60, 19, 70, 90, 89, 43, 5, 42, 65, 76, 40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, 34, 55, 54, 26, 35, 39] 2 | -------------------------------------------------------------------------------- /options/data/cifar100_order2.yaml: -------------------------------------------------------------------------------- 1 | class_order: [58, 30, 93, 69, 21, 77, 3, 78, 12, 71, 65, 40, 16, 49, 89, 46, 24, 66, 19, 41, 5, 29, 15, 73, 11, 70, 90, 63, 67, 25, 59, 72, 80, 94, 54, 33, 18, 96, 2, 10, 43, 9, 57, 81, 76, 50, 32, 6, 37, 7, 68, 91, 88, 95, 85, 4, 60, 36, 22, 27, 39, 42, 34, 51, 55, 28, 53, 48, 38, 17, 83, 86, 56, 35, 45, 79, 99, 84, 97, 82, 98, 26, 47, 44, 62, 13, 31, 0, 75, 14, 52, 74, 8, 20, 1, 92, 87, 23, 64, 61] 2 | -------------------------------------------------------------------------------- /options/data/cifar100_order3.yaml: -------------------------------------------------------------------------------- 1 | class_order: [71, 54, 45, 32, 4, 8, 48, 66, 1, 91, 28, 82, 29, 22, 80, 27, 86, 23, 37, 47, 55, 9, 14, 68, 25, 96, 36, 90, 58, 21, 57, 81, 12, 26, 16, 89, 79, 49, 31, 38, 46, 20, 92, 88, 40, 39, 98, 94, 19, 95, 72, 24, 64, 18, 60, 50, 63, 61, 83, 76, 69, 35, 0, 52, 7, 65, 42, 73, 74, 30, 41, 3, 6, 53, 13, 56, 70, 77, 34, 97, 75, 2, 17, 93, 33, 84, 99, 51, 62, 87, 5, 15, 10, 78, 67, 44, 59, 85, 43, 11] 2 | -------------------------------------------------------------------------------- /options/data/cifar100_order4.yaml: -------------------------------------------------------------------------------- 1 | class_order: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] -------------------------------------------------------------------------------- /options/data/cifar100_order5.yaml: -------------------------------------------------------------------------------- 1 | class_order: [ 5, 2, 7, 93, 9, 99, 12, 6, 21, 15, 53, 46, 31, 48, 39, 45, 92, 25, 61, 73, 70, 32, 55, 91, 90, 42, 76, 52, 83, 75, 43, 47, 24, 11, 13, 37, 18, 38, 40, 10, 67, 66, 41, 33, 74, 16, 79, 86, 78, 97, 80, 64, 54, 28, 14, 50, 35, 71, 1, 62, 88, 0, 60, 95, 36, 51, 87, 58, 56, 65, 59, 68, 82, 4, 81, 29, 27, 3, 34, 22, 72, 30, 89, 77, 63, 8, 84, 98, 17, 26, 23, 94, 20, 85, 19, 96, 49, 57, 44, 69] -------------------------------------------------------------------------------- /options/data/cifar10_2-2.yaml: -------------------------------------------------------------------------------- 1 | data_set: CIFAR10 2 | initial_increment: 2 3 | increment: 2 4 | #memory_size: 2000 5 | 6 | log_category: 2-2 7 | -------------------------------------------------------------------------------- /options/data/cifar10_2-2_500.yaml: -------------------------------------------------------------------------------- 1 | data_set: CIFAR10 2 | initial_increment: 2 3 | increment: 2 4 | memory_size: 500 5 | 6 | log_category: 2-2 7 | 
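The class_order files above only fix a permutation of the label space; the split into incremental tasks comes from the initial_increment and increment fields of the data configs. Below is a minimal sketch (illustrative only, not code from this repo; the actual splitting is handled by the continuum library's ClassIncremental scenario) of how such a permutation maps to per-task class sets:

def task_splits(class_order, initial_increment, increment):
    """Chunk a permuted class list into incremental tasks."""
    splits = [class_order[:initial_increment]]
    for start in range(initial_increment, len(class_order), increment):
        splits.append(class_order[start:start + increment])
    return splits

# e.g. cifar100_10-10.yaml (initial_increment=10, increment=10) combined with
# the identity order of cifar100_order4.yaml gives 10 tasks of 10 classes each:
splits = task_splits(list(range(100)), 10, 10)
assert len(splits) == 10 and splits[0] == list(range(10))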
-------------------------------------------------------------------------------- /options/data/cifar10_joint.yaml: -------------------------------------------------------------------------------- 1 | data_set: CIFAR10 2 | initial_increment: 10 3 | increment: 10 4 | 5 | log_category: joint 6 | -------------------------------------------------------------------------------- /options/data/imagenet1000_100-100.yaml: -------------------------------------------------------------------------------- 1 | data_set: imagenet1000 2 | initial_increment: 100 3 | increment: 100 4 | #memory_size: 20000 5 | 6 | log_category: 100-100 7 | -------------------------------------------------------------------------------- /options/data/imagenet1000_joint.yaml: -------------------------------------------------------------------------------- 1 | data_set: imagenet1000 2 | initial_increment: 1000 3 | increment: 1000 4 | 5 | log_category: joint 6 | -------------------------------------------------------------------------------- /options/data/imagenet1000_order1.yaml: -------------------------------------------------------------------------------- 1 | class_order: [54, 7, 894, 512, 126, 337, 988, 11, 284, 493, 133, 783, 192, 979, 622, 215, 240, 548, 238, 419, 274, 108, 2 | 928, 856, 494, 836, 473, 650, 85, 262, 508, 590, 390, 174, 637, 288, 658, 219, 912, 142, 852, 160, 704, 289, 3 | 123, 323, 600, 542, 999, 634, 391, 761, 490, 842, 127, 850, 665, 990, 597, 722, 748, 14, 77, 437, 394, 859, 4 | 279, 539, 75, 466, 886, 312, 303, 62, 966, 413, 959, 782, 509, 400, 471, 632, 275, 730, 105, 523, 224, 186, 5 | 478, 507, 470, 906, 699, 989, 324, 812, 260, 911, 446, 44, 765, 759, 67, 36, 5, 30, 184, 797, 159, 741, 954, 6 | 465, 533, 585, 150, 101, 897, 363, 818, 620, 824, 154, 956, 176, 588, 986, 172, 223, 461, 94, 141, 621, 659, 7 | 360, 136, 578, 163, 427, 70, 226, 925, 596, 336, 412, 731, 755, 381, 810, 69, 898, 310, 120, 752, 93, 39, 8 | 326, 537, 905, 448, 347, 51, 615, 601, 229, 947, 348, 220, 949, 972, 73, 913, 522, 193, 753, 921, 257, 957, 9 | 691, 155, 820, 584, 948, 92, 582, 89, 379, 392, 64, 904, 169, 216, 694, 103, 410, 374, 515, 484, 624, 409, 10 | 156, 455, 846, 344, 371, 468, 844, 276, 740, 562, 503, 831, 516, 663, 630, 763, 456, 179, 996, 936, 248, 11 | 333, 941, 63, 738, 802, 372, 828, 74, 540, 299, 750, 335, 177, 822, 643, 593, 800, 459, 580, 933, 306, 378, 12 | 76, 227, 426, 403, 322, 321, 808, 393, 27, 200, 764, 651, 244, 479, 3, 415, 23, 964, 671, 195, 569, 917, 13 | 611, 644, 707, 355, 855, 8, 534, 657, 571, 811, 681, 543, 313, 129, 978, 592, 573, 128, 243, 520, 887, 892, 14 | 696, 26, 551, 168, 71, 398, 778, 529, 526, 792, 868, 266, 443, 24, 57, 15, 871, 678, 745, 845, 208, 188, 15 | 674, 175, 406, 421, 833, 106, 994, 815, 581, 676, 49, 619, 217, 631, 934, 932, 568, 353, 863, 827, 425, 420, 16 | 99, 823, 113, 974, 438, 874, 343, 118, 340, 472, 552, 937, 0, 10, 675, 316, 879, 561, 387, 726, 255, 407, 17 | 56, 927, 655, 809, 839, 640, 297, 34, 497, 210, 606, 971, 589, 138, 263, 587, 993, 973, 382, 572, 735, 535, 18 | 139, 524, 314, 463, 895, 376, 939, 157, 858, 457, 935, 183, 114, 903, 767, 666, 22, 525, 902, 233, 250, 825, 19 | 79, 843, 221, 214, 205, 166, 431, 860, 292, 976, 739, 899, 475, 242, 961, 531, 110, 769, 55, 701, 532, 586, 20 | 729, 253, 486, 787, 774, 165, 627, 32, 291, 962, 922, 222, 705, 454, 356, 445, 746, 776, 404, 950, 241, 452, 21 | 245, 487, 706, 2, 137, 6, 98, 647, 50, 91, 202, 556, 38, 68, 649, 258, 345, 361, 464, 514, 958, 504, 826, 22 | 668, 880, 28, 920, 918, 339, 315, 320, 768, 
201, 733, 575, 781, 864, 617, 171, 795, 132, 145, 368, 147, 327, 23 | 713, 688, 848, 690, 975, 354, 853, 148, 648, 300, 436, 780, 693, 682, 246, 449, 492, 162, 97, 59, 357, 198, 24 | 519, 90, 236, 375, 359, 230, 476, 784, 117, 940, 396, 849, 102, 122, 282, 181, 130, 467, 88, 271, 793, 151, 25 | 847, 914, 42, 834, 521, 121, 29, 806, 607, 510, 837, 301, 669, 78, 256, 474, 840, 52, 505, 547, 641, 987, 26 | 801, 629, 491, 605, 112, 429, 401, 742, 528, 87, 442, 910, 638, 785, 264, 711, 369, 428, 805, 744, 380, 725, 27 | 480, 318, 997, 153, 384, 252, 985, 538, 654, 388, 100, 432, 832, 565, 908, 367, 591, 294, 272, 231, 213, 28 | 196, 743, 817, 433, 328, 970, 969, 4, 613, 182, 685, 724, 915, 311, 931, 865, 86, 119, 203, 268, 718, 317, 29 | 926, 269, 161, 209, 807, 645, 513, 261, 518, 305, 758, 872, 58, 65, 146, 395, 481, 747, 41, 283, 204, 564, 30 | 185, 777, 33, 500, 609, 286, 567, 80, 228, 683, 757, 942, 134, 673, 616, 960, 450, 350, 544, 830, 736, 170, 31 | 679, 838, 819, 485, 430, 190, 566, 511, 482, 232, 527, 411, 560, 281, 342, 614, 662, 47, 771, 861, 692, 686, 32 | 277, 373, 16, 946, 265, 35, 9, 884, 909, 610, 358, 18, 737, 977, 677, 803, 595, 135, 458, 12, 46, 418, 599, 33 | 187, 107, 992, 770, 298, 104, 351, 893, 698, 929, 502, 273, 20, 96, 791, 636, 708, 267, 867, 772, 604, 618, 34 | 346, 330, 554, 816, 664, 716, 189, 31, 721, 712, 397, 43, 943, 804, 296, 109, 576, 869, 955, 17, 506, 963, 35 | 786, 720, 628, 779, 982, 633, 891, 734, 980, 386, 365, 794, 325, 841, 878, 370, 695, 293, 951, 66, 594, 717, 36 | 116, 488, 796, 983, 646, 499, 53, 1, 603, 45, 424, 875, 254, 237, 199, 414, 307, 362, 557, 866, 341, 19, 37 | 965, 143, 555, 687, 235, 790, 125, 173, 364, 882, 727, 728, 563, 495, 21, 558, 709, 719, 877, 352, 83, 998, 38 | 991, 469, 967, 760, 498, 814, 612, 715, 290, 72, 131, 259, 441, 924, 773, 48, 625, 501, 440, 82, 684, 862, 39 | 574, 309, 408, 680, 623, 439, 180, 652, 968, 889, 334, 61, 766, 399, 598, 798, 653, 930, 149, 249, 890, 308, 40 | 881, 40, 835, 577, 422, 703, 813, 857, 995, 602, 583, 167, 670, 212, 751, 496, 608, 84, 639, 579, 178, 489, 41 | 37, 197, 789, 530, 111, 876, 570, 700, 444, 287, 366, 883, 385, 536, 460, 851, 81, 144, 60, 251, 13, 953, 42 | 270, 944, 319, 885, 710, 952, 517, 278, 656, 919, 377, 550, 207, 660, 984, 447, 553, 338, 234, 383, 749, 43 | 916, 626, 462, 788, 434, 714, 799, 821, 477, 549, 661, 206, 667, 541, 642, 689, 194, 152, 981, 938, 854, 44 | 483, 332, 280, 546, 389, 405, 545, 239, 896, 672, 923, 402, 423, 907, 888, 140, 870, 559, 756, 25, 211, 158, 45 | 723, 635, 302, 702, 453, 218, 164, 829, 247, 775, 191, 732, 115, 331, 901, 416, 873, 754, 900, 435, 762, 46 | 124, 304, 329, 349, 295, 95, 451, 285, 225, 945, 697, 417] 47 | -------------------------------------------------------------------------------- /options/data/imagenet100_10-10.yaml: -------------------------------------------------------------------------------- 1 | data_set: imagenet100 2 | initial_increment: 10 3 | increment: 10 4 | #memory_size: 2000 5 | 6 | log_category: 10-10 7 | -------------------------------------------------------------------------------- /options/data/imagenet100_joint.yaml: -------------------------------------------------------------------------------- 1 | data_set: imagenet100 2 | initial_increment: 100 3 | increment: 100 4 | 5 | log_category: joint 6 | -------------------------------------------------------------------------------- /options/data/imagenet100_order1.yaml: -------------------------------------------------------------------------------- 1 | 
class_order: [68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50, 2 | 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96, 3 | 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69, 4 | 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33] 5 | -------------------------------------------------------------------------------- /options/data/imagenet100_order2.yaml: -------------------------------------------------------------------------------- 1 | class_order: [38, 34, 87, 75, 81, 17, 99, 67, 69, 12, 28, 25, 32, 42, 61, 5, 82, 2 | 58, 13, 94, 47, 51, 84, 39, 49, 8, 57, 1, 55, 36, 10, 20, 64, 2, 3 | 78, 44, 43, 48, 23, 21, 53, 37, 74, 27, 92, 77, 98, 18, 79, 66, 90, 4 | 46, 68, 26, 9, 83, 80, 30, 22, 16, 29, 97, 41, 85, 0, 52, 15, 14, 5 | 86, 63, 24, 59, 54, 11, 76, 70, 6, 7, 56, 93, 89, 50, 71, 4, 19, 6 | 88, 96, 62, 3, 31, 91, 95, 72, 60, 73, 65, 33, 45, 40, 35] 7 | -------------------------------------------------------------------------------- /options/data/imagenet100_order3.yaml: -------------------------------------------------------------------------------- 1 | class_order: [86, 37, 98, 64, 38, 90, 70, 27, 9, 44, 3, 59, 94, 57, 52, 89, 5, 2 | 99, 72, 10, 78, 41, 32, 7, 61, 71, 23, 46, 29, 74, 42, 11, 21, 54, 3 | 45, 53, 77, 24, 35, 22, 88, 2, 49, 1, 17, 31, 58, 0, 50, 73, 96, 4 | 33, 62, 56, 97, 87, 20, 6, 55, 80, 51, 66, 85, 63, 18, 15, 67, 76, 5 | 26, 75, 65, 14, 36, 68, 43, 40, 92, 82, 39, 93, 25, 91, 84, 79, 34, 6 | 12, 69, 16, 48, 81, 47, 60, 8, 95, 19, 4, 13, 30, 83, 28] -------------------------------------------------------------------------------- /options/data/tinyimg_20-20.yaml: -------------------------------------------------------------------------------- 1 | data_set: TINYIMG 2 | initial_increment: 20 3 | increment: 20 4 | #memory_size: 2000 5 | 6 | log_category: 20-20 7 | -------------------------------------------------------------------------------- /options/data/tinyimg_joint.yaml: -------------------------------------------------------------------------------- 1 | data_set: TINYIMG 2 | initial_increment: 200 3 | increment: 200 4 | 5 | log_category: joint 6 | -------------------------------------------------------------------------------- /options/data/tinyimg_order1.yaml: -------------------------------------------------------------------------------- 1 | class_order: [154, 32, 34, 177, 45, 56, 167, 194, 22, 81, 118, 190, 124, 101, 78, 33, 153, 83, 55, 76, 179, 64, 163, 66, 59, 156, 88, 108, 48, 171, 119, 123, 44, 172, 19, 120, 182, 93, 39, 187, 164, 180, 30, 100, 188, 175, 192, 129, 104, 4, 86, 72, 102, 173, 98, 110, 132, 122, 144, 99, 197, 155, 106, 165, 89, 54, 152, 11, 189, 195, 9, 184, 79, 109, 28, 134, 84, 27, 121, 136, 169, 117, 21, 199, 142, 133, 157, 43, 51, 73, 87, 68, 50, 25, 139, 112, 168, 77, 131, 196, 12, 170, 75, 38, 185, 31, 15, 46, 115, 166, 159, 130, 160, 113, 186, 94, 80, 96, 37, 97, 65, 82, 63, 49, 181, 114, 18, 91, 71, 191, 7, 126, 135, 127, 10, 16, 176, 42, 17, 1, 5, 41, 8, 151, 74, 26, 149, 14, 69, 47, 36, 85, 150, 174, 183, 161, 52, 29, 58, 60, 128, 116, 111, 13, 95, 107, 70, 140, 162, 148, 20, 23, 53, 158, 67, 2, 57, 92, 0, 3, 141, 90, 147, 145, 198, 105, 146, 24, 35, 61, 103, 138, 178, 137, 193, 143, 62, 6, 40, 125] -------------------------------------------------------------------------------- /options/data/tinyimg_order2.yaml: 
-------------------------------------------------------------------------------- 1 | class_order: [ 48, 131, 195, 39, 16, 103, 164, 179, 59, 198, 171, 94, 134, 112, 185, 0, 168, 100, 163, 83, 50, 28, 141, 13, 110, 189, 91, 26, 150, 142, 82, 180, 101, 144, 57, 1, 122, 71, 160, 104, 174, 86, 14, 9, 191, 19, 32, 132, 33, 135, 7, 97, 44, 115, 66, 116, 165, 156, 178, 52, 80, 63, 118, 22, 42, 113, 69, 123, 151, 84, 31, 36, 126, 18, 47, 133, 169, 24, 147, 85, 148, 67, 92, 172, 106, 54, 99, 111, 5, 70, 89, 53, 23, 117, 196, 75, 170, 96, 11, 76, 199, 152, 2, 109, 61, 8, 15, 114, 186, 68, 6, 130, 159, 21, 173, 137, 167, 162, 38, 41, 65, 102, 51, 35, 3, 4, 149, 98, 176, 56, 139, 193, 20, 45, 64, 143, 157, 90, 153, 37, 95, 155, 183, 128, 81, 182, 60, 34, 87, 175, 29, 194, 12, 58, 181, 197, 55, 120, 72, 124, 184, 161, 88, 17, 93, 78, 30, 79, 138, 127, 154, 10, 121, 25, 145, 46, 187, 108, 158, 77, 177, 136, 40, 140, 105, 107, 73, 125, 146, 119, 27, 188, 166, 74, 192, 190, 129, 43, 49, 62] -------------------------------------------------------------------------------- /options/data/tinyimg_order3.yaml: -------------------------------------------------------------------------------- 1 | class_order: [107, 74, 135, 156, 180, 57, 79, 174, 115, 29, 40, 170, 164, 98, 31, 133, 68, 161, 188, 52, 53, 0, 165, 55, 197, 87, 147, 82, 3, 75, 172, 27, 124, 142, 45, 84, 67, 121, 138, 182, 94, 69, 118, 5, 105, 83, 100, 176, 61, 95, 37, 85, 18, 187, 117, 152, 104, 43, 51, 139, 23, 126, 70, 193, 177, 88, 185, 71, 48, 63, 9, 160, 155, 158, 81, 129, 12, 72, 130, 134, 33, 169, 14, 89, 166, 140, 157, 16, 32, 28, 15, 7, 6, 131, 183, 30, 90, 78, 8, 109, 21, 119, 191, 120, 179, 49, 175, 114, 116, 122, 17, 111, 159, 24, 132, 137, 145, 44, 58, 141, 150, 198, 192, 143, 10, 110, 195, 60, 136, 92, 144, 153, 127, 20, 39, 4, 86, 154, 181, 184, 125, 50, 108, 151, 13, 80, 19, 194, 64, 2, 47, 1, 123, 46, 38, 77, 76, 22, 42, 199, 171, 162, 103, 168, 106, 26, 35, 128, 112, 66, 41, 59, 62, 73, 186, 36, 178, 163, 97, 56, 190, 149, 101, 34, 102, 173, 91, 189, 93, 54, 65, 167, 148, 196, 146, 96, 25, 99, 113, 11] -------------------------------------------------------------------------------- /options/model/cifar_birt.yaml: -------------------------------------------------------------------------------- 1 | ####################### 2 | # DyTox, for CIFAR100 # 3 | ####################### 4 | 5 | # Model definition 6 | model: convit 7 | embed_dim: 384 8 | depth: 6 9 | num_heads: 12 10 | patch_size: 4 11 | input_size: 32 12 | local_up_to_layer: 5 13 | class_attention: true 14 | 15 | # Training setting 16 | no_amp: true 17 | eval_every: 50 18 | 19 | # Base hyperparameter 20 | weight_decay: 0.000001 21 | batch_size: 128 22 | incremental_batch_size: 128 23 | incremental_lr: 0.0005 24 | rehearsal: icarl_all 25 | 26 | # Knowledge Distillation 27 | auto_kd: true 28 | 29 | # Finetuning 30 | finetuning: balanced 31 | finetuning_epochs: 20 32 | 33 | # Dytox model 34 | dytox: true 35 | freeze_task: [old_task_tokens, old_heads] 36 | freeze_ft: [sab] 37 | 38 | # Divergence head to get diversity 39 | # head_div: 0 40 | head_div: 0.1 41 | head_div_mode: tr 42 | 43 | # Independent Classifiers 44 | ind_clf: 1-1 45 | bce_loss: true 46 | 47 | 48 | # Advanced Augmentations, here disabled 49 | 50 | ## Erasing 51 | reprob: 0.0 52 | remode: pixel 53 | recount: 1 54 | resplit: false 55 | 56 | ## MixUp & CutMix 57 | mixup: 0.0 58 | cutmix: 0.0 59 | -------------------------------------------------------------------------------- /options/model/imagenet_birt.yaml: 
-------------------------------------------------------------------------------- 1 | ####################### 2 | # DyTox, for ImageNet # 3 | ####################### 4 | 5 | # Model definition 6 | model: convit 7 | embed_dim: 384 8 | depth: 6 9 | num_heads: 12 10 | patch_size: 16 11 | input_size: 224 12 | local_up_to_layer: 5 13 | class_attention: true 14 | 15 | #batch_size: 64 16 | #incremental_batch_size: 64 17 | 18 | # Training setting 19 | no_amp: false 20 | eval_every: 250 21 | 22 | # Base hyperparameter 23 | weight_decay: 0.000001 24 | batch_size: 128 25 | incremental_batch_size: 128 26 | incremental_lr: 0.0005 27 | rehearsal: icarl_all 28 | 29 | # Knowledge Distillation 30 | auto_kd: true 31 | 32 | # Finetuning 33 | finetuning: balanced 34 | finetuning_epochs: 20 35 | 36 | # Dytox model 37 | dytox: true 38 | freeze_task: [old_task_tokens, old_heads] 39 | freeze_ft: [sab] 40 | 41 | # Divergence head to get diversity 42 | head_div: 0.1 43 | head_div_mode: tr 44 | 45 | # Independent Classifiers 46 | ind_clf: 1-1 47 | bce_loss: true 48 | 49 | 50 | # Advanced Augmentations, here disabled 51 | 52 | ## Erasing 53 | reprob: 0.0 54 | remode: pixel 55 | recount: 1 56 | resplit: false 57 | 58 | ## MixUp & CutMix 59 | mixup: 0.0 60 | cutmix: 0.0 61 | -------------------------------------------------------------------------------- /options/model/tinyimg_birt.yaml: -------------------------------------------------------------------------------- 1 | ########################### 2 | # DyTox, for TinyImageNet # 3 | ########################### 4 | 5 | # Model definition 6 | model: convit 7 | embed_dim: 384 8 | depth: 6 9 | num_heads: 12 10 | patch_size: 8 11 | input_size: 64 12 | local_up_to_layer: 5 13 | class_attention: true 14 | 15 | # Training setting 16 | no_amp: true 17 | eval_every: 50 18 | 19 | # Base hyperparameter 20 | weight_decay: 0.000001 21 | batch_size: 128 22 | incremental_batch_size: 128 23 | incremental_lr: 0.0005 24 | rehearsal: icarl_all 25 | 26 | # Knowledge Distillation 27 | auto_kd: true 28 | 29 | # Finetuning 30 | finetuning: balanced 31 | finetuning_epochs: 20 32 | 33 | # Dytox model 34 | dytox: true 35 | freeze_task: [old_task_tokens, old_heads] 36 | freeze_ft: [sab] 37 | 38 | # Divergence head to get diversity 39 | # head_div: 0 40 | head_div: 0.1 41 | head_div_mode: tr 42 | 43 | # Independent Classifiers 44 | ind_clf: 1-1 45 | bce_loss: true 46 | 47 | 48 | # Advanced Augmentations, here disabled 49 | 50 | ## Erasing 51 | reprob: 0.0 52 | remode: pixel 53 | recount: 1 54 | resplit: false 55 | 56 | ## MixUp & CutMix 57 | mixup: 0.0 58 | cutmix: 0.0 59 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.0 2 | torchvision>=0.8.1 3 | timm>=0.3.2 4 | continuum>=1.0.27 5 | --------------------------------------------------------------------------------
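The model configs above are meant to be paired with one of the data configs, with values set in later files overriding earlier ones. A minimal sketch (an assumption for illustration, not the repo's actual option loader in main.py) of how such YAML option files could be merged into a single configuration:

import yaml  # provided by the PyYAML package

def load_options(*paths):
    """Merge option files left to right; later files override earlier keys."""
    config = {}
    for path in paths:
        with open(path) as f:
            config.update(yaml.safe_load(f) or {})
    return config

cfg = load_options(
    "options/data/cifar100_10-10.yaml",
    "options/model/cifar_birt.yaml",
)
# cfg now holds e.g. data_set='CIFAR', increment=10, model='convit', batch_size=128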