├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── external_libs ├── __init__.py ├── continual_learning_algorithms │ ├── .gitignore │ ├── CODE_OF_CONDUCT.md │ ├── CONTRIBUTING.md │ ├── LICENSE │ ├── README.md │ ├── conv_split_cifar.py │ ├── extract_res.py │ ├── fc_mnist.py │ ├── model │ │ ├── __init__.py │ │ └── model.py │ ├── replicate_cifar.sh │ ├── replicate_mnist.sh │ ├── replicate_mnist_stable.sh │ ├── run.sh │ └── utils │ │ ├── __init__.py │ │ ├── data_utils.py │ │ ├── er_utils.py │ │ ├── resnet_utils.py │ │ ├── utils.py │ │ ├── vgg_utils.py │ │ └── vis_utils.py └── hessian_eigenthings │ ├── LICENSE │ ├── __init__.py │ ├── hvp_operator.py │ ├── lanczos.py │ ├── power_iter.py │ ├── spectral_density.py │ └── utils.py ├── replicate_appendix_c5.sh ├── replicate_experiment_1.sh ├── replicate_experiment_2.sh ├── requirements.txt ├── setup_and_install.sh └── stable_sgd ├── __init__.py ├── data_utils.py ├── main.py ├── models.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/vim,macos,python,sublimetext,pycharm+all 2 | # Edit at https://www.gitignore.io/?templates=vim,macos,python,sublimetext,pycharm+all 3 | 4 | ### macOS ### 5 | # General 6 | .DS_Store 7 | .AppleDouble 8 | .LSOverride 9 | 10 | # Icon must end with two \r 11 | Icon 12 | 13 | # Thumbnails 14 | ._* 15 | 16 | # Files that might appear in the root of a volume 17 | .DocumentRevisions-V100 18 | .fseventsd 19 | .Spotlight-V100 20 | .TemporaryItems 21 | .Trashes 22 | .VolumeIcon.icns 23 | .com.apple.timemachine.donotpresent 24 | 25 | # Directories potentially created on remote AFP share 26 | .AppleDB 27 | .AppleDesktop 28 | Network Trash Folder 29 | Temporary Items 30 | .apdisk 31 | 32 | ### PyCharm+all ### 33 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 34 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 35 | 36 | # User-specific stuff 37 | .idea/**/workspace.xml 38 | .idea/**/tasks.xml 39 | .idea/**/usage.statistics.xml 40 | .idea/**/dictionaries 41 | .idea/**/shelf 42 | 43 | # Generated files 44 | .idea/**/contentModel.xml 45 | 46 | # Sensitive or high-churn files 47 | .idea/**/dataSources/ 48 | .idea/**/dataSources.ids 49 | .idea/**/dataSources.local.xml 50 | .idea/**/sqlDataSources.xml 51 | .idea/**/dynamic.xml 52 | .idea/**/uiDesigner.xml 53 | .idea/**/dbnavigator.xml 54 | 55 | # Gradle 56 | .idea/**/gradle.xml 57 | .idea/**/libraries 58 | 59 | # Gradle and Maven with auto-import 60 | # When using Gradle or Maven with auto-import, you should exclude module files, 61 | # since they will be recreated, and may cause churn. Uncomment if using 62 | # auto-import. 
63 | # .idea/modules.xml 64 | # .idea/*.iml 65 | # .idea/modules 66 | # *.iml 67 | # *.ipr 68 | 69 | # CMake 70 | cmake-build-*/ 71 | 72 | # Mongo Explorer plugin 73 | .idea/**/mongoSettings.xml 74 | 75 | # File-based project format 76 | *.iws 77 | 78 | # IntelliJ 79 | out/ 80 | 81 | # mpeltonen/sbt-idea plugin 82 | .idea_modules/ 83 | 84 | # JIRA plugin 85 | atlassian-ide-plugin.xml 86 | 87 | # Cursive Clojure plugin 88 | .idea/replstate.xml 89 | 90 | # Crashlytics plugin (for Android Studio and IntelliJ) 91 | com_crashlytics_export_strings.xml 92 | crashlytics.properties 93 | crashlytics-build.properties 94 | fabric.properties 95 | 96 | # Editor-based Rest Client 97 | .idea/httpRequests 98 | 99 | # Android studio 3.1+ serialized cache file 100 | .idea/caches/build_file_checksums.ser 101 | 102 | ### PyCharm+all Patch ### 103 | # Ignores the whole .idea folder and all .iml files 104 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 105 | 106 | .idea/ 107 | 108 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 109 | 110 | *.iml 111 | modules.xml 112 | .idea/misc.xml 113 | *.ipr 114 | 115 | # Sonarlint plugin 116 | .idea/sonarlint 117 | 118 | ### Python ### 119 | # Byte-compiled / optimized / DLL files 120 | __pycache__/ 121 | *.py[cod] 122 | *$py.class 123 | 124 | # C extensions 125 | *.so 126 | 127 | # Distribution / packaging 128 | .Python 129 | build/ 130 | develop-eggs/ 131 | dist/ 132 | downloads/ 133 | eggs/ 134 | .eggs/ 135 | lib/ 136 | lib64/ 137 | parts/ 138 | sdist/ 139 | var/ 140 | wheels/ 141 | pip-wheel-metadata/ 142 | share/python-wheels/ 143 | *.egg-info/ 144 | .installed.cfg 145 | *.egg 146 | MANIFEST 147 | 148 | # PyInstaller 149 | # Usually these files are written by a python script from a template 150 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 151 | *.manifest 152 | *.spec 153 | 154 | # Installer logs 155 | pip-log.txt 156 | pip-delete-this-directory.txt 157 | 158 | # Unit test / coverage reports 159 | htmlcov/ 160 | .tox/ 161 | .nox/ 162 | .coverage 163 | .coverage.* 164 | .cache 165 | nosetests.xml 166 | coverage.xml 167 | *.cover 168 | 169 | 170 | .hypothesis/ 171 | .pytest_cache/ 172 | 173 | # Translations 174 | *.mo 175 | *.pot 176 | 177 | # Flask stuff: 178 | instance/ 179 | .webassets-cache 180 | 181 | 182 | # Scrapy stuff: 183 | .scrapy 184 | 185 | # Sphinx documentation 186 | docs/_build/ 187 | 188 | # PyBuilder 189 | target/ 190 | 191 | # Jupyter Notebook 192 | .ipynb_checkpoints 193 | 194 | # IPython 195 | profile_default/ 196 | ipython_config.py 197 | 198 | 199 | # pyenv 200 | .python-version 201 | 202 | # pipenv 203 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 204 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 205 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 206 | # install all needed dependencies. 207 | #Pipfile.lock 208 | 209 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 210 | __pypackages__/ 211 | 212 | # Celery stuff 213 | celerybeat-schedule 214 | celerybeat.pid 215 | # celery beat schedule file 216 | celerybeat-schedule 217 | 218 | # Environments 219 | .env 220 | .venv 221 | env/ 222 | venv/ 223 | ENV/ 224 | env.bak/ 225 | venv.bak/ 226 | 227 | 228 | /site 229 | 230 | # mypy 231 | .mypy_cache/ 232 | .dmypy.json 233 | dmypy.json 234 | 235 | # Pyre type checker 236 | .pyre/ 237 | 238 | 239 | ### SublimeText ### 240 | # Cache files for Sublime Text 241 | *.tmlanguage.cache 242 | *.tmPreferences.cache 243 | *.stTheme.cache 244 | 245 | # Workspace files are user-specific 246 | *.sublime-workspace 247 | 248 | # Project files should be checked into the repository, unless a significant 249 | # proportion of contributors will probably not be using Sublime Text 250 | # *.sublime-project 251 | 252 | # SFTP configuration file 253 | sftp-config.json 254 | 255 | # Package control specific files 256 | Package Control.last-run 257 | Package Control.ca-list 258 | Package Control.ca-bundle 259 | Package Control.system-ca-bundle 260 | Package Control.cache/ 261 | Package Control.ca-certs/ 262 | Package Control.merged-ca-bundle 263 | Package Control.user-ca-bundle 264 | oscrypto-ca-bundle.crt 265 | bh_unicode_properties.cache 266 | 267 | # Sublime-github package stores a github token in this file 268 | # https://packagecontrol.io/packages/sublime-github 269 | GitHub.sublime-settings 270 | 271 | ### Vim ### 272 | # Swap 273 | [._]*.s[a-v][a-z] 274 | [._]*.sw[a-p] 275 | [._]s[a-rt-v][a-z] 276 | [._]ss[a-gi-z] 277 | [._]sw[a-p] 278 | 279 | # Session 280 | Session.vim 281 | Sessionx.vim 282 | 283 | # Temporary 284 | .netrwhist 285 | *~ 286 | # Auto-generated tag files 287 | tags 288 | # Persistent undo 289 | [._]*.un~ 290 | stash 291 | stash/* 292 | models 293 | models/ 294 | .python-version 295 | ./.python-version 296 | 297 | ./data/ 298 | ./tmp/ 299 | ./temp/ 300 | data/ 301 | tmp/ 302 | temp/ 303 | outputs/ 304 | output/ 305 | # End of https://www.gitignore.io/api/vim,macos,python,sublimetext,pycharm+all 306 | 307 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Iman Mirzadeh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Understanding the Role of Training Regimes in Continual Learning 2 | Towards increasing stability of neural networks for continual learning (NeurIPS'20) 3 | 4 | **Note: I will add an updated version of the code soon. If you have problems reproducing the results, please see the instructions for reproducing [experiment 1](https://github.com/imirzadeh/stable-continual-learning/issues/1) and [experiment 2](https://github.com/imirzadeh/stable-continual-learning/issues/5).** 5 | 6 | 7 | ## 1. Code Structure 8 | The high-level structure of the code is as follows: 9 | 10 | ``` 11 | root 12 | ├── stable_sgd 13 | │ 14 | └── external_libs 15 | └── continual_learning_algorithms 16 | └── hessian_eigenthings 17 | ``` 18 | 19 | 1. `stable_sgd`: implementations of our stable and plastic training regimes for SGD (in PyTorch). 20 | 2. `external_libs`: third-party implementations used in our experiments: 21 | 2.1 `continual_learning_algorithms`: open-source implementations of A-GEM, ER-Reservoir, and EWC (in TensorFlow). 22 | 2.2 `hessian_eigenthings`: an open-source implementation of deflated power iteration for eigenspectrum calculations (in PyTorch). 23 | 24 | ## 2. Setup & Installation 25 | The code is tested with Python 3.6+, PyTorch 1.5.0, and TensorFlow 1.15.2. The additional numerical and visualization libraries we use are listed in the ``requirements.txt`` file. For convenience, we provide a setup script: 26 | ``` 27 | bash setup_and_install.sh 28 | ``` 29 | 30 | ## 3. Replicating the Results 31 | Note: I will add an updated version of the code soon. If you have problems reproducing the results, please see the instructions for reproducing [experiment 1](https://github.com/imirzadeh/stable-continual-learning/issues/1) and [experiment 2](https://github.com/imirzadeh/stable-continual-learning/issues/5). 32 | 33 | We provide scripts to replicate the results: 34 | * 3.1 Run ```bash replicate_experiment_1.sh``` for experiment 1 (stable vs. plastic). 35 | * 3.2 Run ```bash replicate_experiment_2.sh``` for experiment 2 (comparison with other methods on 20 tasks). 36 | * 3.3 Run ```bash replicate_appendix_c5.sh``` for the experiment in appendix C5 (stabilizing other methods). 37 | 38 | For faster replication, these scripts perform only 3 runs per method per experiment; the results reported in the paper use 5 runs.
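For orientation, the "stable" regime implemented in `stable_sgd` (Section 1, item 1) differs from the plastic one mainly through its training hyper-parameters: a larger initial learning rate that is decayed at every task boundary, dropout, and a small batch size. The snippet below is a minimal sketch of that per-task learning-rate decay only; the function and argument names are illustrative and do not correspond to the exact flags of `stable_sgd/main.py`.

```python
# Illustrative sketch: exponential per-task learning-rate decay used by a
# "stable" training regime. Names and default values are examples only.
def lr_for_task(initial_lr: float, gamma: float, task_id: int, min_lr: float = 1e-4) -> float:
    """Decay the learning rate by a factor `gamma` at every task boundary."""
    return max(initial_lr * (gamma ** task_id), min_lr)

# Example: initial_lr=0.1, gamma=0.4 (cf. the `--decay` flag in the external
# fc_mnist.py script) yields 0.1, 0.04, 0.016, ... across successive tasks.
for task_id in range(5):
    print(task_id, lr_for_task(initial_lr=0.1, gamma=0.4, task_id=task_id))
```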
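`hessian_eigenthings` (Section 1, item 2.2) estimates the leading part of the loss Hessian's eigenspectrum using deflated power iteration on Hessian-vector products. The sketch below shows plain power iteration for the single largest eigenvalue in PyTorch, with no deflation or mini-batching; it illustrates the underlying idea and is not the library's actual API.

```python
import torch

def top_hessian_eigenvalue(loss, params, iters=20):
    """Estimate the largest eigenvalue of the Hessian of `loss` w.r.t. `params`
    via power iteration on Hessian-vector products (illustrative sketch)."""
    grads = torch.autograd.grad(loss, params, create_graph=True)
    v = [torch.randn_like(p) for p in params]
    eigenvalue = None
    for _ in range(iters):
        # Normalize the current direction vector.
        norm = torch.sqrt(sum((x * x).sum() for x in v))
        v = [x / norm for x in v]
        # Hessian-vector product: differentiate (grad . v) w.r.t. the parameters.
        dot = sum((g * x).sum() for g, x in zip(grads, v))
        hv = torch.autograd.grad(dot, params, retain_graph=True)
        # Rayleigh quotient of the normalized direction gives the estimate.
        eigenvalue = sum((h * x).sum() for h, x in zip(hv, v)).item()
        v = [h.detach() for h in hv]
    return eigenvalue

# Usage sketch: loss = criterion(model(x), y)
# top_eig = top_hessian_eigenvalue(loss, list(model.parameters()))
```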
39 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imirzadeh/stable-continual-learning/be4049c6f1c433e6d4593bf3113280dcbeb127b9/__init__.py -------------------------------------------------------------------------------- /external_libs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imirzadeh/stable-continual-learning/be4049c6f1c433e6d4593bf3113280dcbeb127b9/external_libs/__init__.py -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | *.pyc 3 | *.pdf 4 | *.tar.gz 5 | *.npz 6 | *.swp 7 | *tmp_* 8 | *MNIST_data* 9 | *CIFAR_data* 10 | *CUB_data* 11 | *AWA_data* 12 | *snapshots* 13 | *resnet-18-pretrained* 14 | *_experiment* 15 | .nfs* 16 | results* 17 | cross_validate_* 18 | importance_vis* 19 | *.pickle 20 | 21 | 22 | # Created by https://www.gitignore.io/api/vim,macos,sublimetext 23 | # Edit at https://www.gitignore.io/?templates=vim,macos,sublimetext 24 | 25 | ### macOS ### 26 | # General 27 | .DS_Store 28 | .AppleDouble 29 | .LSOverride 30 | 31 | # Icon must end with two \r 32 | Icon 33 | 34 | # Thumbnails 35 | ._* 36 | 37 | # Files that might appear in the root of a volume 38 | .DocumentRevisions-V100 39 | .fseventsd 40 | .Spotlight-V100 41 | .TemporaryItems 42 | .Trashes 43 | .VolumeIcon.icns 44 | .com.apple.timemachine.donotpresent 45 | 46 | # Directories potentially created on remote AFP share 47 | .AppleDB 48 | .AppleDesktop 49 | Network Trash Folder 50 | Temporary Items 51 | .apdisk 52 | 53 | ### SublimeText ### 54 | # Cache files for Sublime Text 55 | *.tmlanguage.cache 56 | *.tmPreferences.cache 57 | *.stTheme.cache 58 | 59 | # Workspace files are user-specific 60 | *.sublime-workspace 61 | 62 | # Project files should be checked into the repository, unless a significant 63 | # proportion of contributors will probably not be using Sublime Text 64 | # *.sublime-project 65 | 66 | # SFTP configuration file 67 | sftp-config.json 68 | 69 | # Package control specific files 70 | Package Control.last-run 71 | Package Control.ca-list 72 | Package Control.ca-bundle 73 | Package Control.system-ca-bundle 74 | Package Control.cache/ 75 | Package Control.ca-certs/ 76 | Package Control.merged-ca-bundle 77 | Package Control.user-ca-bundle 78 | oscrypto-ca-bundle.crt 79 | bh_unicode_properties.cache 80 | 81 | # Sublime-github package stores a github token in this file 82 | # https://packagecontrol.io/packages/sublime-github 83 | GitHub.sublime-settings 84 | 85 | ### Vim ### 86 | # Swap 87 | [._]*.s[a-v][a-z] 88 | [._]*.sw[a-p] 89 | [._]s[a-rt-v][a-z] 90 | [._]ss[a-gi-z] 91 | [._]sw[a-p] 92 | 93 | # Session 94 | Session.vim 95 | Sessionx.vim 96 | 97 | # Temporary 98 | .netrwhist 99 | *~ 100 | 101 | # Auto-generated tag files 102 | tags 103 | 104 | # Persistent undo 105 | [._]*.un~ 106 | 107 | # Coc configuration directory 108 | .vim 109 | 110 | # End of https://www.gitignore.io/api/vim,macos,sublimetext 111 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a 
Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to agem 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## Coding Style 30 | * 2 spaces for indentation rather than tabs 31 | * 80 character line length 32 | * ... 33 | 34 | ## License 35 | By contributing to agem, you agree that your contributions will be licensed 36 | under the LICENSE file in the root directory of this source tree. 37 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2017-present, Facebook, Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
8 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/README.md: -------------------------------------------------------------------------------- 1 | # Efficient Lifelong Learning with A-GEM 2 | 3 | This is the official implementation of the [Averaged Gradient Episodic Memory (A-GEM)](https://arxiv.org/abs/1812.00420) and [Experience Replay with Tiny Memories](http://arxiv.org/abs/1902.10486) in Tensorflow. 4 | 5 | ## Requirements 6 | 7 | TensorFlow >= v1.9.0. 8 | 9 | ## Training 10 | 11 | To replicate the results of the paper on a particular dataset, execute (see the Note below for downloading the CUB and AWA datasets): 12 | ```bash 13 | $ ./replicate_results_iclr19.sh 14 | ``` 15 | Example runs are: 16 | ```bash 17 | $ ./replicate_results_iclr19.sh MNIST 3 /* Train PNN and A-GEM on MNIST */ 18 | $ ./replicate_results_iclr19.sh CUB 1 1 /* Train JE models of RWALK and A-GEM on CUB */ 19 | ``` 20 | 21 | ### Note 22 | For CUB and AWA experiments, download the dataset prior to running the above script. Run following for downloading the datasets: 23 | 24 | ```bash 25 | $ ./download_cub_awa.sh 26 | ``` 27 | The plotting code is provided under the folder `plotting_code/`. Update the paths in the plotting code accordingly. 28 | 29 | ## Experience Replay 30 | The code provides an implementation of experience replay (ER) with reservoir sampling on MNIST and CIFAR datasets. To run the ER experiments execute the following script: 31 | ```bash 32 | $ ./replicate_results_er.sh 33 | ``` 34 | 35 | When using this code, please cite our papers: 36 | 37 | ``` 38 | @inproceedings{AGEM, 39 | title={Efficient Lifelong Learning with A-GEM}, 40 | author={Chaudhry, Arslan and Ranzato, Marc’Aurelio and Rohrbach, Marcus and Elhoseiny, Mohamed}, 41 | booktitle={ICLR}, 42 | year={2019} 43 | } 44 | 45 | @article{chaudhryER_2019, 46 | title={Continual Learning with Tiny Episodic Memories}, 47 | author={Chaudhry, Arslan and Rohrbach, Marcus and Elhoseiny, Mohamed and Ajanthan, Thalaiyasingam and Dokania, Puneet K and Torr, Philip HS and Ranzato, Marc’Aurelio}, 48 | journal={arXiv preprint arXiv:1902.10486, 2019}, 49 | year={2019} 50 | } 51 | 52 | @inproceedings{chaudhry2018riemannian, 53 | title={Riemannian Walk for Incremental Learning: Understanding Forgetting and Intransigence}, 54 | author={Chaudhry, Arslan and Dokania, Puneet K and Ajanthan, Thalaiyasingam and Torr, Philip HS}, 55 | booktitle={ECCV}, 56 | year={2018} 57 | } 58 | ``` 59 | 60 | ## Questions/ Bugs 61 | * For questions, contact the author Arslan Chaudhry (arslan.chaudhry@eng.ox.ac.uk). 62 | * Feel free to open the bugs if anything is broken. 63 | 64 | ## License 65 | This source code is released under The MIT License found in the LICENSE file in the root directory of this source tree. 66 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/conv_split_cifar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Training script for split CIFAR 100 experiment. 
8 | """ 9 | from __future__ import print_function 10 | 11 | import argparse 12 | import os 13 | import sys 14 | import math 15 | import time 16 | 17 | import datetime 18 | import numpy as np 19 | import tensorflow as tf 20 | from copy import deepcopy 21 | from six.moves import cPickle as pickle 22 | 23 | from utils.data_utils import construct_split_cifar 24 | from utils.utils import get_sample_weights, sample_from_dataset, update_episodic_memory, concatenate_datasets, samples_for_each_class, sample_from_dataset_icarl, compute_fgt, load_task_specific_data 25 | from utils.utils import average_acc_stats_across_runs, average_fgt_stats_across_runs, update_reservior 26 | from utils.vis_utils import plot_acc_multiple_runs, plot_histogram, snapshot_experiment_meta_data, snapshot_experiment_eval, snapshot_task_labels 27 | from model import Model 28 | 29 | ############################################################### 30 | ################ Some definitions ############################# 31 | ### These will be edited by the command line options ########## 32 | ############################################################### 33 | 34 | ## Training Options 35 | NUM_RUNS = 5 # Number of experiments to average over 36 | TRAIN_ITERS = 2000 # Number of training iterations per task 37 | BATCH_SIZE = 64 38 | LEARNING_RATE = 0.1 39 | RANDOM_SEED = 3235 40 | VALID_OPTIMS = ['SGD', 'MOMENTUM', 'ADAM'] 41 | OPTIM = 'SGD' 42 | OPT_MOMENTUM = 0.9 43 | OPT_POWER = 0.9 44 | VALID_ARCHS = ['CNN', 'RESNET-S', 'RESNET-B', 'VGG'] 45 | ARCH = 'RESNET-S' 46 | 47 | ## Model options 48 | MODELS = ['VAN', 'PI', 'EWC', 'MAS', 'RWALK', 'M-EWC', 'S-GEM', 'A-GEM', 'FTR_EXT', 'PNN', 'ER'] #List of valid models 49 | IMP_METHOD = 'EWC' 50 | SYNAP_STGTH = 10 51 | FISHER_EMA_DECAY = 0.9 # Exponential moving average decay factor for Fisher computation (online Fisher) 52 | FISHER_UPDATE_AFTER = 50 # Number of training iterations for which the F_{\theta}^t is computed (see Eq. 10 in RWalk paper) 53 | SAMPLES_PER_CLASS = 1 54 | IMG_HEIGHT = 32 55 | IMG_WIDTH = 32 56 | IMG_CHANNELS = 3 57 | TOTAL_CLASSES = 100 # Total number of classes in the dataset 58 | VISUALIZE_IMPORTANCE_MEASURE = False 59 | MEASURE_CONVERGENCE_AFTER = 0.9 60 | EPS_MEM_BATCH_SIZE = 256 61 | DEBUG_EPISODIC_MEMORY = False 62 | K_FOR_CROSS_VAL = 3 63 | TIME_MY_METHOD = False 64 | COUNT_VIOLATONS = False 65 | MEASURE_PERF_ON_EPS_MEMORY = False 66 | 67 | ## Logging, saving and testing options 68 | LOG_DIR = './split_cifar_results' 69 | RESNET18_CIFAR10_CHECKPOINT = './resnet-18-pretrained-cifar10/model.ckpt-19999' 70 | ## Evaluation options 71 | 72 | ## Task split 73 | NUM_TASKS = 23 74 | MULTI_TASK = False 75 | 76 | 77 | # Define function to load/ store training weights. We will use ImageNet initialization later on 78 | def save(saver, sess, logdir, step): 79 | '''Save weights. 80 | 81 | Args: 82 | saver: TensorFlow Saver object. 83 | sess: TensorFlow session. 84 | logdir: path to the snapshots directory. 85 | step: current training step. 86 | ''' 87 | model_name = 'model.ckpt' 88 | checkpoint_path = os.path.join(logdir, model_name) 89 | 90 | if not os.path.exists(logdir): 91 | os.makedirs(logdir) 92 | saver.save(sess, checkpoint_path, global_step=step) 93 | print('The checkpoint has been created.') 94 | 95 | def load(saver, sess, ckpt_path): 96 | '''Load trained weights. 97 | 98 | Args: 99 | saver: TensorFlow Saver object. 100 | sess: TensorFlow session. 101 | ckpt_path: path to checkpoint file with parameters. 
102 | ''' 103 | saver.restore(sess, ckpt_path) 104 | print("Restored model parameters from {}".format(ckpt_path)) 105 | 106 | 107 | def get_arguments(): 108 | """Parse all the arguments provided from the CLI. 109 | 110 | Returns: 111 | A list of parsed arguments. 112 | """ 113 | parser = argparse.ArgumentParser(description="Script for split cifar experiment.") 114 | parser.add_argument("--cross-validate-mode", action="store_true", 115 | help="If option is chosen then snapshoting after each batch is disabled") 116 | parser.add_argument("--online-cross-val", action="store_true", 117 | help="If option is chosen then enable the online cross validation of the learning rate") 118 | parser.add_argument("--train-single-epoch", action="store_true", 119 | help="If option is chosen then train for single epoch") 120 | parser.add_argument("--eval-single-head", action="store_true", 121 | help="If option is chosen then evaluate on a single head setting.") 122 | parser.add_argument("--arch", type=str, default=ARCH, 123 | help="Network Architecture for the experiment.\ 124 | \n \nSupported values: %s"%(VALID_ARCHS)) 125 | parser.add_argument("--num-runs", type=int, default=NUM_RUNS, 126 | help="Total runs/ experiments over which accuracy is averaged.") 127 | parser.add_argument("--train-iters", type=int, default=TRAIN_ITERS, 128 | help="Number of training iterations for each task.") 129 | parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, 130 | help="Mini-batch size for each task.") 131 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, 132 | help="Random Seed.") 133 | parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, 134 | help="Starting Learning rate for each task.") 135 | parser.add_argument("--optim", type=str, default=OPTIM, 136 | help="Optimizer for the experiment. \ 137 | \n \nSupported values: %s"%(VALID_OPTIMS)) 138 | parser.add_argument("--imp-method", type=str, default=IMP_METHOD, 139 | help="Model to be used for LLL. 
\ 140 | \n \nSupported values: %s"%(MODELS)) 141 | parser.add_argument("--synap-stgth", type=float, default=SYNAP_STGTH, 142 | help="Synaptic strength for the regularization.") 143 | parser.add_argument("--fisher-ema-decay", type=float, default=FISHER_EMA_DECAY, 144 | help="Exponential moving average decay for Fisher calculation at each step.") 145 | parser.add_argument("--fisher-update-after", type=int, default=FISHER_UPDATE_AFTER, 146 | help="Number of training iterations after which the Fisher will be updated.") 147 | parser.add_argument("--mem-size", type=int, default=SAMPLES_PER_CLASS, 148 | help="Total size of episodic memory.") 149 | parser.add_argument("--eps-mem-batch", type=int, default=EPS_MEM_BATCH_SIZE, 150 | help="Number of samples per class from previous tasks.") 151 | parser.add_argument("--log-dir", type=str, default=LOG_DIR, 152 | help="Directory where the plots and model accuracies will be stored.") 153 | return parser.parse_args() 154 | 155 | def train_task_sequence(model, sess, datasets, args): 156 | """ 157 | Train and evaluate LLL system such that we only see a example once 158 | Args: 159 | Returns: 160 | dict A dictionary containing mean and stds for the experiment 161 | """ 162 | # List to store accuracy for each run 163 | runs = [] 164 | task_labels_dataset = [] 165 | 166 | if model.imp_method == 'A-GEM' or model.imp_method == 'ER': 167 | use_episodic_memory = True 168 | else: 169 | use_episodic_memory = False 170 | 171 | batch_size = args.batch_size 172 | # Loop over number of runs to average over 173 | for runid in range(args.num_runs): 174 | print('\t\tRun %d:'%(runid)) 175 | 176 | # Initialize the random seeds 177 | np.random.seed(args.random_seed+runid) 178 | 179 | # Get the task labels from the total number of tasks and full label space 180 | task_labels = [] 181 | classes_per_task = TOTAL_CLASSES// NUM_TASKS 182 | total_classes = classes_per_task * model.num_tasks 183 | if args.online_cross_val: 184 | label_array = np.arange(total_classes) 185 | else: 186 | class_label_offset = K_FOR_CROSS_VAL * classes_per_task 187 | label_array = np.arange(class_label_offset, total_classes+class_label_offset) 188 | 189 | np.random.shuffle(label_array) 190 | for tt in range(model.num_tasks): 191 | tt_offset = tt*classes_per_task 192 | task_labels.append(list(label_array[tt_offset:tt_offset+classes_per_task])) 193 | print('Task: {}, Labels:{}'.format(tt, task_labels[tt])) 194 | 195 | # Store the task labels 196 | task_labels_dataset.append(task_labels) 197 | 198 | # Set episodic memory size 199 | episodic_mem_size = args.mem_size * total_classes 200 | 201 | # Initialize all the variables in the model 202 | sess.run(tf.global_variables_initializer()) 203 | 204 | # Run the init ops 205 | model.init_updates(sess) 206 | 207 | # List to store accuracies for a run 208 | evals = [] 209 | 210 | # List to store the classes that we have so far - used at test time 211 | test_labels = [] 212 | 213 | if use_episodic_memory: 214 | # Reserve a space for episodic memory 215 | episodic_images = np.zeros([episodic_mem_size, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS]) 216 | episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES]) 217 | episodic_filled_counter = 0 218 | nd_logit_mask = np.zeros([model.num_tasks, TOTAL_CLASSES]) 219 | count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32) 220 | episodic_filled_counter = 0 221 | examples_seen_so_far = 0 222 | 223 | # Mask for softmax 224 | logit_mask = np.zeros(TOTAL_CLASSES) 225 | if COUNT_VIOLATONS: 226 | violation_count = 
np.zeros(model.num_tasks) 227 | vc = 0 228 | 229 | # Training loop for all the tasks 230 | for task in range(len(task_labels)): 231 | print('\t\tTask %d:'%(task)) 232 | 233 | # If not the first task then restore weights from previous task 234 | if(task > 0 and model.imp_method != 'PNN'): 235 | model.restore(sess) 236 | 237 | if model.imp_method == 'PNN': 238 | pnn_train_phase = np.array(np.zeros(model.num_tasks), dtype=np.bool) 239 | pnn_train_phase[task] = True 240 | pnn_logit_mask = np.zeros([model.num_tasks, TOTAL_CLASSES]) 241 | 242 | # If not in the cross validation mode then concatenate the train and validation sets 243 | task_tr_images, task_tr_labels = load_task_specific_data(datasets[0]['train'], task_labels[task]) 244 | task_val_images, task_val_labels = load_task_specific_data(datasets[0]['validation'], task_labels[task]) 245 | task_train_images, task_train_labels = concatenate_datasets(task_tr_images, task_tr_labels, task_val_images, task_val_labels) 246 | 247 | # If multi_task is set then train using all the datasets of all the tasks 248 | if MULTI_TASK: 249 | if task == 0: 250 | for t_ in range(1, len(task_labels)): 251 | task_tr_images, task_tr_labels = load_task_specific_data(datasets[0]['train'], task_labels[t_]) 252 | task_train_images = np.concatenate((task_train_images, task_tr_images), axis=0) 253 | task_train_labels = np.concatenate((task_train_labels, task_tr_labels), axis=0) 254 | 255 | else: 256 | # Skip training for this task 257 | continue 258 | 259 | print('Received {} images, {} labels at task {}'.format(task_train_images.shape[0], task_train_labels.shape[0], task)) 260 | print('Unique labels in the task: {}'.format(np.unique(np.nonzero(task_train_labels)[1]))) 261 | 262 | # Test for the tasks that we've seen so far 263 | test_labels += task_labels[task] 264 | 265 | # Assign equal weights to all the examples 266 | task_sample_weights = np.ones([task_train_labels.shape[0]], dtype=np.float32) 267 | 268 | num_train_examples = task_train_images.shape[0] 269 | 270 | logit_mask[:] = 0 271 | # Train a task observing sequence of data 272 | if args.train_single_epoch: 273 | # Ceiling operation 274 | num_iters = (num_train_examples + batch_size - 1) // batch_size 275 | if args.cross_validate_mode: 276 | logit_mask[task_labels[task]] = 1.0 277 | else: 278 | num_iters = args.train_iters 279 | # Set the mask only once before starting the training for the task 280 | logit_mask[task_labels[task]] = 1.0 281 | 282 | if MULTI_TASK: 283 | logit_mask[:] = 1.0 284 | 285 | # Randomly suffle the training examples 286 | perm = np.arange(num_train_examples) 287 | np.random.shuffle(perm) 288 | train_x = task_train_images[perm] 289 | train_y = task_train_labels[perm] 290 | task_sample_weights = task_sample_weights[perm] 291 | 292 | # Array to store accuracies when training for task T 293 | ftask = [] 294 | 295 | # Number of iterations after which convergence is checked 296 | convergence_iters = int(num_iters * MEASURE_CONVERGENCE_AFTER) 297 | 298 | # Training loop for task T 299 | for iters in range(num_iters): 300 | 301 | if args.train_single_epoch and not args.cross_validate_mode and not MULTI_TASK: 302 | if (iters <= 20) or (iters > 20 and iters % 50 == 0): 303 | # Snapshot the current performance across all tasks after each mini-batch 304 | fbatch = test_task_sequence(model, sess, datasets[0]['test'], task_labels, task) 305 | ftask.append(fbatch) 306 | if model.imp_method == 'PNN': 307 | pnn_train_phase[:] = False 308 | pnn_train_phase[task] = True 309 | pnn_logit_mask[:] = 0 310 | 
pnn_logit_mask[task][task_labels[task]] = 1.0 311 | elif model.imp_method == 'A-GEM': 312 | nd_logit_mask[:] = 0 313 | nd_logit_mask[task][task_labels[task]] = 1.0 314 | else: 315 | # Set the output labels over which the model needs to be trained 316 | logit_mask[:] = 0 317 | logit_mask[task_labels[task]] = 1.0 318 | 319 | if args.train_single_epoch: 320 | offset = iters * batch_size 321 | if (offset+batch_size <= num_train_examples): 322 | residual = batch_size 323 | else: 324 | residual = num_train_examples - offset 325 | 326 | if model.imp_method == 'PNN': 327 | feed_dict = {model.x: train_x[offset:offset+residual], model.y_[task]: train_y[offset:offset+residual], 328 | model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 0.5} 329 | train_phase_dict = {m_t: i_t for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)} 330 | logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask)} 331 | feed_dict.update(train_phase_dict) 332 | feed_dict.update(logit_mask_dict) 333 | else: 334 | feed_dict = {model.x: train_x[offset:offset+residual], model.y_: train_y[offset:offset+residual], 335 | model.sample_weights: task_sample_weights[offset:offset+residual], 336 | model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 0.5, 337 | model.train_phase: True} 338 | else: 339 | offset = (iters * batch_size) % (num_train_examples - batch_size) 340 | if model.imp_method == 'PNN': 341 | feed_dict = {model.x: train_x[offset:offset+batch_size], model.y_[task]: train_y[offset:offset+batch_size], 342 | model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 0.5} 343 | train_phase_dict = {m_t: i_t for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)} 344 | logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask)} 345 | feed_dict.update(train_phase_dict) 346 | feed_dict.update(logit_mask_dict) 347 | else: 348 | feed_dict = {model.x: train_x[offset:offset+batch_size], model.y_: train_y[offset:offset+batch_size], 349 | model.sample_weights: task_sample_weights[offset:offset+batch_size], 350 | model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 0.5, 351 | model.train_phase: True} 352 | 353 | if model.imp_method == 'VAN': 354 | feed_dict[model.output_mask] = logit_mask 355 | _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) 356 | 357 | elif model.imp_method == 'PNN': 358 | _, loss = sess.run([model.train[task], model.unweighted_entropy[task]], feed_dict=feed_dict) 359 | 360 | elif model.imp_method == 'FTR_EXT': 361 | feed_dict[model.output_mask] = logit_mask 362 | if task == 0: 363 | _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) 364 | else: 365 | _, loss = sess.run([model.train_classifier, model.reg_loss], feed_dict=feed_dict) 366 | 367 | elif model.imp_method == 'EWC' or model.imp_method == 'M-EWC': 368 | feed_dict[model.output_mask] = logit_mask 369 | # If first iteration of the first task then set the initial value of the running fisher 370 | if task == 0 and iters == 0: 371 | sess.run([model.set_initial_running_fisher], feed_dict=feed_dict) 372 | # Update fisher after every few iterations 373 | if (iters + 1) % model.fisher_update_after == 0: 374 | sess.run(model.set_running_fisher) 375 | sess.run(model.reset_tmp_fisher) 376 | 377 | if (iters >= convergence_iters) and (model.imp_method == 'M-EWC'): 378 | _, _, _, _, loss = sess.run([model.weights_old_ops_grouped, model.set_tmp_fisher, model.train, model.update_small_omega, 379 
| model.reg_loss], feed_dict=feed_dict) 380 | else: 381 | _, _, loss = sess.run([model.set_tmp_fisher, model.train, model.reg_loss], feed_dict=feed_dict) 382 | 383 | elif model.imp_method == 'PI': 384 | feed_dict[model.output_mask] = logit_mask 385 | _, _, _, loss = sess.run([model.weights_old_ops_grouped, model.train, model.update_small_omega, 386 | model.reg_loss], feed_dict=feed_dict) 387 | 388 | elif model.imp_method == 'MAS': 389 | feed_dict[model.output_mask] = logit_mask 390 | _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) 391 | 392 | elif model.imp_method == 'A-GEM': 393 | if task == 0: 394 | nd_logit_mask[:] = 0 395 | nd_logit_mask[task][task_labels[task]] = 1.0 396 | logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} 397 | feed_dict.update(logit_mask_dict) 398 | feed_dict[model.mem_batch_size] = batch_size 399 | # Normal application of gradients 400 | _, loss = sess.run([model.train_first_task, model.agem_loss], feed_dict=feed_dict) 401 | else: 402 | ## Compute and store the reference gradients on the previous tasks 403 | # Set the mask for all the previous tasks so far 404 | nd_logit_mask[:] = 0 405 | for tt in range(task): 406 | nd_logit_mask[tt][task_labels[tt]] = 1.0 407 | 408 | if episodic_filled_counter <= args.eps_mem_batch: 409 | mem_sample_mask = np.arange(episodic_filled_counter) 410 | else: 411 | # Sample a random subset from episodic memory buffer 412 | mem_sample_mask = np.random.choice(episodic_filled_counter, args.eps_mem_batch, replace=False) # Sample without replacement so that we don't sample an example more than once 413 | # Store the reference gradient 414 | ref_feed_dict = {model.x: episodic_images[mem_sample_mask], model.y_: episodic_labels[mem_sample_mask], 415 | model.keep_prob: 1.0, model.train_phase: True} 416 | logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} 417 | ref_feed_dict.update(logit_mask_dict) 418 | ref_feed_dict[model.mem_batch_size] = float(len(mem_sample_mask)) 419 | sess.run(model.store_ref_grads, feed_dict=ref_feed_dict) 420 | 421 | # Compute the gradient for current task and project if need be 422 | nd_logit_mask[:] = 0 423 | nd_logit_mask[task][task_labels[task]] = 1.0 424 | logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} 425 | feed_dict.update(logit_mask_dict) 426 | feed_dict[model.mem_batch_size] = batch_size 427 | if COUNT_VIOLATONS: 428 | vc, _, loss = sess.run([model.violation_count, model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict) 429 | else: 430 | _, loss = sess.run([model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict) 431 | # Put the batch in the ring buffer 432 | for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]): 433 | cls = np.unique(np.nonzero(er_y_))[-1] 434 | # Write the example at the location pointed by count_cls[cls] 435 | cls_to_index_map = np.where(np.array(task_labels[task]) == cls)[0][0] 436 | with_in_task_offset = args.mem_size * cls_to_index_map 437 | mem_index = count_cls[cls] + with_in_task_offset + episodic_filled_counter 438 | episodic_images[mem_index] = er_x 439 | episodic_labels[mem_index] = er_y_ 440 | count_cls[cls] = (count_cls[cls] + 1) % args.mem_size 441 | 442 | elif model.imp_method == 'RWALK': 443 | feed_dict[model.output_mask] = logit_mask 444 | # If first iteration of the first task then set the initial value of the running fisher 445 | if task == 0 and iters == 0: 446 | 
sess.run([model.set_initial_running_fisher], feed_dict=feed_dict) 447 | # Store the current value of the weights 448 | sess.run(model.weights_delta_old_grouped) 449 | # Update fisher and importance score after every few iterations 450 | if (iters + 1) % model.fisher_update_after == 0: 451 | # Update the importance score using distance in riemannian manifold 452 | sess.run(model.update_big_omega_riemann) 453 | # Now that the score is updated, compute the new value for running Fisher 454 | sess.run(model.set_running_fisher) 455 | # Store the current value of the weights 456 | sess.run(model.weights_delta_old_grouped) 457 | # Reset the delta_L 458 | sess.run([model.reset_small_omega]) 459 | 460 | _, _, _, _, loss = sess.run([model.set_tmp_fisher, model.weights_old_ops_grouped, 461 | model.train, model.update_small_omega, model.reg_loss], feed_dict=feed_dict) 462 | 463 | elif model.imp_method == 'ER': 464 | mem_filled_so_far = examples_seen_so_far if (examples_seen_so_far < episodic_mem_size) else episodic_mem_size 465 | if mem_filled_so_far < args.eps_mem_batch: 466 | er_mem_indices = np.arange(mem_filled_so_far) 467 | else: 468 | er_mem_indices = np.random.choice(mem_filled_so_far, args.eps_mem_batch, replace=False) 469 | np.random.shuffle(er_mem_indices) 470 | nd_logit_mask[:] = 0 471 | for tt in range(task+1): 472 | nd_logit_mask[tt][task_labels[tt]] = 1.0 473 | logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} 474 | er_train_x_batch = np.concatenate((episodic_images[er_mem_indices], train_x[offset:offset+residual]), axis=0) 475 | er_train_y_batch = np.concatenate((episodic_labels[er_mem_indices], train_y[offset:offset+residual]), axis=0) 476 | feed_dict = {model.x: er_train_x_batch, model.y_: er_train_y_batch, 477 | model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0, 478 | model.train_phase: True} 479 | feed_dict.update(logit_mask_dict) 480 | feed_dict[model.mem_batch_size] = float(er_train_x_batch.shape[0]) 481 | _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) 482 | 483 | # Reservoir update 484 | for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]): 485 | update_reservior(er_x, er_y_, episodic_images, episodic_labels, episodic_mem_size, examples_seen_so_far) 486 | examples_seen_so_far += 1 487 | 488 | if (iters % 100 == 0): 489 | print('Step {:d} {:.3f}'.format(iters, loss)) 490 | 491 | if (math.isnan(loss)): 492 | print('ERROR: NaNs NaNs NaNs!!!') 493 | sys.exit(0) 494 | 495 | print('\t\t\t\tTraining for Task%d done!'%(task)) 496 | 497 | if use_episodic_memory: 498 | episodic_filled_counter += args.mem_size * classes_per_task 499 | 500 | if model.imp_method == 'A-GEM': 501 | if COUNT_VIOLATONS: 502 | violation_count[task] = vc 503 | print('Task {}: Violation Count: {}'.format(task, violation_count)) 504 | sess.run(model.reset_violation_count, feed_dict=feed_dict) 505 | 506 | # Compute the inter-task updates, Fisher/ importance scores etc 507 | # Don't calculate the task updates for the last task 508 | if (task < (len(task_labels) - 1)) or MEASURE_PERF_ON_EPS_MEMORY: 509 | model.task_updates(sess, task, task_train_images, task_labels[task]) # TODO: For MAS, should the gradients be for current task or all the previous tasks 510 | print('\t\t\t\tTask updates after Task%d done!'%(task)) 511 | 512 | if VISUALIZE_IMPORTANCE_MEASURE: 513 | if runid == 0: 514 | for i in range(len(model.fisher_diagonal_at_minima)): 515 | if i == 0: 516 | flatten_fisher = 
np.array(model.fisher_diagonal_at_minima[i].eval()).flatten() 517 | else: 518 | flatten_fisher = np.concatenate((flatten_fisher, 519 | np.array(model.fisher_diagonal_at_minima[i].eval()).flatten())) 520 | 521 | #flatten_fisher [flatten_fisher > 0.1] = 0.1 522 | if args.train_single_epoch: 523 | plot_histogram(flatten_fisher, 100, '/private/home/arslanch/Dropbox/LLL_experiments/Single_Epoch/importance_vis/single_epoch/m_ewc/hist_fisher_task%s.png'%(task)) 524 | else: 525 | plot_histogram(flatten_fisher, 100, '/private/home/arslanch/Dropbox/LLL_experiments/Single_Epoch/importance_vis/single_epoch/m_ewc/hist_fisher_task%s.png'%(task)) 526 | 527 | if args.train_single_epoch and not args.cross_validate_mode: 528 | fbatch = test_task_sequence(model, sess, datasets[0]['test'], task_labels, task) 529 | print('Task: {}, Acc: {}'.format(task, fbatch)) 530 | ftask.append(fbatch) 531 | ftask = np.array(ftask) 532 | if model.imp_method == 'PNN': 533 | pnn_train_phase[:] = False 534 | pnn_train_phase[task] = True 535 | pnn_logit_mask[:] = 0 536 | pnn_logit_mask[task][task_labels[task]] = 1.0 537 | else: 538 | if MEASURE_PERF_ON_EPS_MEMORY: 539 | eps_mem = { 540 | 'images': episodic_images, 541 | 'labels': episodic_labels, 542 | } 543 | # Measure perf on episodic memory 544 | ftask = test_task_sequence(model, sess, eps_mem, task_labels, task, classes_per_task=classes_per_task) 545 | else: 546 | # List to store accuracy for all the tasks for the current trained model 547 | ftask = test_task_sequence(model, sess, datasets[0]['test'], task_labels, task) 548 | print('Task: {}, Acc: {}'.format(task, ftask)) 549 | 550 | # Store the accuracies computed at task T in a list 551 | evals.append(ftask) 552 | 553 | # Reset the optimizer 554 | model.reset_optimizer(sess) 555 | 556 | #-> End for loop task 557 | 558 | runs.append(np.array(evals)) 559 | # End for loop runid 560 | 561 | runs = np.array(runs) 562 | 563 | return runs, task_labels_dataset 564 | 565 | def test_task_sequence(model, sess, test_data, test_tasks, task, classes_per_task=0): 566 | """ 567 | Snapshot the current performance 568 | """ 569 | if TIME_MY_METHOD: 570 | # Only compute the training time 571 | return np.zeros(model.num_tasks) 572 | 573 | final_acc = np.zeros(model.num_tasks) 574 | if model.imp_method == 'PNN' or model.imp_method == 'A-GEM' or model.imp_method == 'ER': 575 | logit_mask = np.zeros([model.num_tasks, TOTAL_CLASSES]) 576 | else: 577 | logit_mask = np.zeros(TOTAL_CLASSES) 578 | 579 | if MEASURE_PERF_ON_EPS_MEMORY: 580 | for tt, labels in enumerate(test_tasks): 581 | # Multi-head evaluation setting 582 | logit_mask[:] = 0 583 | logit_mask[labels] = 1.0 584 | mem_offset = tt*SAMPLES_PER_CLASS*classes_per_task 585 | feed_dict = {model.x: test_data['images'][mem_offset:mem_offset+SAMPLES_PER_CLASS*classes_per_task], 586 | model.y_: test_data['labels'][mem_offset:mem_offset+SAMPLES_PER_CLASS*classes_per_task], model.keep_prob: 1.0, model.train_phase: False, model.output_mask: logit_mask} 587 | acc = model.accuracy.eval(feed_dict = feed_dict) 588 | final_acc[tt] = acc 589 | return final_acc 590 | 591 | for tt, labels in enumerate(test_tasks): 592 | 593 | if not MULTI_TASK: 594 | if tt > task: 595 | return final_acc 596 | 597 | task_test_images, task_test_labels = load_task_specific_data(test_data, labels) 598 | if model.imp_method == 'PNN': 599 | pnn_train_phase = np.array(np.zeros(model.num_tasks), dtype=np.bool) 600 | logit_mask[:] = 0 601 | logit_mask[tt][labels] = 1.0 602 | feed_dict = {model.x: task_test_images, 603 | 
model.y_[tt]: task_test_labels, model.keep_prob: 1.0} 604 | train_phase_dict = {m_t: i_t for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)} 605 | logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, logit_mask)} 606 | feed_dict.update(train_phase_dict) 607 | feed_dict.update(logit_mask_dict) 608 | acc = model.accuracy[tt].eval(feed_dict = feed_dict) 609 | 610 | elif model.imp_method == 'A-GEM' or model.imp_method == 'ER': 611 | logit_mask[:] = 0 612 | logit_mask[tt][labels] = 1.0 613 | feed_dict = {model.x: task_test_images, 614 | model.y_: task_test_labels, model.keep_prob: 1.0, model.train_phase: False} 615 | logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, logit_mask)} 616 | feed_dict.update(logit_mask_dict) 617 | acc = model.accuracy[tt].eval(feed_dict = feed_dict) 618 | 619 | else: 620 | logit_mask[:] = 0 621 | logit_mask[labels] = 1.0 622 | feed_dict = {model.x: task_test_images, 623 | model.y_: task_test_labels, model.keep_prob: 1.0, model.train_phase: False, model.output_mask: logit_mask} 624 | acc = model.accuracy.eval(feed_dict = feed_dict) 625 | 626 | final_acc[tt] = acc 627 | 628 | return final_acc 629 | 630 | def main(): 631 | """ 632 | Create the model and start the training 633 | """ 634 | 635 | # Get the CL arguments 636 | args = get_arguments() 637 | 638 | # Check if the network architecture is valid 639 | if args.arch not in VALID_ARCHS: 640 | raise ValueError("Network architecture %s is not supported!"%(args.arch)) 641 | 642 | # Check if the method to compute importance is valid 643 | if args.imp_method not in MODELS: 644 | raise ValueError("Importance measure %s is undefined!"%(args.imp_method)) 645 | 646 | # Check if the optimizer is valid 647 | if args.optim not in VALID_OPTIMS: 648 | raise ValueError("Optimizer %s is undefined!"%(args.optim)) 649 | 650 | # Create log directories to store the results 651 | if not os.path.exists(args.log_dir): 652 | print('Log directory %s created!'%(args.log_dir)) 653 | os.makedirs(args.log_dir) 654 | 655 | # Generate the experiment key and store the meta data in a file 656 | exper_meta_data = {'ARCH': args.arch, 657 | 'DATASET': 'SPLIT_CIFAR', 658 | 'NUM_RUNS': args.num_runs, 659 | 'TRAIN_SINGLE_EPOCH': args.train_single_epoch, 660 | 'IMP_METHOD': args.imp_method, 661 | 'SYNAP_STGTH': args.synap_stgth, 662 | 'FISHER_EMA_DECAY': args.fisher_ema_decay, 663 | 'FISHER_UPDATE_AFTER': args.fisher_update_after, 664 | 'OPTIM': args.optim, 665 | 'LR': args.learning_rate, 666 | 'BATCH_SIZE': args.batch_size, 667 | 'MEM_SIZE': args.mem_size} 668 | experiment_id = "SPLIT_CIFAR_HERDING_%s_%r_%s_%s_%s_%s_%s-"%(args.arch, args.train_single_epoch, args.imp_method, 669 | str(args.synap_stgth).replace('.', '_'), str(args.learning_rate).replace('.', '_'), 670 | str(args.batch_size), str(args.mem_size)) + datetime.datetime.now().strftime("%y-%m-%d-%H-%M") 671 | snapshot_experiment_meta_data(args.log_dir, experiment_id, exper_meta_data) 672 | 673 | # Get the task labels from the total number of tasks and full label space 674 | if args.online_cross_val: 675 | num_tasks = K_FOR_CROSS_VAL 676 | else: 677 | num_tasks = NUM_TASKS - K_FOR_CROSS_VAL 678 | 679 | # Load the split cifar dataset 680 | data_labs = [np.arange(TOTAL_CLASSES)] 681 | datasets = construct_split_cifar(data_labs) 682 | 683 | # Variables to store the accuracies and standard deviations of the experiment 684 | acc_mean = dict() 685 | acc_std = dict() 686 | 687 | # Reset the default graph 688 | tf.reset_default_graph() 689 | graph = tf.Graph() 690 | 
with graph.as_default(): 691 | 692 | # Set the random seed 693 | tf.set_random_seed(args.random_seed) 694 | 695 | # Define Input and Output of the model 696 | x = tf.placeholder(tf.float32, shape=[None, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS]) 697 | if args.imp_method == 'PNN': 698 | y_ = [] 699 | for i in range(num_tasks): 700 | y_.append(tf.placeholder(tf.float32, shape=[None, TOTAL_CLASSES])) 701 | else: 702 | y_ = tf.placeholder(tf.float32, shape=[None, TOTAL_CLASSES]) 703 | 704 | # Define the optimizer 705 | if args.optim == 'ADAM': 706 | opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate) 707 | 708 | elif args.optim == 'SGD': 709 | opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate) 710 | 711 | elif args.optim == 'MOMENTUM': 712 | base_lr = tf.constant(args.learning_rate) 713 | learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - train_step / training_iters), OPT_POWER)) 714 | opt = tf.train.MomentumOptimizer(args.learning_rate, OPT_MOMENTUM) 715 | 716 | # Create the Model/ contruct the graph 717 | model = Model(x, y_, num_tasks, opt, args.imp_method, args.synap_stgth, args.fisher_update_after, 718 | args.fisher_ema_decay, network_arch=args.arch) 719 | 720 | # Set up tf session and initialize variables. 721 | config = tf.ConfigProto() 722 | config.gpu_options.allow_growth = True 723 | 724 | time_start = time.time() 725 | with tf.Session(config=config, graph=graph) as sess: 726 | runs, task_labels_dataset = train_task_sequence(model, sess, datasets, args) 727 | # Close the session 728 | sess.close() 729 | time_end = time.time() 730 | time_spent = time_end - time_start 731 | 732 | # Store all the results in one dictionary to process later 733 | exper_acc = dict(mean=runs) 734 | exper_labels = dict(labels=task_labels_dataset) 735 | 736 | # If cross-validation flag is enabled, store the stuff in a text file 737 | if args.cross_validate_mode: 738 | acc_mean, acc_std = average_acc_stats_across_runs(runs, model.imp_method) 739 | fgt_mean, fgt_std = average_fgt_stats_across_runs(runs, model.imp_method) 740 | cross_validate_dump_file = args.log_dir + '/' + 'SPLIT_CIFAR_%s_%s'%(args.imp_method, args.optim) + '.txt' 741 | with open(cross_validate_dump_file, 'a') as f: 742 | if MULTI_TASK: 743 | f.write('HERDING: {} \t ARCH: {} \t LR:{} \t LAMBDA: {} \t ACC: {}\n'.format(args.arch, args.learning_rate, args.synap_stgth, acc_mean[-1,:].mean())) 744 | else: 745 | f.write('ARCH: {} \t LR:{} \t LAMBDA: {} \t ACC: {} \t Fgt: {} \t Time: {}\n'.format(args.arch, args.learning_rate, 746 | args.synap_stgth, acc_mean, fgt_mean, str(time_spent))) 747 | 748 | # Store the experiment output to a file 749 | snapshot_experiment_eval(args.log_dir, experiment_id, exper_acc) 750 | snapshot_task_labels(args.log_dir, experiment_id, exper_labels) 751 | 752 | if __name__ == '__main__': 753 | main() 754 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/extract_res.py: -------------------------------------------------------------------------------- 1 | from six.moves import cPickle as pickle 2 | 3 | with open('./ring.pickle', 'rb') as f: 4 | data = pickle.load(f)['mean'] 5 | 6 | print(data.shape) 7 | print(data[0][-1][-1]) 8 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/fc_mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Training script for permute MNIST experiment. 8 | """ 9 | from __future__ import print_function 10 | 11 | import argparse 12 | import os 13 | import sys 14 | import math 15 | import time 16 | 17 | import datetime 18 | import numpy as np 19 | import tensorflow as tf 20 | from copy import deepcopy 21 | from six.moves import cPickle as pickle 22 | 23 | from utils.data_utils import construct_permute_mnist, construct_rotate_mnist 24 | from utils.utils import get_sample_weights, sample_from_dataset, update_episodic_memory, concatenate_datasets, samples_for_each_class, sample_from_dataset_icarl, compute_fgt, update_reservior 25 | from utils.vis_utils import plot_acc_multiple_runs, plot_histogram, snapshot_experiment_meta_data, snapshot_experiment_eval 26 | from model import Model 27 | 28 | ############################################################### 29 | ################ Some definitions ############################# 30 | ### These will be edited by the command line options ########## 31 | ############################################################### 32 | 33 | ## Training Options 34 | EPOCHS_PER_TASK = 1 35 | NUM_RUNS = 10 # Number of experiments to average over 36 | TRAIN_ITERS = 5000 # Number of training iterations per task 37 | BATCH_SIZE = 16 38 | LEARNING_RATE = 1e-3 39 | RANDOM_SEED = 1235 40 | VALID_OPTIMS = ['SGD', 'MOMENTUM', 'ADAM'] 41 | OPTIM = 'SGD' 42 | OPT_POWER = 0.9 43 | OPT_MOMENTUM = 0.9 44 | VALID_ARCHS = ['FC-S', 'FC-B'] 45 | ARCH = 'FC-S' 46 | 47 | ## Model options 48 | MODELS = ['VAN', 'PI', 'EWC', 'MAS', 'RWALK', 'A-GEM', 'S-GEM', 'FTR_EXT', 'PNN', 'ER'] #List of valid models 49 | IMP_METHOD = 'EWC' 50 | SYNAP_STGTH = 10 51 | FISHER_EMA_DECAY = 0.9 # Exponential moving average decay factor for Fisher computation (online Fisher) 52 | FISHER_UPDATE_AFTER = 10 # Number of training iterations for which the F_{\theta}^t is computed (see Eq. 10 in RWalk paper) 53 | SAMPLES_PER_CLASS = 1 # Number of samples per task 54 | INPUT_FEATURE_SIZE = 784 55 | IMG_HEIGHT = 28 56 | IMG_WIDTH = 28 57 | IMG_CHANNELS = 1 58 | TOTAL_CLASSES = 10 # Total number of classes in the dataset 59 | EPS_MEM_BATCH_SIZE = 64 60 | DEBUG_EPISODIC_MEMORY = False 61 | USE_GPU = True 62 | K_FOR_CROSS_VAL = 3 63 | TIME_MY_METHOD = False 64 | COUNT_VIOLATIONS = False 65 | MEASURE_PERF_ON_EPS_MEMORY = False 66 | 67 | ## Logging, saving and testing options 68 | LOG_DIR = './permute_mnist_results' 69 | 70 | ## Evaluation options 71 | 72 | ## Num Tasks 73 | NUM_TASKS = 23 74 | MULTI_TASK = False 75 | 76 | def get_arguments(): 77 | """Parse all the arguments provided from the CLI. 78 | 79 | Returns: 80 | A list of parsed arguments. 
81 | """ 82 | parser = argparse.ArgumentParser(description="Script for permutted mnist experiment.") 83 | parser.add_argument("--cross-validate-mode", action="store_true", 84 | help="If option is chosen then snapshoting after each batch is disabled") 85 | parser.add_argument("--online-cross-val", action="store_false", 86 | help="If option is chosen then enable the online cross validation of the learning rate") 87 | parser.add_argument("--train-single-epoch", action="store_false", 88 | help="If option is chosen then train for single epoch") 89 | parser.add_argument("--eval-single-head", action="store_true", 90 | help="If option is chosen then evaluate on a single head setting.") 91 | parser.add_argument("--arch", type=str, default=ARCH, help="Network Architecture for the experiment.\ 92 | \n \nSupported values: %s"%(VALID_ARCHS)) 93 | parser.add_argument("--num-runs", type=int, default=NUM_RUNS, 94 | help="Total runs/ experiments over which accuracy is averaged.") 95 | parser.add_argument("--train-iters", type=int, default=TRAIN_ITERS, 96 | help="Number of training iterations for each task.") 97 | parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, 98 | help="Mini-batch size for each task.") 99 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, 100 | help="Random Seed.") 101 | parser.add_argument('--dataset', type=str, default='rot-mnist', help='dataset (benchmark). could be rot-mnist or perm-mnist') 102 | parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, 103 | help="Starting Learning rate for each task.") 104 | parser.add_argument("--optim", type=str, default=OPTIM, 105 | help="Optimizer for the experiment. \ 106 | \n \nSupported values: %s"%(VALID_OPTIMS)) 107 | parser.add_argument("--imp-method", type=str, default=IMP_METHOD, 108 | help="Model to be used for LLL. 
\ 109 | \n \nSupported values: %s"%(MODELS)) 110 | parser.add_argument("--synap-stgth", type=float, default=SYNAP_STGTH, 111 | help="Synaptic strength for the regularization.") 112 | parser.add_argument("--fisher-ema-decay", type=float, default=FISHER_EMA_DECAY, 113 | help="Exponential moving average decay for Fisher calculation at each step.") 114 | parser.add_argument("--fisher-update-after", type=int, default=FISHER_UPDATE_AFTER, 115 | help="Number of training iterations after which the Fisher will be updated.") 116 | parser.add_argument("--mem-size", type=int, default=SAMPLES_PER_CLASS, 117 | help="Number of samples per class from previous tasks.") 118 | parser.add_argument("--decay", type=float, default=0.4, 119 | help="learning rate decay factor (gamma)") 120 | parser.add_argument("--eps-mem-batch", type=int, default=EPS_MEM_BATCH_SIZE, 121 | help="Number of samples per class from previous tasks.") 122 | parser.add_argument("--examples-per-task", type=int, default=50000, 123 | help="Number of examples per task.") 124 | parser.add_argument("--log-dir", type=str, default=LOG_DIR, 125 | help="Directory where the plots and model accuracies will be stored.") 126 | return parser.parse_args() 127 | 128 | def train_task_sequence(model, sess, args): 129 | """ 130 | Train and evaluate LLL system such that we only see a example once 131 | Args: 132 | Returns: 133 | dict A dictionary containing mean and stds for the experiment 134 | """ 135 | # List to store accuracy for each run 136 | runs = [] 137 | 138 | batch_size = args.batch_size 139 | 140 | if model.imp_method == 'A-GEM' or model.imp_method == 'ER': 141 | use_episodic_memory = True 142 | else: 143 | use_episodic_memory = False 144 | 145 | # Loop over number of runs to average over 146 | for runid in range(args.num_runs): 147 | print('\t\tRun %d:'%(runid)) 148 | 149 | # Initialize the random seeds 150 | np.random.seed(args.random_seed+runid) 151 | 152 | # Load the permute mnist dataset 153 | if 'rot' in args.dataset or 'rotation' in args.dataset: 154 | datasets = construct_rotate_mnist(model.num_tasks) 155 | else: 156 | datasets = construct_permute_mnist(model.num_tasks) 157 | 158 | print("total datasets => ", len(datasets)) 159 | episodic_mem_size = args.mem_size*model.num_tasks*TOTAL_CLASSES 160 | 161 | # Initialize all the variables in the model 162 | sess.run(tf.global_variables_initializer()) 163 | 164 | # Run the init ops 165 | model.init_updates(sess) 166 | 167 | # List to store accuracies for a run 168 | evals = [] 169 | 170 | # List to store the classes that we have so far - used at test time 171 | test_labels = np.arange(TOTAL_CLASSES) 172 | 173 | if use_episodic_memory: 174 | # Reserve a space for episodic memory 175 | episodic_images = np.zeros([episodic_mem_size, INPUT_FEATURE_SIZE]) 176 | episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES]) 177 | count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32) 178 | episodic_filled_counter = 0 179 | examples_seen_so_far = 0 180 | 181 | # Mask for softmax 182 | # Since all the classes are present in all the tasks so nothing to mask 183 | logit_mask = np.ones(TOTAL_CLASSES) 184 | if model.imp_method == 'PNN': 185 | pnn_train_phase = np.array(np.zeros(model.num_tasks), dtype=np.bool) 186 | pnn_logit_mask = np.ones([model.num_tasks, TOTAL_CLASSES]) 187 | 188 | if COUNT_VIOLATIONS: 189 | violation_count = np.zeros(model.num_tasks) 190 | vc = 0 191 | 192 | # Training loop for all the tasks 193 | for task in range(len(datasets)): 194 | print('\t\tTask %d:'%(task)) 195 | 196 | # If not 
the first task then restore weights from previous task 197 | if(task > 0 and model.imp_method != 'PNN'): 198 | model.restore(sess) 199 | 200 | # Extract training images and labels for the current task 201 | task_train_images = datasets[task]['train']['images'] 202 | task_train_labels = datasets[task]['train']['labels'] 203 | 204 | # If multi_task is set the train using datasets of all the tasks 205 | if MULTI_TASK: 206 | if task == 0: 207 | for t_ in range(1, len(datasets)): 208 | task_train_images = np.concatenate((task_train_images, datasets[t_]['train']['images']), axis=0) 209 | task_train_labels = np.concatenate((task_train_labels, datasets[t_]['train']['labels']), axis=0) 210 | else: 211 | # Skip training for this task 212 | continue 213 | 214 | # Assign equal weights to all the examples 215 | task_sample_weights = np.ones([task_train_labels.shape[0]], dtype=np.float32) 216 | total_train_examples = task_train_images.shape[0] 217 | # Randomly suffle the training examples 218 | perm = np.arange(total_train_examples) 219 | np.random.shuffle(perm) 220 | train_x = task_train_images[perm][:args.examples_per_task] 221 | train_y = task_train_labels[perm][:args.examples_per_task] 222 | task_sample_weights = task_sample_weights[perm][:args.examples_per_task] 223 | 224 | print('Received {} images, {} labels at task {}'.format(train_x.shape[0], train_y.shape[0], task)) 225 | 226 | # Array to store accuracies when training for task T 227 | ftask = [] 228 | 229 | num_train_examples = train_x.shape[0] 230 | 231 | # Train a task observing sequence of data 232 | if args.train_single_epoch: 233 | num_iters = num_train_examples // batch_size 234 | else: 235 | num_iters = int(EPOCHS_PER_TASK*(num_train_examples // batch_size)) 236 | 237 | # Training loop for task T 238 | for iters in range(num_iters): 239 | 240 | if args.train_single_epoch and not args.cross_validate_mode: 241 | if (iters < 10) or (iters < 100 and iters % 10 == 0) or (iters % 100 == 0): 242 | # Snapshot the current performance across all tasks after each mini-batch 243 | fbatch = test_task_sequence(model, sess, datasets, args.online_cross_val) 244 | ftask.append(fbatch) 245 | 246 | offset = (iters * batch_size) % (num_train_examples - batch_size) 247 | residual = batch_size 248 | 249 | if model.imp_method == 'PNN': 250 | pnn_train_phase[:] = False 251 | pnn_train_phase[task] = True 252 | feed_dict = {model.x: train_x[offset:offset+batch_size], model.y_[task]: train_y[offset:offset+batch_size], 253 | model.sample_weights: task_sample_weights[offset:offset+batch_size], 254 | model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0} 255 | train_phase_dict = {m_t: i_t for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)} 256 | logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask)} 257 | feed_dict.update(train_phase_dict) 258 | feed_dict.update(logit_mask_dict) 259 | else: 260 | feed_dict = {model.x: train_x[offset:offset+batch_size], model.y_: train_y[offset:offset+batch_size], 261 | model.sample_weights: task_sample_weights[offset:offset+batch_size], 262 | model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0, 263 | model.output_mask: logit_mask, model.train_phase: True} 264 | 265 | if model.imp_method == 'VAN': 266 | _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) 267 | 268 | elif model.imp_method == 'PNN': 269 | feed_dict[model.task_id] = task 270 | _, loss = sess.run([model.train[task], model.unweighted_entropy[task]], 
feed_dict=feed_dict) 271 | 272 | elif model.imp_method == 'FTR_EXT': 273 | if task == 0: 274 | _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) 275 | else: 276 | _, loss = sess.run([model.train_classifier, model.reg_loss], feed_dict=feed_dict) 277 | 278 | elif model.imp_method == 'EWC': 279 | # If first iteration of the first task then set the initial value of the running fisher 280 | if task == 0 and iters == 0: 281 | sess.run([model.set_initial_running_fisher], feed_dict=feed_dict) 282 | # Update fisher after every few iterations 283 | if (iters + 1) % model.fisher_update_after == 0: 284 | sess.run(model.set_running_fisher) 285 | sess.run(model.reset_tmp_fisher) 286 | 287 | _, _, loss = sess.run([model.set_tmp_fisher, model.train, model.reg_loss], feed_dict=feed_dict) 288 | 289 | elif model.imp_method == 'PI': 290 | _, _, _, loss = sess.run([model.weights_old_ops_grouped, model.train, model.update_small_omega, 291 | model.reg_loss], feed_dict=feed_dict) 292 | 293 | elif model.imp_method == 'MAS': 294 | _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) 295 | 296 | elif model.imp_method == 'A-GEM': 297 | if task == 0: 298 | # Normal application of gradients 299 | _, loss = sess.run([model.train_first_task, model.agem_loss], feed_dict=feed_dict) 300 | else: 301 | ## Compute and store the reference gradients on the previous tasks 302 | if episodic_filled_counter <= args.eps_mem_batch: 303 | mem_sample_mask = np.arange(episodic_filled_counter) 304 | else: 305 | # Sample a random subset from episodic memory buffer 306 | mem_sample_mask = np.random.choice(episodic_filled_counter, args.eps_mem_batch, replace=False) # Sample without replacement so that we don't sample an example more than once 307 | # Store the reference gradient 308 | sess.run(model.store_ref_grads, feed_dict={model.x: episodic_images[mem_sample_mask], model.y_: episodic_labels[mem_sample_mask], 309 | model.keep_prob: 1.0, model.output_mask: logit_mask, model.train_phase: True}) 310 | if COUNT_VIOLATIONS: 311 | vc, _, loss = sess.run([model.violation_count, model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict) 312 | else: 313 | # Compute the gradient for current task and project if need be 314 | _, loss = sess.run([model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict) 315 | # Put the batch in the ring buffer 316 | for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]): 317 | cls = np.unique(np.nonzero(er_y_))[-1] 318 | # Write the example at the location pointed by count_cls[cls] 319 | cls_to_index_map = cls 320 | with_in_task_offset = args.mem_size * cls_to_index_map 321 | mem_index = count_cls[cls] + with_in_task_offset + episodic_filled_counter 322 | episodic_images[mem_index] = er_x 323 | episodic_labels[mem_index] = er_y_ 324 | count_cls[cls] = (count_cls[cls] + 1) % args.mem_size 325 | 326 | elif model.imp_method == 'RWALK': 327 | # If first iteration of the first task then set the initial value of the running fisher 328 | if task == 0 and iters == 0: 329 | sess.run([model.set_initial_running_fisher], feed_dict=feed_dict) 330 | # Store the current value of the weights 331 | sess.run(model.weights_delta_old_grouped) 332 | # Update fisher and importance score after every few iterations 333 | if (iters + 1) % model.fisher_update_after == 0: 334 | # Update the importance score using distance in riemannian manifold 335 | sess.run(model.update_big_omega_riemann) 336 | # Now that the score is updated, compute the new value for 
running Fisher 337 | sess.run(model.set_running_fisher) 338 | # Store the current value of the weights 339 | sess.run(model.weights_delta_old_grouped) 340 | # Reset the delta_L 341 | sess.run([model.reset_small_omega]) 342 | 343 | _, _, _, _, loss = sess.run([model.set_tmp_fisher, model.weights_old_ops_grouped, 344 | model.train, model.update_small_omega, model.reg_loss], feed_dict=feed_dict) 345 | 346 | elif model.imp_method == 'ER': 347 | mem_filled_so_far = examples_seen_so_far if (examples_seen_so_far < episodic_mem_size) else episodic_mem_size 348 | if mem_filled_so_far < args.eps_mem_batch: 349 | er_mem_indices = np.arange(mem_filled_so_far) 350 | else: 351 | er_mem_indices = np.random.choice(mem_filled_so_far, args.eps_mem_batch, replace=False) 352 | np.random.shuffle(er_mem_indices) 353 | # Train on a batch of episodic memory first 354 | er_train_x_batch = np.concatenate((episodic_images[er_mem_indices], train_x[offset:offset+residual]), axis=0) 355 | er_train_y_batch = np.concatenate((episodic_labels[er_mem_indices], train_y[offset:offset+residual]), axis=0) 356 | feed_dict = {model.x: er_train_x_batch, model.y_: er_train_y_batch, 357 | model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0, 358 | model.output_mask: logit_mask, model.train_phase: True} 359 | _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) 360 | for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]): 361 | update_reservior(er_x, er_y_, episodic_images, episodic_labels, episodic_mem_size, examples_seen_so_far) 362 | examples_seen_so_far += 1 363 | 364 | if (iters % 1000 == 0): 365 | print("get global step") 366 | tf.print(model.global_step, output_stream=sys.stderr) 367 | tf.print(model.learning_rate, output_stream=sys.stderr) 368 | lr = sess.run([model.learning_rate], feed_dict=feed_dict) 369 | print(lr) 370 | print('Step {:d} {:.3f}'.format(iters, loss)) 371 | 372 | if (math.isnan(loss)): 373 | print('ERROR: NaNs NaNs Nans!!!') 374 | sys.exit(0) 375 | 376 | print('\t\t\t\tTraining for Task%d done!'%(task)) 377 | 378 | # Upaate the episodic memory filled counter 379 | if use_episodic_memory: 380 | episodic_filled_counter += args.mem_size * TOTAL_CLASSES 381 | 382 | if model.imp_method == 'A-GEM' and COUNT_VIOLATIONS: 383 | violation_count[task] = vc 384 | print('Task {}: Violation Count: {}'.format(task, violation_count)) 385 | sess.run(model.reset_violation_count, feed_dict=feed_dict) 386 | 387 | # Compute the inter-task updates, Fisher/ importance scores etc 388 | # Don't calculate the task updates for the last task 389 | if (task < (len(datasets) - 1)) or MEASURE_PERF_ON_EPS_MEMORY: 390 | model.task_updates(sess, task, task_train_images, np.arange(TOTAL_CLASSES)) 391 | print('\t\t\t\tTask updates after Task%d done!'%(task)) 392 | 393 | if args.train_single_epoch and not args.cross_validate_mode: 394 | fbatch = test_task_sequence(model, sess, datasets, False) 395 | ftask.append(fbatch) 396 | ftask = np.array(ftask) 397 | else: 398 | if MEASURE_PERF_ON_EPS_MEMORY: 399 | eps_mem = { 400 | 'images': episodic_images, 401 | 'labels': episodic_labels, 402 | } 403 | # Measure perf on episodic memory 404 | ftask = test_task_sequence(model, sess, eps_mem, args.online_cross_val) 405 | else: 406 | # List to store accuracy for all the tasks for the current trained model 407 | ftask = test_task_sequence(model, sess, datasets, args.online_cross_val) 408 | 409 | # Store the accuracies computed at task T in a list 410 | evals.append(ftask) 411 | 412 
| # Reset the optimizer 413 | model.reset_optimizer(sess) 414 | 415 | #-> End for loop task 416 | 417 | runs.append(np.array(evals)) 418 | # End for loop runid 419 | 420 | runs = np.array(runs) 421 | 422 | return runs 423 | 424 | def test_task_sequence(model, sess, test_data, cross_validate_mode): 425 | """ 426 | Snapshot the current performance 427 | """ 428 | if TIME_MY_METHOD: 429 | # Only compute the training time 430 | return np.zeros(model.num_tasks) 431 | 432 | list_acc = [] 433 | if model.imp_method == 'PNN': 434 | pnn_logit_mask = np.ones([model.num_tasks, TOTAL_CLASSES]) 435 | else: 436 | logit_mask = np.ones(TOTAL_CLASSES) 437 | 438 | if MEASURE_PERF_ON_EPS_MEMORY: 439 | for task in range(model.num_tasks): 440 | mem_offset = task*SAMPLES_PER_CLASS*TOTAL_CLASSES 441 | feed_dict = {model.x: test_data['images'][mem_offset:mem_offset+SAMPLES_PER_CLASS*TOTAL_CLASSES], 442 | model.y_: test_data['labels'][mem_offset:mem_offset+SAMPLES_PER_CLASS*TOTAL_CLASSES], model.keep_prob: 1.0, 443 | model.output_mask: logit_mask, model.train_phase: False} 444 | acc = model.accuracy.eval(feed_dict = feed_dict) 445 | list_acc.append(acc) 446 | print(list_acc) 447 | return list_acc 448 | 449 | for task, _ in enumerate(test_data): 450 | 451 | if model.imp_method == 'PNN': 452 | pnn_train_phase = np.array(np.zeros(model.num_tasks), dtype=np.bool) 453 | feed_dict = {model.x: test_data[task]['test']['images'], 454 | model.y_[task]: test_data[task]['test']['labels'], model.keep_prob: 1.0} 455 | train_phase_dict = {m_t: i_t for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)} 456 | logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask)} 457 | feed_dict.update(train_phase_dict) 458 | feed_dict.update(logit_mask_dict) 459 | acc = model.accuracy[task].eval(feed_dict = feed_dict) 460 | else: 461 | feed_dict = {model.x: test_data[task]['test']['images'], 462 | model.y_: test_data[task]['test']['labels'], model.keep_prob: 1.0, 463 | model.output_mask: logit_mask, model.train_phase: False} 464 | acc = model.accuracy.eval(feed_dict = feed_dict) 465 | 466 | list_acc.append(acc) 467 | print("accuracy_list => ", list_acc) 468 | return list_acc 469 | 470 | def main(): 471 | """ 472 | Create the model and start the training 473 | """ 474 | 475 | # Get the CL arguments 476 | args = get_arguments() 477 | # 478 | 479 | # Check if the network architecture is valid 480 | if args.arch not in VALID_ARCHS: 481 | raise ValueError("Network architecture %s is not supported!"%(args.arch)) 482 | 483 | # Check if the method to compute importance is valid 484 | if args.imp_method not in MODELS: 485 | raise ValueError("Importance measure %s is undefined!"%(args.imp_method)) 486 | 487 | # Check if the optimizer is valid 488 | if args.optim not in VALID_OPTIMS: 489 | raise ValueError("Optimizer %s is undefined!"%(args.optim)) 490 | 491 | # Create log directories to store the results 492 | if not os.path.exists(args.log_dir): 493 | print('Log directory %s created!'%(args.log_dir)) 494 | os.makedirs(args.log_dir) 495 | 496 | # Generate the experiment key and store the meta data in a file 497 | exper_meta_data = {'DATASET': 'PERMUTE_MNIST', 498 | 'NUM_RUNS': args.num_runs, 499 | 'TRAIN_SINGLE_EPOCH': args.train_single_epoch, 500 | 'IMP_METHOD': args.imp_method, 501 | 'SYNAP_STGTH': args.synap_stgth, 502 | 'FISHER_EMA_DECAY': args.fisher_ema_decay, 503 | 'FISHER_UPDATE_AFTER': args.fisher_update_after, 504 | 'OPTIM': args.optim, 505 | 'LR': args.learning_rate, 506 | 'BATCH_SIZE': args.batch_size, 507 | 
'MEM_SIZE': args.mem_size} 508 | experiment_id = "PERMUTE_MNIST_HERDING_%s_%s_%s_%s_%r_%s-"%(args.arch, args.train_single_epoch, args.imp_method, str(args.synap_stgth).replace('.', '_'), 509 | str(args.batch_size), str(args.mem_size)) + datetime.datetime.now().strftime("%y-%m-%d-%H-%M") 510 | snapshot_experiment_meta_data(args.log_dir, experiment_id, exper_meta_data) 511 | 512 | # Get the subset of data depending on training or cross-validation mode 513 | 514 | args.online_cross_val = False 515 | if args.online_cross_val: 516 | num_tasks = K_FOR_CROSS_VAL 517 | else: 518 | num_tasks = NUM_TASKS #- K_FOR_CROSS_VAL 519 | 520 | # Variables to store the accuracies and standard deviations of the experiment 521 | acc_mean = dict() 522 | acc_std = dict() 523 | 524 | # Reset the default graph 525 | tf.reset_default_graph() 526 | graph = tf.Graph() 527 | with graph.as_default(): 528 | 529 | # Set the random seed 530 | tf.set_random_seed(args.random_seed) 531 | 532 | # Define Input and Output of the model 533 | x = tf.placeholder(tf.float32, shape=[None, INPUT_FEATURE_SIZE]) 534 | #x = tf.placeholder(tf.float32, shape=[None, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS]) 535 | if args.imp_method == 'PNN': 536 | y_ = [] 537 | for i in range(num_tasks): 538 | y_.append(tf.placeholder(tf.float32, shape=[None, TOTAL_CLASSES])) 539 | else: 540 | y_ = tf.placeholder(tf.float32, shape=[None, TOTAL_CLASSES]) 541 | 542 | # Define the optimizer 543 | if args.optim == 'ADAM': 544 | opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate) 545 | 546 | elif args.optim == 'SGD': 547 | opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate) 548 | 549 | elif args.optim == 'MOMENTUM': 550 | base_lr = tf.constant(args.learning_rate) 551 | learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - train_step / training_iters), OPT_POWER)) 552 | opt = tf.train.MomentumOptimizer(args.learning_rate, OPT_MOMENTUM) 553 | 554 | # Create the Model/ contruct the graph 555 | model = Model(x, y_, num_tasks, opt, args.imp_method, args.synap_stgth, args.fisher_update_after, 556 | args.fisher_ema_decay, network_arch=args.arch, all_args=args) 557 | 558 | # Set up tf session and initialize variables. 
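# allow_growth makes TensorFlow allocate GPU memory on demand instead of reserving it all at startup;
# when USE_GPU is False, device_count={'GPU': 0} hides all GPUs so the session runs on CPU only.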
559 | if USE_GPU: 560 | config = tf.ConfigProto() 561 | config.gpu_options.allow_growth = True 562 | else: 563 | config = tf.ConfigProto( 564 | device_count = {'GPU': 0} 565 | ) 566 | 567 | time_start = time.time() 568 | with tf.Session(config=config, graph=graph) as sess: 569 | runs = train_task_sequence(model, sess, args) 570 | # Close the session 571 | sess.close() 572 | time_end = time.time() 573 | time_spent = time_end - time_start 574 | 575 | # Store all the results in one dictionary to process later 576 | exper_acc = dict(mean=runs) 577 | 578 | # If cross-validation flag is enabled, store the stuff in a text file 579 | if args.cross_validate_mode: 580 | acc_mean = runs.mean(0) 581 | acc_std = runs.std(0) 582 | cross_validate_dump_file = args.log_dir + '/' + 'PERMUTE_MNIST_%s_%s'%(args.imp_method, args.optim) + '.txt' 583 | with open(cross_validate_dump_file, 'a') as f: 584 | if MULTI_TASK: 585 | f.write('GPU:{} \t ARCH: {} \t LR:{} \t LAMBDA: {} \t ACC: {}\n'.format(USE_GPU, args.arch, args.learning_rate, 586 | args.synap_stgth, acc_mean[-1, :].mean())) 587 | else: 588 | f.write('GPU: {} \t ARCH: {} \t LR:{} \t LAMBDA: {} \t ACC: {} \t Fgt: {} \t Time: {}\n'.format(USE_GPU, args.arch, args.learning_rate, 589 | args.synap_stgth, acc_mean[-1, :].mean(), compute_fgt(acc_mean), str(time_spent))) 590 | 591 | # Store the experiment output to a file 592 | snapshot_experiment_eval(args.log_dir, experiment_id, exper_acc) 593 | 594 | if __name__ == '__main__': 595 | main() 596 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .model import Model 7 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/replicate_cifar.sh: -------------------------------------------------------------------------------- 1 | ARCH='RESNET-S' 2 | NUM_RUNS=1 3 | BATCH_SIZE=10 4 | EPS_MEM_BATCH_SIZE=10 5 | 6 | OPTIM='SGD' 7 | lr=0.03 8 | lam=0.0 9 | LOG_DIR='results/cifar100' 10 | if [ ! 
-d $LOG_DIR ]; then 11 | mkdir -pv $LOG_DIR 12 | fi 13 | 14 | python3 ./conv_split_cifar.py --train-single-epoch --mem-size 1 --arch $ARCH --random-seed 1234 --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'A-GEM' --synap-stgth 0.0 --log-dir $LOG_DIR 15 | python3 ./conv_split_cifar.py --train-single-epoch --mem-size 1 --arch $ARCH --random-seed 7295 --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'A-GEM' --synap-stgth 0.0 --log-dir $LOG_DIR 16 | python3 ./conv_split_cifar.py --train-single-epoch --mem-size 1 --arch $ARCH --random-seed 5234 --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'A-GEM' --synap-stgth 0.0 --log-dir $LOG_DIR 17 | 18 | python3 ./conv_split_cifar.py --train-single-epoch --mem-size 1 --arch $ARCH --random-seed 1234 --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'EWC' --synap-stgth 10.0 --log-dir $LOG_DIR 19 | python3 ./conv_split_cifar.py --train-single-epoch --mem-size 1 --arch $ARCH --random-seed 7295 --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'EWC' --synap-stgth 10.0 --log-dir $LOG_DIR 20 | python3 ./conv_split_cifar.py --train-single-epoch --mem-size 1 --arch $ARCH --random-seed 5234 --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'EWC' --synap-stgth 10.0 --log-dir $LOG_DIR 21 | 22 | 23 | python3 ./conv_split_cifar.py --train-single-epoch --mem-size 1 --arch $ARCH --random-seed 1234 --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'ER' --synap-stgth 0.0 --log-dir $LOG_DIR 24 | python3 ./conv_split_cifar.py --train-single-epoch --mem-size 1 --arch $ARCH --random-seed 7295 --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'ER' --synap-stgth 0.0 --log-dir $LOG_DIR 25 | python3 ./conv_split_cifar.py --train-single-epoch --mem-size 1 --arch $ARCH --random-seed 5234 --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'ER' --synap-stgth 0.0 --log-dir $LOG_DIR 26 | 27 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/replicate_mnist.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | #! 
/bin/bash 4 | 5 | IMP_METHOD='A-GEM' 6 | NUM_RUNS=1 7 | BATCH_SIZE=32 8 | EPS_MEM_BATCH_SIZE=10 9 | MEM_SIZE=1 10 | LOG_DIR='results/mnist' 11 | lr=0.1 12 | ARCH='FC-S' 13 | OPTIM='SGD' 14 | lambda=10 15 | 16 | 17 | echo "Replicating results for stable ER-Reservoir" 18 | python3 ./fc_mnist.py --dataset $1 --random-seed 1345 --decay 1.0 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'ER' --log-dir $LOG_DIR 19 | python3 ./fc_mnist.py --dataset $1 --random-seed 1455 --decay 1.0 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'ER' --log-dir $LOG_DIR 20 | python3 ./fc_mnist.py --dataset $1 --random-seed 1668 --decay 1.0 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'ER' --log-dir $LOG_DIR 21 | 22 | 23 | echo "Replicating results for stable A-GEM" 24 | python3 ./fc_mnist.py --dataset $1 --random-seed 1345 --decay 1.0 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'A-GEM' --log-dir $LOG_DIR 25 | python3 ./fc_mnist.py --dataset $1 --random-seed 1455 --decay 1.0 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'A-GEM' --log-dir $LOG_DIR 26 | python3 ./fc_mnist.py --dataset $1 --random-seed 1668 --decay 1.0 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'A-GEM' --log-dir $LOG_DIR 27 | 28 | 29 | echo "Replicating results for stable EWC" 30 | python3 ./fc_mnist.py --dataset $1 --random-seed 1345 --decay 1.0 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'EWC' --log-dir $LOG_DIR 31 | python3 ./fc_mnist.py --dataset $1 --random-seed 1455 --decay 1.0 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'EWC' --log-dir $LOG_DIR 32 | python3 ./fc_mnist.py --dataset $1 --random-seed 1668 --decay 1.0 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'EWC' --log-dir $LOG_DIR 33 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/replicate_mnist_stable.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | #! 
/bin/bash 4 | 5 | IMP_METHOD='A-GEM' 6 | NUM_RUNS=1 7 | BATCH_SIZE=32 8 | EPS_MEM_BATCH_SIZE=10 9 | MEM_SIZE=1 10 | LOG_DIR='results/mnist' 11 | lr=0.1 12 | ARCH='FC-S' 13 | OPTIM='SGD' 14 | lambda=10 15 | 16 | 17 | 18 | python3 ./fc_mnist.py --dataset 'rot-mnist' --random-seed 1345 --decay 0.65 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'ER' --log-dir $LOG_DIR 19 | python3 ./fc_mnist.py --dataset 'rot-mnist' --random-seed 1455 --decay 0.65 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'ER' --log-dir $LOG_DIR 20 | python3 ./fc_mnist.py --dataset 'rot-mnist' --random-seed 1668 --decay 0.65 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'ER' --log-dir $LOG_DIR 21 | 22 | 23 | python3 ./fc_mnist.py --dataset 'rot-mnist' --random-seed 1345 --decay 0.65 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'A-GEM' --log-dir $LOG_DIR 24 | python3 ./fc_mnist.py --dataset 'rot-mnist' --random-seed 1455 --decay 0.65 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'A-GEM' --log-dir $LOG_DIR 25 | python3 ./fc_mnist.py --dataset 'rot-mnist' --random-seed 1668 --decay 0.65 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'A-GEM' --log-dir $LOG_DIR 26 | 27 | 28 | python3 ./fc_mnist.py --dataset 'rot-mnist' --random-seed 1345 --decay 0.65 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'A-GEM' --log-dir $LOG_DIR 29 | python3 ./fc_mnist.py --dataset 'rot-mnist' --random-seed 1455 --decay 0.65 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'A-GEM' --log-dir $LOG_DIR 30 | python3 ./fc_mnist.py --dataset 'rot-mnist' --random-seed 1668 --decay 0.65 --examples-per-task 50000 --arch $ARCH --num-runs $NUM_RUNS --batch-size $BATCH_SIZE --optim $OPTIM --learning-rate $lr --imp-method 'A-GEM' --log-dir $LOG_DIR 31 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/run.sh: -------------------------------------------------------------------------------- 1 | apt-get install libsm6 libxrender1 libfontconfig1 libsm6 libxext6 libxrender-dev python-cairocffi python3-cairocffi libcurl4-openssl-dev libssl-dev 2 | python3 -m pip install opencv-contrib-python 3 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
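# Re-export the dataset construction, episodic-memory, ResNet/VGG layer and visualization helpers
# so callers can import them directly from the `utils` package.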
5 | 6 | from .data_utils import construct_permute_mnist, construct_split_mnist, construct_split_cifar, construct_rotate_mnist 7 | from .data_utils import image_scaling, random_crop_and_pad_image, random_horizontal_flip 8 | from .utils import clone_variable_list, create_fc_layer, create_conv_layer, sample_from_dataset, update_episodic_memory, update_episodic_memory_with_less_data, concatenate_datasets 9 | from .utils import samples_for_each_class, sample_from_dataset_icarl, get_sample_weights, compute_fgt, load_task_specific_data, load_task_specific_data_in_proportion 10 | from .utils import average_acc_stats_across_runs, average_fgt_stats_across_runs 11 | from .vis_utils import plot_acc_multiple_runs, plot_histogram, snapshot_experiment_meta_data, snapshot_experiment_eval, snapshot_task_labels 12 | from .resnet_utils import _conv, _fc, _bn, _residual_block, _residual_block_first 13 | from .vgg_utils import vgg_conv_layer, vgg_fc_layer 14 | from .er_utils import update_reservior, update_fifo_buffer, er_mem_update_hindsight, update_avg_image_vectors -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Define utility functions for manipulating datasets 8 | """ 9 | import os 10 | import numpy as np 11 | import sys 12 | from copy import deepcopy 13 | 14 | import tensorflow as tf 15 | from tensorflow.examples.tutorials.mnist import input_data 16 | 17 | from six.moves.urllib.request import urlretrieve 18 | from six.moves import cPickle as pickle 19 | import tarfile 20 | import zipfile 21 | import random 22 | import cv2 23 | from scipy import ndimage 24 | import matplotlib 25 | matplotlib.use('Agg') 26 | import matplotlib.pyplot as plt 27 | 28 | #IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 29 | IMG_MEAN = np.array((103.94,116.78,123.68), dtype=np.float32) 30 | ############################################################ 31 | ### Data augmentation utils ################################ 32 | ############################################################ 33 | def image_scaling(images): 34 | """ 35 | Randomly scales the images between 0.5 to 1.5 times the original size. 36 | Args: 37 | images: Training images to scale. 38 | """ 39 | scale = tf.random_uniform([1], minval=0.5, maxval=1.5, dtype=tf.float32, seed=None) 40 | h_new = tf.to_int32(tf.multiply(tf.to_float(tf.shape(images)[1]), scale)) 41 | w_new = tf.to_int32(tf.multiply(tf.to_float(tf.shape(images)[2]), scale)) 42 | new_shape = tf.squeeze(tf.stack([h_new, w_new]), squeeze_dims=[1]) 43 | images = tf.image.resize_images(images, new_shape) 44 | result = tf.map_fn(lambda img: tf.image.random_flip_left_right(img), images) 45 | return result 46 | 47 | 48 | def random_crop_and_pad_image(images, crop_h, crop_w): 49 | """ 50 | Randomly crop and pads the input images. 51 | Args: 52 | images: Training i mages to crop/ pad. 53 | crop_h: Height of cropped segment. 54 | crop_w: Width of cropped segment. 
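Returns:
        Batch of images, each padded (if needed) and randomly cropped to crop_h x crop_w.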
55 | """ 56 | image_shape = tf.shape(images) 57 | image_pad = tf.image.pad_to_bounding_box(images, 0, 0, tf.maximum(crop_h, image_shape[1]), tf.maximum(crop_w, image_shape[2])) 58 | img_crop = tf.map_fn(lambda img: tf.random_crop(img, [crop_h,crop_w,3]), image_pad) 59 | return img_crop 60 | 61 | def random_horizontal_flip(x): 62 | """ 63 | Randomly flip a batch of images horizontally 64 | Args: 65 | x Tensor of shape B x H x W x C 66 | Returns: 67 | random_flipped Randomly flipped tensor of shape B x H x W x C 68 | """ 69 | # Define random horizontal flip 70 | flips = [(slice(None, None, None), slice(None, None, random.choice([-1, None])), slice(None, None, None)) 71 | for _ in xrange(x.shape[0])] 72 | random_flipped = np.array([img[flip] for img, flip in zip(x, flips)]) 73 | return random_flipped 74 | 75 | ############################################################ 76 | ### CIFAR download utils ################################### 77 | ############################################################ 78 | CIFAR_10_URL = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 79 | CIFAR_100_URL = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 80 | CIFAR_10_DIR = "/cifar_10" 81 | CIFAR_100_DIR = "/cifar_100" 82 | 83 | def construct_split_cifar(task_labels, is_cifar_100=True): 84 | """ 85 | Construct Split CIFAR-10 and CIFAR-100 datasets 86 | 87 | Args: 88 | task_labels Labels of different tasks 89 | data_dir Data directory where the CIFAR data will be saved 90 | """ 91 | 92 | data_dir = 'CIFAR_data' 93 | 94 | # Get the cifar dataset 95 | cifar_data = _get_cifar(data_dir, is_cifar_100) 96 | 97 | # Define a list for storing the data for different tasks 98 | datasets = [] 99 | 100 | # Data splits 101 | sets = ["train", "validation", "test"] 102 | 103 | for task in task_labels: 104 | 105 | for set_name in sets: 106 | this_set = cifar_data[set_name] 107 | 108 | global_class_indices = np.column_stack(np.nonzero(this_set[1])) 109 | count = 0 110 | 111 | for cls in task: 112 | if count == 0: 113 | class_indices = np.squeeze(global_class_indices[global_class_indices[:,1] == 114 | cls][:,np.array([True, False])]) 115 | else: 116 | class_indices = np.append(class_indices, np.squeeze(global_class_indices[global_class_indices[:,1] ==\ 117 | cls][:,np.array([True, False])])) 118 | 119 | count += 1 120 | 121 | class_indices = np.sort(class_indices, axis=None) 122 | 123 | if set_name == "train": 124 | train = { 125 | 'images':deepcopy(this_set[0][class_indices, :]), 126 | 'labels':deepcopy(this_set[1][class_indices, :]), 127 | } 128 | elif set_name == "validation": 129 | validation = { 130 | 'images':deepcopy(this_set[0][class_indices, :]), 131 | 'labels':deepcopy(this_set[1][class_indices, :]), 132 | } 133 | elif set_name == "test": 134 | test = { 135 | 'images':deepcopy(this_set[0][class_indices, :]), 136 | 'labels':deepcopy(this_set[1][class_indices, :]), 137 | } 138 | 139 | cifar = { 140 | 'train': train, 141 | 'validation': validation, 142 | 'test': test, 143 | } 144 | 145 | datasets.append(cifar) 146 | 147 | return datasets 148 | 149 | 150 | def _get_cifar(data_dir, is_cifar_100): 151 | """ 152 | Get the CIFAR-10 and CIFAR-100 datasets 153 | 154 | Args: 155 | data_dir Directory where the downloaded data will be stored 156 | """ 157 | x_train = None 158 | y_train = None 159 | x_validation = None 160 | y_validation = None 161 | x_test = None 162 | y_test = None 163 | l = None 164 | 165 | # Download the dataset if needed 166 | _cifar_maybe_download_and_extract(data_dir) 167 | 168 | # Dictionary 
to store the dataset 169 | dataset = dict() 170 | dataset['train'] = [] 171 | dataset['validation'] = [] 172 | dataset['test'] = [] 173 | 174 | def dense_to_one_hot(labels_dense, num_classes=100): 175 | num_labels = labels_dense.shape[0] 176 | index_offset = np.arange(num_labels) * num_classes 177 | labels_one_hot = np.zeros((num_labels, num_classes)) 178 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 179 | 180 | return labels_one_hot 181 | 182 | if is_cifar_100: 183 | # Load the training data of CIFAR-100 184 | f = open(data_dir + CIFAR_100_DIR + '/train', 'rb') 185 | datadict = pickle.load(f, encoding='latin1') 186 | f.close() 187 | 188 | _X = datadict['data'] 189 | _Y = np.array(datadict['fine_labels']) 190 | _Y = dense_to_one_hot(_Y, num_classes=100) 191 | 192 | _X = np.array(_X, dtype=float) / 255.0 193 | _X = _X.reshape([-1, 3, 32, 32]) 194 | _X = _X.transpose([0, 2, 3, 1]) 195 | 196 | # Compute the data mean for normalization 197 | x_train_mean = np.mean(_X, axis=0) 198 | 199 | x_train = _X[:40000] 200 | y_train = _Y[:40000] 201 | 202 | x_validation = _X[40000:] 203 | y_validation = _Y[40000:] 204 | else: 205 | # Load all the training batches of the CIFAR-10 206 | for i in range(5): 207 | f = open(data_dir + CIFAR_10_DIR + '/data_batch_' + str(i + 1), 'rb') 208 | datadict = pickle.load(f, encoding='latin1') 209 | f.close() 210 | 211 | _X = datadict['data'] 212 | _Y = np.array(datadict['labels']) 213 | _Y = dense_to_one_hot(_Y, num_classes=10) 214 | 215 | _X = np.array(_X, dtype=float) / 255.0 216 | _X = _X.reshape([-1, 3, 32, 32]) 217 | _X = _X.transpose([0, 2, 3, 1]) 218 | 219 | if x_train is None: 220 | x_train = _X 221 | y_train = _Y 222 | else: 223 | x_train = np.concatenate((x_train, _X), axis=0) 224 | y_train = np.concatenate((y_train, _Y), axis=0) 225 | 226 | # Compute the data mean for normalization 227 | x_train_mean = np.mean(x_train, axis=0) 228 | x_validation = x_train[:40000] # We don't use validation set with CIFAR-10 229 | y_validation = y_train[40000:] 230 | 231 | # Normalize the train and validation sets 232 | x_train -= x_train_mean 233 | x_validation -= x_train_mean 234 | 235 | dataset['train'].append(x_train) 236 | dataset['train'].append(y_train) 237 | dataset['train'].append(l) 238 | 239 | dataset['validation'].append(x_validation) 240 | dataset['validation'].append(y_validation) 241 | dataset['validation'].append(l) 242 | 243 | if is_cifar_100: 244 | # Load the test batch of CIFAR-100 245 | f = open(data_dir + CIFAR_100_DIR + '/test', 'rb') 246 | datadict = pickle.load(f, encoding='latin1') 247 | f.close() 248 | 249 | _X = datadict['data'] 250 | _Y = np.array(datadict['fine_labels']) 251 | _Y = dense_to_one_hot(_Y, num_classes=100) 252 | else: 253 | # Load the test batch of CIFAR-10 254 | f = open(data_dir + CIFAR_10_DIR + '/test_batch', 'rb') 255 | datadict = pickle.load(f, encoding='latin1') 256 | f.close() 257 | 258 | _X = datadict["data"] 259 | _Y = np.array(datadict['labels']) 260 | _Y = dense_to_one_hot(_Y, num_classes=10) 261 | 262 | _X = np.array(_X, dtype=float) / 255.0 263 | _X = _X.reshape([-1, 3, 32, 32]) 264 | _X = _X.transpose([0, 2, 3, 1]) 265 | 266 | x_test = _X 267 | y_test = _Y 268 | 269 | # Normalize the test set 270 | x_test -= x_train_mean 271 | 272 | dataset['test'].append(x_test) 273 | dataset['test'].append(y_test) 274 | dataset['test'].append(l) 275 | 276 | return dataset 277 | 278 | 279 | def _print_download_progress(count, block_size, total_size): 280 | """ 281 | Show the download progress of the cifar data 282 | """ 
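# urlretrieve calls this reporthook with (blocks transferred so far, block size in bytes, total file size),
# so the downloaded fraction is count * block_size / total_size.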
283 | pct_complete = float(count * block_size) / total_size 284 | msg = "\r- Download progress: {0:.1%}".format(pct_complete) 285 | sys.stdout.write(msg) 286 | sys.stdout.flush() 287 | 288 | 289 | def _cifar_maybe_download_and_extract(data_dir): 290 | """ 291 | Routine to download and extract the cifar dataset 292 | 293 | Args: 294 | data_dir Directory where the downloaded data will be stored 295 | """ 296 | cifar_10_directory = data_dir + CIFAR_10_DIR 297 | cifar_100_directory = data_dir + CIFAR_100_DIR 298 | 299 | # If the data_dir does not exist, create the directory and download 300 | # the data 301 | if not os.path.exists(data_dir): 302 | os.makedirs(data_dir) 303 | 304 | url = CIFAR_10_URL 305 | filename = url.split('/')[-1] 306 | file_path = os.path.join(data_dir, filename) 307 | zip_cifar_10 = file_path 308 | file_path, _ = urlretrieve(url=url, filename=file_path, reporthook=_print_download_progress) 309 | 310 | print() 311 | print("Download finished. Extracting files.") 312 | if file_path.endswith(".zip"): 313 | zipfile.ZipFile(file=file_path, mode="r").extractall(data_dir) 314 | elif file_path.endswith((".tar.gz", ".tgz")): 315 | tarfile.open(name=file_path, mode="r:gz").extractall(data_dir) 316 | print("Done.") 317 | 318 | url = CIFAR_100_URL 319 | filename = url.split('/')[-1] 320 | file_path = os.path.join(data_dir, filename) 321 | zip_cifar_100 = file_path 322 | file_path, _ = urlretrieve(url=url, filename=file_path, reporthook=_print_download_progress) 323 | 324 | print() 325 | print("Download finished. Extracting files.") 326 | if file_path.endswith(".zip"): 327 | zipfile.ZipFile(file=file_path, mode="r").extractall(data_dir) 328 | elif file_path.endswith((".tar.gz", ".tgz")): 329 | tarfile.open(name=file_path, mode="r:gz").extractall(data_dir) 330 | print("Done.") 331 | 332 | os.rename(data_dir + "/cifar-10-batches-py", cifar_10_directory) 333 | os.rename(data_dir + "/cifar-100-python", cifar_100_directory) 334 | os.remove(zip_cifar_10) 335 | os.remove(zip_cifar_100) 336 | 337 | 338 | ######################################### 339 | ## MNIST Utils ########################## 340 | ######################################### 341 | def reformat_mnist(datasets): 342 | """ 343 | Routine to Reformat the mnist dataset into a 3d tensor 344 | """ 345 | image_size = 28 # Height of MNIST dataset 346 | num_channels = 1 # Gray scale 347 | for i in range(len(datasets)): 348 | sets = ["train", "validation", "test"] 349 | for set_name in sets: 350 | datasets[i]['%s'%set_name]['images'] = datasets[i]['%s'%set_name]['images'].reshape\ 351 | ((-1, image_size, image_size, num_channels)).astype(np.float32) 352 | 353 | return datasets 354 | 355 | 356 | def rotate_image_by_angle(img, angle=45): 357 | WIDTH, HEIGHT = 28 , 28 358 | img = img.reshape((WIDTH, HEIGHT)) 359 | img = ndimage.rotate(img, angle, reshape=False, order=0) 360 | out = np.array(img).flatten() 361 | return out 362 | 363 | def construct_rotate_mnist(num_tasks): 364 | mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 365 | datasets = [] 366 | 367 | for i in range(num_tasks): 368 | per_task_rotation = 180.0 / num_tasks 369 | rotation_degree = (i - 1)*per_task_rotation 370 | rotation_degree -= (np.random.random()*per_task_rotation) 371 | copied_mnist = deepcopy(mnist) 372 | sets = ["train", "validation", "test"] 373 | for set_name in sets: 374 | this_set = getattr(copied_mnist, set_name) # shallow copy 375 | 376 | rotate_image_by_angle(this_set._images[0]) 377 | this_set._images = np.array([rotate_image_by_angle(img, 
rotation_degree) for img in this_set._images]) 378 | if set_name == "train": 379 | train = { 380 | 'images':this_set._images, 381 | 'labels':this_set.labels, 382 | } 383 | elif set_name == "validation": 384 | validation = { 385 | 'images':this_set._images, 386 | 'labels':this_set.labels, 387 | } 388 | elif set_name == "test": 389 | test = { 390 | 'images':this_set._images, 391 | 'labels':this_set.labels, 392 | } 393 | dataset = { 394 | 'train': train, 395 | 'validation': validation, 396 | 'test': test, 397 | } 398 | 399 | datasets.append(dataset) 400 | 401 | return datasets 402 | def construct_permute_mnist(num_tasks): 403 | """ 404 | Construct a dataset of permutted mnist images 405 | 406 | Args: 407 | num_tasks Number of tasks 408 | Returns 409 | dataset A permutted mnist dataset 410 | """ 411 | # Download and store mnist dataset 412 | mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 413 | 414 | datasets = [] 415 | 416 | for i in range(num_tasks): 417 | perm_inds = list(range(mnist.train.images.shape[1])) 418 | np.random.shuffle(perm_inds) 419 | copied_mnist = deepcopy(mnist) 420 | sets = ["train", "validation", "test"] 421 | for set_name in sets: 422 | this_set = getattr(copied_mnist, set_name) # shallow copy 423 | this_set._images = np.transpose(np.array([this_set.images[:,c] for c in perm_inds])) 424 | # print(this_set._images.shape) 425 | if set_name == "train": 426 | train = { 427 | 'images':this_set._images, 428 | 'labels':this_set.labels, 429 | } 430 | elif set_name == "validation": 431 | validation = { 432 | 'images':this_set._images, 433 | 'labels':this_set.labels, 434 | } 435 | elif set_name == "test": 436 | test = { 437 | 'images':this_set._images, 438 | 'labels':this_set.labels, 439 | } 440 | dataset = { 441 | 'train': train, 442 | 'validation': validation, 443 | 'test': test, 444 | } 445 | 446 | datasets.append(dataset) 447 | 448 | return datasets 449 | 450 | def construct_split_mnist(task_labels): 451 | """ 452 | Construct a split mnist dataset 453 | 454 | Args: 455 | task_labels List of split labels 456 | 457 | Returns: 458 | dataset A list of split datasets 459 | 460 | """ 461 | # Download and store mnist dataset 462 | mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 463 | 464 | datasets = [] 465 | 466 | sets = ["train", "validation", "test"] 467 | 468 | for task in task_labels: 469 | 470 | for set_name in sets: 471 | this_set = getattr(mnist, set_name) 472 | 473 | global_class_indices = np.column_stack(np.nonzero(this_set.labels)) 474 | count = 0 475 | 476 | for cls in task: 477 | if count == 0: 478 | class_indices = np.squeeze(global_class_indices[global_class_indices[:,1] ==\ 479 | cls][:,np.array([True, False])]) 480 | else: 481 | class_indices = np.append(class_indices, np.squeeze(global_class_indices[global_class_indices[:,1] ==\ 482 | cls][:,np.array([True, False])])) 483 | count += 1 484 | 485 | class_indices = np.sort(class_indices, axis=None) 486 | 487 | if set_name == "train": 488 | train = { 489 | 'images':deepcopy(mnist.train.images[class_indices, :]), 490 | 'labels':deepcopy(mnist.train.labels[class_indices, :]), 491 | } 492 | elif set_name == "validation": 493 | validation = { 494 | 'images':deepcopy(mnist.validation.images[class_indices, :]), 495 | 'labels':deepcopy(mnist.validation.labels[class_indices, :]), 496 | } 497 | elif set_name == "test": 498 | test = { 499 | 'images':deepcopy(mnist.test.images[class_indices, :]), 500 | 'labels':deepcopy(mnist.test.labels[class_indices, :]), 501 | } 502 | 503 | mnist2 = { 504 | 'train': 
train, 505 | 'validation': validation, 506 | 'test': test, 507 | } 508 | 509 | datasets.append(mnist2) 510 | 511 | return datasets 512 | 513 | ################################################### 514 | ###### ImageNet Utils ############################# 515 | ################################################### 516 | def construct_split_imagenet(task_labels, data_dir): 517 | """ 518 | Construct Split ImageNet dataset 519 | 520 | Args: 521 | task_labels Labels of different tasks 522 | data_dir Data directory from where to load the imagenet data 523 | """ 524 | 525 | # Load the imagenet dataset 526 | imagenet_data = _load_imagenet(data_dir) 527 | 528 | # Define a list for storing the data for different tasks 529 | datasets = [] 530 | 531 | # Data splits 532 | sets = ["train", "test"] 533 | 534 | for task in task_labels: 535 | 536 | for set_name in sets: 537 | this_set = imagenet_data[set_name] 538 | 539 | global_class_indices = np.column_stack(np.nonzero(this_set[1])) 540 | count = 0 541 | 542 | for cls in task: 543 | if count == 0: 544 | class_indices = np.squeeze(global_class_indices[global_class_indices[:,1] == 545 | cls][:,np.array([True, False])]) 546 | else: 547 | class_indices = np.append(class_indices, np.squeeze(global_class_indices[global_class_indices[:,1] ==\ 548 | cls][:,np.array([True, False])])) 549 | 550 | count += 1 551 | 552 | class_indices = np.sort(class_indices, axis=None) 553 | 554 | if set_name == "train": 555 | train = { 556 | 'images':deepcopy(this_set[0][class_indices, :]), 557 | 'labels':deepcopy(this_set[1][class_indices, :]), 558 | } 559 | elif set_name == "test": 560 | test = { 561 | 'images':deepcopy(this_set[0][class_indices, :]), 562 | 'labels':deepcopy(this_set[1][class_indices, :]), 563 | } 564 | 565 | imagenet = { 566 | 'train': train, 567 | 'test': test, 568 | } 569 | 570 | datasets.append(imagenet) 571 | 572 | return datasets 573 | 574 | def _load_imagenet(data_dir): 575 | """ 576 | Load the ImageNet data 577 | 578 | Args: 579 | data_dir Directory where the pickle files have been dumped 580 | """ 581 | x_train = None 582 | y_train = None 583 | x_test = None 584 | y_test = None 585 | 586 | # Dictionary to store the dataset 587 | dataset = dict() 588 | dataset['train'] = [] 589 | dataset['test'] = [] 590 | 591 | def dense_to_one_hot(labels_dense, num_classes=100): 592 | num_labels = labels_dense.shape[0] 593 | index_offset = np.arange(num_labels) * num_classes 594 | labels_one_hot = np.zeros((num_labels, num_classes)) 595 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 596 | 597 | return labels_one_hot 598 | 599 | # Load the training batches 600 | for i in range(4): 601 | f = open(data_dir + '/train_batch_' + str(i), 'rb') 602 | datadict = pickle.load(f, encoding='latin1') 603 | f.close() 604 | 605 | _X = datadict['data'] 606 | _Y = np.array(datadict['labels']) 607 | 608 | # Convert the lables to one-hot 609 | _Y = dense_to_one_hot(_Y) 610 | 611 | # Normalize the images 612 | _X = np.array(_X, dtype=float)/ 255.0 613 | _X = _X.reshape([-1, 224, 224, 3]) 614 | 615 | if x_train is None: 616 | x_train = _X 617 | y_train = _Y 618 | else: 619 | x_train = np.concatenate((x_train, _X), axis=0) 620 | y_train = np.concatenate((y_train, _Y), axis=0) 621 | 622 | dataset['train'].append(x_train) 623 | dataset['train'].append(y_train) 624 | 625 | # Load test batches 626 | for i in range(4): 627 | f = open(data_dir + '/test_batch_' + str(i), 'rb') 628 | datadict = pickle.load(f, encoding='latin1') 629 | f.close() 630 | 631 | _X = datadict['data'] 632 | _Y = 
np.array(datadict['labels']) 633 | 634 | # Convert the lables to one-hot 635 | _Y = dense_to_one_hot(_Y) 636 | 637 | # Normalize the images 638 | _X = np.array(_X, dtype=float)/ 255.0 639 | _X = _X.reshape([-1, 224, 224, 3]) 640 | 641 | if x_test is None: 642 | x_test = _X 643 | y_test = _Y 644 | else: 645 | x_test = np.concatenate((x_test, _X), axis=0) 646 | y_test = np.concatenate((y_test, _Y), axis=0) 647 | 648 | dataset['test'].append(x_test) 649 | dataset['test'].append(y_test) 650 | 651 | 652 | return dataset 653 | 654 | if __name__ == "__main__": 655 | construct_rotate_mnist(20) 656 | # rotate_image_by_angle(np.array([[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]])) 657 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/utils/er_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def update_reservior(images, labels, episodic_images, episodic_labels, M, N): 4 | """ 5 | Update the episodic memory with current example using the reservior sampling 6 | """ 7 | for er_x, er_y in zip(images, labels): 8 | if M > N: 9 | episodic_images[N] = er_x 10 | episodic_labels[N] = er_y 11 | else: 12 | j = np.random.randint(0, N) 13 | if j < M: 14 | episodic_images[j] = er_x 15 | episodic_labels[j] = er_y 16 | N += 1 17 | 18 | return N 19 | 20 | def update_fifo_buffer(images, labels, episodic_images, episodic_labels, task_labels, mem_per_class, count_cls, N): 21 | for er_x, er_y in zip(images, labels): 22 | cls = np.unique(np.nonzero(er_y))[-1] 23 | # Write the example at the location pointed by count_cls[cls] 24 | cls_to_index_map = np.where(np.array(task_labels) == cls)[0][0] 25 | with_in_task_offset = mem_per_class * cls_to_index_map 26 | mem_index = count_cls[cls] + with_in_task_offset + N 27 | episodic_images[mem_index] = er_x 28 | episodic_labels[mem_index] = er_y 29 | count_cls[cls] = (count_cls[cls] + 1) % mem_per_class 30 | 31 | return 32 | 33 | def er_mem_update_hindsight(model, sess, x_hat_batch, y_hat_batch, episodic_images, episodic_labels, episodic_filled_counter, 34 | task_labels, logit_mask, phi_hat, avg_img_vectors, args, loop_over_mem=50): 35 | """ 36 | Update the episodic memory using hindsight 37 | """ 38 | # Store the current estimate of the parameters in the star_params 39 | sess.run(model.set_star_vars) 40 | 41 | # Train on the episodic memory to get the new estimate of the parameters 42 | batch_size = 10 43 | samples_at_a_time = episodic_filled_counter if (episodic_filled_counter <= batch_size) else batch_size 44 | for jj in range(loop_over_mem): 45 | mem_indices = np.random.choice(episodic_filled_counter, samples_at_a_time, replace=False) 46 | train_x = episodic_images[mem_indices] 47 | train_y = episodic_labels[mem_indices] 48 | feed_dict = {model.x: train_x, model.y_: train_y, model.keep_prob: 1.0, 49 | model.learning_rate: args.learning_rate} 50 | feed_dict[model.output_mask] = logit_mask 51 | _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) 52 | if jj % 5 == 0: 53 | print('Hindsight loss:{}'.format(loss)) 54 | 55 | 56 | # Update this synthetic samples by maximizing forgetting loss while maintaining good performance on the current task 57 | for jj, cls in enumerate(task_labels): 58 | y_hat_dense = np.repeat(cls, 1) 59 | y_hat_one_hot = _dense_to_one_hot(y_hat_dense, model.total_classes) 60 | 61 | # Initialize the anchor for this task this class 62 | 
sess.run(model.anchor_xx.assign(np.expand_dims(avg_img_vectors[cls], axis=0))) 63 | 64 | for ii in range(100): 65 | feed_dict = {model.y_: y_hat_one_hot, model.phi_hat_reference: phi_hat[cls] , model.keep_prob: 1.0} 66 | feed_dict[model.output_mask] = logit_mask 67 | fgt_loss, phi_dist, total_loss, _ = sess.run([model.negForgetting_loss, model.phi_distance, model.hindsight_objective, model.update_hindsight_anchor], feed_dict=feed_dict) 68 | if ii%100 == 0: 69 | print('Fgt_loss: {}\t Phi_dist: {}\t Total: {}'.format(fgt_loss, phi_dist, total_loss)) 70 | 71 | # Store the learned images in the episodic memory 72 | offset = jj 73 | class_x_hat = sess.run(model.anchor_xx) 74 | x_hat_batch[jj] = class_x_hat 75 | y_hat_batch[jj] = y_hat_one_hot 76 | 77 | # Restore the weights 78 | sess.run(model.restore_weights) 79 | return x_hat_batch, y_hat_batch 80 | 81 | def update_avg_image_vectors(train_x, train_y, avg_img_vectors, running_alpha=0.5): 82 | """ 83 | Updates the average image vectors 84 | 85 | avg_img_vectors => TOTAL_CLASSES x H x W x C 86 | """ 87 | # For each label in the batch, update the corresponding avg_image_vector 88 | num_examples = train_x.shape[0] 89 | for ii in range(num_examples): 90 | yy = train_y[ii] 91 | cls = np.nonzero(yy) 92 | avg_img_vectors[cls] -= (1 - running_alpha) * (avg_img_vectors[cls] - train_x[ii]) # running average 93 | 94 | return 95 | 96 | 97 | # -------------------------- Internet APIs ---------------------------------------------------------------------------------------- 98 | def _dense_to_one_hot(labels_dense, num_classes): 99 | num_labels = labels_dense.shape[0] 100 | index_offset = np.arange(num_labels) * num_classes 101 | labels_one_hot = np.zeros((num_labels, num_classes)) 102 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 103 | 104 | return labels_one_hot 105 | 106 | def _get_indices_of_class_examples(train_y, cls): 107 | """ 108 | Returns the indies of examples with given class label 109 | """ 110 | global_class_indices = np.column_stack(np.nonzero(train_y)) 111 | class_indices = np.squeeze(global_class_indices[global_class_indices[:,1] == cls][:,np.array([True, False])]) 112 | return class_indices -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/utils/resnet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import tensorflow as tf 8 | import numpy as np 9 | 10 | def _conv(x, kernel_size, out_channels, stride, var_list, pad="SAME", name="conv"): 11 | """ 12 | Define API for conv operation. This includes kernel declaration and 13 | conv operation both. 
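    Args:
        x             Input feature map (NHWC tensor)
        kernel_size   Spatial size of the (square) kernel
        out_channels  Number of output channels
        stride        Convolution stride
        var_list      List to which the created kernel variable is appended
        pad           Padding scheme (default 'SAME')
        name          Variable scope name

    Returns:
        Output of the convolution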
14 | """ 15 | in_channels = x.get_shape().as_list()[-1] 16 | with tf.variable_scope(name): 17 | #n = kernel_size * kernel_size * out_channels 18 | n = kernel_size * in_channels 19 | stdv = 1.0 / math.sqrt(n) 20 | w = tf.get_variable('kernel', [kernel_size, kernel_size, in_channels, out_channels], 21 | tf.float32, 22 | initializer=tf.random_uniform_initializer(-stdv, stdv)) 23 | #initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0/n))) 24 | 25 | # Append the variable to the trainable variables list 26 | var_list.append(w) 27 | 28 | # Do the convolution operation 29 | output = tf.nn.conv2d(x, w, [1, stride, stride, 1], padding=pad) 30 | return output 31 | 32 | def _fc(x, out_dim, var_list, name="fc", is_cifar=False): 33 | """ 34 | Define API for the fully connected layer. This includes both the variable 35 | declaration and matmul operation. 36 | """ 37 | in_dim = x.get_shape().as_list()[1] 38 | stdv = 1.0 / math.sqrt(in_dim) 39 | with tf.variable_scope(name): 40 | # Define the weights and biases for this layer 41 | w = tf.get_variable('weights', [in_dim, out_dim], tf.float32, 42 | initializer=tf.random_uniform_initializer(-stdv, stdv)) 43 | #initializer=tf.truncated_normal_initializer(stddev=0.1)) 44 | if is_cifar: 45 | b = tf.get_variable('biases', [out_dim], tf.float32, initializer=tf.random_uniform_initializer(-stdv, stdv)) 46 | else: 47 | b = tf.get_variable('biases', [out_dim], tf.float32, initializer=tf.constant_initializer(0)) 48 | 49 | # Append the variable to the trainable variables list 50 | var_list.append(w) 51 | var_list.append(b) 52 | 53 | # Do the FC operation 54 | output = tf.matmul(x, w) + b 55 | return output 56 | 57 | def _bn(x, var_list, train_phase, name='bn_'): 58 | """ 59 | Batch normalization on convolutional maps. 60 | Args: 61 | 62 | Return: 63 | """ 64 | n_out = x.get_shape().as_list()[3] 65 | with tf.variable_scope(name): 66 | beta = tf.get_variable('beta', shape=[n_out], dtype=tf.float32, initializer=tf.constant_initializer(0.0)) 67 | gamma = tf.get_variable('gamma', shape=[n_out], dtype=tf.float32, initializer=tf.constant_initializer(1.0)) 68 | var_list.append(beta) 69 | var_list.append(gamma) 70 | batch_mean, batch_var = tf.nn.moments(x, [0,1,2], name='moments') 71 | ema = tf.train.ExponentialMovingAverage(decay=0.9) 72 | 73 | def mean_var_with_update(): 74 | ema_apply_op = ema.apply([batch_mean, batch_var]) 75 | with tf.control_dependencies([ema_apply_op]): 76 | return tf.identity(batch_mean), tf.identity(batch_var) 77 | 78 | mean, var = tf.cond(train_phase, 79 | mean_var_with_update, 80 | lambda: (ema.average(batch_mean), ema.average(batch_var))) 81 | normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3) 82 | 83 | return normed 84 | 85 | def _residual_block(x, trainable_vars, train_phase, apply_relu=True, name="unit"): 86 | """ 87 | ResNet block when the number of channels across the skip connections are the same 88 | """ 89 | in_channels = x.get_shape().as_list()[-1] 90 | with tf.variable_scope(name) as scope: 91 | shortcut = x 92 | x = _conv(x, 3, in_channels, 1, trainable_vars, name='conv_1') 93 | x = _bn(x, trainable_vars, train_phase, name="bn_1") 94 | x = tf.nn.relu(x) 95 | x = _conv(x, 3, in_channels, 1, trainable_vars, name='conv_2') 96 | x = _bn(x, trainable_vars, train_phase, name="bn_2") 97 | 98 | x = x + shortcut 99 | if apply_relu == True: 100 | x = tf.nn.relu(x) 101 | 102 | return x 103 | 104 | def _residual_block_first(x, out_channels, strides, trainable_vars, train_phase, apply_relu=True, name="unit", 
is_ATT_DATASET=False): 105 | """ 106 | A generic ResNet Block 107 | """ 108 | in_channels = x.get_shape().as_list()[-1] 109 | with tf.variable_scope(name) as scope: 110 | # Figure out the shortcut connection first 111 | if in_channels == out_channels: 112 | if strides == 1: 113 | shortcut = tf.identity(x) 114 | else: 115 | shortcut = tf.nn.max_pool(x, [1, strides, strides, 1], [1, strides, strides, 1], 'VALID') 116 | else: 117 | shortcut = _conv(x, 1, out_channels, strides, trainable_vars, name="shortcut") 118 | if not is_ATT_DATASET: 119 | shortcut = _bn(shortcut, trainable_vars, train_phase, name="bn_0") 120 | 121 | # Residual block 122 | x = _conv(x, 3, out_channels, strides, trainable_vars, name="conv_1") 123 | x = _bn(x, trainable_vars, train_phase, name="bn_1") 124 | x = tf.nn.relu(x) 125 | x = _conv(x, 3, out_channels, 1, trainable_vars, name="conv_2") 126 | x = _bn(x, trainable_vars, train_phase, name="bn_2") 127 | 128 | x = x + shortcut 129 | if apply_relu: 130 | x = tf.nn.relu(x) 131 | 132 | return x 133 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Define some utility functions 8 | """ 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | def clone_variable_list(variable_list): 13 | """ 14 | Clone the variable list 15 | """ 16 | return [tf.identity(var) for var in variable_list] 17 | 18 | def create_fc_layer(input, w, b, apply_relu=True): 19 | """ 20 | Construct a Fully Connected layer 21 | Args: 22 | w Weights 23 | b Biases 24 | apply_relu Apply relu (T/F)? 25 | 26 | Returns: 27 | Output of an FC layer 28 | """ 29 | with tf.name_scope('fc_layer'): 30 | output = tf.matmul(input, w) + b 31 | 32 | # Apply relu 33 | if apply_relu: 34 | output = tf.nn.relu(output) 35 | 36 | return output 37 | 38 | def create_conv_layer(input, w, b, stride=1, apply_relu=True): 39 | """ 40 | Construct a convolutional layer 41 | Args: 42 | w Weights 43 | b Biases 44 | pre_activations List where the pre_activations will be stored 45 | apply_relu Apply relu (T/F)? 
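        stride          Stride for the convolution (default 1)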
46 | 47 | Returns: 48 | Output of a conv layer 49 | """ 50 | with tf.name_scope('conv_layer'): 51 | # Do the convolution operation 52 | output = tf.nn.conv2d(input, w, [1, stride, stride, 1], padding='SAME') + b 53 | 54 | # Apply relu 55 | if apply_relu: 56 | output = tf.nn.relu(output) 57 | 58 | return output 59 | 60 | def load_task_specific_data_in_proportion(datasets, task_labels, classes_appearing_in_tasks, class_seen_already): 61 | """ 62 | Loads task specific data from the datasets proportionate to classes appearing in different tasks 63 | """ 64 | global_class_indices = np.column_stack(np.nonzero(datasets['labels'])) 65 | count = 0 66 | for cls in task_labels: 67 | if count == 0: 68 | class_indices = np.squeeze(global_class_indices[global_class_indices[:,1] == cls][:,np.array([True, False])]) 69 | total_class_instances = class_indices.size 70 | num_instances_to_choose = total_class_instances // classes_appearing_in_tasks[cls] 71 | offset = (class_seen_already[cls] - 1) * num_instances_to_choose 72 | final_class_indices = class_indices[offset: offset+num_instances_to_choose] 73 | else: 74 | current_class_indices = np.squeeze(global_class_indices[global_class_indices[:,1] == cls][:,np.array([True, False])]) 75 | total_class_instances = current_class_indices.size 76 | num_instances_to_choose = total_class_instances // classes_appearing_in_tasks[cls] 77 | offset = (class_seen_already[cls] - 1) * num_instances_to_choose 78 | final_class_indices = np.append(final_class_indices, current_class_indices[offset: offset+num_instances_to_choose]) 79 | count += 1 80 | final_class_indices = np.sort(final_class_indices, axis=None) 81 | return datasets['images'][final_class_indices, :], datasets['labels'][final_class_indices, :] 82 | 83 | 84 | def load_task_specific_data(datasets, task_labels): 85 | """ 86 | Loads task specific data from the datasets 87 | """ 88 | global_class_indices = np.column_stack(np.nonzero(datasets['labels'])) 89 | count = 0 90 | for cls in task_labels: 91 | if count == 0: 92 | class_indices = np.squeeze(global_class_indices[global_class_indices[:,1] == cls][:,np.array([True, False])]) 93 | else: 94 | class_indices = np.append(class_indices, np.squeeze(global_class_indices[global_class_indices[:,1] == cls][:,np.array([True, False])])) 95 | count += 1 96 | class_indices = np.sort(class_indices, axis=None) 97 | return datasets['images'][class_indices, :], datasets['labels'][class_indices, :] 98 | 99 | def samples_for_each_class(dataset_labels, task): 100 | """ 101 | Numbers of samples for each class in the task 102 | Args: 103 | dataset_labels Labels to count samples from 104 | task Labels with in a task 105 | 106 | Returns 107 | """ 108 | num_samples = np.zeros([len(task)], dtype=np.float32) 109 | i = 0 110 | for label in task: 111 | global_class_indices = np.column_stack(np.nonzero(dataset_labels)) 112 | class_indices = np.squeeze(global_class_indices[global_class_indices[:,1] == label][:,np.array([True, False])]) 113 | class_indices = np.sort(class_indices, axis=None) 114 | num_samples[i] = len(class_indices) 115 | i += 1 116 | 117 | return num_samples 118 | 119 | 120 | def get_sample_weights(labels, tasks): 121 | weights = np.zeros([labels.shape[0]], dtype=np.float32) 122 | for label in tasks: 123 | global_class_indices = np.column_stack(np.nonzero(labels)) 124 | class_indices = np.array(np.squeeze(global_class_indices[global_class_indices[:,1] == label][:,np.array([True, False])])) 125 | total_class_samples = class_indices.shape[0] 126 | weights[class_indices] = 1.0/ 
total_class_samples 127 | 128 | # Rescale the weights such that min is 1. This will make the weights of less observed 129 | # examples 1. 130 | weights /= weights.min() 131 | 132 | return weights 133 | 134 | def update_episodic_memory_with_less_data(task_dataset, importance_array, total_mem_size, task, episodic_images, episodic_labels, task_labels=None, is_herding=False): 135 | """ 136 | Update the episodic memory when the task data is less than the memory size 137 | Args: 138 | 139 | Returns: 140 | """ 141 | num_examples_in_task = task_dataset['images'].shape[0] 142 | # Empty spaces in the episodic memory 143 | empty_spaces = np.sum(np.sum(episodic_labels, axis=1) == 0) 144 | if empty_spaces >= num_examples_in_task: 145 | # Find where the empty spaces are in order 146 | empty_indices = np.where(np.sum(episodic_labels, axis=1) == 0)[0] 147 | # Store the whole task data in the episodic memory 148 | episodic_images[empty_indices[:num_examples_in_task]] = task_dataset['images'] 149 | episodic_labels[empty_indices[:num_examples_in_task]] = task_dataset['labels'] 150 | elif empty_spaces == 0: 151 | # Compute the amount of space in the episodic memory for the new task 152 | space_for_new_task = total_mem_size// (task + 1) # task 0, 1, ... 153 | # Get the indices to update in the episodic memory 154 | eps_mem_indices = np.random.choice(total_mem_size, space_for_new_task, replace=False) # Sample without replacement 155 | # Get the indices of important samples from the task dataset 156 | label_importance = importance_array + 1e-32 157 | label_importance /= np.sum(label_importance) # Convert to a probability distribution 158 | task_mem_indices = np.random.choice(num_examples_in_task, space_for_new_task, p=label_importance, replace=False) # Sample without replacement 159 | # Update the episodic memory 160 | episodic_images[eps_mem_indices] = task_dataset['images'][task_mem_indices] 161 | episodic_labels[eps_mem_indices] = task_dataset['labels'][task_mem_indices] 162 | else: 163 | # When there is some free space but not enough to store the whole task 164 | # Find where the empty spaces are in order 165 | empty_indices = np.where(np.sum(episodic_labels, axis=1) == 0)[0] 166 | # Store some of the examples from task in the memory 167 | episodic_images[empty_indices] = task_dataset['images'][:len(empty_indices)] 168 | episodic_labels[empty_indices] = task_dataset['labels'][:len(empty_indices)] 169 | # Adjust the remanining samples in the episodic memory 170 | space_for_new_task = (total_mem_size // (task + 1)) - len(empty_indices) # task 0, 1, ... 
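        # Each of the (task + 1) tasks seen so far is allotted total_mem_size // (task + 1) slots;
        # the len(empty_indices) examples just written into the free slots already count toward
        # this task's quota, so only the remainder is sampled below, and the task-side indices are
        # offset by len(empty_indices) to avoid re-storing the examples placed above.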
171 | # Get the indices to update in the episodic memory 172 | eps_mem_indices = np.random.choice((total_mem_size - len(empty_indices)), space_for_new_task, replace=False) # Sample without replacement 173 | # Get the indices of important samples from the task dataset 174 | label_importance = importance_array[len(empty_indices):] + 1e-32 175 | label_importance /= np.sum(label_importance) # Convert to a probability distribution 176 | updated_num_examples_in_task = num_examples_in_task - len(empty_indices) 177 | task_mem_indices = np.random.choice(updated_num_examples_in_task, space_for_new_task, p=label_importance, replace=False) # Sample without replacement 178 | task_mem_indices += len(empty_indices) # Add the offset 179 | # Update the episodic memory 180 | episodic_images[eps_mem_indices] = task_dataset['images'][task_mem_indices] 181 | episodic_labels[eps_mem_indices] = task_dataset['labels'][task_mem_indices] 182 | 183 | def update_episodic_memory(task_dataset, importance_array, total_mem_size, task, episodic_images, episodic_labels, task_labels=None, is_herding=False): 184 | """ 185 | Update the episodic memory with new task data 186 | Args: 187 | 188 | Reruns: 189 | """ 190 | num_examples_in_task = task_dataset['images'].shape[0] 191 | # Compute the amount of space in the episodic memory for the new task 192 | space_for_new_task = total_mem_size// (task + 1) # task 0, 1, ... 193 | # Get the indices to update in the episodic memory 194 | eps_mem_indices = np.random.choice(total_mem_size, space_for_new_task, replace=False) # Sample without replacement 195 | if is_herding and task_labels is not None: 196 | # Get the samples based on herding 197 | imp_images, imp_labels = sample_from_dataset_icarl(task_dataset, importance_array, task_labels, space_for_new_task//len(task_labels)) 198 | episodic_images[eps_mem_indices[np.arange(imp_images.shape[0])]] = imp_images 199 | episodic_labels[eps_mem_indices[np.arange(imp_images.shape[0])]] = imp_labels 200 | else: 201 | # Get the indices of important samples from the task dataset 202 | label_importance = importance_array + 1e-32 203 | label_importance /= np.sum(label_importance) # Convert to a probability distribution 204 | task_mem_indices = np.random.choice(num_examples_in_task, space_for_new_task, p=label_importance, replace=False) # Sample without replacement 205 | # Update the episodic memory 206 | episodic_images[eps_mem_indices] = task_dataset['images'][task_mem_indices] 207 | episodic_labels[eps_mem_indices] = task_dataset['labels'][task_mem_indices] 208 | 209 | def sample_from_dataset(dataset, importance_array, task, samples_count, preds=None): 210 | """ 211 | Samples from a dataset based on a probability distribution 212 | Args: 213 | dataset Dataset to sample from 214 | importance_array Importance scores (not necessarily have to be a prob distribution) 215 | task Labels with in a task 216 | samples_count Number of samples to return 217 | 218 | Return: 219 | images Important images 220 | labels Important labels 221 | """ 222 | 223 | count = 0 224 | # For each label in the task extract the important samples 225 | for label in task: 226 | global_class_indices = np.column_stack(np.nonzero(dataset['labels'])) 227 | class_indices = np.squeeze(global_class_indices[global_class_indices[:,1] == label][:,np.array([True, False])]) 228 | class_indices = np.sort(class_indices, axis=None) 229 | 230 | if (preds is not None): 231 | # Find the indices where prediction match the correct label 232 | pred_indices = np.where(preds == label)[0] 233 | 234 | # 
Find the correct prediction indices 235 | correct_pred_indices = np.intersect1d(pred_indices, class_indices) 236 | 237 | else: 238 | correct_pred_indices = class_indices 239 | 240 | # Extract the importance for the label 241 | label_importance = importance_array[correct_pred_indices] + 1e-32 242 | label_importance /= np.sum(label_importance) 243 | 244 | actual_samples_count = min(samples_count, np.count_nonzero(label_importance)) 245 | #print('Storing {} samples from {} class'.format(actual_samples_count, label)) 246 | 247 | # If no samples are correctly classified then skip saving the samples 248 | if (actual_samples_count != 0): 249 | 250 | # Extract the important indices 251 | imp_indices = np.random.choice(correct_pred_indices, actual_samples_count, p=label_importance, replace=False) 252 | 253 | if count == 0: 254 | images = dataset['images'][imp_indices] 255 | labels = dataset['labels'][imp_indices] 256 | else: 257 | images = np.vstack((images, dataset['images'][imp_indices])) 258 | labels = np.vstack((labels, dataset['labels'][imp_indices])) 259 | 260 | count += 1 261 | 262 | if count != 0: 263 | return images, labels 264 | else: 265 | return None, None 266 | 267 | def concatenate_datasets(current_images, current_labels, prev_images, prev_labels): 268 | """ 269 | Concatnates current dataset with the previous one. This will be used for 270 | adding important samples from the previous datasets 271 | Args: 272 | current_images Images of current dataset 273 | current_labels Labels of current dataset 274 | prev_images List containing images of previous datasets 275 | prev_labels List containing labels of previous datasets 276 | 277 | Returns: 278 | images Concatenated images 279 | labels Concatenated labels 280 | """ 281 | """ 282 | images = current_images 283 | labels = current_labels 284 | for i in range(len(prev_images)): 285 | images = np.vstack((images, prev_images[i])) 286 | labels = np.vstack((labels, prev_labels[i])) 287 | """ 288 | images = np.concatenate((current_images, prev_images), axis=0) 289 | labels = np.concatenate((current_labels, prev_labels), axis=0) 290 | 291 | return images, labels 292 | 293 | 294 | def sample_from_dataset_icarl(dataset, features, task, samples_count, preds=None): 295 | """ 296 | Samples from a dataset based on a icarl - mean of features 297 | Args: 298 | dataset Dataset to sample from 299 | features Features - activation before the last layer 300 | task Labels with in a task 301 | samples_count Number of samples to return 302 | 303 | Return: 304 | images Important images 305 | labels Important labels 306 | """ 307 | 308 | print('Herding based sampling!') 309 | #samples_count = min(samples_count, dataset['images'].shape[0]) 310 | count = 0 311 | # For each label in the task extract the important samples 312 | for label in task: 313 | global_class_indices = np.column_stack(np.nonzero(dataset['labels'])) 314 | class_indices = np.squeeze(global_class_indices[global_class_indices[:,1] == label][:,np.array([True, False])]) 315 | class_indices = np.sort(class_indices, axis=None) 316 | 317 | if (preds is not None): 318 | # Find the indices where prediction match the correct label 319 | pred_indices = np.where(preds == label)[0] 320 | 321 | # Find the correct prediction indices 322 | correct_pred_indices = np.intersect1d(pred_indices, class_indices) 323 | 324 | else: 325 | correct_pred_indices = class_indices 326 | 327 | mean_feature = np.mean(features[correct_pred_indices, :], axis=0) 328 | 329 | actual_samples_count = min(samples_count, 
len(correct_pred_indices)) 330 | 331 | # If no samples are correctly classified then skip saving the samples 332 | imp_indices = np.zeros(actual_samples_count, dtype=np.int32) 333 | sample_sum= np.zeros(mean_feature.shape) 334 | if (actual_samples_count != 0): 335 | # Extract the important indices 336 | for i in range(actual_samples_count): 337 | sample_mean = (features[correct_pred_indices, :] + 338 | np.tile(sample_sum, [len(correct_pred_indices),1]))/ float(i + 1) 339 | norm_distance = np.linalg.norm((np.tile(mean_feature, [len(correct_pred_indices),1]) 340 | - sample_mean), ord=2, axis=1) 341 | imp_indices[i] = correct_pred_indices[np.argmin(norm_distance)] 342 | sample_sum = sample_sum + features[imp_indices[i], :] 343 | 344 | if count == 0: 345 | images = dataset['images'][imp_indices] 346 | labels = dataset['labels'][imp_indices] 347 | else: 348 | images = np.vstack((images, dataset['images'][imp_indices])) 349 | labels = np.vstack((labels, dataset['labels'][imp_indices])) 350 | 351 | count += 1 352 | 353 | if count != 0: 354 | return images, labels 355 | else: 356 | return None, None 357 | 358 | def average_acc_stats_across_runs(data, key): 359 | """ 360 | Compute the average accuracy statistics (mean and std) across runs 361 | """ 362 | num_runs = data.shape[0] 363 | avg_acc = np.zeros(num_runs) 364 | for i in range(num_runs): 365 | avg_acc[i] = np.mean(data[i][-1]) 366 | 367 | return avg_acc.mean()*100, avg_acc.std()*100 368 | 369 | def average_fgt_stats_across_runs(data, key): 370 | """ 371 | Compute the forgetting statistics (mean and std) across runs 372 | """ 373 | num_runs = data.shape[0] 374 | fgt = np.zeros(num_runs) 375 | wst_fgt = np.zeros(num_runs) 376 | for i in range(num_runs): 377 | fgt[i] = compute_fgt(data[i]) 378 | 379 | return fgt.mean(), fgt.std() 380 | 381 | def compute_fgt(data): 382 | """ 383 | Given a TxT data matrix, compute average forgetting at T-th task 384 | """ 385 | num_tasks = data.shape[0] 386 | T = num_tasks - 1 387 | fgt = 0.0 388 | for i in range(T): 389 | fgt += np.max(data[:T,i]) - data[T, i] 390 | 391 | avg_fgt = fgt/ float(num_tasks - 1) 392 | return avg_fgt 393 | 394 | def update_reservior(current_image, current_label, episodic_images, episodic_labels, M, N): 395 | """ 396 | Update the episodic memory with current example using the reservior sampling 397 | """ 398 | if M > N: 399 | episodic_images[N] = current_image 400 | episodic_labels[N] = current_label 401 | else: 402 | j = np.random.randint(0, N) 403 | if j < M: 404 | episodic_images[j] = current_image 405 | episodic_labels[j] = current_label 406 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/utils/vgg_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import tensorflow as tf 8 | import numpy as np 9 | 10 | def vgg_conv_layer(x, kernel_size, out_channels, stride, var_list, pad="SAME", name="conv"): 11 | """ 12 | Define API for conv operation. This includes kernel declaration and 13 | conv operation both followed by relu. 
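    Args:
        x             Input feature map (NHWC tensor)
        kernel_size   Spatial size of the (square) kernel
        out_channels  Number of output channels
        stride        Convolution stride
        var_list      List to which the created kernel and bias variables are appended
        pad           Padding scheme (default 'SAME')
        name          Variable scope name

    Returns:
        ReLU of the biased convolution output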
14 | """ 15 | in_channels = x.get_shape().as_list()[-1] 16 | with tf.variable_scope(name): 17 | #n = kernel_size * kernel_size * out_channels 18 | n = kernel_size * in_channels 19 | stdv = 1.0 / math.sqrt(n) 20 | w = tf.get_variable('kernel_weights', [kernel_size, kernel_size, in_channels, out_channels], 21 | tf.float32, 22 | initializer=tf.random_uniform_initializer(-stdv, stdv)) 23 | b = tf.get_variable('kernel_biases', [out_channels], tf.float32, initializer=tf.random_uniform_initializer(-stdv, stdv)) 24 | 25 | # Append the variable to the trainable variables list 26 | var_list.append(w) 27 | var_list.append(b) 28 | 29 | # Do the convolution operation 30 | bias = tf.nn.bias_add(tf.nn.conv2d(x, w, [1, stride, stride, 1], padding=pad), b) 31 | relu = tf.nn.relu(bias) 32 | return relu 33 | 34 | def vgg_fc_layer(x, out_dim, var_list, apply_relu=True, name="fc"): 35 | """ 36 | Define API for the fully connected layer. This includes both the variable 37 | declaration and matmul operation. 38 | """ 39 | in_dim = x.get_shape().as_list()[1] 40 | stdv = 1.0 / math.sqrt(in_dim) 41 | with tf.variable_scope(name): 42 | # Define the weights and biases for this layer 43 | w = tf.get_variable('weights', [in_dim, out_dim], tf.float32, 44 | initializer=tf.random_uniform_initializer(-stdv, stdv)) 45 | b = tf.get_variable('biases', [out_dim], tf.float32, initializer=tf.random_uniform_initializer(-stdv, stdv)) 46 | 47 | # Append the variable to the trainable variables list 48 | var_list.append(w) 49 | var_list.append(b) 50 | 51 | # Do the FC operation 52 | output = tf.matmul(x, w) + b 53 | 54 | # Apply relu if needed 55 | if apply_relu: 56 | output = tf.nn.relu(output) 57 | 58 | return output 59 | -------------------------------------------------------------------------------- /external_libs/continual_learning_algorithms/utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | """ 7 | Define some utility functions 8 | """ 9 | import numpy as np 10 | 11 | import matplotlib 12 | matplotlib.use('agg') 13 | import matplotlib.colors as colors 14 | import matplotlib.cm as cmx 15 | import matplotlib.pyplot as plt 16 | import matplotlib.figure as figure 17 | from six.moves import cPickle as pickle 18 | 19 | def snapshot_experiment_eval(logdir, experiment_id, data): 20 | """ 21 | Store the output of the experiment in a file 22 | """ 23 | snapshot_file = logdir + '/' + experiment_id + '.pickle' 24 | with open(snapshot_file, 'wb') as f: 25 | pickle.dump(data, f) 26 | 27 | print('Experimental Eval has been snapshotted to %s!'%(snapshot_file)) 28 | 29 | def snapshot_task_labels(logdir, experiment_id, data): 30 | """ 31 | Store the output of the experiment in a file 32 | """ 33 | snapshot_file = logdir + '/' + experiment_id + '_task_labels.pickle' 34 | with open(snapshot_file, 'wb') as f: 35 | pickle.dump(data, f) 36 | 37 | print('Experimental Eval has been snapshotted to %s!'%(snapshot_file)) 38 | 39 | def snapshot_experiment_meta_data(logdir, experiment_id, exper_meta_data): 40 | """ 41 | Store the meta-data of the experiment in a file 42 | """ 43 | meta_file = logdir + '/' + experiment_id + '.txt' 44 | with open(meta_file, 'w') as f: 45 | for key in exper_meta_data: 46 | print('{}: {}'.format(key, exper_meta_data[key])) 47 | f.write('{}:{} \n'.format(key, exper_meta_data[key])) 48 | 49 | print('Experimental meta-data has been snapshotted to %s!'%(meta_file)) 50 | 51 | def plot_acc_multiple_runs(data, task_labels, valid_measures, n_stats, plot_name=None): 52 | """ 53 | Plots the accuracies 54 | Args: 55 | task_labels List of tasks 56 | n_stats Number of runs 57 | plot_name Name of the file where the plot will be saved 58 | 59 | Returns: 60 | """ 61 | n_tasks = len(task_labels) 62 | plt.figure(figsize=(14, 3)) 63 | axs = [plt.subplot(1,n_tasks+1,1)] 64 | for i in range(1, n_tasks + 1): 65 | axs.append(plt.subplot(1, n_tasks+1, i+1, sharex=axs[0], sharey=axs[0])) 66 | 67 | fmt_chars = ['o', 's', 'd'] 68 | fmts = [] 69 | for i in range(len(valid_measures)): 70 | fmts.append(fmt_chars[i%len(fmt_chars)]) 71 | 72 | plot_keys = sorted(data['mean'].keys()) 73 | 74 | for k, cval in enumerate(plot_keys): 75 | label = "c=%g"%cval 76 | mean_vals = data['mean'][cval] 77 | std_vals = data['std'][cval] 78 | for j in range(n_tasks+1): 79 | plt.sca(axs[j]) 80 | errorbar_kwargs = dict(fmt="%s-"%fmts[k], markersize=5) 81 | if j < n_tasks: 82 | norm= np.sqrt(n_stats) # np.sqrt(n_stats) for SEM or 1 for STDEV 83 | axs[j].errorbar(np.arange(n_tasks)+1, mean_vals[:, j], yerr=std_vals[:, j]/norm, label=label, **errorbar_kwargs) 84 | else: 85 | mean_stuff = [] 86 | std_stuff = [] 87 | for i in range(len(data['mean'][cval])): 88 | mean_stuff.append(data['mean'][cval][i][:i+1].mean()) 89 | std_stuff.append(np.sqrt((data['std'][cval][i][:i+1]**2).sum())/(n_stats*np.sqrt(n_stats))) 90 | plt.errorbar(range(1,n_tasks+1), mean_stuff, yerr=std_stuff, label="%s"%valid_measures[k], **errorbar_kwargs) 91 | plt.xticks(np.arange(n_tasks)+1) 92 | plt.xlim((1.0,5.5)) 93 | """ 94 | # Uncomment this if clutter along y-axis needs to be removed 95 | if j == 0: 96 | axs[j].set_yticks([0.5,1]) 97 | else: 98 | plt.setp(axs[j].get_yticklabels(), visible=False) 99 | plt.ylim((0.45,1.1)) 100 | """ 101 | 102 | for i, ax in enumerate(axs): 103 | if i < n_tasks: 104 | ax.set_title((['Task %d (%d to %d)'%(j+1,task_labels[j][0], task_labels[j][-1])\ 105 | for j in range(n_tasks)] + ['average'])[i], fontsize=8) 106 | 
else: 107 | ax.set_title("Average", fontsize=8) 108 | ax.axhline(0.5, color='k', linestyle=':', label="chance", zorder=0) 109 | 110 | handles, labels = axs[-1].get_legend_handles_labels() 111 | 112 | # Reorder legend so chance is last 113 | axs[-1].legend([handles[j] for j in [i for i in range(len(valid_measures)+1)]], 114 | [labels[j] for j in [i for i in range(len(valid_measures)+1)]], loc='best', fontsize=6) 115 | 116 | axs[0].set_xlabel("Tasks") 117 | axs[0].set_ylabel("Accuracy") 118 | plt.gcf().tight_layout() 119 | plt.grid('on') 120 | if plot_name == None: 121 | plt.show() 122 | else: 123 | plt.savefig(plot_name) 124 | 125 | def plot_histogram(data, n_bins=10, plot_name='my_hist'): 126 | plt.hist(data, bins=n_bins) 127 | plt.savefig(plot_name) 128 | plt.close() 129 | -------------------------------------------------------------------------------- /external_libs/hessian_eigenthings/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Noah Golmant 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /external_libs/hessian_eigenthings/__init__.py: -------------------------------------------------------------------------------- 1 | """ Top-level module for hessian eigenvec computation 2 | This library is cited in our paper. 3 | """ 4 | from .power_iter import power_iteration, deflated_power_iteration 5 | from .lanczos import lanczos 6 | from .hvp_operator import HVPOperator, compute_hessian_eigenthings 7 | 8 | __all__ = [ 9 | "power_iteration", 10 | "deflated_power_iteration", 11 | "lanczos", 12 | "HVPOperator", 13 | "compute_hessian_eigenthings", 14 | ] 15 | 16 | name = "hessian_eigenthings" 17 | -------------------------------------------------------------------------------- /external_libs/hessian_eigenthings/hvp_operator.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module defines a linear operator to compute the hessian-vector product 3 | for a given pytorch model using subsampled data. 
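
The core step is the standard double-backward trick; a minimal sketch (`model`,
`criterion`, `x`, `y` and `vec` are placeholder names, not objects defined in this
module):

    loss = criterion(model(x), y)
    grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    flat_grad = torch.cat([g.contiguous().view(-1) for g in grads])
    hvp = torch.autograd.grad(flat_grad, model.parameters(), grad_outputs=vec)
    flat_hvp = torch.cat([g.contiguous().view(-1) for g in hvp])

HVPOperator below performs the same computation, but accumulates the gradient over
data chunks and can average the product over the whole dataloader.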
4 | """ 5 | import torch 6 | from .power_iter import Operator, deflated_power_iteration 7 | from .lanczos import lanczos 8 | 9 | 10 | class HVPOperator(Operator): 11 | """ 12 | Use PyTorch autograd for Hessian Vec product calculation 13 | model: PyTorch network to compute hessian for 14 | dataloader: pytorch dataloader that we get examples from to compute grads 15 | loss: Loss function to descend (e.g. F.cross_entropy) 16 | use_gpu: use cuda or not 17 | max_samples: max number of examples per batch using all GPUs. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | model, 23 | dataloader, 24 | criterion, 25 | use_gpu=True, 26 | full_dataset=True, 27 | max_samples=256, 28 | ): 29 | size = int(sum(p.numel() for p in model.parameters())) 30 | super(HVPOperator, self).__init__(size) 31 | self.grad_vec = torch.zeros(size) 32 | self.model = model 33 | if use_gpu: 34 | self.model = self.model.cuda() 35 | self.dataloader = dataloader 36 | # Make a copy since we will go over it a bunch 37 | self.dataloader_iter = iter(dataloader) 38 | self.criterion = criterion 39 | self.use_gpu = use_gpu 40 | self.full_dataset = full_dataset 41 | self.max_samples = max_samples 42 | 43 | def apply(self, vec): 44 | """ 45 | Returns H*vec where H is the hessian of the loss w.r.t. 46 | the vectorized model parameters 47 | """ 48 | if self.full_dataset: 49 | return self._apply_full(vec) 50 | else: 51 | return self._apply_batch(vec) 52 | 53 | def _apply_batch(self, vec): 54 | # compute original gradient, tracking computation graph 55 | self.zero_grad() 56 | grad_vec = self.prepare_grad() 57 | self.zero_grad() 58 | # take the second gradient 59 | grad_grad = torch.autograd.grad( 60 | grad_vec, self.model.parameters(), grad_outputs=vec, only_inputs=True 61 | ) 62 | # concatenate the results over the different components of the network 63 | hessian_vec_prod = torch.cat([g.contiguous().view(-1) for g in grad_grad]) 64 | return hessian_vec_prod 65 | 66 | def _apply_full(self, vec): 67 | n = len(self.dataloader) 68 | hessian_vec_prod = None 69 | for _ in range(n): 70 | if hessian_vec_prod is not None: 71 | hessian_vec_prod += self._apply_batch(vec) 72 | else: 73 | hessian_vec_prod = self._apply_batch(vec) 74 | hessian_vec_prod = hessian_vec_prod / n 75 | return hessian_vec_prod 76 | 77 | def zero_grad(self): 78 | """ 79 | Zeros out the gradient info for each parameter in the model 80 | """ 81 | for p in self.model.parameters(): 82 | if p.grad is not None: 83 | p.grad.data.zero_() 84 | 85 | def prepare_grad(self): 86 | """ 87 | Compute gradient w.r.t loss over all parameters and vectorize 88 | """ 89 | try: 90 | all_inputs, all_targets = next(self.dataloader_iter) 91 | except StopIteration: 92 | self.dataloader_iter = iter(self.dataloader) 93 | all_inputs, all_targets = next(self.dataloader_iter) 94 | 95 | num_chunks = max(1, len(all_inputs) // self.max_samples) 96 | 97 | grad_vec = None 98 | 99 | input_chunks = all_inputs.chunk(num_chunks) 100 | target_chunks = all_targets.chunk(num_chunks) 101 | for input, target in zip(input_chunks, target_chunks): 102 | if self.use_gpu: 103 | input = input.cuda() 104 | target = target.cuda() 105 | 106 | output = self.model(input) 107 | loss = self.criterion(output, target) 108 | grad_dict = torch.autograd.grad( 109 | loss, self.model.parameters(), create_graph=True 110 | ) 111 | if grad_vec is not None: 112 | grad_vec += torch.cat([g.contiguous().view(-1) for g in grad_dict]) 113 | else: 114 | grad_vec = torch.cat([g.contiguous().view(-1) for g in grad_dict]) 115 | grad_vec /= num_chunks 116 
| self.grad_vec = grad_vec 117 | return self.grad_vec 118 | 119 | 120 | def compute_hessian_eigenthings( 121 | model, 122 | dataloader, 123 | loss, 124 | num_eigenthings=10, 125 | full_dataset=True, 126 | mode="power_iter", 127 | use_gpu=True, 128 | max_samples=512, 129 | **kwargs 130 | ): 131 | """ 132 | Computes the top `num_eigenthings` eigenvalues and eigenvecs 133 | for the hessian of the given model by using subsampled power iteration 134 | with deflation and the hessian-vector product 135 | 136 | Parameters 137 | --------------- 138 | 139 | model : Module 140 | pytorch model for this netowrk 141 | dataloader : torch.data.DataLoader 142 | dataloader with x,y pairs for which we compute the loss. 143 | loss : torch.nn.modules.Loss | torch.nn.functional criterion 144 | loss function to differentiate through 145 | num_eigenthings : int 146 | number of eigenvalues/eigenvecs to compute. computed in order of 147 | decreasing eigenvalue magnitude. 148 | full_dataset : boolean 149 | if true, each power iteration call evaluates the gradient over the 150 | whole dataset. 151 | mode : str ['power_iter', 'lanczos'] 152 | which backend to use to compute the top eigenvalues. 153 | use_gpu: 154 | if true, attempt to use cuda for all lin alg computatoins 155 | max_samples: 156 | the maximum number of samples that can fit on-memory. used 157 | to accumulate gradients for large batches. 158 | **kwargs: 159 | contains additional parameters passed onto lanczos or power_iter. 160 | """ 161 | hvp_operator = HVPOperator( 162 | model, 163 | dataloader, 164 | loss, 165 | use_gpu=use_gpu, 166 | full_dataset=full_dataset, 167 | max_samples=max_samples, 168 | ) 169 | eigenvals, eigenvecs = None, None 170 | if mode == "power_iter": 171 | eigenvals, eigenvecs = deflated_power_iteration( 172 | hvp_operator, num_eigenthings, use_gpu=use_gpu, **kwargs 173 | ) 174 | elif mode == "lanczos": 175 | eigenvals, eigenvecs = lanczos( 176 | hvp_operator, num_eigenthings, use_gpu=use_gpu, **kwargs 177 | ) 178 | else: 179 | raise ValueError("Unsupported mode %s (must be power_iter or lanczos)" % mode) 180 | return eigenvals, eigenvecs 181 | -------------------------------------------------------------------------------- /external_libs/hessian_eigenthings/lanczos.py: -------------------------------------------------------------------------------- 1 | """ Use scipy/ARPACK implicitly restarted lanczos to find top k eigenthings """ 2 | import numpy as np 3 | import torch 4 | from scipy.sparse.linalg import LinearOperator as ScipyLinearOperator 5 | from scipy.sparse.linalg import eigsh 6 | from warnings import warn 7 | 8 | 9 | def lanczos( 10 | operator, 11 | num_eigenthings=10, 12 | which="LM", 13 | max_steps=20, 14 | tol=1e-6, 15 | num_lanczos_vectors=None, 16 | init_vec=None, 17 | use_gpu=False, 18 | ): 19 | """ 20 | Use the scipy.sparse.linalg.eigsh hook to the ARPACK lanczos algorithm 21 | to find the top k eigenvalues/eigenvectors. 22 | 23 | Parameters 24 | ------------- 25 | operator: power_iter.Operator 26 | linear operator to solve. 27 | num_eigenthings : int 28 | number of eigenvalue/eigenvector pairs to compute 29 | which : str ['LM', SM', 'LA', SA'] 30 | L,S = largest, smallest. M, A = in magnitude, algebriac 31 | SM = smallest in magnitude. LA = largest algebraic. 32 | max_steps : int 33 | maximum number of arnoldi updates 34 | tol : float 35 | relative accuracy of eigenvalues / stopping criterion 36 | num_lanczos_vectors : int 37 | number of lanczos vectors to compute. 
if None, > 2*num_eigenthings 38 | init_vec: [torch.Tensor, torch.cuda.Tensor] 39 | if None, use random tensor. this is the init vec for arnoldi updates. 40 | use_gpu: bool 41 | if true, use cuda tensors. 42 | 43 | Returns 44 | ---------------- 45 | eigenvalues : np.ndarray 46 | array containing `num_eigenthings` eigenvalues of the operator 47 | eigenvectors : np.ndarray 48 | array containing `num_eigenthings` eigenvectors of the operator 49 | """ 50 | if isinstance(operator.size, int): 51 | size = operator.size 52 | else: 53 | size = operator.size[0] 54 | shape = (size, size) 55 | 56 | if num_lanczos_vectors is None: 57 | num_lanczos_vectors = min(2 * num_eigenthings, size - 1) 58 | if num_lanczos_vectors < 2 * num_eigenthings: 59 | warn( 60 | "[lanczos] number of lanczos vectors should usually be > 2*num_eigenthings" 61 | ) 62 | 63 | def _scipy_apply(x): 64 | x = torch.from_numpy(x) 65 | if use_gpu: 66 | x = x.cuda() 67 | return operator.apply(x.float()).cpu().numpy() 68 | 69 | scipy_op = ScipyLinearOperator(shape, _scipy_apply) 70 | if init_vec is None: 71 | init_vec = np.random.rand(size) 72 | elif isinstance(init_vec, torch.Tensor): 73 | init_vec = init_vec.cpu().numpy() 74 | eigenvals, eigenvecs = eigsh( 75 | A=scipy_op, 76 | k=num_eigenthings, 77 | which=which, 78 | maxiter=max_steps, 79 | tol=tol, 80 | ncv=num_lanczos_vectors, 81 | return_eigenvectors=True, 82 | ) 83 | return eigenvals, eigenvecs.T 84 | -------------------------------------------------------------------------------- /external_libs/hessian_eigenthings/power_iter.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains functions to perform power iteration with deflation 3 | to compute the top eigenvalues and eigenvectors of a linear operator 4 | """ 5 | import numpy as np 6 | import torch 7 | 8 | from .utils import log, progress_bar 9 | 10 | 11 | class Operator: 12 | """ 13 | maps x -> Lx for a linear operator L 14 | """ 15 | 16 | def __init__(self, size): 17 | self.size = size 18 | 19 | def apply(self, vec): 20 | """ 21 | Function mapping vec -> L vec where L is a linear operator 22 | """ 23 | raise NotImplementedError 24 | 25 | 26 | class LambdaOperator(Operator): 27 | """ 28 | Linear operator based on a provided lambda function 29 | """ 30 | 31 | def __init__(self, apply_fn, size): 32 | super(LambdaOperator, self).__init__(size) 33 | self.apply_fn = apply_fn 34 | 35 | def apply(self, x): 36 | return self.apply_fn(x) 37 | 38 | 39 | def deflated_power_iteration( 40 | operator, 41 | num_eigenthings=10, 42 | power_iter_steps=20, 43 | power_iter_err_threshold=1e-4, 44 | momentum=0.0, 45 | use_gpu=True, 46 | to_numpy=True, 47 | ): 48 | """ 49 | Compute top k eigenvalues by repeatedly subtracting out dyads 50 | operator: linear operator that gives us access to matrix vector product 51 | num_eigenvals number of eigenvalues to compute 52 | power_iter_steps: number of steps per run of power iteration 53 | power_iter_err_threshold: early stopping threshold for power iteration 54 | returns: np.ndarray of top eigenvalues, np.ndarray of top eigenvectors 55 | """ 56 | eigenvals = [] 57 | eigenvecs = [] 58 | current_op = operator 59 | prev_vec = None 60 | 61 | def _deflate(x, val, vec): 62 | return val * vec.dot(x) * vec 63 | 64 | log("beginning deflated power iteration") 65 | for i in range(num_eigenthings): 66 | log("computing eigenvalue/vector %d of %d" % (i + 1, num_eigenthings)) 67 | eigenval, eigenvec = power_iteration( 68 | current_op, 69 | power_iter_steps, 70 | 
power_iter_err_threshold, 71 | momentum=momentum, 72 | use_gpu=use_gpu, 73 | init_vec=prev_vec, 74 | ) 75 | log("eigenvalue %d: %.4f" % (i + 1, eigenval)) 76 | 77 | def _new_op_fn(x, op=current_op, val=eigenval, vec=eigenvec): 78 | return op.apply(x) - _deflate(x, val, vec) 79 | 80 | current_op = LambdaOperator(_new_op_fn, operator.size) 81 | prev_vec = eigenvec 82 | eigenvals.append(eigenval) 83 | eigenvec = eigenvec.cpu() 84 | if to_numpy: 85 | eigenvecs.append(eigenvec.numpy()) 86 | else: 87 | eigenvecs.append(eigenvec) 88 | 89 | eigenvals = np.array(eigenvals) 90 | eigenvecs = np.array(eigenvecs) 91 | 92 | # sort them in descending order 93 | sorted_inds = np.argsort(eigenvals) 94 | eigenvals = eigenvals[sorted_inds][::-1] 95 | eigenvecs = eigenvecs[sorted_inds][::-1] 96 | return eigenvals, eigenvecs 97 | 98 | 99 | def power_iteration( 100 | operator, steps=20, error_threshold=1e-4, momentum=0.0, use_gpu=True, init_vec=None 101 | ): 102 | """ 103 | Compute dominant eigenvalue/eigenvector of a matrix 104 | operator: linear Operator giving us matrix-vector product access 105 | steps: number of update steps to take 106 | returns: (principal eigenvalue, principal eigenvector) pair 107 | """ 108 | vector_size = operator.size # input dimension of operator 109 | if init_vec is None: 110 | vec = torch.rand(vector_size) 111 | else: 112 | vec = init_vec 113 | 114 | if use_gpu: 115 | vec = vec.cuda() 116 | 117 | prev_lambda = 0.0 118 | prev_vec = torch.randn_like(vec) 119 | for i in range(steps): 120 | prev_vec = vec / (torch.norm(vec) + 1e-6) 121 | new_vec = operator.apply(vec) - momentum * prev_vec 122 | # need to handle case where we end up in the nullspace of the operator. 123 | # in this case, we are done. 124 | if torch.sum(new_vec).item() == 0.0: 125 | return 0.0, new_vec 126 | lambda_estimate = vec.dot(new_vec).item() 127 | diff = lambda_estimate - prev_lambda 128 | vec = new_vec.detach() / torch.norm(new_vec) 129 | if lambda_estimate == 0.0: # for low-rank 130 | error = 1.0 131 | else: 132 | error = np.abs(diff / lambda_estimate) 133 | progress_bar(i, steps, "power iter error: %.4f" % error) 134 | if error < error_threshold: 135 | return lambda_estimate, vec 136 | prev_lambda = lambda_estimate 137 | 138 | return lambda_estimate, vec 139 | -------------------------------------------------------------------------------- /external_libs/hessian_eigenthings/spectral_density.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def _lanczos_step(vec, size, current_draw): 6 | 7 | pass 8 | 9 | 10 | def lanczos( 11 | operator, 12 | max_steps=20, 13 | tol=1e-6, 14 | num_lanczos_vectors=None, 15 | init_vec=None, 16 | use_gpu=False, 17 | ): 18 | """ 19 | Use the scipy.sparse.linalg.eigsh hook to the ARPACK lanczos algorithm 20 | to find the top k eigenvalues/eigenvectors. 21 | 22 | Parameters 23 | ------------- 24 | operator: power_iter.Operator 25 | linear operator to solve. 26 | num_eigenthings : int 27 | number of eigenvalue/eigenvector pairs to compute 28 | which : str ['LM', SM', 'LA', SA'] 29 | L,S = largest, smallest. M, A = in magnitude, algebriac 30 | SM = smallest in magnitude. LA = largest algebraic. 31 | max_steps : int 32 | maximum number of arnoldi updates 33 | tol : float 34 | relative accuracy of eigenvalues / stopping criterion 35 | num_lanczos_vectors : int 36 | number of lanczos vectors to compute. 
if None, > 2*num_eigenthings 37 | init_vec: [torch.Tensor, torch.cuda.Tensor] 38 | if None, use random tensor. this is the init vec for arnoldi updates. 39 | use_gpu: bool 40 | if true, use cuda tensors. 41 | 42 | Returns 43 | ---------------- 44 | eigenvalues : np.ndarray 45 | array containing `num_eigenthings` eigenvalues of the operator 46 | eigenvectors : np.ndarray 47 | array containing `num_eigenthings` eigenvectors of the operator 48 | """ 49 | if isinstance(operator.size, int): 50 | size = operator.size 51 | else: 52 | size = operator.size[0] 53 | 54 | if num_lanczos_vectors is None: 55 | num_lanczos_vectors = min(2 * num_eigenthings, size - 1) 56 | if num_lanczos_vectors < 2 * num_eigenthings: 57 | warn( 58 | "[lanczos] number of lanczos vectors should usually be > 2*num_eigenthings" 59 | ) 60 | 61 | return eigenvals 62 | -------------------------------------------------------------------------------- /external_libs/hessian_eigenthings/utils.py: -------------------------------------------------------------------------------- 1 | """ small helpers """ 2 | import shutil 3 | import sys 4 | import time 5 | 6 | TOTAL_BAR_LENGTH = 65.0 7 | 8 | term_width = shutil.get_terminal_size().columns 9 | 10 | 11 | def log(msg): 12 | # TODO make this an actual logger lol 13 | print("[hessian_eigenthings] " + str(msg)) 14 | 15 | 16 | last_time = time.time() 17 | begin_time = last_time 18 | 19 | 20 | def format_time(seconds): 21 | """ converts seconds into day-hour-minute-second-ms string format """ 22 | days = int(seconds / 3600 / 24) 23 | seconds = seconds - days * 3600 * 24 24 | hours = int(seconds / 3600) 25 | seconds = seconds - hours * 3600 26 | minutes = int(seconds / 60) 27 | seconds = seconds - minutes * 60 28 | secondsf = int(seconds) 29 | seconds = seconds - secondsf 30 | millis = int(seconds * 1000) 31 | 32 | f = "" 33 | i = 1 34 | if days > 0: 35 | f += str(days) + "D" 36 | i += 1 37 | if hours > 0 and i <= 2: 38 | f += str(hours) + "h" 39 | i += 1 40 | if minutes > 0 and i <= 2: 41 | f += str(minutes) + "m" 42 | i += 1 43 | if secondsf > 0 and i <= 2: 44 | f += str(secondsf) + "s" 45 | i += 1 46 | if millis > 0 and i <= 2: 47 | f += str(millis) + "ms" 48 | i += 1 49 | if f == "": 50 | f = "0ms" 51 | return f 52 | 53 | 54 | def progress_bar(current, total, msg=None): 55 | """ handy utility to display an updating progress bar... 56 | percentage completed is computed as current/total 57 | 58 | from: https://github.com/noahgolmant/skeletor/blob/master/skeletor/utils.py 59 | """ 60 | global last_time, begin_time # pylint: disable=global-statement 61 | if current == 0: 62 | begin_time = time.time() # Reset for new bar. 63 | 64 | cur_len = int(TOTAL_BAR_LENGTH * current / total) 65 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 66 | 67 | sys.stdout.write(" [") 68 | for _ in range(cur_len): 69 | sys.stdout.write("=") 70 | sys.stdout.write(">") 71 | for _ in range(rest_len): 72 | sys.stdout.write(".") 73 | sys.stdout.write("]") 74 | 75 | cur_time = time.time() 76 | step_time = cur_time - last_time 77 | last_time = cur_time 78 | tot_time = cur_time - begin_time 79 | 80 | L = [] 81 | L.append(" Step: %s" % format_time(step_time)) 82 | L.append(" | Tot: %s" % format_time(tot_time)) 83 | if msg: 84 | L.append(" | " + msg) 85 | 86 | msg = "".join(L) 87 | sys.stdout.write(msg) 88 | for _ in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3): 89 | sys.stdout.write(" ") 90 | 91 | # Go back to the center of the bar. 
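    # (emit backspaces so the "current/total" counter written next lands in the middle of the bar)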
92 | for _ in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2): 93 | sys.stdout.write("\b") 94 | sys.stdout.write(" %d/%d " % (current + 1, total)) 95 | 96 | if current < total - 1: 97 | sys.stdout.write("\r") 98 | else: 99 | sys.stdout.write("\n") 100 | sys.stdout.flush() 101 | -------------------------------------------------------------------------------- /replicate_appendix_c5.sh: -------------------------------------------------------------------------------- 1 | echo "Make sure you have tensorflow==1.15.2 installed." 2 | cd ./external_libs/continual_learning_algorithms 3 | bash replicate_mnist_stable.sh 4 | -------------------------------------------------------------------------------- /replicate_experiment_1.sh: -------------------------------------------------------------------------------- 1 | echo "************************ replicating experiment 1 (rotated MNIST) ***********************" 2 | echo " >>>>>>>> Plastic (Naive) SGD " 3 | python -m stable_sgd.main --dataset rot-mnist --tasks 5 --epochs-per-task 5 --lr 0.01 --gamma 1.0 --hiddens 100 --batch-size 64 --dropout 0.0 --seed 1234 4 | python -m stable_sgd.main --dataset rot-mnist --tasks 5 --epochs-per-task 5 --lr 0.01 --gamma 1.0 --hiddens 100 --batch-size 64 --dropout 0.0 --seed 4567 5 | python -m stable_sgd.main --dataset rot-mnist --tasks 5 --epochs-per-task 5 --lr 0.01 --gamma 1.0 --hiddens 100 --batch-size 64 --dropout 0.0 --seed 7891 6 | 7 | echo " >>>>>>>> Stable SGD " 8 | python -m stable_sgd.main --dataset rot-mnist --tasks 5 --epochs-per-task 5 --lr 0.1 --gamma 0.4 --hiddens 100 --batch-size 16 --dropout 0.5 --seed 1234 9 | python -m stable_sgd.main --dataset rot-mnist --tasks 5 --epochs-per-task 5 --lr 0.1 --gamma 0.4 --hiddens 100 --batch-size 16 --dropout 0.5 --seed 4567 10 | python -m stable_sgd.main --dataset rot-mnist --tasks 5 --epochs-per-task 5 --lr 0.1 --gamma 0.4 --hiddens 100 --batch-size 16 --dropout 0.5 --seed 7891 11 | 12 | 13 | 14 | echo "************************ replicating experiment 1 (permuted MNIST) ***********************" 15 | echo " >>>>>>>> Plastic (Naive) SGD " 16 | python -m stable_sgd.main --dataset perm-mnist --tasks 5 --epochs-per-task 5 --lr 0.01 --gamma 1.0 --hiddens 100 --batch-size 64 --dropout 0.0 --seed 1234 17 | python -m stable_sgd.main --dataset perm-mnist --tasks 5 --epochs-per-task 5 --lr 0.01 --gamma 1.0 --hiddens 100 --batch-size 64 --dropout 0.0 --seed 4567 18 | python -m stable_sgd.main --dataset perm-mnist --tasks 5 --epochs-per-task 5 --lr 0.01 --gamma 1.0 --hiddens 100 --batch-size 64 --dropout 0.0 --seed 7891 19 | 20 | echo " >>>>>>>> Stable SGD " 21 | python -m stable_sgd.main --dataset perm-mnist --tasks 5 --epochs-per-task 5 --lr 0.1 --gamma 0.4 --hiddens 100 --batch-size 16 --dropout 0.5 --seed 1234 22 | python -m stable_sgd.main --dataset perm-mnist --tasks 5 --epochs-per-task 5 --lr 0.1 --gamma 0.4 --hiddens 100 --batch-size 16 --dropout 0.5 --seed 4567 23 | python -m stable_sgd.main --dataset perm-mnist --tasks 5 --epochs-per-task 5 --lr 0.1 --gamma 0.4 --hiddens 100 --batch-size 16 --dropout 0.5 --seed 7891 24 | -------------------------------------------------------------------------------- /replicate_experiment_2.sh: -------------------------------------------------------------------------------- 1 | echo "************************ replicating experiment 2 (rotated MNIST) ***********************" 2 | echo " >>>>>>>> Plastic (Naive) SGD " 3 | python -m stable_sgd.main --dataset rot-mnist --tasks 20 --epochs-per-task 1 --lr 0.01 --gamma 1.0 --hiddens 256 
--batch-size 10 --dropout 0.0 --seed 1234
4 | python -m stable_sgd.main --dataset rot-mnist --tasks 20 --epochs-per-task 1 --lr 0.01 --gamma 1.0 --hiddens 256 --batch-size 10 --dropout 0.0 --seed 4567
5 | python -m stable_sgd.main --dataset rot-mnist --tasks 20 --epochs-per-task 1 --lr 0.01 --gamma 1.0 --hiddens 256 --batch-size 10 --dropout 0.0 --seed 7891
6 |
7 | echo ""
8 | echo " >>>>>>>> Stable SGD "
9 | python -m stable_sgd.main --dataset rot-mnist --tasks 20 --epochs-per-task 1 --lr 0.1 --gamma 0.5 --hiddens 256 --batch-size 10 --dropout 0.5 --seed 1234
10 | python -m stable_sgd.main --dataset rot-mnist --tasks 20 --epochs-per-task 1 --lr 0.1 --gamma 0.5 --hiddens 256 --batch-size 10 --dropout 0.5 --seed 4567
11 | python -m stable_sgd.main --dataset rot-mnist --tasks 20 --epochs-per-task 1 --lr 0.1 --gamma 0.5 --hiddens 256 --batch-size 10 --dropout 0.5 --seed 7891
12 |
13 | echo ""
14 | echo ">>>>>>>>> Other Methods (ER, A-GEM, EWC)"
15 | echo "Make sure you have tensorflow==1.12 installed. (see the readme doc)"
16 | cd ./external_libs/continual_learning_algorithms
17 | bash replicate_mnist.sh rot-mnist
18 | cd ../..
19 |
20 | echo "************************ replicating experiment 2 (permuted MNIST) ***********************"
21 | echo " >>>>>>>> Plastic (Naive) SGD "
22 | python -m stable_sgd.main --dataset perm-mnist --tasks 20 --epochs-per-task 1 --lr 0.01 --gamma 1.0 --hiddens 256 --batch-size 10 --dropout 0.0 --seed 1234
23 | python -m stable_sgd.main --dataset perm-mnist --tasks 20 --epochs-per-task 1 --lr 0.01 --gamma 1.0 --hiddens 256 --batch-size 10 --dropout 0.0 --seed 4567
24 | python -m stable_sgd.main --dataset perm-mnist --tasks 20 --epochs-per-task 1 --lr 0.01 --gamma 1.0 --hiddens 256 --batch-size 10 --dropout 0.0 --seed 7891
25 |
26 | echo ""
27 | echo " >>>>>>>> Stable SGD "
28 | python -m stable_sgd.main --dataset perm-mnist --tasks 20 --epochs-per-task 1 --lr 0.1 --gamma 0.8 --hiddens 256 --batch-size 10 --dropout 0.5 --seed 1234
29 | python -m stable_sgd.main --dataset perm-mnist --tasks 20 --epochs-per-task 1 --lr 0.1 --gamma 0.8 --hiddens 256 --batch-size 10 --dropout 0.5 --seed 4567
30 | python -m stable_sgd.main --dataset perm-mnist --tasks 20 --epochs-per-task 1 --lr 0.1 --gamma 0.8 --hiddens 256 --batch-size 10 --dropout 0.5 --seed 7891
31 |
32 | echo ""
33 | echo ">>>>>>>>> Other Methods (ER, A-GEM, EWC)"
34 | echo "Make sure you have tensorflow==1.12 installed. (see the readme doc)"
35 | cd ./external_libs/continual_learning_algorithms
36 | bash replicate_mnist.sh perm-mnist
37 | cd ../..
38 | 39 | 40 | 41 | echo "************************ replicating experiment 2 (Split CIFAR-100) ***********************" 42 | echo " >>>>>>>> Plastic (Naive) SGD " 43 | python -m stable_sgd.main --dataset cifar100 --tasks 20 --epochs-per-task 1 --lr 0.01 --gamma 1.0 --hiddens 256 --batch-size 10 --dropout 0.0 --seed 1234 44 | python -m stable_sgd.main --dataset cifar100 --tasks 20 --epochs-per-task 1 --lr 0.01 --gamma 1.0 --hiddens 256 --batch-size 10 --dropout 0.0 --seed 4567 45 | python -m stable_sgd.main --dataset cifar100 --tasks 20 --epochs-per-task 1 --lr 0.01 --gamma 1.0 --hiddens 256 --batch-size 10 --dropout 0.0 --seed 7891 46 | 47 | echo "" 48 | echo " >>>>>>>> Stable SGD " 49 | python -m stable_sgd.main --dataset cifar100 --tasks 20 --epochs-per-task 1 --lr 0.1 --gamma 0.8 --hiddens 256 --batch-size 10 --dropout 0.5 --seed 1234 50 | python -m stable_sgd.main --dataset cifar100 --tasks 20 --epochs-per-task 1 --lr 0.1 --gamma 0.8 --hiddens 256 --batch-size 10 --dropout 0.5 --seed 4567 51 | python -m stable_sgd.main --dataset cifar100 --tasks 20 --epochs-per-task 1 --lr 0.1 --gamma 0.8 --hiddens 256 --batch-size 10 --dropout 0.5 --seed 7891 52 | 53 | echo "" 54 | echo ">>>>>>>>> Other Methods (ER, A-GEM, EWC)" 55 | echo "Make sure you have tensorflow==1.12 installed. (see the readme doc)" 56 | cd ./external_libs/continual_learning_algorithms 57 | bash replicate_cifar.sh 58 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | future==0.18.2 3 | kiwisolver==1.2.0 4 | matplotlib==3.2.1 5 | numpy==1.18.5 6 | pandas==1.0.4 7 | Pillow==7.1.2 8 | pyparsing==2.4.7 9 | python-dateutil==2.8.1 10 | pytz==2020.1 11 | scipy==1.4.1 12 | seaborn==0.10.1 13 | six==1.15.0 14 | torch==1.5.0 15 | torchvision==0.6.0 16 | opencv-contrib-python==4.2.0.34 17 | opencv-python==4.2.0.34 18 | tensorflow==1.15.4 19 | tensorflow-gpu==1.15.4 20 | -------------------------------------------------------------------------------- /setup_and_install.sh: -------------------------------------------------------------------------------- 1 | pip install -r requirements.txt 2 | cd external_libs/continual_learning_algorithms 3 | bash run.sh 4 | cd ../.. 5 | echo "Setup was successful!" -------------------------------------------------------------------------------- /stable_sgd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imirzadeh/stable-continual-learning/be4049c6f1c433e6d4593bf3113280dcbeb127b9/stable_sgd/__init__.py -------------------------------------------------------------------------------- /stable_sgd/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | from torch.utils.data import TensorDataset, DataLoader 5 | import torchvision.transforms.functional as TorchVisionFunc 6 | 7 | 8 | def get_permuted_mnist(task_id, batch_size): 9 | """ 10 | Get the dataset loaders (train and test) for a `single` task of permuted MNIST. 11 | This function will be called several times for each task. 
12 | 13 | :param task_id: id of the task [starts from 1] 14 | :param batch_size: 15 | :return: a tuple: (train loader, test loader) 16 | """ 17 | 18 | # convention, the first task will be the original MNIST images, and hence no permutation 19 | if task_id == 1: 20 | idx_permute = np.array(range(784)) 21 | else: 22 | idx_permute = torch.from_numpy(np.random.RandomState().permutation(784)) 23 | transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), 24 | torchvision.transforms.Lambda(lambda x: x.view(-1)[idx_permute] ), 25 | ]) 26 | mnist_train = torchvision.datasets.MNIST('./data/', train=True, download=True, transform=transforms) 27 | train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, num_workers=4, pin_memory=True, shuffle=True) 28 | test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True, transform=transforms), batch_size=256, shuffle=False, num_workers=4, pin_memory=True) 29 | 30 | return train_loader, test_loader 31 | 32 | 33 | def get_permuted_mnist_tasks(num_tasks, batch_size): 34 | """ 35 | Returns the datasets for sequential tasks of permuted MNIST 36 | 37 | :param num_tasks: number of tasks. 38 | :param batch_size: batch-size for loaders. 39 | :return: a dictionary where each key is a dictionary itself with train, and test loaders. 40 | """ 41 | datasets = {} 42 | for task_id in range(1, num_tasks+1): 43 | train_loader, test_loader = get_permuted_mnist(task_id, batch_size) 44 | datasets[task_id] = {'train': train_loader, 'test': test_loader} 45 | return datasets 46 | 47 | 48 | class RotationTransform: 49 | """ 50 | Rotation transforms for the images in `Rotation MNIST` dataset. 51 | """ 52 | def __init__(self, angle): 53 | self.angle = angle 54 | 55 | def __call__(self, x): 56 | return TorchVisionFunc.rotate(x, self.angle, fill=(0,)) 57 | 58 | 59 | def get_rotated_mnist(task_id, batch_size): 60 | """ 61 | Returns the dataset for a single task of Rotation MNIST dataset 62 | :param task_id: 63 | :param batch_size: 64 | :return: 65 | """ 66 | per_task_rotation = 10 67 | rotation_degree = (task_id - 1)*per_task_rotation 68 | rotation_degree -= (np.random.random()*per_task_rotation) 69 | 70 | transforms = torchvision.transforms.Compose([ 71 | RotationTransform(rotation_degree), 72 | torchvision.transforms.ToTensor(), 73 | ]) 74 | 75 | train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=True, download=True, transform=transforms), batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) 76 | test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True, transform=transforms), batch_size=256, shuffle=False, num_workers=4, pin_memory=True) 77 | 78 | return train_loader, test_loader 79 | 80 | 81 | def get_rotated_mnist_tasks(num_tasks, batch_size): 82 | """ 83 | Returns data loaders for all tasks of rotation MNIST dataset. 84 | :param num_tasks: number of tasks in the benchmark. 
85 | :param batch_size: 86 | :return: 87 | """ 88 | datasets = {} 89 | for task_id in range(1, num_tasks+1): 90 | train_loader, test_loader = get_rotated_mnist(task_id, batch_size) 91 | datasets[task_id] = {'train': train_loader, 'test': test_loader} 92 | return datasets 93 | 94 | 95 | def get_split_cifar100(task_id, batch_size, cifar_train, cifar_test): 96 | """ 97 | Returns a single task of split CIFAR-100 dataset 98 | :param task_id: 99 | :param batch_size: 100 | :return: 101 | """ 102 | 103 | 104 | start_class = (task_id-1)*5 105 | end_class = task_id * 5 106 | 107 | targets_train = torch.tensor(cifar_train.targets) 108 | target_train_idx = ((targets_train >= start_class) & (targets_train < end_class)) 109 | 110 | targets_test = torch.tensor(cifar_test.targets) 111 | target_test_idx = ((targets_test >= start_class) & (targets_test < end_class)) 112 | 113 | train_loader = torch.utils.data.DataLoader(torch.utils.data.dataset.Subset(cifar_train, np.where(target_train_idx==1)[0]), batch_size=batch_size, shuffle=True) 114 | test_loader = torch.utils.data.DataLoader(torch.utils.data.dataset.Subset(cifar_test, np.where(target_test_idx==1)[0]), batch_size=batch_size) 115 | 116 | return train_loader, test_loader 117 | 118 | 119 | def get_split_cifar100_tasks(num_tasks, batch_size): 120 | """ 121 | Returns data loaders for all tasks of split CIFAR-100 122 | :param num_tasks: 123 | :param batch_size: 124 | :return: 125 | """ 126 | datasets = {} 127 | 128 | # convention: tasks starts from 1 not 0 ! 129 | # task_id = 1 (i.e., first task) => start_class = 0, end_class = 4 130 | cifar_transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),]) 131 | cifar_train = torchvision.datasets.CIFAR100('./data/', train=True, download=True, transform=cifar_transforms) 132 | cifar_test = torchvision.datasets.CIFAR100('./data/', train=False, download=True, transform=cifar_transforms) 133 | 134 | for task_id in range(1, num_tasks+1): 135 | train_loader, test_loader = get_split_cifar100(task_id, batch_size, cifar_train, cifar_test) 136 | datasets[task_id] = {'train': train_loader, 'test': test_loader} 137 | return datasets 138 | 139 | # if __name__ == "__main__": 140 | # dataset = get_split_cifar100(1) 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /stable_sgd/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | import torch.nn as nn 6 | from stable_sgd.models import MLP, ResNet18 7 | from stable_sgd.data_utils import get_permuted_mnist_tasks, get_rotated_mnist_tasks, get_split_cifar100_tasks 8 | from stable_sgd.utils import parse_arguments, DEVICE, init_experiment, end_experiment, log_metrics, log_hessian, save_checkpoint 9 | 10 | 11 | def train_single_epoch(net, optimizer, loader, criterion, task_id=None): 12 | """ 13 | Train the model for a single epoch 14 | 15 | :param net: 16 | :param optimizer: 17 | :param loader: 18 | :param criterion: 19 | :param task_id: 20 | :return: 21 | """ 22 | net = net.to(DEVICE) 23 | net.train() 24 | 25 | for batch_idx, (data, target) in enumerate(loader): 26 | data = data.to(DEVICE) 27 | target = target.to(DEVICE) 28 | optimizer.zero_grad() 29 | if task_id: 30 | pred = net(data, task_id) 31 | else: 32 | pred = net(data) 33 | loss = criterion(pred, target) 34 | loss.backward() 35 | optimizer.step() 36 | return net 37 | 38 | 39 | def eval_single_epoch(net, loader, criterion, 
task_id=None): 40 | """ 41 | Evaluate the model for a single epoch 42 | 43 | :param net: 44 | :param loader: 45 | :param criterion: 46 | :param task_id: 47 | :return: 48 | """ 49 | net = net.to(DEVICE) 50 | net.eval() 51 | test_loss = 0 52 | correct = 0 53 | with torch.no_grad(): 54 | for data, target in loader: 55 | data = data.to(DEVICE) 56 | target = target.to(DEVICE) 57 | # for cifar head 58 | if task_id is not None: 59 | output = net(data, task_id) 60 | else: 61 | output = net(data) 62 | test_loss += criterion(output, target).item() 63 | pred = output.data.max(1, keepdim=True)[1] 64 | correct += pred.eq(target.data.view_as(pred)).sum() 65 | test_loss /= len(loader.dataset) 66 | correct = correct.to('cpu') 67 | avg_acc = 100.0 * float(correct.numpy()) / len(loader.dataset) 68 | return {'accuracy': avg_acc, 'loss': test_loss} 69 | 70 | 71 | def get_benchmark_data_loader(args): 72 | """ 73 | Returns the benchmark loader which could be either of these: 74 | get_split_cifar100_tasks, get_permuted_mnist_tasks, or get_rotated_mnist_tasks 75 | 76 | :param args: 77 | :return: a function which when called, returns all tasks 78 | """ 79 | if args.dataset == 'perm-mnist' or args.dataset == 'permuted-mnist': 80 | return get_permuted_mnist_tasks 81 | elif args.dataset == 'rot-mnist' or args.dataset == 'rotation-mnist': 82 | return get_rotated_mnist_tasks 83 | elif args.dataset == 'cifar-100' or args.dataset == 'cifar100': 84 | return get_split_cifar100_tasks 85 | else: 86 | raise Exception("Unknown dataset.\n"+ 87 | "The code supports 'perm-mnist', 'rot-mnist', and 'cifar-100'.") 88 | 89 | 90 | def get_benchmark_model(args): 91 | """ 92 | Return the corresponding PyTorch model for the experiment 93 | :param args: 94 | :return: 95 | """ 96 | if 'mnist' in args.dataset: 97 | if args.tasks == 20 and args.hiddens < 256: 98 | print("Warning! The main paper uses an MLP with 256 hidden neurons for the 20-task experiments.") 99 | return MLP(args.hiddens, {'dropout': args.dropout}).to(DEVICE) 100 | elif 'cifar' in args.dataset: 101 | return ResNet18(config={'dropout': args.dropout}).to(DEVICE) 102 | else: 103 | raise Exception("Unknown dataset.\n"+ 104 | "The code supports 'perm-mnist', 'rot-mnist', and 'cifar-100'.") 105 | 106 | 107 | def run(args): 108 | """ 109 | Run a single run of the experiment. 110 | 111 | :param args: please see `utils.py` for arguments and options 112 | """ 113 | # init experiment 114 | acc_db, loss_db, hessian_eig_db = init_experiment(args) 115 | 116 | # load benchmarks and model 117 | print("Loading {} tasks for {}".format(args.tasks, args.dataset)) 118 | tasks = get_benchmark_data_loader(args)(args.tasks, args.batch_size) 119 | print("loaded all tasks!") 120 | model = get_benchmark_model(args) 121 | 122 | # criterion 123 | criterion = nn.CrossEntropyLoss().to(DEVICE) 124 | time = 0 125 | 126 | for current_task_id in range(1, args.tasks+1): 127 | print("================== TASK {} / {} =================".format(current_task_id, args.tasks)) 128 | train_loader = tasks[current_task_id]['train'] 129 | lr = max(args.lr * args.gamma ** (current_task_id), 0.00005) 130 | 131 | for epoch in range(1, args.epochs_per_task+1): 132 | # 1. train and save 133 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.8) 134 | train_single_epoch(model, optimizer, train_loader, criterion, current_task_id) 135 | time += 1 136 | 137 | # 2. evaluate on all tasks up to now, including the current task 138 | for prev_task_id in range(1, current_task_id+1): 139 | # 2.0.
only evaluate once a task is finished 140 | if epoch == args.epochs_per_task: 141 | model = model.to(DEVICE) 142 | val_loader = tasks[prev_task_id]['test'] 143 | 144 | # 2.1. compute accuracy and loss 145 | metrics = eval_single_epoch(model, val_loader, criterion, prev_task_id) 146 | acc_db, loss_db = log_metrics(metrics, time, prev_task_id, acc_db, loss_db) 147 | 148 | # 2.2. (optional) compute eigenvalues and eigenvectors of Loss Hessian 149 | if prev_task_id == current_task_id and args.compute_eigenspectrum: 150 | hessian_eig_db = log_hessian(model, val_loader, time, prev_task_id, hessian_eig_db) 151 | 152 | # 2.3. save model parameters 153 | save_checkpoint(model, time) 154 | 155 | end_experiment(args, acc_db, loss_db, hessian_eig_db) 156 | 157 | 158 | if __name__ == "__main__": 159 | args = parse_arguments() 160 | run(args) -------------------------------------------------------------------------------- /stable_sgd/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import relu, avg_pool2d 4 | 5 | 6 | class MLP(nn.Module): 7 | """ 8 | Two layer MLP for MNIST benchmarks. 9 | """ 10 | def __init__(self, hiddens, config): 11 | super(MLP, self).__init__() 12 | self.W1 = nn.Linear(784, hiddens) 13 | self.relu = nn.ReLU(inplace=True) 14 | self.dropout_1 = nn.Dropout(p=config['dropout']) 15 | self.W2 = nn.Linear(hiddens, hiddens) 16 | self.dropout_2 = nn.Dropout(p=config['dropout']) 17 | self.W3 = nn.Linear(hiddens, 10) 18 | 19 | def forward(self, x, task_id=None): 20 | x = x.view(-1, 784) 21 | out = self.W1(x) 22 | out = self.relu(out) 23 | out = self.dropout_1(out) 24 | out = self.W2(out) 25 | out = self.relu(out) 26 | out = self.dropout_2(out) 27 | out = self.W3(out) 28 | return out 29 | 30 | 31 | def conv3x3(in_planes, out_planes, stride=1): 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=1, bias=False) 34 | 35 | 36 | class BasicBlock(nn.Module): 37 | expansion = 1 38 | 39 | def __init__(self, in_planes, planes, stride=1, config={}): 40 | super(BasicBlock, self).__init__() 41 | self.conv1 = conv3x3(in_planes, planes, stride) 42 | self.conv2 = conv3x3(planes, planes) 43 | 44 | self.shortcut = nn.Sequential() 45 | if stride != 1 or in_planes != self.expansion * planes: 46 | self.shortcut = nn.Sequential( 47 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, 48 | stride=stride, bias=False), 49 | ) 50 | self.IC1 = nn.Sequential( 51 | nn.BatchNorm2d(planes), 52 | nn.Dropout(p=config['dropout']) 53 | ) 54 | 55 | self.IC2 = nn.Sequential( 56 | nn.BatchNorm2d(planes), 57 | nn.Dropout(p=config['dropout']) 58 | ) 59 | 60 | def forward(self, x): 61 | out = self.conv1(x) 62 | out = relu(out) 63 | out = self.IC1(out) 64 | 65 | out += self.shortcut(x) 66 | out = relu(out) 67 | out = self.IC2(out) 68 | return out 69 | 70 | 71 | class ResNet(nn.Module): 72 | def __init__(self, block, num_blocks, num_classes, nf, config={}): 73 | super(ResNet, self).__init__() 74 | self.in_planes = nf 75 | 76 | self.conv1 = conv3x3(3, nf * 1) 77 | self.bn1 = nn.BatchNorm2d(nf * 1) 78 | self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1, config=config) 79 | self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2, config=config) 80 | self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2, config=config) 81 | self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2, config=config) 82 | self.linear = 
nn.Linear(nf * 8 * block.expansion, num_classes) 83 | 84 | def _make_layer(self, block, planes, num_blocks, stride, config): 85 | strides = [stride] + [1] * (num_blocks - 1) 86 | layers = [] 87 | for stride in strides: 88 | layers.append(block(self.in_planes, planes, stride, config=config)) 89 | self.in_planes = planes * block.expansion 90 | return nn.Sequential(*layers) 91 | 92 | def forward(self, x, task_id): 93 | bsz = x.size(0) 94 | out = relu(self.bn1(self.conv1(x.view(bsz, 3, 32, 32)))) 95 | out = self.layer1(out) 96 | out = self.layer2(out) 97 | out = self.layer3(out) 98 | out = self.layer4(out) 99 | out = avg_pool2d(out, 4) 100 | out = out.view(out.size(0), -1) 101 | out = self.linear(out) 102 | t = task_id 103 | offset1 = int((t-1) * 5) 104 | offset2 = int(t * 5) 105 | if offset1 > 0: 106 | out[:, :offset1].data.fill_(-10e10) 107 | if offset2 < 100: 108 | out[:, offset2:100].data.fill_(-10e10) 109 | return out 110 | 111 | 112 | def ResNet18(nclasses=100, nf=20, config={}): 113 | net = ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf, config=config) 114 | return net 115 | -------------------------------------------------------------------------------- /stable_sgd/utils.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import torch 3 | import argparse 4 | import matplotlib 5 | import numpy as np 6 | import pandas as pd 7 | matplotlib.use('Agg') 8 | import seaborn as sns 9 | from pathlib import Path 10 | import matplotlib.pyplot as plt 11 | from external_libs.hessian_eigenthings import compute_hessian_eigenthings 12 | 13 | 14 | TRIAL_ID = uuid.uuid4().hex.upper()[0:6] 15 | EXPERIMENT_DIRECTORY = './outputs/{}'.format(TRIAL_ID) 16 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 17 | 18 | 19 | def parse_arguments(): 20 | parser = argparse.ArgumentParser(description='Argument parser') 21 | parser.add_argument('--tasks', default=5, type=int, help='total number of tasks') 22 | parser.add_argument('--epochs-per-task', default=1, type=int, help='epochs per task') 23 | parser.add_argument('--dataset', default='rot-mnist', type=str, help='dataset. options: rot-mnist, perm-mnist, cifar100') 24 | parser.add_argument('--batch-size', default=10, type=int, help='batch-size') 25 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 26 | parser.add_argument('--gamma', default=0.4, type=float, help='learning rate decay. Use 1.0 for no decay') 27 | parser.add_argument('--dropout', default=0.25, type=float, help='dropout probability. Use 0.0 for no dropout') 28 | parser.add_argument('--hiddens', default=256, type=int, help='num of hidden neurons in each layer of a 2-layer MLP') 29 | parser.add_argument('--compute-eigenspectrum', default=False, type=bool, help='compute eigenvalues/eigenvectors?') 30 | parser.add_argument('--seed', default=1234, type=int, help='random seed') 31 | 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def init_experiment(args): 37 | print('------------------- Experiment started -----------------') 38 | print(f"Parameters:\n seed={args.seed}\n benchmark={args.dataset}\n num_tasks={args.tasks}\n "+ 39 | f"epochs_per_task={args.epochs_per_task}\n batch_size={args.batch_size}\n "+ 40 | f"learning_rate={args.lr}\n learning rate decay(gamma)={args.gamma}\n dropout prob={args.dropout}\n") 41 | 42 | # 1. setup seed for reproducibility 43 | torch.manual_seed(args.seed) 44 | np.random.seed(args.seed) 45 | 46 | # 2. 
create directory to save results 47 | Path(EXPERIMENT_DIRECTORY).mkdir(parents=True, exist_ok=True) 48 | print("The results will be saved in {}\n".format(EXPERIMENT_DIRECTORY)) 49 | 50 | # 3. create data structures to store metrics 51 | loss_db = {t: [0 for i in range(args.tasks*args.epochs_per_task)] for t in range(1, args.tasks+1)} 52 | acc_db = {t: [0 for i in range(args.tasks*args.epochs_per_task)] for t in range(1, args.tasks+1)} 53 | hessian_eig_db = {} 54 | return acc_db, loss_db, hessian_eig_db 55 | 56 | 57 | def end_experiment(args, acc_db, loss_db, hessian_eig_db): 58 | 59 | # 1. save all metrics into csv file 60 | acc_df = pd.DataFrame(acc_db) 61 | acc_df.to_csv(EXPERIMENT_DIRECTORY+'/accs.csv') 62 | visualize_result(acc_df, EXPERIMENT_DIRECTORY+'/accs.png') 63 | 64 | loss_df = pd.DataFrame(loss_db) 65 | loss_df.to_csv(EXPERIMENT_DIRECTORY+'/loss.csv') 66 | visualize_result(loss_df, EXPERIMENT_DIRECTORY+'/loss.png') 67 | 68 | hessian_df = pd.DataFrame(hessian_eig_db) 69 | hessian_df.to_csv(EXPERIMENT_DIRECTORY+'/hessian_eigs.csv') 70 | 71 | # 2. calculate average accuracy and forgetting (c.f. ``evaluation`` section in our paper) 72 | score = np.mean([acc_db[i][-1] for i in acc_db.keys()]) 73 | forget = np.mean([max(acc_db[i])-acc_db[i][-1] for i in range(1, args.tasks)])/100.0 74 | 75 | print('average accuracy = {}, forget = {}'.format(score, forget)) 76 | print() 77 | print('------------------- Experiment ended -----------------') 78 | 79 | 80 | def log_metrics(metrics, time, task_id, acc_db, loss_db): 81 | """ 82 | Log accuracy and loss at different times of training 83 | """ 84 | print('epoch {}, task:{}, metrics: {}'.format(time, task_id, metrics)) 85 | # log to db 86 | acc = metrics['accuracy'] 87 | loss = metrics['loss'] 88 | loss_db[task_id][time-1] = loss 89 | acc_db[task_id][time-1] = acc 90 | return acc_db, loss_db 91 | 92 | 93 | def save_eigenvec(filename, arr): 94 | """ 95 | Save eigenvectors to file 96 | """ 97 | np.save(filename, arr) 98 | 99 | 100 | def log_hessian(model, loader, time, task_id, hessian_eig_db): 101 | """ 102 | Compute and log Hessian for a specific task 103 | 104 | :param model: The PyTorch Model 105 | :param loader: Dataloader [to calculate loss and then Hessian] 106 | :param time: time is a discrete concept regarding epoch. If we have T tasks each with E epoch, 107 | time will be from 0, to (T x E)-1. E.g., if we have 5 tasks with 5 epochs each, then when we finish 108 | task 1, time will be 5. 
109 | :param task_id: Task id (to distinguish between Hessians of different tasks) 110 | :param hessian_eig_db: (The dictionary to store hessians) 111 | :return: 112 | """ 113 | criterion = torch.nn.CrossEntropyLoss().to(DEVICE) 114 | use_gpu = True if DEVICE != 'cpu' else False 115 | est_eigenvals, est_eigenvecs = compute_hessian_eigenthings( 116 | model, 117 | loader, 118 | criterion, 119 | num_eigenthings=3, 120 | power_iter_steps=18, 121 | power_iter_err_threshold=1e-5, 122 | momentum=0, 123 | use_gpu=use_gpu, 124 | ) 125 | key = 'task-{}-epoch-{}'.format(task_id, time-1) 126 | hessian_eig_db[key] = est_eigenvals 127 | save_eigenvec(EXPERIMENT_DIRECTORY+"/{}-vec.npy".format(key), est_eigenvecs) 128 | return hessian_eig_db 129 | 130 | 131 | def save_checkpoint(model, time): 132 | """ 133 | Save a checkpoint of model parameters 134 | :param model: pytorch model 135 | :param time: int 136 | """ 137 | filename = '{directory}/model-{trial}-{time}.pth'.format(directory=EXPERIMENT_DIRECTORY, trial=TRIAL_ID, time=time) 138 | torch.save(model.cpu().state_dict(), filename) 139 | 140 | 141 | def visualize_result(df, filename): 142 | ax = sns.lineplot(data=df, dashes=False) 143 | ax.figure.savefig(filename, dpi=250) 144 | plt.close() 145 | --------------------------------------------------------------------------------
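A note on the hyperparameters in the replication scripts above: stable_sgd/main.py rebuilds the SGD optimizer for every epoch with a learning rate that is decayed once per task, lr = max(args.lr * args.gamma ** current_task_id, 0.00005). Below is a minimal, self-contained sketch of that schedule (the helper name task_lr is illustrative and not part of the repository), evaluated for the "Plastic (Naive) SGD" and "Stable SGD" settings used in replicate_experiment_1.sh:

def task_lr(base_lr, gamma, task_id, floor=5e-5):
    """Learning rate applied throughout task `task_id` (tasks are 1-indexed), as in run()."""
    return max(base_lr * gamma ** task_id, floor)

# Plastic (naive) SGD: --lr 0.01 --gamma 1.0 -> constant learning rate
print([task_lr(0.01, 1.0, t) for t in range(1, 6)])           # [0.01, 0.01, 0.01, 0.01, 0.01]

# Stable SGD: --lr 0.1 --gamma 0.4 -> geometric decay, one step per task
print([round(task_lr(0.1, 0.4, t), 5) for t in range(1, 6)])  # [0.04, 0.016, 0.0064, 0.00256, 0.00102]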
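The split CIFAR-100 benchmark relies on a convention shared by get_split_cifar100() in stable_sgd/data_utils.py and the output masking in ResNet.forward() in stable_sgd/models.py: task ids start at 1 and each task owns five consecutive classes. A small sketch of that mapping (the function name task_class_range is illustrative only, not part of the repository):

def task_class_range(task_id, classes_per_task=5):
    """Half-open range [start, end) of CIFAR-100 classes belonging to `task_id`."""
    start = (task_id - 1) * classes_per_task
    end = task_id * classes_per_task
    return start, end

print(task_class_range(1))   # (0, 5)    -> classes 0..4, matching the comment in get_split_cifar100_tasks
print(task_class_range(20))  # (95, 100) -> classes 95..99; logits outside this range are masked in ResNet.forward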
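One caveat about parse_arguments() in stable_sgd/utils.py: --compute-eigenspectrum is declared with type=bool, and argparse converts the raw command-line string with bool(), so any non-empty value, including an explicit False, parses as True; only omitting the flag keeps the default of False. The standalone sketch below only demonstrates that standard argparse behaviour and is not part of the repository; if the interface were ever changed, action='store_true' would be a common alternative.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--compute-eigenspectrum', default=False, type=bool)

print(parser.parse_args([]).compute_eigenspectrum)                                    # False (default)
print(parser.parse_args(['--compute-eigenspectrum', 'True']).compute_eigenspectrum)   # True
print(parser.parse_args(['--compute-eigenspectrum', 'False']).compute_eigenspectrum)  # True, because bool('False') is truthy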