├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data ├── cifar100.py ├── mnist_permutations.py ├── mnist_rotations.py └── raw │ └── raw.py ├── main.py ├── metrics ├── __init__.py └── metrics.py ├── model ├── __init__.py ├── common.py ├── ewc.py ├── gem.py ├── icarl.py ├── independent.py ├── multimodal.py └── single.py ├── requirements.txt ├── results └── plot_results.py └── run_experiments.sh /.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | *.txt 3 | *.pyc 4 | *.pdf 5 | *.tar.gz 6 | *.npz 7 | 8 | /data/raw/cifar-100-python 9 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to GradientEpisodicMemory 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. 
You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to GradientEpisodicMemory, you agree that your contributions 31 | will be licensed under the LICENSE file in the root directory of this source 32 | tree. 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. 
The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. 
Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. 
Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. 
The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. 
Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. 
indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 
287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. 
This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. 
To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 
398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gradient Episodic Memory for Continual Learning 2 | 3 | Source code for [the paper](https://arxiv.org/abs/1706.08840): 4 | 5 | ``` 6 | @inproceedings{GradientEpisodicMemory, 7 | title={Gradient Episodic Memory for Continual Learning}, 8 | author={Lopez-Paz, David and Ranzato, Marc'Aurelio}, 9 | booktitle={NIPS}, 10 | year={2017} 11 | } 12 | ``` 13 | 14 | To replicate the experiments, execute `./run_experiments.sh`. 15 | 16 | This source code is released under an Attribution-NonCommercial 4.0 International 17 | license; find out more about it [here](LICENSE). 18 | -------------------------------------------------------------------------------- /data/cifar100.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
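The script that follows partitions CIFAR-100 into `n_tasks` class-incremental splits, with task `t` holding the labels in `[t * cpt, (t + 1) * cpt)`. A minimal sketch of just that partitioning step (the helper `class_ranges` is hypothetical, not part of this repository):

```python
def class_ranges(n_classes=100, n_tasks=10):
    # classes per task; mirrors cpt = int(100 / n_tasks) in the script below
    cpt = n_classes // n_tasks
    # task t keeps the examples whose label falls in [t * cpt, (t + 1) * cpt)
    return [(t * cpt, (t + 1) * cpt) for t in range(n_tasks)]

print(class_ranges())  # ten disjoint intervals covering labels 0..99
```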
6 | 7 | import argparse 8 | import os.path 9 | import torch 10 | 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument('--i', default='raw/cifar100.pt', help='input directory') 14 | parser.add_argument('--o', default='cifar100.pt', help='output file') 15 | parser.add_argument('--n_tasks', default=10, type=int, help='number of tasks') 16 | parser.add_argument('--seed', default=0, type=int, help='random seed') 17 | args = parser.parse_args() 18 | 19 | torch.manual_seed(args.seed) 20 | 21 | tasks_tr = [] 22 | tasks_te = [] 23 | 24 | x_tr, y_tr, x_te, y_te = torch.load(os.path.join(args.i)) 25 | x_tr = x_tr.float().view(x_tr.size(0), -1) / 255.0 26 | x_te = x_te.float().view(x_te.size(0), -1) / 255.0 27 | 28 | cpt = int(100 / args.n_tasks) 29 | 30 | for t in range(args.n_tasks): 31 | c1 = t * cpt 32 | c2 = (t + 1) * cpt 33 | i_tr = ((y_tr >= c1) & (y_tr < c2)).nonzero().view(-1) 34 | i_te = ((y_te >= c1) & (y_te < c2)).nonzero().view(-1) 35 | tasks_tr.append([(c1, c2), x_tr[i_tr].clone(), y_tr[i_tr].clone()]) 36 | tasks_te.append([(c1, c2), x_te[i_te].clone(), y_te[i_te].clone()]) 37 | 38 | torch.save([tasks_tr, tasks_te], args.o) 39 | -------------------------------------------------------------------------------- /data/mnist_permutations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
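The script that follows builds the "permuted MNIST" benchmark: each task applies one fixed random permutation of the 784 pixel positions to every image, which is what `x_tr.index_select(1, p)` does below. A small standard-library sketch of the same idea (the helper names are hypothetical):

```python
import random


def pixel_permutation(n_pixels, seed):
    # one fixed permutation per task, reproducible from the seed
    rng = random.Random(seed)
    p = list(range(n_pixels))
    rng.shuffle(p)
    return p


def permute_image(flat_image, p):
    # reorder the pixels of one flattened image with the task's permutation
    return [flat_image[idx] for idx in p]
```

Because the permutation is the same for every image in a task, each task is MNIST under a fixed, label-preserving relabeling of the input dimensions.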
6 | 7 | import argparse 8 | import os.path 9 | import torch 10 | 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument('--i', default='raw/', help='input directory') 14 | parser.add_argument('--o', default='mnist_permutations.pt', help='output file') 15 | parser.add_argument('--n_tasks', default=3, type=int, help='number of tasks') 16 | parser.add_argument('--seed', default=0, type=int, help='random seed') 17 | args = parser.parse_args() 18 | 19 | torch.manual_seed(args.seed) 20 | 21 | tasks_tr = [] 22 | tasks_te = [] 23 | 24 | x_tr, y_tr = torch.load(os.path.join(args.i, 'mnist_train.pt')) 25 | x_te, y_te = torch.load(os.path.join(args.i, 'mnist_test.pt')) 26 | x_tr = x_tr.float().view(x_tr.size(0), -1) / 255.0 27 | x_te = x_te.float().view(x_te.size(0), -1) / 255.0 28 | y_tr = y_tr.view(-1).long() 29 | y_te = y_te.view(-1).long() 30 | 31 | for t in range(args.n_tasks): 32 | p = torch.randperm(x_tr.size(1)).long().view(-1) 33 | 34 | tasks_tr.append(['random permutation', x_tr.index_select(1, p), y_tr]) 35 | tasks_te.append(['random permutation', x_te.index_select(1, p), y_te]) 36 | 37 | torch.save([tasks_tr, tasks_te], args.o) 38 | -------------------------------------------------------------------------------- /data/mnist_rotations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
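The script that follows builds "rotated MNIST": the `[min_rot, max_rot]` range is divided evenly across tasks, and each task rotates every image by one angle drawn uniformly from its band. A sketch of the band computation alone (the helper `rotation_band` is hypothetical, not part of this repository):

```python
def rotation_band(t, n_tasks, min_rot=0.0, max_rot=90.0):
    # task t's angle band: an even 1/n_tasks slice of [min_rot, max_rot]
    span = max_rot - min_rot
    lo = t / n_tasks * span + min_rot
    hi = (t + 1) / n_tasks * span + min_rot
    return lo, hi
```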
6 | 7 | from torchvision import transforms 8 | from PIL import Image 9 | import argparse 10 | import os.path 11 | import random 12 | import torch 13 | 14 | 15 | def rotate_dataset(d, rotation): 16 | result = torch.FloatTensor(d.size(0), 784) 17 | tensor = transforms.ToTensor() 18 | 19 | for i in range(d.size(0)): 20 | img = Image.fromarray(d[i].numpy(), mode='L') 21 | result[i] = tensor(img.rotate(rotation)).view(784) 22 | return result 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument('--i', default='raw/', help='input directory') 28 | parser.add_argument('--o', default='mnist_rotations.pt', help='output file') 29 | parser.add_argument('--n_tasks', default=10, type=int, help='number of tasks') 30 | parser.add_argument('--min_rot', default=0., 31 | type=float, help='minimum rotation') 32 | parser.add_argument('--max_rot', default=90., 33 | type=float, help='maximum rotation') 34 | parser.add_argument('--seed', default=0, type=int, help='random seed') 35 | 36 | args = parser.parse_args() 37 | 38 | torch.manual_seed(args.seed) 39 | 40 | tasks_tr = [] 41 | tasks_te = [] 42 | 43 | x_tr, y_tr = torch.load(os.path.join(args.i, 'mnist_train.pt')) 44 | x_te, y_te = torch.load(os.path.join(args.i, 'mnist_test.pt')) 45 | 46 | for t in range(args.n_tasks): 47 | min_rot = 1.0 * t / args.n_tasks * (args.max_rot - args.min_rot) + \ 48 | args.min_rot 49 | max_rot = 1.0 * (t + 1) / args.n_tasks * \ 50 | (args.max_rot - args.min_rot) + args.min_rot 51 | rot = random.random() * (max_rot - min_rot) + min_rot 52 | 53 | tasks_tr.append([rot, rotate_dataset(x_tr, rot), y_tr]) 54 | tasks_te.append([rot, rotate_dataset(x_te, rot), y_te]) 55 | 56 | torch.save([tasks_tr, tasks_te], args.o) 57 | -------------------------------------------------------------------------------- /data/raw/raw.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import subprocess 9 | import pickle 10 | import torch 11 | import os 12 | 13 | cifar_path = "cifar-100-python.tar.gz" 14 | mnist_path = "mnist.npz" 15 | 16 | # URL from: https://www.cs.toronto.edu/~kriz/cifar.html 17 | if not os.path.exists(cifar_path): 18 | subprocess.call("wget https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz", shell=True) 19 | 20 | subprocess.call("tar xzfv cifar-100-python.tar.gz", shell=True) 21 | 22 | # URL from: https://github.com/fchollet/keras/blob/master/keras/datasets/mnist.py 23 | if not os.path.exists(mnist_path): 24 | subprocess.call("wget https://s3.amazonaws.com/img-datasets/mnist.npz", shell=True) 25 | 26 | def unpickle(file): 27 | with open(file, 'rb') as fo: 28 | dict = pickle.load(fo, encoding='bytes') 29 | return dict 30 | 31 | cifar100_train = unpickle('cifar-100-python/train') 32 | cifar100_test = unpickle('cifar-100-python/test') 33 | 34 | x_tr = torch.from_numpy(cifar100_train[b'data']) 35 | y_tr = torch.LongTensor(cifar100_train[b'fine_labels']) 36 | x_te = torch.from_numpy(cifar100_test[b'data']) 37 | y_te = torch.LongTensor(cifar100_test[b'fine_labels']) 38 | 39 | torch.save((x_tr, y_tr, x_te, y_te), 'cifar100.pt') 40 | 41 | f = np.load('mnist.npz') 42 | x_tr = torch.from_numpy(f['x_train']) 43 | y_tr = torch.from_numpy(f['y_train']).long() 44 | x_te = torch.from_numpy(f['x_test']) 45 | y_te = torch.from_numpy(f['y_test']).long() 46 | f.close() 47 | 48 | torch.save((x_tr, y_tr), 'mnist_train.pt') 49 | torch.save((x_te, y_te), 'mnist_test.pt') 50 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import importlib 8 | import datetime 9 | import argparse 10 | import random 11 | import uuid 12 | import time 13 | import os 14 | 15 | import numpy as np 16 | 17 | import torch 18 | from metrics.metrics import confusion_matrix 19 | 20 | # continuum iterator ######################################################### 21 | 22 | 23 | def load_datasets(args): 24 | d_tr, d_te = torch.load(args.data_path + '/' + args.data_file) 25 | n_inputs = d_tr[0][1].size(1) 26 | n_outputs = 0 27 | for i in range(len(d_tr)): 28 | n_outputs = max(n_outputs, d_tr[i][2].max().item()) 29 | n_outputs = max(n_outputs, d_te[i][2].max().item()) 30 | return d_tr, d_te, n_inputs, n_outputs + 1, len(d_tr) 31 | 32 | 33 | class Continuum: 34 | 35 | def __init__(self, data, args): 36 | self.data = data 37 | self.batch_size = args.batch_size 38 | n_tasks = len(data) 39 | task_permutation = range(n_tasks) 40 | 41 | if args.shuffle_tasks == 'yes': 42 | task_permutation = torch.randperm(n_tasks).tolist() 43 | 44 | sample_permutations = [] 45 | 46 | for t in range(n_tasks): 47 | N = data[t][1].size(0) 48 | if args.samples_per_task <= 0: 49 | n = N 50 | else: 51 | n = min(args.samples_per_task, N) 52 | 53 | p = torch.randperm(N)[0:n] 54 | sample_permutations.append(p) 55 | 56 | self.permutation = [] 57 | 58 | for t in range(n_tasks): 59 | task_t = task_permutation[t] 60 | for _ in range(args.n_epochs): 61 | task_p = [[task_t, i] for i in sample_permutations[task_t]] 62 | random.shuffle(task_p) 63 | self.permutation += task_p 64 | 65 | self.length = len(self.permutation) 66 | self.current = 0 67 | 68 | def __iter__(self): 69 | return self 70 | 71 | def next(self): 72 | return self.__next__() 73 | 74 | def __next__(self): 75 | if self.current >= self.length: 76 | raise StopIteration 77 | else: 78 | ti = self.permutation[self.current][0] 79 | j = [] 80 | i = 0 81 | 
while (((self.current + i) < self.length) and 82 | (self.permutation[self.current + i][0] == ti) and 83 | (i < self.batch_size)): 84 | j.append(self.permutation[self.current + i][1]) 85 | i += 1 86 | self.current += i 87 | j = torch.LongTensor(j) 88 | return self.data[ti][1][j], ti, self.data[ti][2][j] 89 | 90 | # train handle ############################################################### 91 | 92 | 93 | def eval_tasks(model, tasks, args): 94 | model.eval() 95 | result = [] 96 | for i, task in enumerate(tasks): 97 | t = i 98 | x = task[1] 99 | y = task[2] 100 | rt = 0 101 | 102 | eval_bs = x.size(0) 103 | 104 | for b_from in range(0, x.size(0), eval_bs): 105 | b_to = min(b_from + eval_bs, x.size(0))  # do not drop the last example 106 | if b_from == b_to: 107 | xb = x[b_from].view(1, -1) 108 | yb = torch.LongTensor([y[b_to]]).view(1, -1) 109 | else: 110 | xb = x[b_from:b_to] 111 | yb = y[b_from:b_to] 112 | if args.cuda: 113 | xb = xb.cuda() 114 | _, pb = torch.max(model(xb, t).data.cpu(), 1, keepdim=False) 115 | rt += (pb == yb).float().sum() 116 | 117 | result.append(rt / x.size(0)) 118 | 119 | return result 120 | 121 | 122 | def life_experience(model, continuum, x_te, args): 123 | result_a = [] 124 | result_t = [] 125 | 126 | current_task = 0 127 | time_start = time.time() 128 | 129 | for (i, (x, t, y)) in enumerate(continuum): 130 | if(((i % args.log_every) == 0) or (t != current_task)): 131 | result_a.append(eval_tasks(model, x_te, args)) 132 | result_t.append(current_task) 133 | current_task = t 134 | 135 | v_x = x.view(x.size(0), -1) 136 | v_y = y.long() 137 | 138 | if args.cuda: 139 | v_x = v_x.cuda() 140 | v_y = v_y.cuda() 141 | 142 | model.train() 143 | model.observe(v_x, t, v_y) 144 | 145 | result_a.append(eval_tasks(model, x_te, args)) 146 | result_t.append(current_task) 147 | 148 | time_end = time.time() 149 | time_spent = time_end - time_start 150 | 151 | return torch.Tensor(result_t), torch.Tensor(result_a), time_spent 152 | 153 | 154 | if __name__ == "__main__": 155 | parser = 
argparse.ArgumentParser(description='Continuum learning') 156 | 157 | # model parameters 158 | parser.add_argument('--model', type=str, default='single', 159 | help='model to train') 160 | parser.add_argument('--n_hiddens', type=int, default=100, 161 | help='number of hidden neurons at each layer') 162 | parser.add_argument('--n_layers', type=int, default=2, 163 | help='number of hidden layers') 164 | 165 | # memory parameters 166 | parser.add_argument('--n_memories', type=int, default=0, 167 | help='number of memories per task') 168 | parser.add_argument('--memory_strength', default=0, type=float, 169 | help='memory strength (meaning depends on memory)') 170 | parser.add_argument('--finetune', default='no', type=str, 171 | help='whether to initialize nets in indep. nets') 172 | 173 | # optimizer parameters 174 | parser.add_argument('--n_epochs', type=int, default=1, 175 | help='Number of epochs per task') 176 | parser.add_argument('--batch_size', type=int, default=10, 177 | help='batch size') 178 | parser.add_argument('--lr', type=float, default=1e-3, 179 | help='SGD learning rate') 180 | 181 | # experiment parameters 182 | parser.add_argument('--cuda', type=str, default='no', 183 | help='Use GPU?') 184 | parser.add_argument('--seed', type=int, default=0, 185 | help='random seed') 186 | parser.add_argument('--log_every', type=int, default=100, 187 | help='frequency of logs, in minibatches') 188 | parser.add_argument('--save_path', type=str, default='results/', 189 | help='save models at the end of training') 190 | 191 | # data parameters 192 | parser.add_argument('--data_path', default='data/', 193 | help='path where data is located') 194 | parser.add_argument('--data_file', default='mnist_permutations.pt', 195 | help='data file') 196 | parser.add_argument('--samples_per_task', type=int, default=-1, 197 | help='training samples per task (all if negative)') 198 | parser.add_argument('--shuffle_tasks', type=str, default='no', 199 | help='present tasks in order') 200 
| args = parser.parse_args() 201 | 202 | args.cuda = True if args.cuda == 'yes' else False 203 | args.finetune = True if args.finetune == 'yes' else False 204 | 205 | # multimodal model has one extra layer 206 | if args.model == 'multimodal': 207 | args.n_layers -= 1 208 | 209 | # unique identifier 210 | uid = uuid.uuid4().hex 211 | 212 | # initialize seeds 213 | torch.backends.cudnn.enabled = False 214 | torch.manual_seed(args.seed) 215 | np.random.seed(args.seed) 216 | random.seed(args.seed) 217 | if args.cuda: 218 | torch.cuda.manual_seed_all(args.seed) 219 | 220 | # load data 221 | x_tr, x_te, n_inputs, n_outputs, n_tasks = load_datasets(args) 222 | 223 | # set up continuum 224 | continuum = Continuum(x_tr, args) 225 | 226 | # load model 227 | Model = importlib.import_module('model.' + args.model) 228 | model = Model.Net(n_inputs, n_outputs, n_tasks, args) 229 | if args.cuda: 230 | model.cuda() 231 | 232 | # run model on continuum 233 | result_t, result_a, spent_time = life_experience( 234 | model, continuum, x_te, args) 235 | 236 | # prepare saving path and file name 237 | if not os.path.exists(args.save_path): 238 | os.makedirs(args.save_path) 239 | 240 | fname = args.model + '_' + args.data_file + '_' 241 | fname += datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 242 | fname += '_' + uid 243 | fname = os.path.join(args.save_path, fname) 244 | 245 | # save confusion matrix and print one line of stats 246 | stats = confusion_matrix(result_t, result_a, fname + '.txt') 247 | one_liner = str(vars(args)) + ' # ' 248 | one_liner += ' '.join(["%.3f" % stat for stat in stats]) 249 | print(fname + ': ' + one_liner + ' # ' + str(spent_time)) 250 | 251 | # save all results in binary file 252 | torch.save((result_t, result_a, model.state_dict(), 253 | stats, one_liner, args), fname + '.pt') 254 | -------------------------------------------------------------------------------- /metrics/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /metrics/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from __future__ import print_function 8 | 9 | import torch 10 | 11 | 12 | def task_changes(result_t): 13 | n_tasks = int(result_t.max() + 1) 14 | changes = [] 15 | current = result_t[0] 16 | for i, t in enumerate(result_t): 17 | if t != current: 18 | changes.append(i) 19 | current = t 20 | 21 | return n_tasks, changes 22 | 23 | 24 | def confusion_matrix(result_t, result_a, fname=None): 25 | nt, changes = task_changes(result_t) 26 | 27 | baseline = result_a[0] 28 | changes = torch.LongTensor(changes + [result_a.size(0)]) - 1 29 | result = result_a[changes] 30 | 31 | # acc[t] equals result[t,t] 32 | acc = result.diag() 33 | fin = result[nt - 1] 34 | # bwt[t] equals result[T,t] - acc[t] 35 | bwt = result[nt - 1] - acc 36 | 37 | # fwt[t] equals result[t-1,t] - baseline[t] 38 | fwt = torch.zeros(nt) 39 | for t in range(1, nt): 40 | fwt[t] = result[t - 1, t] - baseline[t] 41 | 42 | if fname is not None: 43 | f = open(fname, 'w') 44 | 45 | print(' '.join(['%.4f' % r for r in baseline]), file=f) 46 | print('|', file=f) 47 | for row in range(result.size(0)): 48 | print(' '.join(['%.4f' % r for r in result[row]]), file=f) 49 | print('', file=f) 50 | # print('Diagonal Accuracy: %.4f' % acc.mean(), file=f) 51 | print('Final Accuracy: %.4f' % fin.mean(), file=f) 52 | print('Backward: %.4f' % bwt.mean(), 
file=f) 53 | print('Forward: %.4f' % fwt.mean(), file=f) 54 | f.close() 55 | 56 | stats = [] 57 | # stats.append(acc.mean()) 58 | stats.append(fin.mean()) 59 | stats.append(bwt.mean()) 60 | stats.append(fwt.mean()) 61 | 62 | return stats 63 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /model/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
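Editor's note: model/common.py below defines the two backbones shared by every strategy in model/ — an MLP with Xavier-initialized linear layers for the MNIST variants, and a slimmed-down ResNet18 (20 base filters) for CIFAR-100. Each strategy sizes the MLP as `[n_inputs] + [nh] * nl + [n_outputs]`. A dependency-free sketch of that sizing expression (`mlp_sizes` is an illustrative helper, not part of the repo, and the concrete numbers are examples):

```python
# Sketch: how the strategies compose the shared MLP's layer sizes.
# mlp_sizes is a hypothetical helper; the repo inlines this expression.
def mlp_sizes(n_inputs, n_hiddens, n_layers, n_outputs):
    # e.g. 784 inputs, two hidden layers of 100 units, 10 outputs
    return [n_inputs] + [n_hiddens] * n_layers + [n_outputs]

sizes = mlp_sizes(784, 100, 2, 10)
print(sizes)  # [784, 100, 100, 10]
```

With `n_layers=0` the list collapses to `[n_inputs, n_outputs]`, i.e. a single linear map, which matches how MLP builds one `nn.Linear` per consecutive pair of sizes.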
6 | 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn.functional import relu, avg_pool2d 11 | 12 | 13 | def Xavier(m): 14 | if m.__class__.__name__ == 'Linear': 15 | fan_in, fan_out = m.weight.data.size(1), m.weight.data.size(0) 16 | std = 1.0 * math.sqrt(2.0 / (fan_in + fan_out)) 17 | a = math.sqrt(3.0) * std 18 | m.weight.data.uniform_(-a, a) 19 | m.bias.data.fill_(0.0) 20 | 21 | 22 | class MLP(nn.Module): 23 | def __init__(self, sizes): 24 | super(MLP, self).__init__() 25 | layers = [] 26 | 27 | for i in range(0, len(sizes) - 1): 28 | layers.append(nn.Linear(sizes[i], sizes[i + 1])) 29 | if i < (len(sizes) - 2): 30 | layers.append(nn.ReLU()) 31 | 32 | self.net = nn.Sequential(*layers) 33 | self.net.apply(Xavier) 34 | 35 | def forward(self, x): 36 | return self.net(x) 37 | 38 | 39 | def conv3x3(in_planes, out_planes, stride=1): 40 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 41 | padding=1, bias=False) 42 | 43 | 44 | class BasicBlock(nn.Module): 45 | expansion = 1 46 | 47 | def __init__(self, in_planes, planes, stride=1): 48 | super(BasicBlock, self).__init__() 49 | self.conv1 = conv3x3(in_planes, planes, stride) 50 | self.bn1 = nn.BatchNorm2d(planes) 51 | self.conv2 = conv3x3(planes, planes) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | 54 | self.shortcut = nn.Sequential() 55 | if stride != 1 or in_planes != self.expansion * planes: 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, 58 | stride=stride, bias=False), 59 | nn.BatchNorm2d(self.expansion * planes) 60 | ) 61 | 62 | def forward(self, x): 63 | out = relu(self.bn1(self.conv1(x))) 64 | out = self.bn2(self.conv2(out)) 65 | out += self.shortcut(x) 66 | out = relu(out) 67 | return out 68 | 69 | 70 | class ResNet(nn.Module): 71 | def __init__(self, block, num_blocks, num_classes, nf): 72 | super(ResNet, self).__init__() 73 | self.in_planes = nf 74 | 75 | self.conv1 = conv3x3(3, nf * 1) 76 | self.bn1 = 
nn.BatchNorm2d(nf * 1) 77 | self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1) 78 | self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2) 79 | self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2) 80 | self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2) 81 | self.linear = nn.Linear(nf * 8 * block.expansion, num_classes) 82 | 83 | def _make_layer(self, block, planes, num_blocks, stride): 84 | strides = [stride] + [1] * (num_blocks - 1) 85 | layers = [] 86 | for stride in strides: 87 | layers.append(block(self.in_planes, planes, stride)) 88 | self.in_planes = planes * block.expansion 89 | return nn.Sequential(*layers) 90 | 91 | def forward(self, x): 92 | bsz = x.size(0) 93 | out = relu(self.bn1(self.conv1(x.view(bsz, 3, 32, 32)))) 94 | out = self.layer1(out) 95 | out = self.layer2(out) 96 | out = self.layer3(out) 97 | out = self.layer4(out) 98 | out = avg_pool2d(out, 4) 99 | out = out.view(out.size(0), -1) 100 | out = self.linear(out) 101 | return out 102 | 103 | 104 | def ResNet18(nclasses, nf=20): 105 | return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf) 106 | -------------------------------------------------------------------------------- /model/ewc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
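Editor's note: model/ewc.py below implements Elastic Weight Consolidation. At each task boundary, `observe` estimates a diagonal Fisher matrix from the squared gradients on the stored minibatches and records the parameters at that point; training on later tasks then adds, per past task, a quadratic penalty `reg * F_i * (theta_i - theta*_i)^2`, with `reg` set by `--memory_strength`. A minimal scalar sketch of that penalty term (plain floats stand in for parameter tensors; the function name is illustrative):

```python
def ewc_penalty(params, optpar, fisher, reg):
    # sum over parameters of reg * F_i * (theta_i - theta*_i)^2,
    # the term Net.observe adds to the current-task loss
    return sum(reg * f * (p - o) ** 2
               for p, o, f in zip(params, optpar, fisher))

# a high-Fisher (important) parameter pays more for the same drift
high = ewc_penalty([1.0], [0.0], [10.0], 0.5)
low = ewc_penalty([1.0], [0.0], [0.1], 0.5)
print(high, low)  # 5.0 0.05
```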
6 | 7 | import torch 8 | from .common import MLP, ResNet18 9 | 10 | 11 | class Net(torch.nn.Module): 12 | 13 | def __init__(self, 14 | n_inputs, 15 | n_outputs, 16 | n_tasks, 17 | args): 18 | super(Net, self).__init__() 19 | nl, nh = args.n_layers, args.n_hiddens 20 | self.reg = args.memory_strength 21 | 22 | # setup network 23 | self.is_cifar = (args.data_file == 'cifar100.pt') 24 | if self.is_cifar: 25 | self.net = ResNet18(n_outputs) 26 | else: 27 | self.net = MLP([n_inputs] + [nh] * nl + [n_outputs]) 28 | 29 | # setup optimizer 30 | self.opt = torch.optim.SGD(self.net.parameters(), lr=args.lr) 31 | 32 | # setup losses 33 | self.bce = torch.nn.CrossEntropyLoss() 34 | 35 | # setup memories 36 | self.current_task = 0 37 | self.fisher = {} 38 | self.optpar = {} 39 | self.memx = None 40 | self.memy = None 41 | 42 | if self.is_cifar: 43 | self.nc_per_task = n_outputs / n_tasks 44 | else: 45 | self.nc_per_task = n_outputs 46 | self.n_outputs = n_outputs 47 | self.n_memories = args.n_memories 48 | 49 | def compute_offsets(self, task): 50 | if self.is_cifar: 51 | offset1 = task * self.nc_per_task 52 | offset2 = (task + 1) * self.nc_per_task 53 | else: 54 | offset1 = 0 55 | offset2 = self.n_outputs 56 | return int(offset1), int(offset2) 57 | 58 | def forward(self, x, t): 59 | output = self.net(x) 60 | if self.is_cifar: 61 | # make sure we predict classes within the current task 62 | offset1, offset2 = self.compute_offsets(t) 63 | if offset1 > 0: 64 | output[:, :offset1].data.fill_(-10e10) 65 | if offset2 < self.n_outputs: 66 | output[:, int(offset2):self.n_outputs].data.fill_(-10e10) 67 | return output 68 | 69 | def observe(self, x, t, y): 70 | self.net.train() 71 | 72 | # task switch? estimate the Fisher of the task that just ended 
73 | if t != self.current_task: 74 | self.net.zero_grad() 75 | 76 | if self.is_cifar: 77 | offset1, offset2 = self.compute_offsets(self.current_task) 78 | self.bce((self.net(self.memx)[:, offset1: offset2]), 79 | self.memy - offset1).backward() 80 | else: 81 | self.bce(self(self.memx, 82 | self.current_task), 83 | self.memy).backward() 84 | self.fisher[self.current_task] = [] 85 | self.optpar[self.current_task] = [] 86 | for p in self.net.parameters(): 87 | pd = p.data.clone() 88 | pg = p.grad.data.clone().pow(2) 89 | self.optpar[self.current_task].append(pd) 90 | self.fisher[self.current_task].append(pg) 91 | self.current_task = t 92 | self.memx = None 93 | self.memy = None 94 | 95 | if self.memx is None: 96 | self.memx = x.data.clone() 97 | self.memy = y.data.clone() 98 | else: 99 | if self.memx.size(0) < self.n_memories: 100 | self.memx = torch.cat((self.memx, x.data.clone())) 101 | self.memy = torch.cat((self.memy, y.data.clone())) 102 | if self.memx.size(0) > self.n_memories: 103 | self.memx = self.memx[:self.n_memories] 104 | self.memy = self.memy[:self.n_memories] 105 | 106 | self.net.zero_grad() 107 | if self.is_cifar: 108 | offset1, offset2 = self.compute_offsets(t) 109 | loss = self.bce((self.net(x)[:, offset1: offset2]), 110 | y - offset1) 111 | else: 112 | loss = self.bce(self(x, t), y) 113 | for tt in range(t): 114 | for i, p in enumerate(self.net.parameters()): 115 | l = self.reg * self.fisher[tt][i] 116 | l = l * (p - self.optpar[tt][i]).pow(2) 117 | loss += l.sum() 118 | loss.backward() 119 | self.opt.step() 120 | -------------------------------------------------------------------------------- /model/gem.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
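Editor's note: model/gem.py below implements Gradient Episodic Memory. The gradient g of the current minibatch is accepted only if its dot product with the gradient g_k of every stored task memory is non-negative; when some dot product is negative, `project2cone2` solves a small quadprog QP for the closest gradient (in L2) satisfying all constraints. With a single past task, and ignoring the small margin the code adds to the constraints, the projection has a closed form, g' = g - (⟨g, g_k⟩ / ⟨g_k, g_k⟩) g_k. A dependency-free sketch of that single-constraint case (`project_single` is an illustrative helper, not the repo's API):

```python
def project_single(g, g_mem):
    # Project g onto the half-space {x : <x, g_mem> >= 0} when it
    # conflicts with the memory gradient; otherwise keep it unchanged.
    dot = sum(a * b for a, b in zip(g, g_mem))
    if dot >= 0:
        return list(g)
    nrm = sum(b * b for b in g_mem)
    return [a - (dot / nrm) * b for a, b in zip(g, g_mem)]

g = project_single([1.0, 0.0], [-1.0, 1.0])
print(g)  # [0.5, 0.5] -- now orthogonal to the memory gradient
```

The projected gradient no longer decreases performance on the remembered task (its dot product with g_mem is exactly zero here), which is the constraint the general QP enforces across all past tasks at once.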
6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | 11 | import numpy as np 12 | import quadprog 13 | 14 | from .common import MLP, ResNet18 15 | 16 | # Auxiliary functions useful for GEM's inner optimization. 17 | 18 | def compute_offsets(task, nc_per_task, is_cifar): 19 | """ 20 | Compute offsets for cifar to determine which 21 | outputs to select for a given task. 22 | """ 23 | if is_cifar: 24 | offset1 = task * nc_per_task 25 | offset2 = (task + 1) * nc_per_task 26 | else: 27 | offset1 = 0 28 | offset2 = nc_per_task 29 | return offset1, offset2 30 | 31 | 32 | def store_grad(pp, grads, grad_dims, tid): 33 | """ 34 | This stores parameter gradients of past tasks. 35 | pp: parameters 36 | grads: gradients 37 | grad_dims: list with number of parameters per layers 38 | tid: task id 39 | """ 40 | # store the gradients 41 | grads[:, tid].fill_(0.0) 42 | cnt = 0 43 | for param in pp(): 44 | if param.grad is not None: 45 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) 46 | en = sum(grad_dims[:cnt + 1]) 47 | grads[beg: en, tid].copy_(param.grad.data.view(-1)) 48 | cnt += 1 49 | 50 | 51 | def overwrite_grad(pp, newgrad, grad_dims): 52 | """ 53 | This is used to overwrite the gradients with a new gradient 54 | vector, whenever violations occur. 55 | pp: parameters 56 | newgrad: corrected gradient 57 | grad_dims: list storing number of parameters at each layer 58 | """ 59 | cnt = 0 60 | for param in pp(): 61 | if param.grad is not None: 62 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) 63 | en = sum(grad_dims[:cnt + 1]) 64 | this_grad = newgrad[beg: en].contiguous().view( 65 | param.grad.data.size()) 66 | param.grad.data.copy_(this_grad) 67 | cnt += 1 68 | 69 | 70 | def project2cone2(gradient, memories, margin=0.5, eps=1e-3): 71 | """ 72 | Solves the GEM dual QP described in the paper given a proposed 73 | gradient "gradient", and a memory of task gradients "memories". 74 | Overwrites "gradient" with the final projected update. 
75 | 76 | input: gradient, p-vector 77 | input: memories, (t * p)-vector 78 | output: x, p-vector 79 | """ 80 | memories_np = memories.cpu().t().double().numpy() 81 | gradient_np = gradient.cpu().contiguous().view(-1).double().numpy() 82 | t = memories_np.shape[0] 83 | P = np.dot(memories_np, memories_np.transpose()) 84 | P = 0.5 * (P + P.transpose()) + np.eye(t) * eps 85 | q = np.dot(memories_np, gradient_np) * -1 86 | G = np.eye(t) 87 | h = np.zeros(t) + margin 88 | v = quadprog.solve_qp(P, q, G, h)[0] 89 | x = np.dot(v, memories_np) + gradient_np 90 | gradient.copy_(torch.Tensor(x).view(-1, 1)) 91 | 92 | 93 | class Net(nn.Module): 94 | def __init__(self, 95 | n_inputs, 96 | n_outputs, 97 | n_tasks, 98 | args): 99 | super(Net, self).__init__() 100 | nl, nh = args.n_layers, args.n_hiddens 101 | self.margin = args.memory_strength 102 | self.is_cifar = (args.data_file == 'cifar100.pt') 103 | if self.is_cifar: 104 | self.net = ResNet18(n_outputs) 105 | else: 106 | self.net = MLP([n_inputs] + [nh] * nl + [n_outputs]) 107 | 108 | self.ce = nn.CrossEntropyLoss() 109 | self.n_outputs = n_outputs 110 | 111 | self.opt = optim.SGD(self.parameters(), args.lr) 112 | 113 | self.n_memories = args.n_memories 114 | self.gpu = args.cuda 115 | 116 | # allocate episodic memory 117 | self.memory_data = torch.FloatTensor( 118 | n_tasks, self.n_memories, n_inputs) 119 | self.memory_labs = torch.LongTensor(n_tasks, self.n_memories) 120 | if args.cuda: 121 | self.memory_data = self.memory_data.cuda() 122 | self.memory_labs = self.memory_labs.cuda() 123 | 124 | # allocate temporary synaptic memory 125 | self.grad_dims = [] 126 | for param in self.parameters(): 127 | self.grad_dims.append(param.data.numel()) 128 | self.grads = torch.Tensor(sum(self.grad_dims), n_tasks) 129 | if args.cuda: 130 | self.grads = self.grads.cuda() 131 | 132 | # allocate counters 133 | self.observed_tasks = [] 134 | self.old_task = -1 135 | self.mem_cnt = 0 136 | if self.is_cifar: 137 | self.nc_per_task = 
int(n_outputs / n_tasks) 138 | else: 139 | self.nc_per_task = n_outputs 140 | 141 | def forward(self, x, t): 142 | output = self.net(x) 143 | if self.is_cifar: 144 | # make sure we predict classes within the current task 145 | offset1 = int(t * self.nc_per_task) 146 | offset2 = int((t + 1) * self.nc_per_task) 147 | if offset1 > 0: 148 | output[:, :offset1].data.fill_(-10e10) 149 | if offset2 < self.n_outputs: 150 | output[:, offset2:self.n_outputs].data.fill_(-10e10) 151 | return output 152 | 153 | def observe(self, x, t, y): 154 | # update memory 155 | if t != self.old_task: 156 | self.observed_tasks.append(t) 157 | self.old_task = t 158 | 159 | # Update ring buffer storing examples from current task 160 | bsz = y.data.size(0) 161 | endcnt = min(self.mem_cnt + bsz, self.n_memories) 162 | effbsz = endcnt - self.mem_cnt 163 | self.memory_data[t, self.mem_cnt: endcnt].copy_( 164 | x.data[: effbsz]) 165 | if bsz == 1: 166 | self.memory_labs[t, self.mem_cnt] = y.data[0] 167 | else: 168 | self.memory_labs[t, self.mem_cnt: endcnt].copy_( 169 | y.data[: effbsz]) 170 | self.mem_cnt += effbsz 171 | if self.mem_cnt == self.n_memories: 172 | self.mem_cnt = 0 173 | 174 | # compute gradient on previous tasks 175 | if len(self.observed_tasks) > 1: 176 | for tt in range(len(self.observed_tasks) - 1): 177 | self.zero_grad() 178 | # fwd/bwd on the examples in the memory 179 | past_task = self.observed_tasks[tt] 180 | 181 | offset1, offset2 = compute_offsets(past_task, self.nc_per_task, 182 | self.is_cifar) 183 | ptloss = self.ce( 184 | self.forward( 185 | self.memory_data[past_task], 186 | past_task)[:, offset1: offset2], 187 | self.memory_labs[past_task] - offset1) 188 | ptloss.backward() 189 | store_grad(self.parameters, self.grads, self.grad_dims, 190 | past_task) 191 | 192 | # now compute the grad on the current minibatch 193 | self.zero_grad() 194 | 195 | offset1, offset2 = compute_offsets(t, self.nc_per_task, self.is_cifar) 196 | loss = self.ce(self.forward(x, t)[:, offset1: 
offset2], y - offset1) 197 | loss.backward() 198 | 199 | # check if gradient violates constraints 200 | if len(self.observed_tasks) > 1: 201 | # copy gradient 202 | store_grad(self.parameters, self.grads, self.grad_dims, t) 203 | indx = torch.cuda.LongTensor(self.observed_tasks[:-1]) if self.gpu \ 204 | else torch.LongTensor(self.observed_tasks[:-1]) 205 | dotp = torch.mm(self.grads[:, t].unsqueeze(0), 206 | self.grads.index_select(1, indx)) 207 | if (dotp < 0).sum() != 0: 208 | project2cone2(self.grads[:, t].unsqueeze(1), 209 | self.grads.index_select(1, indx), self.margin) 210 | # copy gradients back 211 | overwrite_grad(self.parameters, self.grads[:, t], 212 | self.grad_dims) 213 | self.opt.step() 214 | -------------------------------------------------------------------------------- /model/icarl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | import numpy as np 10 | import random 11 | 12 | from .common import ResNet18 13 | 14 | 15 | class Net(torch.nn.Module): 16 | # Re-implementation of 17 | # S.-A. Rebuffi, A. Kolesnikov, G. Sperl, and C. H. Lampert. 18 | # iCaRL: Incremental classifier and representation learning. 19 | # CVPR, 2017. 
20 | def __init__(self, 21 | n_inputs, 22 | n_outputs, 23 | n_tasks, 24 | args): 25 | super(Net, self).__init__() 26 | self.nt = n_tasks 27 | self.reg = args.memory_strength 28 | self.n_memories = args.n_memories 29 | self.num_exemplars = 0 30 | self.n_feat = n_outputs 31 | self.n_classes = n_outputs 32 | self.samples_per_task = args.samples_per_task 33 | if self.samples_per_task <= 0: 34 | raise ValueError('set explicitly args.samples_per_task') 35 | self.examples_seen = 0 36 | 37 | # setup network 38 | assert(args.data_file == 'cifar100.pt') 39 | self.net = ResNet18(n_outputs) 40 | 41 | # setup optimizer 42 | self.opt = torch.optim.SGD(self.parameters(), lr=args.lr) 43 | 44 | # setup losses 45 | self.bce = torch.nn.CrossEntropyLoss() 46 | self.kl = torch.nn.KLDivLoss() # for distillation 47 | self.lsm = torch.nn.LogSoftmax(dim=1) 48 | self.sm = torch.nn.Softmax(dim=1) 49 | 50 | # memory 51 | self.memx = None # stores raw inputs, PxD 52 | self.memy = None 53 | self.mem_class_x = {} # stores exemplars class by class 54 | self.mem_class_y = {} 55 | 56 | self.gpu = args.cuda 57 | self.nc_per_task = int(n_outputs / n_tasks) 58 | self.n_outputs = n_outputs 59 | 60 | def compute_offsets(self, task): 61 | offset1 = task * self.nc_per_task 62 | offset2 = (task + 1) * self.nc_per_task 63 | return int(offset1), int(offset2) 64 | 65 | def forward(self, x, t): 66 | # nearest neighbor 67 | nd = self.n_feat 68 | ns = x.size(0) 69 | if t * self.nc_per_task not in self.mem_class_x.keys(): 70 | # no exemplar in memory yet, output uniform distr. 
over classes in 71 | # task t above, we check presence of first class for this task, we 72 | # should check them all 73 | out = torch.Tensor(ns, self.n_classes).fill_(-10e10) 74 | out[:, int(t * self.nc_per_task): int((t + 1) * self.nc_per_task)].fill_( 75 | 1.0 / self.nc_per_task) 76 | if self.gpu: 77 | out = out.cuda() 78 | return out 79 | means = torch.ones(self.nc_per_task, nd) * float('inf') 80 | if self.gpu: 81 | means = means.cuda() 82 | offset1, offset2 = self.compute_offsets(t) 83 | for cc in range(offset1, offset2): 84 | means[cc - 85 | offset1] = self.net(self.mem_class_x[cc]).data.mean(0) 86 | classpred = torch.LongTensor(ns) 87 | preds = self.net(x).data.clone() 88 | for ss in range(ns): 89 | dist = (means - preds[ss].expand(self.nc_per_task, nd)).norm(2, 1) 90 | _, ii = dist.min(0) 91 | ii = ii.squeeze() 92 | classpred[ss] = ii.item() + offset1 93 | 94 | out = torch.zeros(ns, self.n_classes) 95 | if self.gpu: 96 | out = out.cuda() 97 | for ss in range(ns): 98 | out[ss, classpred[ss]] = 1 99 | return out # return 1-of-C code, ns x nc 100 | 101 | def forward_training(self, x, t): 102 | output = self.net(x) 103 | # make sure we predict classes within the current task 104 | offset1, offset2 = self.compute_offsets(t) 105 | if offset1 > 0: 106 | output[:, :offset1].data.fill_(-10e10) 107 | if offset2 < self.n_outputs: 108 | output[:, offset2:self.n_outputs].data.fill_(-10e10) 109 | return output 110 | 111 | def observe(self, x, t, y): 112 | self.net.train() 113 | self.examples_seen += x.size(0) 114 | 115 | if self.memx is None: 116 | self.memx = x.data.clone() 117 | self.memy = y.data.clone() 118 | else: 119 | self.memx = torch.cat((self.memx, x.data.clone())) 120 | self.memy = torch.cat((self.memy, y.data.clone())) 121 | 122 | self.net.zero_grad() 123 | offset1, offset2 = self.compute_offsets(t) 124 | loss = self.bce((self.net(x)[:, offset1: offset2]), 125 | y - offset1) 126 | 127 | if self.num_exemplars > 0: 128 | # distillation 129 | for tt in range(t): 
130 | # first generate a minibatch with one example per class from 131 | # previous tasks 132 | inp_dist = torch.zeros(self.nc_per_task, x.size(1)) 133 | target_dist = torch.zeros(self.nc_per_task, self.n_feat) 134 | offset1, offset2 = self.compute_offsets(tt) 135 | if self.gpu: 136 | inp_dist = inp_dist.cuda() 137 | target_dist = target_dist.cuda() 138 | for cc in range(self.nc_per_task): 139 | indx = random.randint(0, len(self.mem_class_x[cc + offset1]) - 1) 140 | inp_dist[cc] = self.mem_class_x[cc + offset1][indx].clone() 141 | target_dist[cc] = self.mem_class_y[cc + 142 | offset1][indx].clone() 143 | # Add distillation loss 144 | loss += self.reg * self.kl( 145 | self.lsm(self.net(inp_dist) 146 | [:, offset1: offset2]), 147 | self.sm(target_dist[:, offset1: offset2])) * self.nc_per_task 148 | # bprop and update 149 | loss.backward() 150 | self.opt.step() 151 | 152 | # check whether this is the last minibatch of the current task 153 | # We assume only 1 epoch! 154 | if self.examples_seen == self.samples_per_task: 155 | self.examples_seen = 0 156 | # get labels from previous task; we assume labels are consecutive 157 | if self.gpu: 158 | all_labs = torch.LongTensor(np.unique(self.memy.cpu().numpy())) 159 | else: 160 | all_labs = torch.LongTensor(np.unique(self.memy.numpy())) 161 | num_classes = all_labs.size(0) 162 | assert(num_classes == self.nc_per_task) 163 | # Reduce exemplar set by updating value of num. 
exemplars per class 164 | self.num_exemplars = int(self.n_memories / 165 | (num_classes + len(self.mem_class_x.keys()))) 166 | offset1, offset2 = self.compute_offsets(t) 167 | for ll in range(num_classes): 168 | lab = all_labs[ll].cuda() if self.gpu else all_labs[ll] 169 | indxs = (self.memy == lab).nonzero().squeeze() 170 | cdata = self.memx.index_select(0, indxs) 171 | 172 | # Construct exemplar set for last task 173 | mean_feature = self.net(cdata)[ 174 | :, offset1: offset2].data.clone().mean(0) 175 | nd = self.nc_per_task 176 | exemplars = torch.zeros(self.num_exemplars, x.size(1)) 177 | if self.gpu: 178 | exemplars = exemplars.cuda() 179 | ntr = cdata.size(0) 180 | # used to keep track of which examples we have already used 181 | taken = torch.zeros(ntr) 182 | model_output = self.net(cdata)[ 183 | :, offset1: offset2].data.clone() 184 | for ee in range(self.num_exemplars): 185 | prev = torch.zeros(1, nd) 186 | if self.gpu: 187 | prev = prev.cuda() 188 | if ee > 0: 189 | prev = self.net(exemplars[:ee])[ 190 | :, offset1: offset2].data.clone().sum(0) 191 | cost = (mean_feature.expand(ntr, nd) - (model_output 192 | + prev.expand(ntr, nd)) / (ee + 1)).norm(2, 1).squeeze() 193 | _, indx = cost.sort(0) 194 | winner = 0 195 | while winner < indx.size(0) and taken[indx[winner]] == 1: 196 | winner += 1 197 | if winner < indx.size(0): 198 | taken[indx[winner]] = 1 199 | exemplars[ee] = cdata[indx[winner]].clone() 200 | else: 201 | exemplars = exemplars[:indx.size(0), :].clone() 202 | self.num_exemplars = indx.size(0) 203 | break 204 | # update memory with exemplars 205 | self.mem_class_x[lab.item()] = exemplars.clone() 206 | 207 | # recompute outputs for distillation purposes 208 | for cc in self.mem_class_x.keys(): 209 | self.mem_class_y[cc] = self.net( 210 | self.mem_class_x[cc]).data.clone() 211 | self.memx = None 212 | self.memy = None 213 | -------------------------------------------------------------------------------- /model/independent.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from .common import MLP, ResNet18 9 | 10 | 11 | class Net(torch.nn.Module): 12 | 13 | def __init__(self, 14 | n_inputs, 15 | n_outputs, 16 | n_tasks, 17 | args): 18 | super(Net, self).__init__() 19 | nl, nh = args.n_layers, args.n_hiddens 20 | self.nets = torch.nn.ModuleList() 21 | self.opts = [] 22 | 23 | self.is_cifar = (args.data_file == 'cifar100.pt') 24 | if self.is_cifar: 25 | self.nc_per_task = n_outputs / n_tasks 26 | self.n_outputs = n_outputs 27 | 28 | # setup network 29 | for _ in range(n_tasks): 30 | if self.is_cifar: 31 | self.nets.append( 32 | ResNet18(int(n_outputs / n_tasks), int(20 / n_tasks))) 33 | else: 34 | self.nets.append( 35 | MLP([n_inputs] + [int(nh / n_tasks)] * nl + [n_outputs])) 36 | 37 | # setup optimizer 38 | for t in range(n_tasks): 39 | self.opts.append(torch.optim.SGD(self.nets[t].parameters(), 40 | lr=args.lr)) 41 | 42 | # setup loss 43 | self.bce = torch.nn.CrossEntropyLoss() 44 | 45 | self.finetune = args.finetune 46 | self.gpu = args.cuda 47 | self.old_task = 0 48 | 49 | def forward(self, x, t): 50 | output = self.nets[t](x) 51 | if self.is_cifar: 52 | bigoutput = torch.Tensor(x.size(0), self.n_outputs) 53 | if self.gpu: 54 | bigoutput = bigoutput.cuda() 55 | bigoutput.fill_(-10e10) 56 | bigoutput[:, int(t * self.nc_per_task): int((t + 1) * self.nc_per_task)].copy_( 57 | output.data) 58 | return bigoutput 59 | else: 60 | return output 61 | 62 | def observe(self, x, t, y): 63 | # detect beginning of a new task 64 | if self.finetune and t > 0 and t != self.old_task: 65 | # initialize current network like the previous one 66 | for ppold, ppnew in zip(self.nets[self.old_task].parameters(), 67 | self.nets[t].parameters()): 68 | 
ppnew.data.copy_(ppold.data) 69 | self.old_task = t 70 | 71 | self.train() 72 | self.zero_grad() 73 | if self.is_cifar: 74 | self.bce(self.nets[t](x), y - int(t * self.nc_per_task)).backward() 75 | else: 76 | self.bce(self(x, t), y).backward() 77 | self.opts[t].step() 78 | -------------------------------------------------------------------------------- /model/multimodal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | def reset_bias(m): 12 | m.bias.data.fill_(0.0) 13 | 14 | 15 | class Net(nn.Module): 16 | def __init__(self, 17 | n_inputs, 18 | n_outputs, 19 | n_tasks, 20 | args): 21 | super(Net, self).__init__() 22 | 23 | self.i_layer = nn.ModuleList() 24 | self.h_layer = nn.ModuleList() 25 | self.o_layer = nn.ModuleList() 26 | 27 | self.n_layers = args.n_layers 28 | nh = args.n_hiddens 29 | 30 | if self.n_layers > 0: 31 | # dedicated input layer 32 | for _ in range(n_tasks): 33 | self.i_layer += [nn.Linear(n_inputs, nh)] 34 | reset_bias(self.i_layer[-1]) 35 | 36 | # shared hidden layers 37 | self.h_layer += [nn.ModuleList()] 38 | for _ in range(self.n_layers): 39 | self.h_layer[0] += [nn.Linear(nh, nh)] 40 | reset_bias(self.h_layer[0][-1]) 41 | 42 | # shared output layer 43 | self.o_layer += [nn.Linear(nh, n_outputs)] 44 | reset_bias(self.o_layer[-1]) 45 | 46 | # linear model falls back to independent models 47 | else: 48 | self.i_layer += [nn.Linear(n_inputs, n_outputs)] 49 | reset_bias(self.i_layer[-1]) 50 | 51 | self.relu = nn.ReLU() 52 | self.soft = nn.LogSoftmax(dim=1) 53 | self.loss = nn.NLLLoss() 54 | self.optimizer = torch.optim.SGD(self.parameters(), args.lr) 55 | 56 | def forward(self, x, t): 57 | h = x 58 | 59 | if self.n_layers == 0: 60 | y =
self.soft(self.i_layer[t if isinstance(t, int) else t[0]](h)) 61 | else: 62 | # task-specific input 63 | h = self.relu(self.i_layer[t if isinstance(t, int) else t[0]](h)) 64 | # shared hiddens 65 | for l in range(self.n_layers): 66 | h = self.relu(self.h_layer[0][l](h)) 67 | # shared output 68 | y = self.soft(self.o_layer[0](h)) 69 | 70 | return y 71 | 72 | def observe(self, x, t, y): 73 | self.zero_grad() 74 | self.loss(self.forward(x, t), y).backward() 75 | self.optimizer.step() 76 | -------------------------------------------------------------------------------- /model/single.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from .common import MLP, ResNet18 9 | 10 | 11 | class Net(torch.nn.Module): 12 | 13 | def __init__(self, 14 | n_inputs, 15 | n_outputs, 16 | n_tasks, 17 | args): 18 | super(Net, self).__init__() 19 | nl, nh = args.n_layers, args.n_hiddens 20 | 21 | # setup network 22 | self.is_cifar = (args.data_file == 'cifar100.pt') 23 | if self.is_cifar: 24 | self.net = ResNet18(n_outputs) 25 | else: 26 | self.net = MLP([n_inputs] + [nh] * nl + [n_outputs]) 27 | 28 | # setup optimizer 29 | self.opt = torch.optim.SGD(self.parameters(), lr=args.lr) 30 | 31 | # setup losses 32 | self.bce = torch.nn.CrossEntropyLoss() 33 | 34 | if self.is_cifar: 35 | self.nc_per_task = n_outputs / n_tasks 36 | else: 37 | self.nc_per_task = n_outputs 38 | self.n_outputs = n_outputs 39 | 40 | def compute_offsets(self, task): 41 | if self.is_cifar: 42 | offset1 = task * self.nc_per_task 43 | offset2 = (task + 1) * self.nc_per_task 44 | else: 45 | offset1 = 0 46 | offset2 = self.n_outputs 47 | return int(offset1), int(offset2) 48 | 49 | def forward(self, x, t): 50 | output = self.net(x) 51 | if self.is_cifar: 52 | # 
make sure we predict classes within the current task 53 | offset1, offset2 = self.compute_offsets(t) 54 | if offset1 > 0: 55 | output[:, :offset1].data.fill_(-10e10) 56 | if offset2 < self.n_outputs: 57 | output[:, offset2:self.n_outputs].data.fill_(-10e10) 58 | return output 59 | 60 | def observe(self, x, t, y): 61 | self.train() 62 | self.zero_grad() 63 | if self.is_cifar: 64 | offset1, offset2 = self.compute_offsets(t) 65 | self.bce((self.net(x)[:, offset1: offset2]), 66 | y - offset1).backward() 67 | else: 68 | self.bce(self(x, t), y).backward() 69 | self.opt.step() 70 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | Pillow 4 | quadprog 5 | torch 6 | torchvision 7 | -------------------------------------------------------------------------------- /results/plot_results.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import matplotlib as mpl 8 | mpl.use('Agg') 9 | 10 | # if 'roman' in mpl.font_manager.weight_dict.keys(): 11 | #     del mpl.font_manager.weight_dict['roman'] 12 | #     mpl.font_manager._rebuild() 13 | 14 | # serif font, falling back to DejaVu Serif when Times is unavailable 15 | mpl.rcParams["font.family"] = ["Times New Roman", "DejaVu Serif"] 16 | 17 | from matplotlib import pyplot as plt 18 | from glob import glob 19 | import numpy as np 20 | import torch 21 | 22 | models = ['single', 'independent', 'multimodal', 'icarl', 'ewc', 'gem'] 23 | datasets = ['mnist_permutations', 'mnist_rotations', 'cifar100'] 24 | 25 | names_datasets = {'mnist_permutations': 'MNIST permutations', 26 | 'mnist_rotations': 'MNIST rotations', 27 | 'cifar100': 'CIFAR-100'} 28 | 29 | names_models = {'single': 'single', 30 | 'independent': 'independent', 31 | 'multimodal': 'multimodal', 32 | 'icarl': 'iCaRL', 33 | 'ewc': 'EWC', 34 | 'gem': 'GEM'} 35 | 36 | colors = {'single': 'C0', 37 | 'independent': 'C1', 38 | 'multimodal': 'C2', 39 | 'icarl': 'C2', 40 | 'ewc': 'C3', 41 | 'gem': 'C4'} 42 | 43 | barplot = {} 44 | 45 | for dataset in datasets: 46 | barplot[dataset] = {} 47 | for model in models: 48 | barplot[dataset][model] = {} 49 | matches = glob(model + '*' + dataset + '*.pt') 50 | if len(matches): 51 | data = torch.load(matches[0], map_location=lambda storage, loc: storage) 52 | acc, bwt, fwt = data[3][:] 53 | barplot[dataset][model]['acc'] = acc 54 | barplot[dataset][model]['bwt'] = bwt 55 | barplot[dataset][model]['fwt'] = fwt 56 | 57 | for dataset in datasets: 58 | x_lab = [] 59 | y_acc = [] 60 | y_bwt = [] 61 | y_fwt = [] 62 | 63 | for i, model in enumerate(models): 64 | if barplot[dataset][model] != {}: 65 | x_lab.append(model) 66 | y_acc.append(barplot[dataset][model]['acc']) 67 | y_bwt.append(barplot[dataset][model]['bwt']) 68 | y_fwt.append(barplot[dataset][model]['fwt']) 69 | 70 | x_ind = np.arange(len(y_acc)) 71 | 72 | plt.figure(figsize=(7, 3)) 73 | all_colors = [] 74 | for xi, yi, li in zip(x_ind, y_acc, x_lab): 75
| plt.bar(xi, yi, label=names_models[li], color=colors[li]) 76 | all_colors.append(colors[li]) 77 | plt.bar(x_ind + (len(y_acc) + 1) * 1, y_bwt, color=all_colors) 78 | plt.bar(x_ind + (len(y_acc) + 1) * 2, y_fwt, color=all_colors) 79 | plt.xticks([2, 8, 14], ['ACC', 'BWT', 'FWT'], fontsize=16) 80 | plt.yticks(fontsize=16) 81 | plt.xlim(-1, len(y_acc) * 3 + 2) 82 | plt.ylabel('classification accuracy', fontsize=16) 83 | plt.title(names_datasets[dataset], fontsize=16) 84 | plt.legend(fontsize=12) 85 | plt.tight_layout() 86 | plt.savefig('barplot_%s.pdf' % dataset, bbox_inches='tight') 87 | # plt.show() 88 | 89 | evoplot = {} 90 | 91 | for dataset in datasets: 92 | evoplot[dataset] = {} 93 | for model in models: 94 | matches = glob(model + '*' + dataset + '*.pt') 95 | if len(matches): 96 | data = torch.load(matches[0], map_location=lambda storage, loc: storage) 97 | evoplot[dataset][model] = data[1][:, 0].numpy() 98 | 99 | for dataset in datasets: 100 | 101 | plt.figure(figsize=(7, 3)) 102 | for model in models: 103 | if model in evoplot[dataset]: 104 | x = np.arange(len(evoplot[dataset][model])) 105 | x = (x - x.min()) / (x.max() - x.min()) * 20 106 | plt.plot(x, evoplot[dataset][model], color=colors[model], lw=3) 107 | plt.xticks(range(0, 21, 2)) 108 | 109 | plt.xticks(fontsize=16) 110 | plt.yticks(fontsize=16) 111 | #plt.xlabel('task number', fontsize=16) 112 | plt.title(names_datasets[dataset], fontsize=16) 113 | plt.tight_layout() 114 | plt.savefig('evoplot_%s.pdf' % dataset, bbox_inches='tight') 115 | # plt.show() 116 | -------------------------------------------------------------------------------- /run_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MY_PYTHON="python" 4 | MNIST_ROTA="--n_layers 2 --n_hiddens 100 --data_path data/ --save_path results/ --batch_size 10 --log_every 100 --samples_per_task 1000 --data_file mnist_rotations.pt --cuda no --seed 0" 5 | MNIST_PERM="--n_layers 2 
--n_hiddens 100 --data_path data/ --save_path results/ --batch_size 10 --log_every 100 --samples_per_task 1000 --data_file mnist_permutations.pt --cuda no --seed 0" 6 | CIFAR_100i="--n_layers 2 --n_hiddens 100 --data_path data/ --save_path results/ --batch_size 10 --log_every 100 --samples_per_task 2500 --data_file cifar100.pt --cuda yes --seed 0" 7 | 8 | # build datasets 9 | cd data/ 10 | cd raw/ 11 | 12 | $MY_PYTHON raw.py 13 | 14 | cd .. 15 | 16 | $MY_PYTHON mnist_permutations.py \ 17 | --o mnist_permutations.pt \ 18 | --seed 0 \ 19 | --n_tasks 20 20 | 21 | $MY_PYTHON mnist_rotations.py \ 22 | --o mnist_rotations.pt\ 23 | --seed 0 \ 24 | --min_rot 0 \ 25 | --max_rot 180 \ 26 | --n_tasks 20 27 | 28 | $MY_PYTHON cifar100.py \ 29 | --o cifar100.pt \ 30 | --seed 0 \ 31 | --n_tasks 20 32 | 33 | cd .. 34 | 35 | # model "single" 36 | $MY_PYTHON main.py $MNIST_ROTA --model single --lr 0.003 37 | $MY_PYTHON main.py $MNIST_PERM --model single --lr 0.03 38 | $MY_PYTHON main.py $CIFAR_100i --model single --lr 1.0 39 | 40 | # model "independent" 41 | $MY_PYTHON main.py $MNIST_ROTA --model independent --lr 0.1 --finetune yes 42 | $MY_PYTHON main.py $MNIST_PERM --model independent --lr 0.03 --finetune yes 43 | $MY_PYTHON main.py $CIFAR_100i --model independent --lr 0.3 --finetune yes 44 | 45 | # model "multimodal" 46 | $MY_PYTHON main.py $MNIST_ROTA --model multimodal --lr 0.1 47 | $MY_PYTHON main.py $MNIST_PERM --model multimodal --lr 0.1 48 | 49 | # model "EWC" 50 | $MY_PYTHON main.py $MNIST_ROTA --model ewc --lr 0.01 --n_memories 1000 --memory_strength 1000 51 | $MY_PYTHON main.py $MNIST_PERM --model ewc --lr 0.1 --n_memories 10 --memory_strength 3 52 | $MY_PYTHON main.py $CIFAR_100i --model ewc --lr 1.0 --n_memories 10 --memory_strength 1 53 | 54 | # model "iCARL" 55 | $MY_PYTHON main.py $CIFAR_100i --model icarl --lr 1.0 --n_memories 1280 --memory_strength 1 56 | 57 | # model "GEM" 58 | $MY_PYTHON main.py $MNIST_ROTA --model gem --lr 0.1 --n_memories 256 --memory_strength 
0.5 59 | $MY_PYTHON main.py $MNIST_PERM --model gem --lr 0.1 --n_memories 256 --memory_strength 0.5 60 | $MY_PYTHON main.py $CIFAR_100i --model gem --lr 0.1 --n_memories 256 --memory_strength 0.5 61 | 62 | # plot results 63 | cd results/ 64 | $MY_PYTHON plot_results.py 65 | cd .. 66 | --------------------------------------------------------------------------------
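The checkpoints read by `results/plot_results.py` store their summary metrics in `data[3]` as the triple `(acc, bwt, fwt)`: average accuracy, backward transfer, and forward transfer, as defined in the GEM paper. As an illustrative sketch of what those numbers mean — this is not the repo's `metrics/metrics.py`, and the function name and toy values below are hypothetical — the three metrics can be computed from an accuracy matrix `r`, where `r[i, j]` is the test accuracy on task `j` after training on tasks `0..i`:

```python
import numpy as np

def continual_metrics(r, baseline):
    """ACC, BWT, FWT for a continual-learning run.

    r[i, j]     : test accuracy on task j after training on tasks 0..i
    baseline[j] : accuracy on task j of a model at random initialization
    """
    n = r.shape[0]
    # ACC: mean accuracy over all tasks after training on the last one
    acc = r[-1].mean()
    # BWT: how much later training changed accuracy on earlier tasks
    # (negative values indicate forgetting)
    bwt = np.mean([r[-1, j] - r[j, j] for j in range(n - 1)])
    # FWT: how much earlier training helped a task before it was seen,
    # relative to the random-initialization baseline
    fwt = np.mean([r[j - 1, j] - baseline[j] for j in range(1, n)])
    return acc, bwt, fwt

# toy 3-task example (values are made up for illustration)
r = np.array([[0.9, 0.1, 0.1],
              [0.8, 0.9, 0.1],
              [0.7, 0.8, 0.9]])
baseline = np.full(3, 0.1)
print(continual_metrics(r, baseline))  # ACC ≈ 0.8, BWT ≈ -0.15, FWT = 0.0
```

In this toy run the model forgets a little (BWT < 0) and gains nothing on unseen tasks (FWT = 0), which is the pattern the bar plots above visualize per model and dataset.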