├── LICENSE ├── README.md ├── docs └── _static │ ├── acc_mat.png │ ├── asym_illustration(1).png │ ├── att_fun.png │ ├── cil_survey_approaches.png │ ├── facil_logo.jpg │ ├── facil_logo.png │ ├── tb1.png │ └── tb2.png ├── environment.yml ├── requirements.txt ├── scripts ├── script_cifar100.sh └── script_imagenet.sh └── src ├── README.md ├── __pycache__ ├── last_layer_analysis.cpython-36.pyc ├── last_layer_analysis.cpython-38.pyc ├── utils.cpython-36.pyc └── utils.cpython-38.pyc ├── approach ├── README.md ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── ewc.cpython-38.pyc │ ├── incremental_learning.cpython-38.pyc │ ├── lwf.cpython-38.pyc │ ├── olwf.cpython-38.pyc │ ├── olwf_asym.cpython-38.pyc │ ├── olwf_asympost.cpython-38.pyc │ ├── olwf_jsd.cpython-38.pyc │ └── path_integral.cpython-38.pyc ├── bic.py ├── dmc.py ├── eeil.py ├── ewc.py ├── finetuning.py ├── freezing.py ├── icarl.py ├── il2m.py ├── incremental_learning.py ├── joint.py ├── lucir.py ├── lwf.py ├── lwm.py ├── mas.py ├── oewc.py ├── olwf_asym.py ├── olwf_asym_original.py ├── olwf_asympost.py ├── path_integral.py └── r_walk.py ├── datasets ├── README.md ├── __pycache__ │ ├── base_dataset.cpython-36.pyc │ ├── base_dataset.cpython-38.pyc │ ├── data_loader.cpython-36.pyc │ ├── data_loader.cpython-38.pyc │ ├── dataset_config.cpython-36.pyc │ ├── dataset_config.cpython-38.pyc │ ├── exemplars_dataset.cpython-38.pyc │ ├── exemplars_selection.cpython-38.pyc │ ├── memory_dataset.cpython-36.pyc │ └── memory_dataset.cpython-38.pyc ├── base_dataset.py ├── data_loader.py ├── dataset_config.py ├── exemplars_dataset.py ├── exemplars_selection.py └── memory_dataset.py ├── gridsearch.py ├── gridsearch_config.py ├── last_layer_analysis.py ├── loggers ├── README.md ├── __pycache__ │ ├── disk_logger.cpython-38.pyc │ ├── exp_logger.cpython-36.pyc │ ├── exp_logger.cpython-38.pyc │ └── tensorboard_logger.cpython-38.pyc ├── disk_logger.py ├── exp_logger.py └── tensorboard_logger.py ├── main_incremental.py ├── networks ├── README.md ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── early_conv_vit.cpython-38.pyc │ ├── early_conv_vit_net.cpython-38.pyc │ ├── efficient_net.cpython-38.pyc │ ├── lenet.cpython-36.pyc │ ├── lenet.cpython-38.pyc │ ├── mobile_net.cpython-38.pyc │ ├── network.cpython-38.pyc │ ├── ovit.cpython-36.pyc │ ├── ovit.cpython-38.pyc │ ├── ovit_tiny_16_augreg_224.cpython-36.pyc │ ├── ovit_tiny_16_augreg_224.cpython-38.pyc │ ├── resnet32.cpython-36.pyc │ ├── resnet32.cpython-38.pyc │ ├── timm_vit_tiny_16_augreg_224.cpython-38.pyc │ ├── vggnet.cpython-36.pyc │ ├── vggnet.cpython-38.pyc │ ├── vit_original.cpython-38.pyc │ └── vit_tiny_16_augreg_224.cpython-38.pyc ├── early_conv_vit.py ├── early_conv_vit_net.py ├── efficient_net.py ├── fpt.py ├── lenet.py ├── mobile_net.py ├── network.py ├── ovit.py ├── ovit_tiny_16_augreg_224.py ├── pretrained_weights │ └── augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz ├── resnet32.py ├── timm_vit_tiny_16_augreg_224.py ├── vggnet.py ├── vit_original.py └── vit_tiny_16_augreg_224.py ├── test.ipynb ├── test.py ├── tests ├── README.md ├── __init__.py ├── test_bic.py ├── test_dataloader.py ├── test_datasets_transforms.py ├── test_dmc.py ├── test_eeil.py ├── test_ewc.py ├── test_finetuning.py ├── test_fix_bn.py ├── test_freezing.py ├── test_gridsearch.py ├── test_icarl.py ├── test_il2m.py ├── test_joint.py ├── test_last_layer_analysis.py ├── test_loggers.py 
├── test_lucir.py ├── test_lwf.py ├── test_lwm.py ├── test_mas.py ├── test_multisoftmax.py ├── test_path_integral.py ├── test_rwalk.py ├── test_stop_at_task.py └── test_warmup.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Marc Masana Castrillo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Continual Learning with Vision Transformers 2 | 3 | > Update: Our paper wins the best runner-up award at the [3rd CLVision workshop](https://sites.google.com/view/clvision2022/call-for-papers/accepted-papers). 4 | 5 | This repo hosts the official implementation of our CVPR 2022 workshop paper [Towards Exemplar-Free Continual Learning in Vision Transformers: an Account of Attention, Functional and Weight Regularization](https://openaccess.thecvf.com/content/CVPR2022W/CLVision/html/Pelosin_Towards_Exemplar-Free_Continual_Learning_in_Vision_Transformers_An_Account_of_CVPRW_2022_paper.html). 6 | 7 | TLDR; We introduce attentional and functional variants for asymmetric and symmetric Pooled Attention Distillation (PAD) losses in Vision Transformers: 8 |
9 | 10 |
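To make the two flavours concrete, here is a minimal, illustrative sketch of a pooled attention distillation loss. Everything in it — the function name, the tensor shapes, and the squared penalty — is a simplification for intuition only; the actual attentional and functional losses are implemented in `src/approach/olwf_asym.py` and `src/approach/olwf_asympost.py`.

```python
import torch
import torch.nn.functional as F

def pad_loss(att_new, att_old, pool_along="height", sym=False):
    """Illustrative Pooled Attention Distillation (PAD) sketch, not the repo's exact loss.

    att_new / att_old: attention maps of shape (..., H, W) from the current
    model and the frozen previous-task model, respectively.
    """
    dim = -2 if pool_along == "height" else -1  # collapse one spatial axis
    p_new = att_new.sum(dim=dim)
    p_old = att_old.sum(dim=dim)
    if sym:
        # symmetric: penalize any deviation from the old pooled attention
        return F.mse_loss(p_new, p_old)
    # asymmetric: only penalize losing attention mass where the old model
    # attended, leaving the new model free to attend to new locations
    return torch.clamp(p_old - p_new, min=0.0).pow(2).mean()
```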
11 | 12 | ## Running the code 13 | 14 |
15 | 16 |
17 | 18 | 19 | Below are two example runs for the asymmetric attentional and functional variants, pooling along the height dimension on ImageNet-100. 20 | 1. Attentional variant: 21 | 22 | ```bash 23 | python3 -u src/main_incremental.py --datasets imagenet_32_reduced --network Early_conv_vit --approach olwf_asym --nepochs $NEPOCHS --log disk --batch-size 1024 --gpu $GPU --exp-name dummy_attentional_exp --lr 0.01 --seed ${seed} --lamb 1.0 --num-tasks $NUM_TASKS --nc-first-task $NC_FIRST_TASK --lr-patience 20 --plast_mu 1.0 --pool-along 'height' 24 | ``` 25 | 26 | 2. Functional variant: 27 | ```bash 28 | python3 -u src/main_incremental.py --datasets imagenet_32_reduced --network Early_conv_vit --approach olwf_asympost --nepochs $NEPOCHS --log disk --batch-size 1024 --gpu $GPU --exp-name dummy_functional_exp --lr 0.01 --seed ${seed} --lamb 1.0 --num-tasks $NUM_TASKS --nc-first-task $NC_FIRST_TASK --lr-patience 20 --plast_mu 1.0 --pool-along 'height' 29 | ``` 30 | 31 | The corresponding runs for the symmetric variants simply add the `--sym` flag: 32 | 1. Attentional variant: 33 | 34 | ```bash 35 | python3 -u src/main_incremental.py --datasets imagenet_32_reduced --network Early_conv_vit --approach olwf_asym --nepochs $NEPOCHS --log disk --batch-size 1024 --gpu $GPU --exp-name dummy_attentional_exp --lr 0.01 --seed ${seed} --lamb 1.0 --num-tasks $NUM_TASKS --nc-first-task $NC_FIRST_TASK --lr-patience 20 --plast_mu 1.0 --pool-along 'height' --sym 36 | ``` 37 | 38 | 2. Functional variant: 39 | ```bash 40 | python3 -u src/main_incremental.py --datasets imagenet_32_reduced --network Early_conv_vit --approach olwf_asympost --nepochs $NEPOCHS --log disk --batch-size 1024 --gpu $GPU --exp-name dummy_functional_exp --lr 0.01 --seed ${seed} --lamb 1.0 --num-tasks $NUM_TASKS --nc-first-task $NC_FIRST_TASK --lr-patience 20 --plast_mu 1.0 --pool-along 'height' --sym 41 | ``` 42 | 43 | Other available continual learning approaches with Vision Transformers include: 44 |
45 |

46 | EWC • Finetuning • LwF • PathInt 47 |

48 |
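Each of these names maps to a module under `src/approach/`, discovered by filename in `src/approach/__init__.py`. As a hypothetical sketch (the exact loading logic lives in `src/main_incremental.py` and is not reproduced here), selecting `--approach ewc` boils down to:

```python
import importlib

# assumes src/ is on sys.path; every approach module exposes a class named Appr
Appr = getattr(importlib.import_module("approach.ewc"), "Appr")
```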
49 | 50 | The detailed scripts for our experiments can be found in `scripts/`. 51 | 52 | ## Cite 53 | If you find our implementation useful, please consider citing our paper: 54 | ```bibtex 55 | @InProceedings{Pelosin_Jha_CVPR, 56 | author = {Pelosin, Francesco and Jha, Saurav and Torsello, Andrea and Raducanu, Bogdan and van de Weijer, Joost}, 57 | title = {Towards Exemplar-Free Continual Learning in Vision Transformers: An Account of Attention, Functional and Weight Regularization}, 58 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 59 | month = {June}, 60 | year = {2022}, 61 | pages = {3820-3829} 62 | } 63 | ``` 64 | ## Acknowledgement 65 | This repo is based on [FACIL](https://github.com/mmasana/FACIL). 66 | 67 |
68 | 69 |
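For reference, the temperature-scaled distillation used by the LwF-style approaches in this repo can be written in a few lines. The sketch below mirrors the `cross_entropy(..., exp=1.0 / T)` helper in `src/approach/lwf.py` (included further down); the wrapper name and defaults are our own:

```python
import torch.nn.functional as F

def kd_cross_entropy(outputs, targets, T=2.0, eps=1e-5):
    # soften both distributions with temperature T (exp = 1 / T in lwf.py)
    out = F.softmax(outputs, dim=1).pow(1.0 / T)
    out = out / out.sum(1, keepdim=True)
    tar = F.softmax(targets, dim=1).pow(1.0 / T)
    tar = tar / tar.sum(1, keepdim=True)
    # numerical smoothing before taking the log, as in lwf.py
    out = out + eps / out.size(1)
    out = out / out.sum(1, keepdim=True)
    return -(tar * out.log()).sum(1).mean()
```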
70 | -------------------------------------------------------------------------------- /docs/_static/acc_mat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/acc_mat.png -------------------------------------------------------------------------------- /docs/_static/asym_illustration(1).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/asym_illustration(1).png -------------------------------------------------------------------------------- /docs/_static/att_fun.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/att_fun.png -------------------------------------------------------------------------------- /docs/_static/cil_survey_approaches.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/cil_survey_approaches.png -------------------------------------------------------------------------------- /docs/_static/facil_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/facil_logo.jpg -------------------------------------------------------------------------------- /docs/_static/facil_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/facil_logo.png -------------------------------------------------------------------------------- /docs/_static/tb1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/tb1.png -------------------------------------------------------------------------------- /docs/_static/tb2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/tb2.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: FACIL 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - absl-py=0.11.0=pyhd3eb1b0_1 8 | - aiohttp=3.7.3=py38h27cfd23_1 9 | - apipkg=1.5=py38_0 10 | - async-timeout=3.0.1=py38h06a4308_0 11 | - attrs=20.3.0=pyhd3eb1b0_0 12 | - blas=1.0=mkl 13 | - blinker=1.4=py38h06a4308_0 14 | - brotlipy=0.7.0=py38h27cfd23_1003 15 | - c-ares=1.17.1=h27cfd23_0 16 | - ca-certificates=2021.1.19=h06a4308_0 17 | - cachetools=4.2.1=pyhd3eb1b0_0 18 | - certifi=2020.12.5=py38h06a4308_0 19 | - cffi=1.14.4=py38h261ae71_0 20 | - chardet=3.0.4=py38h06a4308_1003 21 | - click=7.1.2=pyhd3eb1b0_0 22 | - cryptography=2.9.2=py38h1ba5d50_0 23 | - cudatoolkit=10.1.243=h6bb024c_0 24 | 
- cycler=0.10.0=py38_0 25 | - dbus=1.13.18=hb2f20db_0 26 | - execnet=1.8.0=pyhd3eb1b0_0 27 | - expat=2.2.10=he6710b0_2 28 | - fontconfig=2.13.1=h6c09931_0 29 | - freetype=2.10.4=h5ab3b9f_0 30 | - glib=2.66.1=h92f7085_0 31 | - google-auth=1.24.0=pyhd3eb1b0_0 32 | - google-auth-oauthlib=0.4.2=pyhd3eb1b0_2 33 | - grpcio=1.31.0=py38hf8bcb03_0 34 | - gst-plugins-base=1.14.0=h8213a91_2 35 | - gstreamer=1.14.0=h28cd5cc_2 36 | - icu=58.2=he6710b0_3 37 | - idna=2.10=pyhd3eb1b0_0 38 | - importlib-metadata=2.0.0=py_1 39 | - iniconfig=1.1.1=pyhd3eb1b0_0 40 | - intel-openmp=2020.2=254 41 | - jpeg=9b=h024ee3a_2 42 | - kiwisolver=1.3.1=py38h2531618_0 43 | - lcms2=2.11=h396b838_0 44 | - ld_impl_linux-64=2.33.1=h53a641e_7 45 | - libedit=3.1.20191231=h14c3975_1 46 | - libffi=3.3=he6710b0_2 47 | - libgcc-ng=9.1.0=hdf63c60_0 48 | - libpng=1.6.37=hbc83047_0 49 | - libprotobuf=3.14.0=h8c45485_0 50 | - libstdcxx-ng=9.1.0=hdf63c60_0 51 | - libtiff=4.1.0=h2733197_1 52 | - libuuid=1.0.3=h1bed415_2 53 | - libuv=1.40.0=h7b6447c_0 54 | - libxcb=1.14=h7b6447c_0 55 | - libxml2=2.9.10=hb55368b_3 56 | - lz4-c=1.9.3=h2531618_0 57 | - markdown=3.3.3=py38h06a4308_0 58 | - matplotlib=3.3.2=h06a4308_0 59 | - matplotlib-base=3.3.2=py38h817c723_0 60 | - mkl=2020.2=256 61 | - mkl-service=2.3.0=py38he904b0f_0 62 | - mkl_fft=1.2.0=py38h23d657b_0 63 | - mkl_random=1.1.1=py38h0573a6f_0 64 | - more-itertools=8.6.0=pyhd3eb1b0_0 65 | - multidict=4.7.6=py38h7b6447c_1 66 | - ncurses=6.2=he6710b0_1 67 | - ninja=1.10.2=py38hff7bd54_0 68 | - numpy=1.19.2=py38h54aff64_0 69 | - numpy-base=1.19.2=py38hfa32c7d_0 70 | - oauthlib=3.1.0=py_0 71 | - olefile=0.46=py_0 72 | - openssl=1.1.1i=h27cfd23_0 73 | - packaging=20.9=pyhd3eb1b0_0 74 | - pandas=1.2.1=py38ha9443f7_0 75 | - pcre=8.44=he6710b0_0 76 | - pillow=8.1.0=py38he98fc37_0 77 | - pip=20.3.3=py38h06a4308_0 78 | - pluggy=0.13.1=py38_0 79 | - protobuf=3.14.0=py38h2531618_1 80 | - py=1.10.0=pyhd3eb1b0_0 81 | - pyasn1=0.4.8=py_0 82 | - pyasn1-modules=0.2.8=py_0 83 | - pycparser=2.20=py_2 84 | - pyjwt=2.0.1=py38h06a4308_0 85 | - pyopenssl=20.0.1=pyhd3eb1b0_1 86 | - pyparsing=2.4.7=pyhd3eb1b0_0 87 | - pyqt=5.9.2=py38h05f1152_4 88 | - pysocks=1.7.1=py38h06a4308_0 89 | - pytest=6.2.2=py38h06a4308_2 90 | - pytest-forked=1.3.0=pyhd3eb1b0_0 91 | - pytest-xdist=2.2.0=pyhd3eb1b0_0 92 | - python=3.8.5=h7579374_1 93 | - python-dateutil=2.8.1=pyhd3eb1b0_0 94 | - pytorch=1.7.1=py3.8_cuda10.1.243_cudnn7.6.3_0 95 | - pytz=2021.1=pyhd3eb1b0_0 96 | - qt=5.9.7=h5867ecd_1 97 | - readline=8.1=h27cfd23_0 98 | - requests=2.25.1=pyhd3eb1b0_0 99 | - requests-oauthlib=1.3.0=py_0 100 | - rsa=4.7=pyhd3eb1b0_1 101 | - setuptools=52.0.0=py38h06a4308_0 102 | - sip=4.19.13=py38he6710b0_0 103 | - six=1.15.0=py38h06a4308_0 104 | - sqlite=3.33.0=h62c20be_0 105 | - tensorboard=2.3.0=pyh4dce500_0 106 | - tensorboard-plugin-wit=1.6.0=py_0 107 | - tk=8.6.10=hbc83047_0 108 | - toml=0.10.1=py_0 109 | - torchvision=0.8.2=py38_cu101 110 | - tornado=6.1=py38h27cfd23_0 111 | - typing-extensions=3.7.4.3=hd3eb1b0_0 112 | - typing_extensions=3.7.4.3=pyh06a4308_0 113 | - urllib3=1.26.3=pyhd3eb1b0_0 114 | - werkzeug=1.0.1=pyhd3eb1b0_0 115 | - wheel=0.36.2=pyhd3eb1b0_0 116 | - xz=5.2.5=h7b6447c_0 117 | - yarl=1.5.1=py38h7b6447c_0 118 | - zipp=3.4.0=pyhd3eb1b0_0 119 | - zlib=1.2.11=h7b6447c_3 120 | - zstd=1.4.5=h9ceee32_0 121 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # NOTE: Previous versions of pytorch 
and torchvision might also work, 2 | # but we haven't tested them yet 3 | torch>=1.7.1 4 | torchvision>=0.8.2 5 | matplotlib 6 | numpy 7 | tensorboard -------------------------------------------------------------------------------- /scripts/script_cifar100.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | 4 | if [ "$1" != "" ]; then 5 | echo "Running approach: $1" 6 | else 7 | echo "No approach has been assigned." 8 | fi 9 | if [ "$2" != "" ]; then 10 | echo "Running on gpu: $2" 11 | else 12 | echo "No gpu has been assigned." 13 | fi 14 | 15 | PROJECT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && cd .. && pwd )" 16 | SRC_DIR="$PROJECT_DIR/src" 17 | echo "Project dir: $PROJECT_DIR" 18 | echo "Sources dir: $SRC_DIR" 19 | 20 | RESULTS_DIR="$PROJECT_DIR/results" 21 | if [ "$4" != "" ]; then 22 | RESULTS_DIR=$4 23 | else 24 | echo "No results dir is given. Default will be used." 25 | fi 26 | echo "Results dir: $RESULTS_DIR" 27 | 28 | MU=1.0 29 | # LAMBDA=1.0 30 | 31 | for LAMBDA in 0.0 32 | do 33 | for seed in $(seq 1 1 1) 34 | do 35 | python3 -u src/main_incremental.py --datasets cifar100_224 --network OVit_tiny_16_augreg_224 --approach olwf_asym --nepochs 100 --log disk --batch-size 96 --gpu 0 --exp-name 10epochs_sym_pool_layer_6_${SPARSENESS}_lambda_${LAMBDA}_50_cls_6tasks_${seed}_mu_${MU} --after-norm --sym --lr 0.01 --seed ${seed} --lamb $LAMBDA --plast_mu $MU --num-tasks 6 --nc-first-task 50 --lr-patience 15 --int-layer --pool-layers 6 36 | python3 -u src/main_incremental.py --datasets cifar100_224 --network OVit_tiny_16_augreg_224 --approach olwf_asym --nepochs 100 --log disk --batch-size 96 --gpu 0 --exp-name 10epochs_sym_pool_layer_1_${SPARSENESS}_lambda_${LAMBDA}_50_cls_6tasks_${seed}_mu_${MU} --after-norm --sym --lr 0.01 --seed ${seed} --lamb $LAMBDA --plast_mu $MU --num-tasks 6 --nc-first-task 50 --lr-patience 15 --int-layer --pool-layers 1 37 | python3 -u src/main_incremental.py --datasets cifar100_224 --network OVit_tiny_16_augreg_224 --approach olwf_asym --nepochs 100 --log disk --batch-size 96 --gpu 0 --exp-name 10epochs_sym_pool_layer_12_${SPARSENESS}_lambda_${LAMBDA}_50_cls_6tasks_${seed}_mu_${MU} --after-norm --sym --lr 0.01 --seed ${seed} --lamb $LAMBDA --plast_mu $MU --num-tasks 6 --nc-first-task 50 --lr-patience 15 --int-layer --pool-layers 12 38 | 39 | done 40 | done 41 | -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | # Framework for Analysis of Class-Incremental Learning 2 | Run the code with: 3 | ``` 4 | python3 -u src/main_incremental.py 5 | ``` 6 | followed by general options: 7 | 8 | * `--gpu`: index of GPU to run the experiment on (default=0) 9 | * `--results-path`: path where results are stored (default='../results') 10 | * `--exp-name`: experiment name (default=None) 11 | * `--seed`: random seed (default=0) 12 | * `--save-models`: save trained models (default=False) 13 | * `--last-layer-analysis`: plot last layer analysis (default=False) 14 | * `--no-cudnn-deterministic`: disable CUDNN deterministic (default=False) 15 | 16 | and specific options for each of the code parts (corresponding to folders): 17 | 18 | * `--approach`: learning approach used (default='finetuning') [[more](approach/README.md)] 19 | * `--datasets`: dataset or datasets used (default=['cifar100']) [[more](datasets/README.md)] 20 | * `--network`: network architecture used
(default='resnet32') [[more](networks/README.md)] 21 | * `--log`: loggers used (default='disk') [[more](loggers/README.md)] 22 | 23 | Go to each of their respective READMEs to see all the available options. 24 | 25 | ## Approaches 26 | Initially, the approaches included in the framework correspond to the ones presented in 27 | _**Class-incremental learning: survey and performance evaluation**_ (preprint, 2020). The regularization-based 28 | approaches are EWC, MAS, PathInt, LwF, LwM and DMC (green). The rehearsal approaches are iCaRL, EEIL and RWalk (blue). 29 | The bias-correction approaches are IL2M, BiC and LUCIR (orange). 30 | 31 | ![alt text](../docs/_static/cil_survey_approaches.png "Survey approaches") 32 | 33 | More approaches will be included in the future. To learn more about them refer to the readme in 34 | [src/approach](approach). 35 | 36 | ## Datasets 37 | To learn about the dataset management refer to the readme in [src/datasets](datasets). 38 | 39 | ## Networks 40 | To learn about the different torchvision and custom networks refer to the readme in [src/networks](networks). 41 | 42 | ## GridSearch 43 | We implement the option to use a realistic grid search for hyperparameters which only takes into account the task at 44 | hand, without access to previous or future information not available in the incremental learning scenario. It 45 | corresponds to the one introduced in _**Class-incremental learning: survey and performance evaluation**_. The GridSearch 46 | can be applied by using: 47 | 48 | * `--gridsearch-tasks`: number of tasks to apply GridSearch (-1: all tasks) (default=-1) 49 | 50 | which we recommend setting to the total number of tasks of the experiment for a more realistic tuning of the 51 | learning rate and the forgetting-intransigence trade-off. However, since this has a considerable extra 52 | computational cost, it can also be set to the first 3 tasks, which would fix those hyperparameters for the remaining 53 | tasks. Other GridSearch options include (see the schematic sketch below): 54 | 55 | * `--gridsearch-config`: configuration file for GridSearch options (default='gridsearch_config') [[more](gridsearch_config.py)] 56 | * `--gridsearch-acc-drop-thr`: GridSearch accuracy drop threshold (default=0.2) 57 | * `--gridsearch-hparam-decay`: GridSearch hyperparameter decay (default=0.5) 58 | * `--gridsearch-max-num-searches`: GridSearch maximum number of hyperparameter searches (default=7) 59 | 60 | ## Utils 61 | Some utility functions are provided in `utils.py`.
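As referenced in the GridSearch section above, the search these options describe amounts to decaying a stability hyperparameter until the accuracy drop on the current task falls under a threshold. The following is a schematic sketch under stated assumptions — `evaluate`, the variable names, and the stopping rule are ours, not the exact logic of `src/gridsearch.py`:

```python
def search_hparam(evaluate, start=5000.0, decay=0.5, acc_drop_thr=0.2, max_searches=7):
    """Decay a stability hyperparameter until current-task accuracy is acceptable."""
    ref_acc = evaluate(0.0)  # reference run without the stability term
    value = start
    for _ in range(max_searches):
        if (ref_acc - evaluate(value)) <= acc_drop_thr * ref_acc:
            break            # accuracy drop within threshold: accept this value
        value *= decay       # too intransigent: relax the penalty and retry
    return value
```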
-------------------------------------------------------------------------------- /src/__pycache__/last_layer_analysis.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/__pycache__/last_layer_analysis.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/last_layer_analysis.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/__pycache__/last_layer_analysis.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /src/approach/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # list all approaches available 4 | __all__ = list( 5 | map(lambda x: x[:-3], 6 | filter(lambda x: x not in ['__init__.py', 'incremental_learning.py'] and x.endswith('.py'), 7 | os.listdir(os.path.dirname(__file__)) 8 | ) 9 | ) 10 | ) 11 | -------------------------------------------------------------------------------- /src/approach/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/approach/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/approach/__pycache__/ewc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/ewc.cpython-38.pyc -------------------------------------------------------------------------------- /src/approach/__pycache__/incremental_learning.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/incremental_learning.cpython-38.pyc -------------------------------------------------------------------------------- /src/approach/__pycache__/lwf.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/lwf.cpython-38.pyc -------------------------------------------------------------------------------- /src/approach/__pycache__/olwf.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/olwf.cpython-38.pyc -------------------------------------------------------------------------------- /src/approach/__pycache__/olwf_asym.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/olwf_asym.cpython-38.pyc -------------------------------------------------------------------------------- /src/approach/__pycache__/olwf_asympost.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/olwf_asympost.cpython-38.pyc -------------------------------------------------------------------------------- /src/approach/__pycache__/olwf_jsd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/olwf_jsd.cpython-38.pyc -------------------------------------------------------------------------------- /src/approach/__pycache__/path_integral.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/path_integral.cpython-38.pyc -------------------------------------------------------------------------------- /src/approach/ewc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools 3 | from argparse import ArgumentParser 4 | 5 | from datasets.exemplars_dataset import ExemplarsDataset 6 | from .incremental_learning import Inc_Learning_Appr 7 | 8 | 9 | class Appr(Inc_Learning_Appr): 10 | """Class implementing the Elastic Weight Consolidation (EWC) approach 11 | described in http://arxiv.org/abs/1612.00796 12 | """ 13 | 14 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000, 15 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 16 | logger=None, exemplars_dataset=None, lamb=5000, alpha=0.5, fi_sampling_type='max_pred', 17 | fi_num_samples=-1): 18 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 19 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 20 | exemplars_dataset) 21 | self.lamb = lamb 22 | self.alpha = alpha 23 | self.sampling_type = fi_sampling_type 24 | self.num_samples = fi_num_samples 25 | 26 | # In all cases, we only keep importance weights for the model, but not for the heads. 
27 | feat_ext = self.model.model 28 | # Store current parameters as the initial parameters before first task starts 29 | self.older_params = {n: p.clone().detach() for n, p in feat_ext.named_parameters() if p.requires_grad} 30 | # Store fisher information weight importance 31 | self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() 32 | if p.requires_grad} 33 | 34 | @staticmethod 35 | def exemplars_dataset_class(): 36 | return ExemplarsDataset 37 | 38 | @staticmethod 39 | def extra_parser(args): 40 | """Returns a parser containing the approach specific parameters""" 41 | parser = ArgumentParser() 42 | # Eq. 3: "lambda sets how important the old task is compared to the new one" 43 | parser.add_argument('--lamb', default=5000, type=float, required=False, 44 | help='Forgetting-intransigence trade-off (default=%(default)s)') 45 | # Define how old and new fisher is fused, by default it is a 50-50 fusion 46 | parser.add_argument('--alpha', default=0.5, type=float, required=False, 47 | help='EWC alpha (default=%(default)s)') 48 | parser.add_argument('--fi-sampling-type', default='max_pred', type=str, required=False, 49 | choices=['true', 'max_pred', 'multinomial'], 50 | help='Sampling type for Fisher information (default=%(default)s)') 51 | parser.add_argument('--fi-num-samples', default=-1, type=int, required=False, 52 | help='Number of samples for Fisher information (-1: all available) (default=%(default)s)') 53 | 54 | return parser.parse_known_args(args) 55 | 56 | def _get_optimizer(self): 57 | """Returns the optimizer""" 58 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 59 | # if there are no exemplars, previous heads are not modified 60 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 61 | else: 62 | params = self.model.parameters() 63 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 64 | 65 | def compute_fisher_matrix_diag(self, trn_loader): 66 | # Store Fisher Information 67 | fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters() 68 | if p.requires_grad} 69 | # Compute fisher information for specified number of samples -- rounded to the batch size 70 | n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \ 71 | else (len(trn_loader.dataset) // trn_loader.batch_size) 72 | # Do forward and backward pass to compute the fisher information 73 | self.model.train() 74 | for images, targets in itertools.islice(trn_loader, n_samples_batches): 75 | outputs = self.model.forward(images.to(self.device)) 76 | 77 | if self.sampling_type == 'true': 78 | # Use the labels to compute the gradients based on the CE-loss with the ground truth 79 | preds = targets.to(self.device) 80 | elif self.sampling_type == 'max_pred': 81 | # Not use labels and compute the gradients related to the prediction the model has learned 82 | preds = torch.cat(outputs, dim=1).argmax(1).flatten() 83 | elif self.sampling_type == 'multinomial': 84 | # Use a multinomial sampling to compute the gradients 85 | probs = torch.nn.functional.softmax(torch.cat(outputs, dim=1), dim=1) 86 | preds = torch.multinomial(probs, len(targets)).flatten() 87 | 88 | loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), preds) 89 | self.optimizer.zero_grad() 90 | loss.backward() 91 | # Accumulate all gradients from loss with regularization 92 | for n, p in self.model.model.named_parameters(): 93 | if p.grad is not None: 
94 | fisher[n] += p.grad.pow(2) * len(targets) 95 | # Apply mean across all samples 96 | n_samples = n_samples_batches * trn_loader.batch_size 97 | fisher = {n: (p / n_samples) for n, p in fisher.items()} 98 | return fisher 99 | 100 | def train_loop(self, t, trn_loader, val_loader): 101 | """Contains the epochs loop""" 102 | 103 | # add exemplars to train_loader 104 | if len(self.exemplars_dataset) > 0 and t > 0: 105 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 106 | batch_size=trn_loader.batch_size, 107 | shuffle=True, 108 | num_workers=trn_loader.num_workers, 109 | pin_memory=trn_loader.pin_memory) 110 | 111 | # FINETUNING TRAINING -- contains the epochs loop 112 | super().train_loop(t, trn_loader, val_loader) 113 | 114 | # EXEMPLAR MANAGEMENT -- select training subset 115 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 116 | 117 | def post_train_process(self, t, trn_loader): 118 | """Runs after training all the epochs of the task (after the train session)""" 119 | 120 | # Store current parameters for the next task 121 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad} 122 | 123 | # calculate Fisher information 124 | curr_fisher = self.compute_fisher_matrix_diag(trn_loader) 125 | # merge fisher information, we do not want to keep fisher information for each task in memory 126 | for n in self.fisher.keys(): 127 | # Added option to accumulate fisher over time with a pre-fixed growing alpha 128 | if self.alpha == -1: 129 | alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device) 130 | self.fisher[n] = alpha * self.fisher[n] + (1 - alpha) * curr_fisher[n] 131 | else: 132 | self.fisher[n] = (self.alpha * self.fisher[n] + (1 - self.alpha) * curr_fisher[n]) 133 | 134 | def criterion(self, t, outputs, targets): 135 | """Returns the loss value""" 136 | loss = 0 137 | if t > 0: 138 | loss_reg = 0 139 | # Eq. 
3: elastic weight consolidation quadratic penalty 140 | for n, p in self.model.model.named_parameters(): 141 | if n in self.fisher.keys(): 142 | loss_reg += torch.sum(self.fisher[n] * (p - self.older_params[n]).pow(2)) / 2 143 | loss += self.lamb * loss_reg 144 | # Current cross-entropy loss -- with exemplars use all heads 145 | if len(self.exemplars_dataset) > 0: 146 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 147 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 148 | -------------------------------------------------------------------------------- /src/approach/finetuning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from argparse import ArgumentParser 3 | 4 | from .incremental_learning import Inc_Learning_Appr 5 | from datasets.exemplars_dataset import ExemplarsDataset 6 | 7 | 8 | class Appr(Inc_Learning_Appr): 9 | """Class implementing the finetuning baseline""" 10 | 11 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000, 12 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 13 | logger=None, exemplars_dataset=None, all_outputs=False): 14 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 15 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 16 | exemplars_dataset) 17 | self.all_out = all_outputs 18 | 19 | @staticmethod 20 | def exemplars_dataset_class(): 21 | return ExemplarsDataset 22 | 23 | @staticmethod 24 | def extra_parser(args): 25 | """Returns a parser containing the approach specific parameters""" 26 | parser = ArgumentParser() 27 | parser.add_argument('--all-outputs', action='store_true', required=False, 28 | help='Allow all weights related to all outputs to be modified (default=%(default)s)') 29 | return parser.parse_known_args(args) 30 | 31 | def _get_optimizer(self): 32 | """Returns the optimizer""" 33 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1 and not self.all_out: 34 | # if there are no exemplars, previous heads are not modified 35 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 36 | else: 37 | params = self.model.parameters() 38 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 39 | 40 | def train_loop(self, t, trn_loader, val_loader): 41 | """Contains the epochs loop""" 42 | 43 | # add exemplars to train_loader 44 | if len(self.exemplars_dataset) > 0 and t > 0: 45 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 46 | batch_size=trn_loader.batch_size, 47 | shuffle=True, 48 | num_workers=trn_loader.num_workers, 49 | pin_memory=trn_loader.pin_memory) 50 | 51 | # FINETUNING TRAINING -- contains the epochs loop 52 | super().train_loop(t, trn_loader, val_loader) 53 | 54 | # EXEMPLAR MANAGEMENT -- select training subset 55 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 56 | 57 | def criterion(self, t, outputs, targets): 58 | """Returns the loss value""" 59 | if self.all_out or len(self.exemplars_dataset) > 0: 60 | return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 61 | return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 62 | 
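The `outputs[t]` / `targets - self.model.task_offset[t]` indexing used by `criterion()` in `ewc.py` and `finetuning.py` above relies on simple per-task bookkeeping (`task_cls`, `task_offset`). A standalone illustration with hypothetical task sizes:

```python
import torch

task_cls = torch.tensor([50, 10, 10])  # hypothetical: classes per task
task_offset = torch.cat([torch.zeros(1, dtype=torch.long), task_cls.cumsum(0)[:-1]])

target = 53                                 # a global class label
t = (task_cls.cumsum(0) <= target).sum()    # -> tensor(1): label falls in task 1
local = target - task_offset[t]             # -> tensor(3): class 3 within head 1
```

The same pattern appears in `il2m.py`'s `calculate_metrics()` below.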
-------------------------------------------------------------------------------- /src/approach/freezing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from argparse import ArgumentParser 3 | 4 | from .incremental_learning import Inc_Learning_Appr 5 | from datasets.exemplars_dataset import ExemplarsDataset 6 | 7 | 8 | class Appr(Inc_Learning_Appr): 9 | """Class implementing the freezing baseline""" 10 | 11 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000, 12 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 13 | logger=None, exemplars_dataset=None, freeze_after=0, all_outputs=False): 14 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 15 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 16 | exemplars_dataset) 17 | self.freeze_after = freeze_after 18 | self.all_out = all_outputs 19 | 20 | @staticmethod 21 | def exemplars_dataset_class(): 22 | return ExemplarsDataset 23 | 24 | @staticmethod 25 | def extra_parser(args): 26 | """Returns a parser containing the approach specific parameters""" 27 | parser = ArgumentParser() 28 | parser.add_argument('--freeze-after', default=0, type=int, required=False, 29 | help='Freeze model except current head after the specified task (default=%(default)s)') 30 | parser.add_argument('--all-outputs', action='store_true', required=False, 31 | help='Allow all weights related to all outputs to be modified (default=%(default)s)') 32 | return parser.parse_known_args(args) 33 | 34 | def _get_optimizer(self): 35 | """Returns the optimizer""" 36 | return torch.optim.SGD(self._train_parameters(), lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 37 | 38 | def _has_exemplars(self): 39 | """Returns True in case exemplars are being used""" 40 | return self.exemplars_dataset is not None and len(self.exemplars_dataset) > 0 41 | 42 | def post_train_process(self, t, trn_loader): 43 | """Runs after training all the epochs of the task (after the train session)""" 44 | if t >= self.freeze_after: 45 | self.model.freeze_backbone() 46 | 47 | def train_loop(self, t, trn_loader, val_loader): 48 | """Contains the epochs loop""" 49 | 50 | # add exemplars to train_loader 51 | if t > 0 and self._has_exemplars(): 52 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 53 | batch_size=trn_loader.batch_size, 54 | shuffle=True, 55 | num_workers=trn_loader.num_workers, 56 | pin_memory=trn_loader.pin_memory) 57 | 58 | # FINETUNING TRAINING -- contains the epochs loop 59 | super().train_loop(t, trn_loader, val_loader) 60 | 61 | # EXEMPLAR MANAGEMENT -- select training subset 62 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 63 | 64 | def train_epoch(self, t, trn_loader): 65 | """Runs a single epoch""" 66 | self._model_train(t) 67 | for images, targets in trn_loader: 68 | # Forward current model 69 | outputs = self.model(images.to(self.device)) 70 | loss = self.criterion(t, outputs, targets.to(self.device)) 71 | # Backward 72 | self.optimizer.zero_grad() 73 | loss.backward() 74 | torch.nn.utils.clip_grad_norm_(self._train_parameters(), self.clipgrad) 75 | self.optimizer.step() 76 | 77 | def _model_train(self, t): 78 | """Freezes the necessary weights""" 79 | if self.fix_bn and t > 0: 80 | self.model.freeze_bn() 81 | if self.freeze_after >= 0 
and t <= self.freeze_after: # non-frozen task - whole model to train 82 | self.model.train() 83 | else: 84 | self.model.model.eval() 85 | if self._has_exemplars(): 86 | # with exemplars - use all heads 87 | for head in self.model.heads: 88 | head.train() 89 | else: 90 | # no exemplars - use current head 91 | self.model.heads[-1].train() 92 | 93 | def _train_parameters(self): 94 | """Includes the necessary weights to the optimizer""" 95 | if len(self.model.heads) <= (self.freeze_after + 1): 96 | return self.model.parameters() 97 | else: 98 | if self._has_exemplars(): 99 | return [p for head in self.model.heads for p in head.parameters()] 100 | else: 101 | return self.model.heads[-1].parameters() 102 | 103 | def criterion(self, t, outputs, targets): 104 | """Returns the loss value""" 105 | if self.all_out or self._has_exemplars(): 106 | return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 107 | return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 108 | -------------------------------------------------------------------------------- /src/approach/il2m.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .incremental_learning import Inc_Learning_Appr 5 | from datasets.exemplars_dataset import ExemplarsDataset 6 | 7 | 8 | class Appr(Inc_Learning_Appr): 9 | """Class implementing the Class Incremental Learning With Dual Memory (IL2M) approach described in 10 | https://openaccess.thecvf.com/content_ICCV_2019/papers/Belouadah_IL2M_Class_Incremental_Learning_With_Dual_Memory_ICCV_2019_paper.pdf 11 | """ 12 | 13 | def __init__(self, model, device, nepochs=100, lr=0.1, lr_min=1e-4, lr_factor=3, lr_patience=10, clipgrad=10000, 14 | momentum=0.9, wd=0.0001, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, 15 | eval_on_train=False, logger=None, exemplars_dataset=None): 16 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 17 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 18 | exemplars_dataset) 19 | self.init_classes_means = [] 20 | self.current_classes_means = [] 21 | self.models_confidence = [] 22 | # FLAG to not do scores rectification while finetuning training 23 | self.ft_train = False 24 | 25 | have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class 26 | assert (have_exemplars > 0), 'Error: IL2M needs exemplars.' 27 | 28 | @staticmethod 29 | def exemplars_dataset_class(): 30 | return ExemplarsDataset 31 | 32 | def il2m(self, t, trn_loader): 33 | """Compute and store statistics for score rectification""" 34 | old_classes_number = sum(self.model.task_cls[:t]) 35 | classes_counts = [0 for _ in range(sum(self.model.task_cls))] 36 | models_counts = 0 37 | 38 | # to store statistics for the classes as learned in the current incremental state 39 | self.current_classes_means = [0 for _ in range(old_classes_number)] 40 | # to store statistics for past classes as learned in their initial states 41 | for cls in range(old_classes_number, old_classes_number + self.model.task_cls[t]): 42 | self.init_classes_means.append(0) 43 | # to store statistics for model confidence in different states (i.e. 
avg top-1 pred scores) 44 | self.models_confidence.append(0) 45 | 46 | # compute the mean prediction scores that will be used to rectify scores in subsequent tasks 47 | with torch.no_grad(): 48 | self.model.eval() 49 | for images, targets in trn_loader: 50 | outputs = self.model(images.to(self.device)) 51 | scores = np.array(torch.cat(outputs, dim=1).data.cpu().numpy(), dtype=float) 52 | for m in range(len(targets)): 53 | if targets[m] < old_classes_number: 54 | # computation of class means for past classes of the current state. 55 | self.current_classes_means[targets[m]] += scores[m, targets[m]] 56 | classes_counts[targets[m]] += 1 57 | else: 58 | # compute the mean prediction scores for the new classes of the current state 59 | self.init_classes_means[targets[m]] += scores[m, targets[m]] 60 | classes_counts[targets[m]] += 1 61 | # compute the mean top scores for the new classes of the current state 62 | self.models_confidence[t] += np.max(scores[m]) 63 | models_counts += 1 64 | # Normalize by corresponding number of images 65 | for cls in range(old_classes_number): 66 | self.current_classes_means[cls] /= classes_counts[cls] 67 | for cls in range(old_classes_number, old_classes_number + self.model.task_cls[t]): 68 | self.init_classes_means[cls] /= classes_counts[cls] 69 | self.models_confidence[t] /= models_counts 70 | 71 | def train_loop(self, t, trn_loader, val_loader): 72 | """Contains the epochs loop""" 73 | 74 | # add exemplars to train_loader 75 | if len(self.exemplars_dataset) > 0 and t > 0: 76 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 77 | batch_size=trn_loader.batch_size, 78 | shuffle=True, 79 | num_workers=trn_loader.num_workers, 80 | pin_memory=trn_loader.pin_memory) 81 | 82 | # FINETUNING TRAINING -- contains the epochs loop 83 | self.ft_train = True 84 | super().train_loop(t, trn_loader, val_loader) 85 | self.ft_train = False 86 | 87 | # IL2M outputs rectification 88 | self.il2m(t, trn_loader) 89 | 90 | # EXEMPLAR MANAGEMENT -- select training subset 91 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 92 | 93 | def calculate_metrics(self, outputs, targets): 94 | """Contains the main Task-Aware and Task-Agnostic metrics""" 95 | if self.ft_train: 96 | # no score rectification while training 97 | hits_taw, hits_tag = super().calculate_metrics(outputs, targets) 98 | else: 99 | # Task-Aware Multi-Head 100 | pred = torch.zeros_like(targets.to(self.device)) 101 | for m in range(len(pred)): 102 | this_task = (self.model.task_cls.cumsum(0) <= targets[m]).sum() 103 | pred[m] = outputs[this_task][m].argmax() + self.model.task_offset[this_task] 104 | hits_taw = (pred == targets.to(self.device)).float() 105 | # Task-Agnostic Multi-Head 106 | if self.multi_softmax: 107 | outputs = [torch.nn.functional.log_softmax(output, dim=1) for output in outputs] 108 | # Eq.
1: rectify predicted scores 109 | old_classes_number = sum(self.model.task_cls[:-1]) 110 | for m in range(len(targets)): 111 | rectified_outputs = torch.cat(outputs, dim=1) 112 | pred[m] = rectified_outputs[m].argmax() 113 | if old_classes_number: 114 | # if the top-1 class predicted by the network is a new one, rectify the score 115 | if int(pred[m]) >= old_classes_number: 116 | for o in range(old_classes_number): 117 | o_task = int((self.model.task_cls.cumsum(0) <= o).sum()) 118 | rectified_outputs[m, o] *= (self.init_classes_means[o] / self.current_classes_means[o]) * \ 119 | (self.models_confidence[-1] / self.models_confidence[o_task]) 120 | pred[m] = rectified_outputs[m].argmax() 121 | # otherwise, rectification is not done because an old class is directly predicted 122 | hits_tag = (pred == targets.to(self.device)).float() 123 | return hits_taw, hits_tag 124 | 125 | def criterion(self, t, outputs, targets): 126 | """Returns the loss value""" 127 | if len(self.exemplars_dataset) > 0: 128 | return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 129 | return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 130 | -------------------------------------------------------------------------------- /src/approach/joint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from argparse import ArgumentParser 3 | from torch.utils.data import DataLoader, Dataset 4 | 5 | from .incremental_learning import Inc_Learning_Appr 6 | from datasets.exemplars_dataset import ExemplarsDataset 7 | 8 | 9 | class Appr(Inc_Learning_Appr): 10 | """Class implementing the joint baseline""" 11 | 12 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000, 13 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 14 | logger=None, exemplars_dataset=None, freeze_after=-1): 15 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 16 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 17 | exemplars_dataset) 18 | self.trn_datasets = [] 19 | self.val_datasets = [] 20 | self.freeze_after = freeze_after 21 | 22 | have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class 23 | assert (have_exemplars == 0), 'Warning: Joint does not use exemplars. Comment this line to force it.' 
24 | 25 | @staticmethod 26 | def exemplars_dataset_class(): 27 | return ExemplarsDataset 28 | 29 | @staticmethod 30 | def extra_parser(args): 31 | """Returns a parser containing the approach specific parameters""" 32 | parser = ArgumentParser() 33 | parser.add_argument('--freeze-after', default=-1, type=int, required=False, 34 | help='Freeze model except heads after the specified task' 35 | '(-1: normal Incremental Joint Training, no freeze) (default=%(default)s)') 36 | return parser.parse_known_args(args) 37 | 38 | def post_train_process(self, t, trn_loader): 39 | """Runs after training all the epochs of the task (after the train session)""" 40 | if self.freeze_after > -1 and t >= self.freeze_after: 41 | self.model.freeze_all() 42 | for head in self.model.heads: 43 | for param in head.parameters(): 44 | param.requires_grad = True 45 | 46 | def train_loop(self, t, trn_loader, val_loader): 47 | """Contains the epochs loop""" 48 | 49 | # add new datasets to existing cumulative ones 50 | self.trn_datasets.append(trn_loader.dataset) 51 | self.val_datasets.append(val_loader.dataset) 52 | trn_dset = JointDataset(self.trn_datasets) 53 | val_dset = JointDataset(self.val_datasets) 54 | trn_loader = DataLoader(trn_dset, 55 | batch_size=trn_loader.batch_size, 56 | shuffle=True, 57 | num_workers=trn_loader.num_workers, 58 | pin_memory=trn_loader.pin_memory) 59 | val_loader = DataLoader(val_dset, 60 | batch_size=val_loader.batch_size, 61 | shuffle=False, 62 | num_workers=val_loader.num_workers, 63 | pin_memory=val_loader.pin_memory) 64 | # continue training as usual 65 | super().train_loop(t, trn_loader, val_loader) 66 | 67 | def train_epoch(self, t, trn_loader): 68 | """Runs a single epoch""" 69 | if self.freeze_after < 0 or t <= self.freeze_after: 70 | self.model.train() 71 | if self.fix_bn and t > 0: 72 | self.model.freeze_bn() 73 | else: 74 | self.model.eval() 75 | for head in self.model.heads: 76 | head.train() 77 | for images, targets in trn_loader: 78 | # Forward current model 79 | outputs = self.model(images.to(self.device)) 80 | loss = self.criterion(t, outputs, targets.to(self.device)) 81 | # Backward 82 | self.optimizer.zero_grad() 83 | loss.backward() 84 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad) 85 | self.optimizer.step() 86 | 87 | def criterion(self, t, outputs, targets): 88 | """Returns the loss value""" 89 | return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 90 | 91 | 92 | class JointDataset(Dataset): 93 | """Characterizes a dataset for PyTorch -- this dataset accumulates each task dataset incrementally""" 94 | 95 | def __init__(self, datasets): 96 | self.datasets = datasets 97 | self._len = sum([len(d) for d in self.datasets]) 98 | 99 | def __len__(self): 100 | 'Denotes the total number of samples' 101 | return self._len 102 | 103 | def __getitem__(self, index): 104 | for d in self.datasets: 105 | if len(d) <= index: 106 | index -= len(d) 107 | else: 108 | x, y = d[index] 109 | return x, y 110 | -------------------------------------------------------------------------------- /src/approach/lwf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | from argparse import ArgumentParser 4 | 5 | from .incremental_learning import Inc_Learning_Appr 6 | from datasets.exemplars_dataset import ExemplarsDataset 7 | 8 | 9 | class Appr(Inc_Learning_Appr): 10 | """Class implementing the Learning Without Forgetting (LwF) approach 11 | described in 
https://arxiv.org/abs/1606.09282 12 | """ 13 | 14 | # Weight decay of 0.0005 is used in the original article (page 4). 15 | # Page 4: "The warm-up step greatly enhances fine-tuning’s old-task performance, but is not so crucial to either our 16 | # method or the compared Less Forgetting Learning (see Table 2(b))." 17 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000, 18 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 19 | logger=None, exemplars_dataset=None, lamb=1, T=2): 20 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 21 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 22 | exemplars_dataset) 23 | self.model_old = None 24 | self.lamb = lamb 25 | self.T = T 26 | 27 | @staticmethod 28 | def exemplars_dataset_class(): 29 | return ExemplarsDataset 30 | 31 | @staticmethod 32 | def extra_parser(args): 33 | """Returns a parser containing the approach specific parameters""" 34 | parser = ArgumentParser() 35 | # Page 5: "lambda is a loss balance weight, set to 1 for most our experiments. Making lambda larger will favor 36 | # the old task performance over the new task’s, so we can obtain a old-task-new-task performance line by 37 | # changing lambda." 38 | parser.add_argument('--lamb', default=1, type=float, required=False, 39 | help='Forgetting-intransigence trade-off (default=%(default)s)') 40 | # Page 5: "We use T=2 according to a grid search on a held out set, which aligns with the authors’ 41 | # recommendations." -- Using a higher value for T produces a softer probability distribution over classes. 42 | parser.add_argument('--T', default=2, type=int, required=False, 43 | help='Temperature scaling (default=%(default)s)') 44 | return parser.parse_known_args(args) 45 | 46 | def _get_optimizer(self): 47 | """Returns the optimizer""" 48 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 49 | # if there are no exemplars, previous heads are not modified 50 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 51 | else: 52 | params = self.model.parameters() 53 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 54 | 55 | def train_loop(self, t, trn_loader, val_loader): 56 | """Contains the epochs loop""" 57 | 58 | # add exemplars to train_loader 59 | if len(self.exemplars_dataset) > 0 and t > 0: 60 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 61 | batch_size=trn_loader.batch_size, 62 | shuffle=True, 63 | num_workers=trn_loader.num_workers, 64 | pin_memory=trn_loader.pin_memory) 65 | 66 | # FINETUNING TRAINING -- contains the epochs loop 67 | super().train_loop(t, trn_loader, val_loader) 68 | 69 | # EXEMPLAR MANAGEMENT -- select training subset 70 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 71 | 72 | def post_train_process(self, t, trn_loader): 73 | """Runs after training all the epochs of the task (after the train session)""" 74 | 75 | # Restore best and save model for future tasks 76 | self.model_old = deepcopy(self.model) 77 | self.model_old.eval() 78 | self.model_old.freeze_all() 79 | 80 | def train_epoch(self, t, trn_loader): 81 | """Runs a single epoch""" 82 | self.model.train() 83 | if self.fix_bn and t > 0: 84 | self.model.freeze_bn() 85 | for images, targets in trn_loader: 86 | # Forward old 
model 87 | targets_old = None 88 | if t > 0: 89 | targets_old = self.model_old(images.to(self.device)) 90 | # Forward current model 91 | outputs = self.model(images.to(self.device)) 92 | loss = self.criterion(t, outputs, targets.to(self.device), targets_old) 93 | # Backward 94 | self.optimizer.zero_grad() 95 | loss.backward() 96 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad) 97 | self.optimizer.step() 98 | 99 | def eval(self, t, val_loader): 100 | """Contains the evaluation code""" 101 | with torch.no_grad(): 102 | total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0 103 | self.model.eval() 104 | for images, targets in val_loader: 105 | # Forward old model 106 | targets_old = None 107 | if t > 0: 108 | targets_old = self.model_old(images.to(self.device)) 109 | # Forward current model 110 | outputs = self.model(images.to(self.device)) 111 | loss = self.criterion(t, outputs, targets.to(self.device), targets_old) 112 | hits_taw, hits_tag = self.calculate_metrics(outputs, targets) 113 | # Log 114 | total_loss += loss.data.cpu().numpy().item() * len(targets) 115 | total_acc_taw += hits_taw.sum().data.cpu().numpy().item() 116 | total_acc_tag += hits_tag.sum().data.cpu().numpy().item() 117 | total_num += len(targets) 118 | return total_loss / total_num, total_acc_taw / total_num, total_acc_tag / total_num 119 | 120 | def cross_entropy(self, outputs, targets, exp=1.0, size_average=True, eps=1e-5): 121 | """Calculates cross-entropy with temperature scaling""" 122 | out = torch.nn.functional.softmax(outputs, dim=1) 123 | tar = torch.nn.functional.softmax(targets, dim=1) 124 | if exp != 1: 125 | out = out.pow(exp) 126 | out = out / out.sum(1).view(-1, 1).expand_as(out) 127 | tar = tar.pow(exp) 128 | tar = tar / tar.sum(1).view(-1, 1).expand_as(tar) 129 | out = out + eps / out.size(1) 130 | out = out / out.sum(1).view(-1, 1).expand_as(out) 131 | ce = -(tar * out.log()).sum(1) 132 | if size_average: 133 | ce = ce.mean() 134 | return ce 135 | 136 | def criterion(self, t, outputs, targets, outputs_old=None): 137 | """Returns the loss value""" 138 | loss = 0 139 | if t > 0: 140 | # Knowledge distillation loss for all previous tasks 141 | loss += self.lamb * self.cross_entropy(torch.cat(outputs[:t], dim=1), 142 | torch.cat(outputs_old[:t], dim=1), exp=1.0 / self.T) 143 | # Current cross-entropy loss -- with exemplars use all heads 144 | if len(self.exemplars_dataset) > 0: 145 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 146 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 147 | -------------------------------------------------------------------------------- /src/approach/mas.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools 3 | from argparse import ArgumentParser 4 | 5 | from .incremental_learning import Inc_Learning_Appr 6 | from datasets.exemplars_dataset import ExemplarsDataset 7 | 8 | 9 | class Appr(Inc_Learning_Appr): 10 | """Class implementing the Memory Aware Synapses (MAS) approach (global version) 11 | described in https://arxiv.org/abs/1711.09601 12 | Original code available at https://github.com/rahafaljundi/MAS-Memory-Aware-Synapses 13 | """ 14 | 15 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000, 16 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 17 | logger=None, 
exemplars_dataset=None, lamb=1, alpha=0.5, fi_num_samples=-1): 18 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 19 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 20 | exemplars_dataset) 21 | self.lamb = lamb 22 | self.alpha = alpha 23 | self.num_samples = fi_num_samples 24 | 25 | # In all cases, we only keep importance weights for the model, but not for the heads. 26 | feat_ext = self.model.model 27 | # Store current parameters as the initial parameters before first task starts 28 | self.older_params = {n: p.clone().detach() for n, p in feat_ext.named_parameters() if p.requires_grad} 29 | # Store the parameter importance weights 30 | self.importance = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() 31 | if p.requires_grad} 32 | 33 | @staticmethod 34 | def exemplars_dataset_class(): 35 | return ExemplarsDataset 36 | 37 | @staticmethod 38 | def extra_parser(args): 39 | """Returns a parser containing the approach specific parameters""" 40 | parser = ArgumentParser() 41 | # Eq. 3: lambda is the regularizer trade-off -- In original code: MAS.ipynb block [4]: lambda set to 1 42 | parser.add_argument('--lamb', default=1, type=float, required=False, 43 | help='Forgetting-intransigence trade-off (default=%(default)s)') 44 | # Define how old and new importance is fused, by default it is a 50-50 fusion 45 | parser.add_argument('--alpha', default=0.5, type=float, required=False, 46 | help='MAS alpha (default=%(default)s)') 47 | # Number of samples from train for estimating importance 48 | parser.add_argument('--fi-num-samples', default=-1, type=int, required=False, 49 | help='Number of samples for importance estimation (-1: all available) (default=%(default)s)') 50 | return parser.parse_known_args(args) 51 | 52 | def _get_optimizer(self): 53 | """Returns the optimizer""" 54 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 55 | # if there are no exemplars, previous heads are not modified 56 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 57 | else: 58 | params = self.model.parameters() 59 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 60 | 61 | # Section 4.1: MAS (global) is implemented since the paper shows it is more efficient than l-MAS (local) 62 | def estimate_parameter_importance(self, trn_loader): 63 | # Initialize importance matrices 64 | importance = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters() 65 | if p.requires_grad} 66 | # Compute importance for the specified number of samples -- rounded to the batch size 67 | n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \ 68 | else (len(trn_loader.dataset) // trn_loader.batch_size) 69 | # Do forward and backward pass to accumulate L2-loss gradients 70 | self.model.train() 71 | for images, targets in itertools.islice(trn_loader, n_samples_batches): 72 | # MAS allows any unlabeled data to do the estimation; we use the current task's data, as in the main experiments 73 | outputs = self.model.forward(images.to(self.device)) 74 | # Page 6: labels not required, "...use the gradients of the squared L2-norm of the learned function output." 75 | loss = torch.norm(torch.cat(outputs, dim=1), p=2, dim=1).mean() 76 | self.optimizer.zero_grad() 77 | loss.backward() 78 | # Eq. 
2: accumulate the gradients over the inputs to obtain importance weights 79 | for n, p in self.model.model.named_parameters(): 80 | if p.grad is not None: 81 | importance[n] += p.grad.abs() * len(targets) 82 | # Eq. 2: divide by N total number of samples 83 | n_samples = n_samples_batches * trn_loader.batch_size 84 | importance = {n: (p / n_samples) for n, p in importance.items()} 85 | return importance 86 | 87 | def train_loop(self, t, trn_loader, val_loader): 88 | """Contains the epochs loop""" 89 | 90 | # add exemplars to train_loader 91 | if len(self.exemplars_dataset) > 0 and t > 0: 92 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 93 | batch_size=trn_loader.batch_size, 94 | shuffle=True, 95 | num_workers=trn_loader.num_workers, 96 | pin_memory=trn_loader.pin_memory) 97 | 98 | # FINETUNING TRAINING -- contains the epochs loop 99 | super().train_loop(t, trn_loader, val_loader) 100 | 101 | # EXEMPLAR MANAGEMENT -- select training subset 102 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 103 | 104 | def post_train_process(self, t, trn_loader): 105 | """Runs after training all the epochs of the task (after the train session)""" 106 | 107 | # Store current parameters for the next task 108 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad} 109 | 110 | # calculate parameter importance 111 | curr_importance = self.estimate_parameter_importance(trn_loader) 112 | # merge importance estimates; we do not want to keep a separate estimate for each task in memory 113 | for n in self.importance.keys(): 114 | # Added option to accumulate importance over time with a pre-fixed growing alpha 115 | if self.alpha == -1: 116 | alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device) 117 | self.importance[n] = alpha * self.importance[n] + (1 - alpha) * curr_importance[n] 118 | else: 119 | # As in original code: MAS_utils/MAS_based_Training.py line 638 -- fuse previous and new importance 120 | self.importance[n] = self.alpha * self.importance[n] + (1 - self.alpha) * curr_importance[n] 121 | 122 | def criterion(self, t, outputs, targets): 123 | """Returns the loss value""" 124 | loss = 0 125 | if t > 0: 126 | loss_reg = 0 127 | # Eq. 
3: memory aware synapses regularizer penalty 128 | for n, p in self.model.model.named_parameters(): 129 | if n in self.importance.keys(): 130 | loss_reg += torch.sum(self.importance[n] * (p - self.older_params[n]).pow(2)) / 2 131 | loss += self.lamb * loss_reg 132 | # Current cross-entropy loss -- with exemplars use all heads 133 | if len(self.exemplars_dataset) > 0: 134 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 135 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 136 | -------------------------------------------------------------------------------- /src/approach/oewc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools 3 | from argparse import ArgumentParser 4 | 5 | from datasets.exemplars_dataset import ExemplarsDataset 6 | from .incremental_learning import Inc_Learning_Appr 7 | 8 | 9 | class Appr(Inc_Learning_Appr): 10 | """Class implementing the Elastic Weight Consolidation (EWC) approach 11 | described in http://arxiv.org/abs/1612.00796 12 | """ 13 | 14 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000, 15 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 16 | logger=None, exemplars_dataset=None, lamb=5000, alpha=0.5, fi_sampling_type='max_pred', 17 | fi_num_samples=-1): 18 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 19 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 20 | exemplars_dataset) 21 | self.lamb = lamb 22 | self.alpha = alpha 23 | self.sampling_type = fi_sampling_type 24 | self.num_samples = fi_num_samples 25 | 26 | # In all cases, we only keep importance weights for the model, but not for the heads. 27 | feat_ext = self.model.model 28 | # Store current parameters as the initial parameters before first task starts 29 | self.older_params = {n: p.clone().detach() for n, p in feat_ext.named_parameters() if p.requires_grad} 30 | # Store fisher information weight importance 31 | self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() 32 | if p.requires_grad} 33 | 34 | @staticmethod 35 | def exemplars_dataset_class(): 36 | return ExemplarsDataset 37 | 38 | @staticmethod 39 | def extra_parser(args): 40 | """Returns a parser containing the approach specific parameters""" 41 | parser = ArgumentParser() 42 | # Eq. 
3: "lambda sets how important the old task is compared to the new one" 43 | parser.add_argument('--lamb', default=5000, type=float, required=False, 44 | help='Forgetting-intransigence trade-off (default=%(default)s)') 45 | # Define how old and new fisher is fused, by default it is a 50-50 fusion 46 | parser.add_argument('--alpha', default=0.5, type=float, required=False, 47 | help='EWC alpha (default=%(default)s)') 48 | parser.add_argument('--fi-sampling-type', default='max_pred', type=str, required=False, 49 | choices=['true', 'max_pred', 'multinomial'], 50 | help='Sampling type for Fisher information (default=%(default)s)') 51 | parser.add_argument('--fi-num-samples', default=-1, type=int, required=False, 52 | help='Number of samples for Fisher information (-1: all available) (default=%(default)s)') 53 | 54 | return parser.parse_known_args(args) 55 | 56 | def _get_optimizer(self): 57 | """Returns the optimizer""" 58 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 59 | # if there are no exemplars, previous heads are not modified 60 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 61 | else: 62 | params = self.model.parameters() 63 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 64 | 65 | def compute_fisher_matrix_diag(self, trn_loader): 66 | # Store Fisher Information 67 | fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters() 68 | if p.requires_grad} 69 | # Compute fisher information for specified number of samples -- rounded to the batch size 70 | n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \ 71 | else (len(trn_loader.dataset) // trn_loader.batch_size) 72 | # Do forward and backward pass to compute the fisher information 73 | self.model.train() 74 | for images, targets in itertools.islice(trn_loader, n_samples_batches): 75 | outputs = self.model.forward(images.to(self.device)) 76 | 77 | if self.sampling_type == 'true': 78 | # Use the labels to compute the gradients based on the CE-loss with the ground truth 79 | preds = targets.to(self.device) 80 | elif self.sampling_type == 'max_pred': 81 | # Not use labels and compute the gradients related to the prediction the model has learned 82 | preds = torch.cat(outputs, dim=1).argmax(1).flatten() 83 | elif self.sampling_type == 'multinomial': 84 | # Use a multinomial sampling to compute the gradients 85 | probs = torch.nn.functional.softmax(torch.cat(outputs, dim=1), dim=1) 86 | preds = torch.multinomial(probs, len(targets)).flatten() 87 | 88 | loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), preds) 89 | self.optimizer.zero_grad() 90 | loss.backward() 91 | # Accumulate all gradients from loss with regularization 92 | for n, p in self.model.model.named_parameters(): 93 | if p.grad is not None: 94 | fisher[n] += p.grad.pow(2) * len(targets) 95 | # Apply mean across all samples 96 | n_samples = n_samples_batches * trn_loader.batch_size 97 | fisher = {n: (p / n_samples) for n, p in fisher.items()} 98 | return fisher 99 | 100 | def train_loop(self, t, trn_loader, val_loader): 101 | """Contains the epochs loop""" 102 | 103 | # add exemplars to train_loader 104 | if len(self.exemplars_dataset) > 0 and t > 0: 105 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 106 | batch_size=trn_loader.batch_size, 107 | shuffle=True, 108 | num_workers=trn_loader.num_workers, 109 | pin_memory=trn_loader.pin_memory) 110 
| 111 | # FINETUNING TRAINING -- contains the epochs loop 112 | super().train_loop(t, trn_loader, val_loader) 113 | 114 | # EXEMPLAR MANAGEMENT -- select training subset 115 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 116 | 117 | def post_train_process(self, t, trn_loader): 118 | """Runs after training all the epochs of the task (after the train session)""" 119 | 120 | # Store current parameters for the next task 121 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad} 122 | 123 | # calculate Fisher information 124 | curr_fisher = self.compute_fisher_matrix_diag(trn_loader) 125 | # merge fisher information, we do not want to keep fisher information for each task in memory 126 | for n in self.fisher.keys(): 127 | # Added option to accumulate fisher over time with a pre-fixed growing alpha 128 | if self.alpha == -1: 129 | alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device) 130 | self.fisher[n] = alpha * self.fisher[n] + (1 - alpha) * curr_fisher[n] 131 | else: 132 | self.fisher[n] = (self.alpha * self.fisher[n] + (1 - self.alpha) * curr_fisher[n]) 133 | 134 | def compute_orth_loss(self, old_attention_list, attention_list):  # helper: distance between old and new attention maps 135 | totloss = 0. 136 | for i in range(len(attention_list)): 137 | al = attention_list[i].view(-1, 197, 197).to(self.device)  # 197 tokens = 196 patches + [CLS] for a 224x224 ViT with patch size 16 138 | ol = old_attention_list[i].view(-1, 197, 197).to(self.device) 139 | totloss += (torch.linalg.matrix_norm((al - ol) + 1e-08) * float(i)).mean()  # weighting by block index favors deeper blocks 140 | return totloss 141 | 142 | def criterion(self, t, outputs, targets): 143 | """Returns the loss value""" 144 | loss = 0 145 | if t > 0: 146 | loss_reg = 0 147 | # Eq. 3: elastic weight consolidation quadratic penalty 148 | for n, p in self.model.model.named_parameters(): 149 | if n in self.fisher.keys(): 150 | loss_reg += torch.sum(self.fisher[n] * (p - self.older_params[n]).pow(2)) / 2 151 | loss += self.lamb * loss_reg 152 | # Current cross-entropy loss -- with exemplars use all heads 153 | if len(self.exemplars_dataset) > 0: 154 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 155 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 156 | -------------------------------------------------------------------------------- /src/approach/path_integral.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from argparse import ArgumentParser 3 | 4 | from .incremental_learning import Inc_Learning_Appr 5 | from datasets.exemplars_dataset import ExemplarsDataset 6 | 7 | 8 | class Appr(Inc_Learning_Appr): 9 | """Class implementing the Path Integral (aka Synaptic Intelligence) approach 10 | described in http://proceedings.mlr.press/v70/zenke17a.html 11 | Original code available at https://github.com/ganguli-lab/pathint 12 | """ 13 | 14 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000, 15 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False, 16 | logger=None, exemplars_dataset=None, lamb=0.1, damping=0.1): 17 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd, 18 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger, 19 | exemplars_dataset) 20 | self.lamb = lamb 21 | self.damping = damping 22 | 23 | # In all cases, we only keep importance weights for 
the model, but not for the heads. 24 | feat_ext = self.model.model 25 | # Page 3, following Eq. 3: "The w now have an intuitive interpretation as the parameter specific contribution to 26 | # changes in the total loss." 27 | self.w = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() if p.requires_grad} 28 | # Store current parameters as the initial parameters before first task starts 29 | self.older_params = {n: p.clone().detach().to(self.device) for n, p in feat_ext.named_parameters() 30 | if p.requires_grad} 31 | # Store importance weights matrices 32 | self.importance = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() 33 | if p.requires_grad} 34 | 35 | @staticmethod 36 | def exemplars_dataset_class(): 37 | return ExemplarsDataset 38 | 39 | @staticmethod 40 | def extra_parser(args): 41 | """Returns a parser containing the approach specific parameters""" 42 | parser = ArgumentParser() 43 | # Eq. 4: lamb is the 'c' trade-off parameter from the surrogate loss -- 1e-3 < c < 0.1 44 | parser.add_argument('--lamb', default=0.1, type=float, required=False, 45 | help='Forgetting-intransigence trade-off (default=%(default)s)') 46 | # Eq. 5: damping parameter is set to 0.1 in the MNIST case 47 | parser.add_argument('--damping', default=0.1, type=float, required=False, 48 | help='Damping (default=%(default)s)') 49 | return parser.parse_known_args(args) 50 | 51 | def _get_optimizer(self): 52 | """Returns the optimizer""" 53 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1: 54 | # if there are no exemplars, previous heads are not modified 55 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters()) 56 | else: 57 | params = self.model.parameters() 58 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum) 59 | 60 | def train_loop(self, t, trn_loader, val_loader): 61 | """Contains the epochs loop""" 62 | 63 | # add exemplars to train_loader 64 | if len(self.exemplars_dataset) > 0 and t > 0: 65 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset, 66 | batch_size=trn_loader.batch_size, 67 | shuffle=True, 68 | num_workers=trn_loader.num_workers, 69 | pin_memory=trn_loader.pin_memory) 70 | 71 | # FINETUNING TRAINING -- contains the epochs loop 72 | super().train_loop(t, trn_loader, val_loader) 73 | 74 | # EXEMPLAR MANAGEMENT -- select training subset 75 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform) 76 | 77 | def post_train_process(self, t, trn_loader): 78 | """Runs after training all the epochs of the task (after the train session)""" 79 | 80 | # Eq. 
5: accumulate Omega regularization strength (importance matrix) 81 | with torch.no_grad(): 82 | curr_params = {n: p for n, p in self.model.model.named_parameters() if p.requires_grad} 83 | for n, p in self.importance.items(): 84 | p += self.w[n] / ((curr_params[n] - self.older_params[n]) ** 2 + self.damping) 85 | self.w[n].zero_() 86 | 87 | # Store current parameters for the next task 88 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad} 89 | 90 | def train_epoch(self, t, trn_loader): 91 | """Runs a single epoch""" 92 | self.model.train() 93 | if self.fix_bn and t > 0: 94 | self.model.freeze_bn() 95 | for images, targets in trn_loader: 96 | # store current model without heads 97 | curr_feat_ext = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad} 98 | 99 | # Forward current model 100 | outputs = self.model(images.to(self.device)) 101 | # theoretically this is the correct loss for 2 tasks; however, for more tasks it might rather be the current-task loss 102 | # check https://github.com/ganguli-lab/pathint/blob/master/pathint/optimizers.py line 123 103 | # cross-entropy loss on current task 104 | if len(self.exemplars_dataset) == 0: 105 | loss = torch.nn.functional.cross_entropy(outputs[t], targets.to(self.device) - self.model.task_offset[t]) 106 | else: 107 | # with exemplars we check output from all heads (train data has all labels) 108 | loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets.to(self.device)) 109 | self.optimizer.zero_grad() 110 | loss.backward(retain_graph=True) 111 | # store gradients without regularization term 112 | unreg_grads = {n: p.grad.clone().detach() for n, p in self.model.model.named_parameters() 113 | if p.grad is not None} 114 | # apply loss with path integral regularization 115 | loss = self.criterion(t, outputs, targets.to(self.device)) 116 | 117 | # Backward 118 | self.optimizer.zero_grad() 119 | loss.backward() 120 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad) 121 | self.optimizer.step() 122 | 123 | # Eq. 3: accumulate w, compute the path integral -- "In practice, we can approximate w online as the running 124 | # sum of the product of the gradient with the parameter update". 125 | with torch.no_grad(): 126 | for n, p in self.model.model.named_parameters(): 127 | if n in unreg_grads.keys(): 128 | # w[n] >= 0; the minus sign accounts for the loss decrease 129 | self.w[n] -= unreg_grads[n] * (p.detach() - curr_feat_ext[n]) 130 | 131 | def criterion(self, t, outputs, targets): 132 | """Returns the loss value""" 133 | loss = 0 134 | if t > 0: 135 | loss_reg = 0 136 | # Eq. 4: quadratic surrogate loss 137 | for n, p in self.model.model.named_parameters(): 138 | loss_reg += torch.sum(self.importance[n] * (p - self.older_params[n]).pow(2)) 139 | loss += self.lamb * loss_reg 140 | # Current cross-entropy loss -- with exemplars use all heads 141 | if len(self.exemplars_dataset) > 0: 142 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets) 143 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t]) 144 | -------------------------------------------------------------------------------- /src/datasets/README.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | We include the predefined datasets MNIST, CIFAR-100, SVHN, VGGFace2, ImageNet and ImageNet-subset. The first three are 3 | integrated directly from their respective torchvision classes. The others follow our proposed dataset implementation 4 | (see below). We also include `imagenet_32_reduced` to be used with the DMC approach as an external dataset. 5 | 6 | ## Main usage 7 | When running an experiment, the datasets used can be defined in [main_incremental.py](../main_incremental.py) using 8 | `--datasets`. A single dataset or a list of datasets can be provided, and they will be learned in the given order. If 9 | `--num-tasks` is larger than 1, each dataset will further be split into the given number of tasks. All tasks from the 10 | first dataset will be learned before moving to the next dataset in the list. Other main arguments related to datasets 11 | are (see the example after this list): 12 | 13 | * `--num-workers`: number of subprocesses to use for the dataloader (default=4) 14 | * `--pin-memory`: copy Tensors into CUDA pinned memory before returning them (default=False) 15 | * `--batch-size`: number of samples per batch to load (default=64) 16 | * `--nc-first-task`: number of classes of the first task (default=None) 17 | * `--use-valid-only`: use the validation split instead of the test split (default=False) 18 | * `--stop-at-task`: stop training after the specified task (0: no stop) (default=0)
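19 | 20 | For example, a hypothetical invocation combining these arguments could look as follows (the values are illustrative only; see [script_cifar100.sh](../../scripts/script_cifar100.sh) for complete, tested invocations): 21 | 22 | ``` 23 | python3 src/main_incremental.py --datasets cifar100 --num-tasks 10 --batch-size 64 --num-workers 4 24 | ``` 25 | 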
26 | Datasets are defined in [dataset_config.py](dataset_config.py). Each entry key is considered a dataset name. For a 27 | proper configuration, the mandatory entry value is `path`, which contains the path to the dataset folder (or, for 28 | torchvision datasets, the path to where it will be downloaded). The remaining values are optional transformations to be 29 | applied to the datasets (see [data_loader.py](data_loader.py)): 30 | 31 | * `resize`: resize the input image to the given size on both train and eval 32 | [[source](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.Resize)]. 33 | * `pad`: pad the given image on all sides with the given “pad” value on both train and eval 34 | [[source](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.functional.pad)]. 35 | * `crop`: crop the given image to a random size and aspect ratio for train 36 | [[source](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.RandomResizedCrop)], and at the center 37 | on eval [[source](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.CenterCrop)]. 38 | * `flip`: horizontally flip the given image randomly with a 50% probability on train only 39 | [[source](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.RandomHorizontalFlip)]. 40 | * `normalize`: normalize a tensor image with mean and standard deviation 41 | [[source](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.Normalize)]. 42 | * `class_order`: fixes the class order given the list of ordered labels. It also allows limiting the dataset to only the 43 | classes provided. 44 | * `extend_channel`: number of times the input channels are replicated (e.g. set it to 3 to feed grayscale MNIST to a 3-channel network). 45 | 46 | where the first ones are adapted from PyTorch transforms, and the last two are our own additions. 47 | 48 | ### Dataset types 49 | MNIST, CIFAR-100 and SVHN are small enough to use [memory_dataset.py](memory_dataset.py), which loads all images in 50 | memory. The other datasets are too large to fully load in memory and, therefore, use [base_dataset.py](base_dataset.py). 51 | This dataset type only loads the corresponding paths and labels in memory; the images themselves are loaded in batches by the 52 | PyTorch DataLoader when needed. This can be modified in [data_loader.py](data_loader.py#L64) by changing the 53 | type of dataset used, as sketched below.
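54 | 55 | For instance, a minimal sketch of how the dataset type could be switched (the actual wiring lives in [data_loader.py](data_loader.py); the names `fits_in_memory`, `all_data`, `task`, `trn_transform` and `class_indices` are illustrative placeholders for what the surrounding loader code provides): 56 | 57 | ```python 58 | from datasets import base_dataset as basedat, memory_dataset as memd 59 | 60 | # Sketch only: MemoryDataset pre-loads all images in RAM; BaseDataset keeps paths and loads images on demand 61 | Dataset = memd.MemoryDataset if fits_in_memory else basedat.BaseDataset 62 | trn_dset = Dataset(all_data[task]['trn'], trn_transform, class_indices=class_indices) 63 | ``` 64 | 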
65 | ### Dataset path 66 | Modify the variable `_BASE_DATA_PATH` in [dataset_config.py](dataset_config.py) so that it points to the root folder containing your 67 | datasets. You can also modify the `path` field of specific entries in [dataset_config.py](dataset_config.py) if a particular 68 | dataset is stored at a different location on your machine. 69 | 70 | ### Exemplars 71 | For approaches that can use exemplars, these can be enabled with either `--num-exemplars` or `--num-exemplars-per-class`. The first one uses a fixed total memory, where the stored exemplars are divided among all the classes seen so far. The second one uses a growing memory, which stores the given number of exemplars for each class. 72 | 73 | Different exemplar sampling strategies are implemented in [exemplars_selection.py](exemplars_selection.py). They can be selected by using `--exemplar-selection` followed by one of these strategies: 74 | 75 | * `random`: produces a random list of samples. 76 | * `herding`: sampling based on the distance to the mean sample of each class. From [iCaRL](https://openaccess.thecvf.com/content_cvpr_2017/papers/Rebuffi_iCaRL_Incremental_Classifier_CVPR_2017_paper.pdf) algorithms 4 and 5. 77 | * `entropy`: sampling based on the entropy of each sample. From [RWalk](http://arxiv-export-lb.library.cornell.edu/pdf/1801.10112). 78 | * `distance`: sampling based on the closeness of each sample to the decision boundary. From [RWalk](http://arxiv-export-lb.library.cornell.edu/pdf/1801.10112). 79 | 80 | 81 | ## Adding new datasets 82 | To add a new dataset, follow these steps: 83 | 84 | 1. Create a new entry in [dataset_config.py](dataset_config.py), set `path` to the folder containing the data and add any other transformations or class ordering needed (a sketch of such an entry is given below, after this section). 85 | 2. Depending on whether the dataset is available in [torchvision](https://pytorch.org/docs/stable/torchvision/datasets.html) or is custom: 86 | * If the new dataset is available in **torchvision**, you can add a new option like in lines 67, 79 or 91 from [data_loader.py](data_loader.py). If the dataset is too large to fit in memory, use `base_dataset`, else `memory_dataset`. 87 | * If the dataset is **custom**, the option from line 135 in [data_loader.py](data_loader.py) should be enough. In the same folder as the data, add a `train.txt` and a `test.txt` file, each containing one line per sample with the path and the corresponding class label: 88 | ``` 89 | /data/train/sample1.jpg 0 90 | /data/train/sample2.jpg 1 91 | /data/train/sample3.jpg 2 92 | ... 93 | ``` 94 | 95 | No other files need to be modified. If the dataset is a subset or modification of an already added dataset, only step 1 is needed. 96 | 97 | ### Available custom datasets 98 | We provide `.txt` and `.zip` files for the following datasets for an easier integration with FACIL: 99 | 100 | * (COMING SOON)
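101 | 102 | As a sketch, a new entry in [dataset_config.py](dataset_config.py) could look as follows (the name `my_dataset` and all the values are illustrative placeholders): 103 | 104 | ```python 105 | 'my_dataset': {  # illustrative entry inside the dataset_config dictionary 106 | 'path': join(_BASE_DATA_PATH, 'my_dataset'), 107 | 'resize': 256, 108 | 'crop': 224, 109 | 'flip': True, 110 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 111 | }, 112 | ``` 113 | 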
114 | ## Changing dataset transformations 115 | 116 | There are some cases where it is necessary to get data from an existing dataset with a different transformation (e.g. self-supervision) or none at all (e.g. when selecting exemplars). For these particular cases we provide a simple way to override the transformations of a dataset within a selected context, e.g. taking raw pictures in `numpy` format: 117 | 118 | ```python 119 | with override_dataset_transform(trn_loader.dataset, Lambda(lambda x: np.array(x))) as ds_for_raw: 120 | x, y = zip(*(ds_for_raw[idx] for idx in selected_indices)) 121 | ``` 122 | 123 | This can be used in other cases too, such as switching between train and eval transformations when inspecting the training process. You can also check out a simple unit test in [test_datasets_transforms.py](../tests/test_datasets_transforms.py). 124 | 125 | ## Notes 126 | * As an example, we include two versions of CIFAR-100. The entry `cifar100` is the default one, which shuffles the 127 | class order depending on the given seed. We also include `cifar100_icarl`, which fixes the class order used in iCaRL 128 | (given by seed 1993), and thus makes comparisons fairer with results from papers that use that class 129 | order (e.g. iCaRL, LUCIR, BiC). 130 | * When using multiple machines, you can create a different [dataset_config.py](dataset_config.py) file for each of them. 131 | Otherwise, remember that you can create symbolic links that point to a specific folder with 132 | `ln -s TARGET_DIRECTORY LINK_NAME` in Linux/UNIX. 133 | -------------------------------------------------------------------------------- /src/datasets/__pycache__/base_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/base_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /src/datasets/__pycache__/base_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/base_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /src/datasets/__pycache__/data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /src/datasets/__pycache__/data_loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/data_loader.cpython-38.pyc -------------------------------------------------------------------------------- /src/datasets/__pycache__/dataset_config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/dataset_config.cpython-36.pyc -------------------------------------------------------------------------------- /src/datasets/__pycache__/dataset_config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/dataset_config.cpython-38.pyc 
-------------------------------------------------------------------------------- /src/datasets/__pycache__/exemplars_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/exemplars_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /src/datasets/__pycache__/exemplars_selection.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/exemplars_selection.cpython-38.pyc -------------------------------------------------------------------------------- /src/datasets/__pycache__/memory_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/memory_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /src/datasets/__pycache__/memory_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/memory_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /src/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class BaseDataset(Dataset): 9 | """Characterizes a dataset for PyTorch -- this dataset pre-loads all paths in memory""" 10 | 11 | def __init__(self, data, transform, class_indices=None): 12 | """Initialization""" 13 | self.labels = data['y'] 14 | self.images = data['x'] 15 | self.transform = transform 16 | self.class_indices = class_indices 17 | 18 | def __len__(self): 19 | """Denotes the total number of samples""" 20 | return len(self.images) 21 | 22 | def __getitem__(self, index): 23 | """Generates one sample of data""" 24 | x = Image.open(self.images[index]).convert('RGB') 25 | x = self.transform(x) 26 | y = self.labels[index] 27 | return x, y 28 | 29 | 30 | def get_data(path, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None): 31 | """Prepare data: dataset splits, task partition, class order""" 32 | 33 | data = {} 34 | taskcla = [] 35 | 36 | # read filenames and labels 37 | trn_lines = np.loadtxt(os.path.join(path, 'train.txt'), dtype=str) 38 | tst_lines = np.loadtxt(os.path.join(path, 'test.txt'), dtype=str) 39 | if class_order is None: 40 | num_classes = len(np.unique(trn_lines[:, 1])) 41 | class_order = list(range(num_classes)) 42 | else: 43 | num_classes = len(class_order) 44 | class_order = class_order.copy() 45 | if shuffle_classes: 46 | np.random.shuffle(class_order) 47 | 48 | # compute classes per task and num_tasks 49 | if nc_first_task is None: 50 | cpertask = np.array([num_classes // num_tasks] * num_tasks) 51 | for i in range(num_classes % num_tasks): 52 | cpertask[i] += 1 53 | else: 54 | assert nc_first_task < num_classes, "first task wants more classes than exist" 55 | remaining_classes = num_classes 
- nc_first_task 56 | assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task" # better minimum 2 57 | cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1)) 58 | for i in range(remaining_classes % (num_tasks - 1)): 59 | cpertask[i + 1] += 1 60 | 61 | assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes" 62 | cpertask_cumsum = np.cumsum(cpertask) 63 | init_class = np.concatenate(([0], cpertask_cumsum[:-1])) 64 | 65 | # initialize data structure 66 | for tt in range(num_tasks): 67 | data[tt] = {} 68 | data[tt]['name'] = 'task-' + str(tt) 69 | data[tt]['trn'] = {'x': [], 'y': []} 70 | data[tt]['val'] = {'x': [], 'y': []} 71 | data[tt]['tst'] = {'x': [], 'y': []} 72 | 73 | # ALL OR TRAIN 74 | for this_image, this_label in trn_lines: 75 | if not os.path.isabs(this_image): 76 | this_image = os.path.join(path, this_image) 77 | this_label = int(this_label) 78 | if this_label not in class_order: 79 | continue 80 | # If shuffling is false, it won't change the class number 81 | this_label = class_order.index(this_label) 82 | 83 | # add it to the corresponding split 84 | this_task = (this_label >= cpertask_cumsum).sum() 85 | data[this_task]['trn']['x'].append(this_image) 86 | data[this_task]['trn']['y'].append(this_label - init_class[this_task]) 87 | 88 | # ALL OR TEST 89 | for this_image, this_label in tst_lines: 90 | if not os.path.isabs(this_image): 91 | this_image = os.path.join(path, this_image) 92 | this_label = int(this_label) 93 | if this_label not in class_order: 94 | continue 95 | # If shuffling is false, it won't change the class number 96 | this_label = class_order.index(this_label) 97 | 98 | # add it to the corresponding split 99 | this_task = (this_label >= cpertask_cumsum).sum() 100 | data[this_task]['tst']['x'].append(this_image) 101 | data[this_task]['tst']['y'].append(this_label - init_class[this_task]) 102 | 103 | # check classes 104 | for tt in range(num_tasks): 105 | data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y'])) 106 | assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes" 107 | 108 | # validation 109 | if validation > 0.0: 110 | for tt in data.keys(): 111 | for cc in range(data[tt]['ncla']): 112 | cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0]) 113 | rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation))) 114 | rnd_img.sort(reverse=True) 115 | for ii in range(len(rnd_img)): 116 | data[tt]['val']['x'].append(data[tt]['trn']['x'][rnd_img[ii]]) 117 | data[tt]['val']['y'].append(data[tt]['trn']['y'][rnd_img[ii]]) 118 | data[tt]['trn']['x'].pop(rnd_img[ii]) 119 | data[tt]['trn']['y'].pop(rnd_img[ii]) 120 | 121 | # other 122 | n = 0 123 | for t in data.keys(): 124 | taskcla.append((t, data[t]['ncla'])) 125 | n += data[t]['ncla'] 126 | data['ncla'] = n 127 | 128 | return data, taskcla, class_order 129 | -------------------------------------------------------------------------------- /src/datasets/dataset_config.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | _BASE_DATA_PATH = "/datatmp/users/fpelosin/data/" 4 | 5 | dataset_config = { 6 | 'mnist': { 7 | 'path': join(_BASE_DATA_PATH, 'mnist'), 8 | 'normalize': ((0.1307,), (0.3081,)), 9 | # Use the next 3 lines to use MNIST with a 3x32x32 input 10 | # 'extend_channel': 3, 11 | # 'pad': 2, 12 | # 'normalize': ((0.1,), (0.2752,)) # values including padding 13 | }, 14 | 'svhn': { 15 | 
'path': join(_BASE_DATA_PATH, 'svhn'), 16 | 'resize': (224, 224), 17 | 'crop': None, 18 | 'flip': False, 19 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 20 | }, 21 | 'cifar100_224': { 22 | 'path': join(_BASE_DATA_PATH, 'cifar100'), 23 | 'resize': 224, 24 | 'pad': 4, 25 | 'crop': 224, 26 | 'flip': True, 27 | 'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)) 28 | }, 29 | 'cifar100': { 30 | 'path': join(_BASE_DATA_PATH, 'cifar100'), 31 | 'resize': None, 32 | 'pad': 4, 33 | 'crop': 32, 34 | 'flip': True, 35 | 'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)) 36 | }, 37 | 'cifar10': { 38 | 'path': join(_BASE_DATA_PATH, 'cifar10'), 39 | 'resize': None, 40 | 'pad': 4, 41 | 'crop': 32, 42 | 'flip': True, 43 | 'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)) 44 | }, 45 | 'cifar100_icarl': { 46 | 'path': join(_BASE_DATA_PATH, 'cifar100'), 47 | 'resize': None, 48 | 'pad': 4, 49 | 'crop': 32, 50 | 'flip': True, 51 | 'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)), 52 | 'class_order': [ 53 | 68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50, 54 | 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96, 55 | 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69, 56 | 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33 57 | ] 58 | }, 59 | 'vggface2': { 60 | 'path': join(_BASE_DATA_PATH, 'VGGFace2'), 61 | 'resize': 256, 62 | 'crop': 224, 63 | 'flip': True, 64 | 'normalize': ((0.5199, 0.4116, 0.3610), (0.2604, 0.2297, 0.2169)) 65 | }, 66 | 'imagenet_256': { 67 | 'path': join(_BASE_DATA_PATH, 'ILSVRC12_256'), 68 | 'resize': None, 69 | 'crop': 224, 70 | 'flip': True, 71 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 72 | }, 73 | 'imagenet_subset': { 74 | 'path': join(_BASE_DATA_PATH, 'ILSVRC12_256'), 75 | 'resize': None, 76 | 'crop': 224, 77 | 'flip': True, 78 | 'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 79 | 'class_order': [ 80 | 68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50, 81 | 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96, 82 | 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69, 83 | 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33 84 | ] 85 | }, 86 | 'imagenet_32_reduced': { 87 | 'path': join(_BASE_DATA_PATH, 'ILSVRC12_32'), 88 | 'resize': None, 89 | 'pad': 4, 90 | 'crop': 32, 91 | 'flip': True, 92 | 'normalize': ((0.481, 0.457, 0.408), (0.260, 0.253, 0.268)), 93 | 'class_order': [ 94 | 472, 46, 536, 806, 547, 976, 662, 12, 955, 651, 492, 80, 999, 996, 788, 471, 911, 907, 680, 126, 42, 882, 95 | 327, 719, 716, 224, 918, 647, 808, 261, 140, 908, 833, 925, 57, 388, 407, 215, 45, 479, 525, 641, 915, 923, 96 | 108, 461, 186, 843, 115, 250, 829, 625, 769, 323, 974, 291, 438, 50, 825, 441, 446, 200, 162, 373, 872, 112, 97 | 212, 501, 91, 672, 791, 370, 942, 172, 315, 959, 636, 635, 66, 86, 197, 182, 59, 736, 175, 445, 947, 268, 98 | 238, 298, 926, 851, 494, 760, 61, 293, 696, 659, 69, 819, 912, 486, 706, 343, 390, 484, 282, 729, 575, 731, 99 | 530, 32, 534, 838, 466, 734, 425, 400, 290, 660, 254, 266, 551, 775, 721, 134, 886, 338, 465, 236, 522, 655, 100 | 209, 861, 88, 491, 985, 304, 981, 560, 405, 902, 521, 909, 763, 455, 341, 
905, 280, 776, 113, 434, 274, 581, 101 | 158, 738, 671, 702, 147, 718, 148, 35, 13, 585, 591, 371, 745, 281, 956, 935, 346, 352, 284, 604, 447, 415, 102 | 98, 921, 118, 978, 880, 509, 381, 71, 552, 169, 600, 334, 171, 835, 798, 77, 249, 318, 419, 990, 335, 374, 103 | 949, 316, 755, 878, 946, 142, 299, 863, 558, 306, 183, 417, 64, 765, 565, 432, 440, 939, 297, 805, 364, 735, 104 | 251, 270, 493, 94, 773, 610, 278, 16, 363, 92, 15, 593, 96, 468, 252, 699, 377, 95, 799, 868, 820, 328, 756, 105 | 81, 991, 464, 774, 584, 809, 844, 940, 720, 498, 310, 384, 619, 56, 406, 639, 285, 67, 634, 792, 232, 54, 106 | 664, 818, 513, 349, 330, 207, 361, 345, 279, 549, 944, 817, 353, 228, 312, 796, 193, 179, 520, 451, 871, 107 | 692, 60, 481, 480, 929, 499, 673, 331, 506, 70, 645, 759, 744, 459] 108 | } 109 | } 110 | 111 | # Add missing keys: 112 | for dset in dataset_config.keys(): 113 | for k in ['resize', 'pad', 'crop', 'normalize', 'class_order', 'extend_channel']: 114 | if k not in dataset_config[dset].keys(): 115 | dataset_config[dset][k] = None 116 | if 'flip' not in dataset_config[dset].keys(): 117 | dataset_config[dset]['flip'] = False 118 | -------------------------------------------------------------------------------- /src/datasets/exemplars_dataset.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from argparse import ArgumentParser 3 | 4 | from datasets.memory_dataset import MemoryDataset 5 | 6 | 7 | class ExemplarsDataset(MemoryDataset): 8 | """Exemplar storage for approaches with an interface of Dataset""" 9 | 10 | def __init__(self, transform, class_indices, 11 | num_exemplars=0, num_exemplars_per_class=0, exemplar_selection='random'): 12 | super().__init__({'x': [], 'y': []}, transform, class_indices=class_indices) 13 | self.max_num_exemplars_per_class = num_exemplars_per_class 14 | self.max_num_exemplars = num_exemplars 15 | assert (num_exemplars_per_class == 0) or (num_exemplars == 0), 'Cannot use both limits at once!' 
16 | cls_name = "{}ExemplarsSelector".format(exemplar_selection.capitalize()) 17 | selector_cls = getattr(importlib.import_module(name='datasets.exemplars_selection'), cls_name) 18 | self.exemplars_selector = selector_cls(self) 19 | 20 | # Returns a parser containing the approach specific parameters 21 | @staticmethod 22 | def extra_parser(args): 23 | parser = ArgumentParser("Exemplars Management Parameters") 24 | _group = parser.add_mutually_exclusive_group() 25 | _group.add_argument('--num-exemplars', default=0, type=int, required=False, 26 | help='Fixed memory, total number of exemplars (default=%(default)s)') 27 | _group.add_argument('--num-exemplars-per-class', default=0, type=int, required=False, 28 | help='Growing memory, number of exemplars per class (default=%(default)s)') 29 | parser.add_argument('--exemplar-selection', default='random', type=str, 30 | choices=['herding', 'random', 'entropy', 'distance'], 31 | required=False, help='Exemplar selection strategy (default=%(default)s)') 32 | return parser.parse_known_args(args) 33 | 34 | def _is_active(self): 35 | return self.max_num_exemplars_per_class > 0 or self.max_num_exemplars > 0 36 | 37 | def collect_exemplars(self, model, trn_loader, selection_transform): 38 | if self._is_active(): 39 | self.images, self.labels = self.exemplars_selector(model, trn_loader, selection_transform) 40 | -------------------------------------------------------------------------------- /src/datasets/memory_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class MemoryDataset(Dataset): 8 | """Characterizes a dataset for PyTorch -- this dataset pre-loads all images in memory""" 9 | 10 | def __init__(self, data, transform, class_indices=None): 11 | """Initialization""" 12 | self.labels = data['y'] 13 | self.images = data['x'] 14 | self.transform = transform 15 | self.class_indices = class_indices 16 | 17 | def __len__(self): 18 | """Denotes the total number of samples""" 19 | return len(self.images) 20 | 21 | def __getitem__(self, index): 22 | """Generates one sample of data""" 23 | x = Image.fromarray(self.images[index]) 24 | x = self.transform(x) 25 | y = self.labels[index] 26 | return x, y 27 | 28 | 29 | def get_data(trn_data, tst_data, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None): 30 | """Prepare data: dataset splits, task partition, class order""" 31 | 32 | data = {} 33 | taskcla = [] 34 | if class_order is None: 35 | num_classes = len(np.unique(trn_data['y'])) 36 | class_order = list(range(num_classes)) 37 | else: 38 | num_classes = len(class_order) 39 | class_order = class_order.copy() 40 | if shuffle_classes: 41 | np.random.shuffle(class_order) 42 | 43 | # compute classes per task and num_tasks 44 | if nc_first_task is None: 45 | cpertask = np.array([num_classes // num_tasks] * num_tasks) 46 | for i in range(num_classes % num_tasks): 47 | cpertask[i] += 1 48 | else: 49 | print(nc_first_task, num_classes) 50 | assert nc_first_task < num_classes, "first task wants more classes than exist" 51 | remaining_classes = num_classes - nc_first_task 52 | assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task" # better minimum 2 53 | cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1)) 54 | for i in range(remaining_classes % (num_tasks - 1)): 55 | cpertask[i + 1] += 1 56 | 57 | assert num_classes == 
cpertask.sum(), "something went wrong, the split does not match num classes" 58 | cpertask_cumsum = np.cumsum(cpertask) 59 | init_class = np.concatenate(([0], cpertask_cumsum[:-1])) 60 | 61 | # initialize data structure 62 | for tt in range(num_tasks): 63 | data[tt] = {} 64 | data[tt]['name'] = 'task-' + str(tt) 65 | data[tt]['trn'] = {'x': [], 'y': []} 66 | data[tt]['val'] = {'x': [], 'y': []} 67 | data[tt]['tst'] = {'x': [], 'y': []} 68 | 69 | # ALL OR TRAIN 70 | filtering = np.isin(trn_data['y'], class_order) 71 | if filtering.sum() != len(trn_data['y']): 72 | trn_data['x'] = trn_data['x'][filtering] 73 | trn_data['y'] = np.array(trn_data['y'])[filtering] 74 | for this_image, this_label in zip(trn_data['x'], trn_data['y']): 75 | # If shuffling is false, it won't change the class number 76 | this_label = class_order.index(this_label) 77 | # add it to the corresponding split 78 | this_task = (this_label >= cpertask_cumsum).sum() 79 | data[this_task]['trn']['x'].append(this_image) 80 | data[this_task]['trn']['y'].append(this_label - init_class[this_task]) 81 | 82 | # ALL OR TEST 83 | filtering = np.isin(tst_data['y'], class_order) 84 | if filtering.sum() != len(tst_data['y']): 85 | tst_data['x'] = tst_data['x'][filtering] 86 | tst_data['y'] = tst_data['y'][filtering] 87 | for this_image, this_label in zip(tst_data['x'], tst_data['y']): 88 | # If shuffling is false, it won't change the class number 89 | this_label = class_order.index(this_label) 90 | # add it to the corresponding split 91 | this_task = (this_label >= cpertask_cumsum).sum() 92 | data[this_task]['tst']['x'].append(this_image) 93 | data[this_task]['tst']['y'].append(this_label - init_class[this_task]) 94 | 95 | # check classes 96 | for tt in range(num_tasks): 97 | data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y'])) 98 | assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes" 99 | 100 | # validation 101 | if validation > 0.0: 102 | for tt in data.keys(): 103 | for cc in range(data[tt]['ncla']): 104 | cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0]) 105 | rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation))) 106 | rnd_img.sort(reverse=True) 107 | for ii in range(len(rnd_img)): 108 | data[tt]['val']['x'].append(data[tt]['trn']['x'][rnd_img[ii]]) 109 | data[tt]['val']['y'].append(data[tt]['trn']['y'][rnd_img[ii]]) 110 | data[tt]['trn']['x'].pop(rnd_img[ii]) 111 | data[tt]['trn']['y'].pop(rnd_img[ii]) 112 | # drop_percent = 0.9 113 | # for tt in list(data.keys())[1:]: 114 | # for cc in range(data[tt]['ncla']): 115 | # cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0]) 116 | # rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * drop_percent))) 117 | # rnd_img.sort(reverse=True) 118 | # for ii in range(len(rnd_img)): 119 | # data[tt]['trn']['x'].pop(rnd_img[ii]) 120 | # data[tt]['trn']['y'].pop(rnd_img[ii]) 121 | for tt in list(data.keys()): 122 | print(tt, len(data[tt]['trn']['x'])) 123 | # convert them to numpy arrays 124 | for tt in data.keys(): 125 | for split in ['trn', 'val', 'tst']: 126 | data[tt][split]['x'] = np.asarray(data[tt][split]['x']) 127 | 128 | # other 129 | n = 0 130 | for t in data.keys(): 131 | taskcla.append((t, data[t]['ncla'])) 132 | n += data[t]['ncla'] 133 | data['ncla'] = n 134 | 135 | return data, taskcla, class_order 136 | -------------------------------------------------------------------------------- /src/gridsearch.py: -------------------------------------------------------------------------------- 1 | 
import importlib 2 | from copy import deepcopy 3 | from argparse import ArgumentParser 4 | 5 | import utils 6 | 7 | 8 | class GridSearch: 9 | """Basic class for implementing hyperparameter grid search""" 10 | 11 | def __init__(self, appr_ft, seed, gs_config='gridsearch_config', acc_drop_thr=0.2, hparam_decay=0.5, 12 | max_num_searches=7): 13 | self.seed = seed 14 | GridSearchConfig = getattr(importlib.import_module(name=gs_config), 'GridSearchConfig') 15 | self.appr_ft = appr_ft 16 | self.gs_config = GridSearchConfig() 17 | self.acc_drop_thr = acc_drop_thr 18 | self.hparam_decay = hparam_decay 19 | self.max_num_searches = max_num_searches 20 | self.lr_first = 1.0 21 | 22 | @staticmethod 23 | def extra_parser(args): 24 | """Returns a parser containing the GridSearch specific parameters""" 25 | parser = ArgumentParser() 26 | # Configuration file with a GridSearchConfig class with all necessary args 27 | parser.add_argument('--gridsearch-config', type=str, default='gridsearch_config', required=False, 28 | help='Configuration file for GridSearch options (default=%(default)s)') 29 | # Accuracy threshold drop below which the search stops for that phase 30 | parser.add_argument('--gridsearch-acc-drop-thr', default=0.2, type=float, required=False, 31 | help='GridSearch accuracy drop threshold (default=%(default)f)') 32 | # Value at which hyperparameters decay 33 | parser.add_argument('--gridsearch-hparam-decay', default=0.5, type=float, required=False, 34 | help='GridSearch hyperparameter decay (default=%(default)f)') 35 | # Maximum number of searches before the search stops for that phase 36 | parser.add_argument('--gridsearch-max-num-searches', default=7, type=int, required=False, 37 | help='GridSearch maximum number of hyperparameter searches (default=%(default)s)') 38 | return parser.parse_known_args(args) 39 | 40 | def search_lr(self, model, t, trn_loader, val_loader): 41 | """Search for accuracy and best LR on finetuning""" 42 | best_ft_acc = 0.0 43 | best_ft_lr = 0.0 44 | 45 | # Get general parameters and fix the ones with only one value 46 | gen_params = self.gs_config.get_params('general') 47 | for k, v in gen_params.items(): 48 | if not isinstance(v, list): 49 | setattr(self.appr_ft, k, v) 50 | if t > 0: 51 | # LRs for the search are the 'lr_searches' largest LRs below 'lr_first' 52 | list_lr = [lr for lr in gen_params['lr'] if lr < self.lr_first][:gen_params['lr_searches'][0]] 53 | else: 54 | # For the first task, try a larger LR range 55 | list_lr = gen_params['lr_first'] 56 | 57 | # Iterate through the other variable parameters 58 | for curr_lr in list_lr: 59 | utils.seed_everything(seed=self.seed) 60 | self.appr_ft.model = deepcopy(model) 61 | self.appr_ft.lr = curr_lr 62 | self.appr_ft.train(t, trn_loader, val_loader) 63 | _, ft_acc_taw, _ = self.appr_ft.eval(t, val_loader) 64 | if ft_acc_taw > best_ft_acc: 65 | best_ft_acc = ft_acc_taw 66 | best_ft_lr = curr_lr 67 | print('Current best LR: ' + str(best_ft_lr)) 68 | self.gs_config.current_lr = best_ft_lr 69 | print('Current best acc: {:5.1f}'.format(best_ft_acc * 100)) 70 | # After first task, keep LR used 71 | if t == 0: 72 | self.lr_first = best_ft_lr 73 | 74 | return best_ft_acc, best_ft_lr 75 | 76 | def search_tradeoff(self, appr_name, appr, t, trn_loader, val_loader, best_ft_acc): 77 | """Search for less-forgetting tradeoff with minimum accuracy loss""" 78 | best_tradeoff = None 79 | tradeoff_name = None 80 | 81 | # Get general parameters and fix all the ones that have only one option 82 | appr_params = self.gs_config.get_params(appr_name) 83 
for k, v in appr_params.items(): 84 | if isinstance(v, list): 85 | # Get the trade-off name as the only hyperparameter with multiple values 86 | tradeoff_name = k 87 | else: 88 | # Any other hyperparameters are fixed 89 | setattr(appr, k, v) 90 | 91 | # Only search if there is a trade-off and we are past the first task 92 | if tradeoff_name is not None and t > 0: 93 | # get starting value for trade-off hyperparameter 94 | best_tradeoff = appr_params[tradeoff_name][0] 95 | # iterate through decreasing trade-off values -- limit to `max_num_searches` searches 96 | num_searches = 0 97 | while num_searches < self.max_num_searches: 98 | utils.seed_everything(seed=self.seed) 99 | # Make a deepcopy of the approach without duplicating the logger 100 | appr_gs = type(appr)(deepcopy(appr.model), appr.device, exemplars_dataset=appr.exemplars_dataset) 101 | for attr, value in vars(appr).items(): 102 | if attr == 'logger': 103 | setattr(appr_gs, attr, value) 104 | else: 105 | setattr(appr_gs, attr, deepcopy(value)) 106 | 107 | # update trade-off value 108 | setattr(appr_gs, tradeoff_name, best_tradeoff) 109 | # train this iteration 110 | appr_gs.train(t, trn_loader, val_loader) 111 | _, curr_acc, _ = appr_gs.eval(t, val_loader) 112 | print('Current acc: ' + str(curr_acc) + ' for ' + tradeoff_name + '=' + str(best_tradeoff)) 113 | # Check if accuracy is within the acceptable threshold drop 114 | if curr_acc < ((1 - self.acc_drop_thr) * best_ft_acc): 115 | best_tradeoff = best_tradeoff * self.hparam_decay 116 | else: 117 | break 118 | num_searches += 1 119 | else: 120 | print('There is no trade-off to gridsearch.') 121 | 122 | return best_tradeoff, tradeoff_name 123 | -------------------------------------------------------------------------------- /src/gridsearch_config.py: -------------------------------------------------------------------------------- 1 | class GridSearchConfig: 2 | def __init__(self): 3 | self.params = { 4 | 'general': { 5 | 'lr_first': [5e-1, 1e-1, 5e-2], 6 | 'lr': [1e-1, 5e-2, 1e-2, 5e-3, 1e-3], 7 | 'lr_searches': [3], 8 | 'lr_min': 1e-4, 9 | 'lr_factor': 3, 10 | 'lr_patience': 10, 11 | 'clipping': 10000, 12 | 'momentum': 0.9, 13 | 'wd': 0.0002 14 | }, 15 | 'finetuning': { 16 | }, 17 | 'freezing': { 18 | }, 19 | 'joint': { 20 | }, 21 | 'lwf': { 22 | 'lamb': [10], 23 | 'T': 2 24 | }, 25 | 'icarl': { 26 | 'lamb': [4] 27 | }, 28 | 'dmc': { 29 | 'aux_dataset': 'imagenet_32_reduced', 30 | 'aux_batch_size': 128 31 | }, 32 | 'il2m': { 33 | }, 34 | 'eeil': { 35 | 'lamb': [10], 36 | 'T': 2, 37 | 'lr_finetuning_factor': 0.1, 38 | 'nepochs_finetuning': 40, 39 | 'noise_grad': False 40 | }, 41 | 'bic': { 42 | 'T': 2, 43 | 'val_percentage': 0.1, 44 | 'bias_epochs': 200 45 | }, 46 | 'lucir': { 47 | 'lamda_base': [10], 48 | 'lamda_mr': 1.0, 49 | 'dist': 0.5, 50 | 'K': 2 51 | }, 52 | 'lwm': { 53 | 'beta': [2], 54 | 'gamma': 1.0 55 | }, 56 | 'ewc': { 57 | 'lamb': [10000] 58 | }, 59 | 'mas': { 60 | 'lamb': [400] 61 | }, 62 | 'path_integral': { 63 | 'lamb': [10], 64 | }, 65 | 'r_walk': { 66 | 'lamb': [20], 67 | }, 68 | } 69 | self.current_lr = self.params['general']['lr'][0] 70 | self.current_tradeoff = 0 71 | 72 | def get_params(self, approach): 73 | return self.params[approach] 74 | 
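To make the two-phase search concrete, here is a minimal sketch of how `search_lr` and `search_tradeoff` compose, assuming `appr_ft` (a plain finetuning approach used as the reference), `appr` (e.g. an LwF instance), the network `net`, `num_tasks` and the per-task loaders are built elsewhere, as in [main_incremental.py](main_incremental.py):

```python
# a minimal sketch; `appr_ft`, `appr`, `net`, `num_tasks` and the loaders are assumed
from gridsearch import GridSearch

gridsearch = GridSearch(appr_ft, seed=1, gs_config='gridsearch_config')
for t in range(num_tasks):
    # 1) find the best finetuning LR (and its accuracy) for this task
    best_ft_acc, best_ft_lr = gridsearch.search_lr(net, t, trn_loader[t], val_loader[t])
    appr.lr = best_ft_lr
    # 2) decay the trade-off until accuracy stays within `acc_drop_thr` of the reference
    best_tradeoff, tradeoff_name = gridsearch.search_tradeoff('lwf', appr, t,
                                                              trn_loader[t], val_loader[t], best_ft_acc)
    if tradeoff_name is not None:
        setattr(appr, tradeoff_name, best_tradeoff)
    appr.train(t, trn_loader[t], val_loader[t])
```
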
-------------------------------------------------------------------------------- /src/last_layer_analysis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def last_layer_analysis(heads, task, taskcla, y_lim=False, sort_weights=False): 9 | """Plot last layer weight and bias analysis""" 10 | print('Plotting last layer analysis...') 11 | num_classes = sum([x for (_, x) in taskcla]) 12 | weights, biases, indexes = [], [], [] 13 | class_id = 0 14 | with torch.no_grad(): 15 | for t in range(task + 1): 16 | n_classes_t = taskcla[t][1] 17 | indexes.append(np.arange(class_id, class_id + n_classes_t)) 18 | if type(heads) == torch.nn.Linear: # Single head 19 | biases.append(heads.bias[class_id: class_id + n_classes_t].detach().cpu().numpy()) 20 | weights.append((heads.weight[class_id: class_id + n_classes_t] ** 2).sum(1).sqrt().detach().cpu().numpy()) 21 | else: # Multi-head 22 | weights.append((heads[t].weight ** 2).sum(1).sqrt().detach().cpu().numpy()) 23 | if type(heads[t]) == torch.nn.Linear: 24 | biases.append(heads[t].bias.detach().cpu().numpy()) 25 | else: 26 | biases.append(np.zeros(weights[-1].shape)) # For LUCIR 27 | class_id += n_classes_t 28 | 29 | # Figure weights 30 | f_weights = plt.figure(dpi=300) 31 | ax = f_weights.subplots(nrows=1, ncols=1) 32 | for i, (x, y) in enumerate(zip(indexes, weights), 0): 33 | if sort_weights: 34 | ax.bar(x, sorted(y, reverse=True), label="Task {}".format(i)) 35 | else: 36 | ax.bar(x, y, label="Task {}".format(i)) 37 | ax.set_xlabel("Classes", fontsize=11, fontfamily='serif') 38 | ax.set_ylabel("Weights L2-norm", fontsize=11, fontfamily='serif') 39 | if num_classes is not None: 40 | ax.set_xlim(0, num_classes) 41 | if y_lim: 42 | ax.set_ylim(0, 5) 43 | ax.legend(loc='upper left', fontsize='11') #, fontfamily='serif') 44 | 45 | # Figure biases 46 | f_biases = plt.figure(dpi=300) 47 | ax = f_biases.subplots(nrows=1, ncols=1) 48 | for i, (x, y) in enumerate(zip(indexes, biases), 0): 49 | if sort_weights: 50 | ax.bar(x, sorted(y, reverse=True), label="Task {}".format(i)) 51 | else: 52 | ax.bar(x, y, label="Task {}".format(i)) 53 | ax.set_xlabel("Classes", fontsize=11, fontfamily='serif') 54 | ax.set_ylabel("Bias values", fontsize=11, fontfamily='serif') 55 | if num_classes is not None: 56 | ax.set_xlim(0, num_classes) 57 | if y_lim: 58 | ax.set_ylim(-1.0, 1.0) 59 | ax.legend(loc='upper left', fontsize='11') #, fontfamily='serif') 60 | 61 | return f_weights, f_biases 62 | -------------------------------------------------------------------------------- /src/loggers/README.md: -------------------------------------------------------------------------------- 1 | # Loggers 2 | 3 | We include a disk logger, which logs into files and folders on disk. We also provide a tensorboard logger, which 4 | provides a faster way of analysing a training process without the need for further development. They can be specified with 5 | `--log` followed by `disk`, `tensorboard` or both. Custom loggers can be defined by inheriting the `ExperimentLogger` 6 | in [exp_logger.py](exp_logger.py); a minimal custom logger is sketched below, after the disk logger description. 7 | 8 | When enabled, both loggers will output everything in the path `[RESULTS_PATH]/[DATASETS]_[APPROACH]_[EXP_NAME]` or 9 | `[RESULTS_PATH]/[DATASETS]_[APPROACH]` if `--exp-name` is not set. 10 | 11 | ## Disk logger 12 | The disk logger outputs the following file and folder structure: 13 | - **figures/**: folder where generated figures are logged. 14 | - **models/**: folder where model weight checkpoints are saved. 15 | - **results/**: folder containing the results. 16 | - **acc_tag**: task-agnostic accuracy table. 17 | - **acc_taw**: task-aware accuracy table. 18 | - **avg_acc_tag**: task-agnostic average accuracies. 19 | - **avg_acc_taw**: task-aware average accuracies. 
20 | - **forg_tag**: task-agnostic forgetting table. 21 | - **forg_taw**: task-aware forgetting table. 22 | - **wavg_acc_tag**: task-agnostic average accuracies weighted according to the number of classes of each task. 23 | - **wavg_acc_taw**: task-aware average accuracies weighted according to the number of classes of each task. 24 | - **raw_log**: json file containing all the logged metrics, easily read by many tools (e.g. `pandas`). 25 | - stdout: a copy of the standard output of the terminal. 26 | - stderr: a copy of the error output of the terminal. 27 | 
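As referenced above, a custom logger only needs to subclass `ExperimentLogger` (shown further down in this dump) and override the hooks it cares about. A minimal sketch, assuming a hypothetical module `loggers/console_logger.py`; the module name must end in `_logger` and expose a class named `Logger`, which is how `MultiLogger` discovers it, so it would be selected with `--log console`:

```python
# loggers/console_logger.py -- hypothetical module, selectable as `--log console`
from loggers.exp_logger import ExperimentLogger


class Logger(ExperimentLogger):
    """Minimal custom logger that only prints scalar metrics to the console."""

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        # all other hooks (log_args, log_result, ...) keep the no-op defaults
        print('[task {} | iter {}] {}/{} = {}'.format(task, iter, group, name, value))
```
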
28 | ## TensorBoard logger 29 | The tensorboard logger outputs analogous metrics to the disk logger, separated into different tabs according to the task 30 | and different graphs according to the data splits. 31 | 32 | Screenshot for a 10 task experiment, showing the last task plots: 33 | 34 | Tensorboard Screenshot 35 | 36 | -------------------------------------------------------------------------------- /src/loggers/__pycache__/disk_logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/loggers/__pycache__/disk_logger.cpython-38.pyc -------------------------------------------------------------------------------- /src/loggers/__pycache__/exp_logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/loggers/__pycache__/exp_logger.cpython-36.pyc -------------------------------------------------------------------------------- /src/loggers/__pycache__/exp_logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/loggers/__pycache__/exp_logger.cpython-38.pyc -------------------------------------------------------------------------------- /src/loggers/__pycache__/tensorboard_logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/loggers/__pycache__/tensorboard_logger.cpython-38.pyc -------------------------------------------------------------------------------- /src/loggers/disk_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import numpy as np 6 | from datetime import datetime 7 | 8 | from loggers.exp_logger import ExperimentLogger 9 | 10 | 11 | class Logger(ExperimentLogger): 12 | """Characterizes a disk logger""" 13 | 14 | def __init__(self, log_path, exp_name, begin_time=None): 15 | super(Logger, self).__init__(log_path, exp_name, begin_time) 16 | 17 | self.begin_time_str = self.begin_time.strftime("%Y-%m-%d-%H-%M") 18 | 19 | # Duplicate standard outputs 20 | sys.stdout = FileOutputDuplicator(sys.stdout, 21 | os.path.join(self.exp_path, 'stdout-{}.txt'.format(self.begin_time_str)), 'w') 22 | sys.stderr = FileOutputDuplicator(sys.stderr, 23 | os.path.join(self.exp_path, 'stderr-{}.txt'.format(self.begin_time_str)), 'w') 24 | 25 | # Raw log file 26 | self.raw_log_file = open(os.path.join(self.exp_path, "raw_log-{}.txt".format(self.begin_time_str)), 'a') 27 | 28 | def log_scalar(self, task, iter, name, value, group=None, curtime=None): 29 | if curtime is None: 30 | curtime = datetime.now() 31 | 32 | # Raw dump 33 | entry = {"task": task, "iter": iter, "name": name, "value": value, "group": group, 34 | "time": curtime.strftime("%Y-%m-%d-%H-%M")} 35 | self.raw_log_file.write(json.dumps(entry, sort_keys=True) + "\n") 36 | self.raw_log_file.flush() 37 | 38 | def log_args(self, args): 39 | with open(os.path.join(self.exp_path, 'args-{}.txt'.format(self.begin_time_str)), 'w') as f: 40 | json.dump(args.__dict__, f, separators=(',\n', ' : '), sort_keys=True) 41 | 42 | def log_result(self, array, name, step): 43 | if array.ndim <= 1: 44 | array = array[None] 45 | np.savetxt(os.path.join(self.exp_path, 'results', '{}-{}.txt'.format(name, self.begin_time_str)), 46 | array, '%.6f', delimiter='\t') 47 | 48 | def log_figure(self, name, iter, figure, curtime=None): 49 | 
if curtime is None: curtime = datetime.now() 50 | figure.savefig(os.path.join(self.exp_path, 'figures', 51 | '{}_{}-{}.png'.format(name, iter, curtime.strftime("%Y-%m-%d-%H-%M-%S")))) 52 | figure.savefig(os.path.join(self.exp_path, 'figures', 53 | '{}_{}-{}.pdf'.format(name, iter, curtime.strftime("%Y-%m-%d-%H-%M-%S")))) 54 | 55 | def save_model(self, state_dict, task): 56 | torch.save(state_dict, os.path.join(self.exp_path, "models", "task{}.ckpt".format(task))) 57 | 58 | def __del__(self): 59 | self.raw_log_file.close() 60 | 61 | 62 | class FileOutputDuplicator(object): 63 | def __init__(self, duplicate, fname, mode): 64 | self.file = open(fname, mode) 65 | self.duplicate = duplicate 66 | 67 | def __del__(self): 68 | self.file.close() 69 | 70 | def write(self, data): 71 | self.file.write(data) 72 | self.duplicate.write(data) 73 | 74 | def flush(self): 75 | self.file.flush() 76 | -------------------------------------------------------------------------------- /src/loggers/exp_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | from datetime import datetime 4 | 5 | 6 | class ExperimentLogger: 7 | """Main class for experiment logging""" 8 | 9 | def __init__(self, log_path, exp_name, begin_time=None): 10 | self.log_path = log_path 11 | self.exp_name = exp_name 12 | self.exp_path = os.path.join(log_path, exp_name) 13 | if begin_time is None: 14 | self.begin_time = datetime.now() 15 | else: 16 | self.begin_time = begin_time 17 | 18 | def log_scalar(self, task, iter, name, value, group=None, curtime=None): 19 | pass 20 | 21 | def log_args(self, args): 22 | pass 23 | 24 | def log_result(self, array, name, step): 25 | pass 26 | 27 | def log_figure(self, name, iter, figure, curtime=None): 28 | pass 29 | 30 | def save_model(self, state_dict, task): 31 | pass 32 | 33 | 34 | class MultiLogger(ExperimentLogger): 35 | """This class allows using multiple loggers at the same time""" 36 | 37 | def __init__(self, log_path, exp_name, loggers=None, save_models=True): 38 | super(MultiLogger, self).__init__(log_path, exp_name) 39 | if os.path.exists(self.exp_path): 40 | print("WARNING: {} already exists!".format(self.exp_path)) 41 | else: 42 | os.makedirs(os.path.join(self.exp_path, 'models')) 43 | os.makedirs(os.path.join(self.exp_path, 'results')) 44 | os.makedirs(os.path.join(self.exp_path, 'figures')) 45 | 46 | self.save_models = save_models 47 | self.loggers = [] 48 | for l in loggers: 49 | lclass = getattr(importlib.import_module(name='loggers.' 
+ l + '_logger'), 'Logger') 50 | self.loggers.append(lclass(self.log_path, self.exp_name)) 51 | 52 | def log_scalar(self, task, iter, name, value, group=None, curtime=None): 53 | if curtime is None: 54 | curtime = datetime.now() 55 | for l in self.loggers: 56 | l.log_scalar(task, iter, name, value, group, curtime) 57 | 58 | def log_args(self, args): 59 | for l in self.loggers: 60 | l.log_args(args) 61 | 62 | def log_result(self, array, name, step): 63 | for l in self.loggers: 64 | l.log_result(array, name, step) 65 | 66 | def log_figure(self, name, iter, figure, curtime=None): 67 | if curtime is None: 68 | curtime = datetime.now() 69 | for l in self.loggers: 70 | l.log_figure(name, iter, figure, curtime) 71 | 72 | def save_model(self, state_dict, task): 73 | if self.save_models: 74 | for l in self.loggers: 75 | l.save_model(state_dict, task) 76 | -------------------------------------------------------------------------------- /src/loggers/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | 3 | from loggers.exp_logger import ExperimentLogger 4 | import json 5 | import numpy as np 6 | 7 | 8 | class Logger(ExperimentLogger): 9 | """Characterizes a Tensorboard logger""" 10 | 11 | def __init__(self, log_path, exp_name, begin_time=None): 12 | super(Logger, self).__init__(log_path, exp_name, begin_time) 13 | self.tbwriter = SummaryWriter(self.exp_path) 14 | 15 | def log_scalar(self, task, iter, name, value, group=None, curtime=None): 16 | self.tbwriter.add_scalar(tag="t{}/{}_{}".format(task, group, name), 17 | scalar_value=value, 18 | global_step=iter) 19 | self.tbwriter.file_writer.flush() 20 | 21 | def log_figure(self, name, iter, figure, curtime=None): 22 | self.tbwriter.add_figure(tag=name, figure=figure, global_step=iter) 23 | self.tbwriter.file_writer.flush() 24 | 25 | def log_args(self, args): 26 | self.tbwriter.add_text( 27 | 'args', 28 | json.dumps(args.__dict__, 29 | separators=(',\n', ' : '), 30 | sort_keys=True)) 31 | self.tbwriter.file_writer.flush() 32 | 33 | def log_result(self, array, name, step): 34 | if array.ndim == 1: 35 | # log as scalars 36 | self.tbwriter.add_scalar(f'results/{name}', array[step], step) 37 | 38 | elif array.ndim == 2: 39 | s = "" 40 | i = step 41 | # for i in range(array.shape[0]): 42 | for j in range(array.shape[1]): 43 | s += '{:5.1f}% '.format(100 * array[i, j]) 44 | if np.trace(array) == 0.0: 45 | if i > 0: 46 | s += '\tAvg.:{:5.1f}% \n'.format(100 * array[i, :i].mean()) 47 | else: 48 | s += '\tAvg.:{:5.1f}% \n'.format(100 * array[i, :i + 1].mean()) 49 | self.tbwriter.add_text(f'results/{name}', s, step) 50 | 51 | def __del__(self): 52 | self.tbwriter.close() 53 | -------------------------------------------------------------------------------- /src/networks/README.md: -------------------------------------------------------------------------------- 1 | # Networks 2 | We include a core [network](network.py) class which handles the architecture used as well as the heads needed to do 3 | image classification in an incremental learning setting. 4 | 5 | ## Main usage 6 | When running an experiment, the network model used can be defined in [main_incremental.py](../main_incremental.py) using 7 | `--network`. By default, the existing head of the architecture (usually with 1,000 outputs because of ImageNet) will be 8 | removed since we create a head each time a task is learned. This default behaviour can be disabled by using 9 | `--keep-existing-head`. 
If the architecture used has the option to use a pretrained model, it can be called with 10 | `--pretrained`. 11 | 12 | We define a [network](network.py) class which contains the architecture model class to be used (torchvision models or 13 | custom), and also a `ModuleList()` of heads that grows incrementally as tasks are learned, called `model` and `heads` 14 | respectively. When doing a forward pass, inputs are passed through the `model`-specific forward pass, and its outputs 15 | are then fed to the `heads`. This results in a list of outputs, corresponding to the different tasks learned so far 16 | (multi-head base). However, note that when using approaches for class-incremental learning, which have no 17 | access to the task-ID during test, the heads are treated as if they were concatenated, so the task-ID has no influence. 18 | 19 | We use this system since it is equivalent to having a single head that grows at each task whose outputs are concatenated 20 | afterwards (or to creating a new head with the new number of outputs and copying the previous head weights to their respective 21 | positions). Moreover, an advantage of this system is that it allows updating previous heads by adding them to the 22 | optimizer (needed when using exemplars). This is also important when using some regularization such as weight decay, 23 | which would affect previous task heads/outputs if the corresponding weights are included in the optimizer. Furthermore, 24 | it makes it very easy to evaluate on both task-incremental and class-incremental scenarios. 25 | 26 | ### Torchvision models 27 | * Alexnet: `alexnet` 28 | * DenseNet: `densenet121, densenet169, densenet201, densenet161` 29 | * Googlenet: `googlenet` 30 | * Inception: `inception_v3` 31 | * MobileNet: `mobilenet_v2` 32 | * ResNet: `resnet18`, `resnet34`, `resnet50`, `resnet101`, `resnet152`, `resnext50_32x4d`, `resnext101_32x8d` 33 | * ShuffleNet: `shufflenet_v2_x0_5`, `shufflenet_v2_x1_0`, `shufflenet_v2_x1_5`, `shufflenet_v2_x2_0` 34 | * Squeezenet: `squeezenet1_0`, `squeezenet1_1` 35 | * VGG: `vgg11`, `vgg11_bn`, `vgg13`, `vgg13_bn`, `vgg16`, `vgg16_bn`, `vgg19_bn`, `vgg19` 36 | * WideResNet: `wide_resnet50_2`, `wide_resnet101_2` 37 | 38 | ### Custom models 39 | We include versions of [LeNet](lenet.py), [ResNet-32](resnet32.py) and [VGGnet](vggnet.py), which use a smaller input 40 | size than the torchvision models. LeNet together with MNIST is useful for quick tests and debugging. 41 | 42 | ## Adding new networks 43 | To add a new custom model architecture, follow these steps (a minimal skeleton is sketched right after the list): 44 | 45 | 1. Take as an example [vggnet.py](vggnet.py) and define a Class for the architecture. Initialize all necessary layers 46 | and modules and define a non-incremental last layer (e.g. `self.fc`). Then add `self.head_var = 'fc'` to point to the 47 | variable containing the non-incremental head (it is not important how many classes it has as output since we remove 48 | it when using it for incremental learning). 49 | 2. Define the forward pass of the architecture inside the Class and any other necessary functions. 50 | 3. Define a function outside of the Class to call the model. It needs to contain `num_out` and `pretrained` as inputs. 51 | 4. Include the import to [\_\_init\_\_.py](__init__.py) and add the architecture name to `allmodels`. 52 | 
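A minimal skeleton following the four steps above; `TinyNet` and `tinynet` are hypothetical names used only for illustration:

```python
from torch import nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical minimal architecture following steps 1 and 2 above."""

    def __init__(self, num_classes=10):
        super().__init__()
        # (1) main part of the network
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        # (1) non-incremental last layer, pointed to by `head_var`
        self.fc = nn.Linear(32, num_classes)
        self.head_var = 'fc'

    def forward(self, x):
        # (2) forward pass of the architecture
        x = F.relu(self.conv1(x))
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        return self.fc(x)


def tinynet(num_out=10, pretrained=False):
    # (3) factory taking `num_out` and `pretrained` as inputs
    if pretrained:
        raise NotImplementedError
    return TinyNet(num_out)
```

Step 4 would then amount to importing it in [\_\_init\_\_.py](__init__.py) and appending its name to `allmodels`.

53 | ## Notes 54 | * We provide an implementation of ResNet-32 (see [resnet32.py](resnet32.py)) which is commonly used by several works in 55 | the literature for learning CIFAR-100 in incremental learning scenarios. This network architecture is an adaptation of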
This network architecture is an adaptation of 56 | ResNet for smaller input size. The number of blocks can be modified in 57 | [this line](https://github.com/mmasana/IL_Survey/blob/9837386d9efddf48d22fc4d23e031248decce68d/src/networks/resnet32.py#L113) 58 | by changing `n=5` to `n=3` for ResNet-20, and `n=9` for ResNet-56. 59 | -------------------------------------------------------------------------------- /src/networks/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torchvision import models 3 | 4 | from .lenet import LeNet 5 | from .vggnet import VggNet 6 | from .resnet32 import resnet32 7 | from .ovit_tiny_16_augreg_224 import OVit_tiny_16_augreg_224 8 | from .vit_tiny_16_augreg_224 import Vit_tiny_16_augreg_224 9 | from .timm_vit_tiny_16_augreg_224 import Timm_Vit_tiny_16_augreg_224 10 | from .efficient_net import Efficient_net 11 | from .mobile_net import Mobile_net 12 | from .early_conv_vit import Early_conv_vit 13 | 14 | 15 | # available torchvision models 16 | tvmodels = ['alexnet', 17 | 'densenet121', 'densenet169', 'densenet201', 'densenet161', 18 | 'googlenet', 19 | 'inception_v3', 20 | 'mobilenet_v2', 21 | 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 22 | 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0', 23 | 'squeezenet1_0', 'squeezenet1_1', 24 | 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19', 25 | 'wide_resnet50_2', 'wide_resnet101_2' 26 | ] 27 | 28 | allmodels = tvmodels + ['resnet32', 'LeNet', 'VggNet', "OVit_tiny_16_augreg_224", "Vit_tiny_16_augreg_224", "Timm_Vit_tiny_16_augreg_224", "Efficient_net", "Mobile_net", "Early_conv_vit"] 29 | 30 | 31 | def set_tvmodel_head_var(model): 32 | if type(model) == models.AlexNet: 33 | model.head_var = 'classifier' 34 | elif type(model) == models.DenseNet: 35 | model.head_var = 'classifier' 36 | elif type(model) == models.Inception3: 37 | model.head_var = 'fc' 38 | elif type(model) == models.ResNet: 39 | model.head_var = 'fc' 40 | elif type(model) == models.VGG: 41 | model.head_var = 'classifier' 42 | elif type(model) == models.GoogLeNet: 43 | model.head_var = 'fc' 44 | elif type(model) == models.MobileNetV2: 45 | model.head_var = 'classifier' 46 | elif type(model) == models.ShuffleNetV2: 47 | model.head_var = 'fc' 48 | elif type(model) == models.SqueezeNet: 49 | model.head_var = 'classifier' 50 | else: 51 | raise ModuleNotFoundError -------------------------------------------------------------------------------- /src/networks/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/early_conv_vit.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/early_conv_vit.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/early_conv_vit_net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/early_conv_vit_net.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/efficient_net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/efficient_net.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/lenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/lenet.cpython-36.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/lenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/lenet.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/mobile_net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/mobile_net.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/network.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/network.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/ovit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/ovit.cpython-36.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/ovit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/ovit.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/ovit_tiny_16_augreg_224.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/ovit_tiny_16_augreg_224.cpython-36.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/ovit_tiny_16_augreg_224.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/ovit_tiny_16_augreg_224.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/resnet32.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/resnet32.cpython-36.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/resnet32.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/resnet32.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/timm_vit_tiny_16_augreg_224.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/timm_vit_tiny_16_augreg_224.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/vggnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/vggnet.cpython-36.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/vggnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/vggnet.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/vit_original.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/vit_original.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/__pycache__/vit_tiny_16_augreg_224.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/vit_tiny_16_augreg_224.cpython-38.pyc -------------------------------------------------------------------------------- /src/networks/early_conv_vit.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from .early_conv_vit_net 
import EarlyConvViT 4 | 5 | 6 | class Early_conv_vit(nn.Module): 7 | 8 | def __init__(self, num_classes=100, pretrained=False): 9 | super().__init__() 10 | 11 | self.model = EarlyConvViT( 12 | dim=768, 13 | num_classes=2, 14 | depth = 12, 15 | heads = 12, 16 | mlp_dim = 2048, 17 | channels=3, 18 | ) 19 | 20 | #import ipdb; ipdb.set_trace() 21 | if pretrained: 22 | raise NotImplementedError 23 | 24 | # last classifier layer (head) with as many outputs as classes 25 | self.fc = nn.Linear(in_features=768, out_features=num_classes, bias=True) 26 | # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments 27 | self.head_var = 'fc' 28 | 29 | def forward(self, x): 30 | h = self.fc(self.model(x)) 31 | return h 32 | 33 | 34 | def early_conv_vit(num_out=100, pretrained=False): 35 | if pretrained: 36 | raise NotImplementedError # no pretrained weights available; the class would raise anyway 37 | return Early_conv_vit(num_out, pretrained) 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /src/networks/early_conv_vit_net.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | import torch 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 6 | 7 | POST_ATTENTION_LIST = [] 8 | POST_REC = False 9 | 10 | def start_post_rec(): 11 | global POST_REC 12 | POST_REC = True 13 | 14 | def stop_post_rec(): 15 | global POST_REC 16 | POST_REC = False 17 | 18 | def get_post_attention_list(): 19 | global POST_ATTENTION_LIST 20 | post_attention_list = POST_ATTENTION_LIST 21 | POST_ATTENTION_LIST = [] 22 | return post_attention_list 23 | 24 | ATTENTION_LIST = [] 25 | REC = False 26 | 27 | def start_rec(): 28 | global REC 29 | REC = True 30 | 31 | def stop_rec(): 32 | global REC 33 | REC = False 34 | 35 | def get_attention_list(): 36 | global ATTENTION_LIST 37 | attention_list = ATTENTION_LIST 38 | ATTENTION_LIST = [] 39 | return attention_list 40 | 41 | 42 | # helpers 43 | def make_tuple(t): 44 | """ 45 | return the input if it's already a tuple. 46 | return a tuple of the input if the input is not already a tuple.
47 | """ 48 | return t if isinstance(t, tuple) else (t, t) 49 | 50 | # classes 51 | class PreNorm(nn.Module): 52 | def __init__(self, dim, fn): 53 | super().__init__() 54 | self.norm = nn.LayerNorm(dim) 55 | self.fn = fn 56 | def forward(self, x, **kwargs): 57 | return self.fn(self.norm(x), **kwargs) 58 | 59 | class FeedForward(nn.Module): 60 | def __init__(self, dim, hidden_dim, dropout=0.): 61 | super().__init__() 62 | self.net = nn.Sequential( 63 | nn.Linear(dim, hidden_dim), 64 | nn.GELU(), 65 | nn.Dropout(dropout), 66 | nn.Linear(hidden_dim, dim), 67 | nn.Dropout(dropout) 68 | ) 69 | def forward(self, x): 70 | return self.net(x) 71 | 72 | class Attention(nn.Module): 73 | def __init__(self, dim, heads=7, dim_head=64, dropout=0.): 74 | """ 75 | reduced the default number of heads by 1 per https://arxiv.org/pdf/2106.14881v2.pdf 76 | """ 77 | super().__init__() 78 | inner_dim = dim_head * heads 79 | project_out = not (heads == 1 and dim_head == dim) 80 | 81 | self.heads = heads 82 | self.scale = dim_head ** -0.5 83 | 84 | self.attend = nn.Softmax(dim = -1) 85 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 86 | 87 | self.to_out = nn.Sequential( 88 | nn.Linear(inner_dim, dim), 89 | nn.Dropout(dropout) 90 | ) if project_out else nn.Identity() 91 | 92 | def forward(self, x): 93 | global REC 94 | global ATTENTION_LIST 95 | global POST_REC 96 | global POST_ATTENTION_LIST 97 | 98 | qkv = self.to_qkv(x).chunk(3, dim = -1) 99 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 100 | 101 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 102 | 103 | if REC: 104 | ATTENTION_LIST.append(dots) 105 | 106 | attn = self.attend(dots) 107 | 108 | out = torch.matmul(attn, v) 109 | 110 | if POST_REC: 111 | POST_ATTENTION_LIST.append(out) 112 | 113 | out = rearrange(out, 'b h n d -> b n (h d)') 114 | return self.to_out(out) 115 | 116 | class Transformer(nn.Module): 117 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 118 | super().__init__() 119 | self.layers = nn.ModuleList([]) 120 | for _ in range(depth): 121 | self.layers.append(nn.ModuleList([ 122 | PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)), 123 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 124 | ])) 125 | def forward(self, x): 126 | for attn, ff in self.layers: 127 | x = attn(x) + x 128 | x = ff(x) + x 129 | return x 130 | 131 | class EarlyConvViT(nn.Module): 132 | def __init__(self, *, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.): 133 | """ 134 | 3x3 conv, stride 1, 5 conv layers per https://arxiv.org/pdf/2106.14881v2.pdf 135 | """ 136 | super().__init__() 137 | 138 | n_filter_list = (channels, 24, 48, 96, 196) # hardcoding for now because that's what the paper used 139 | 140 | self.conv_layers = nn.Sequential( 141 | *[nn.Sequential( 142 | nn.Conv2d(in_channels=n_filter_list[i], 143 | out_channels=n_filter_list[i + 1], 144 | kernel_size=3, # hardcoding for now because that's what the paper used 145 | stride=2, # hardcoding for now because that's what the paper used 146 | padding=1), # hardcoding for now because that's what the paper used 147 | ) 148 | for i in range(len(n_filter_list)-1) 149 | ]) 150 | 151 | self.conv_layers.add_module("conv_1x1", torch.nn.Conv2d(in_channels=n_filter_list[-1], 152 | out_channels=dim, 153 | stride=1, # hardcoding for now because that's what the paper used 154 | kernel_size=1, # hardcoding for now because that's what the paper used 
155 | padding=0)) # hardcoding for now because that's what the paper used 156 | self.conv_layers.add_module("flatten image", 157 | Rearrange('batch channels height width -> batch (height width) channels')) 158 | self.pos_embedding = nn.Parameter(torch.randn(1, n_filter_list[-1] + 1, dim)) 159 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 160 | self.dropout = nn.Dropout(emb_dropout) 161 | 162 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 163 | 164 | self.pool = pool 165 | self.to_latent = nn.Identity() 166 | 167 | """ 168 | self.mlp_head = nn.Sequential( 169 | nn.LayerNorm(dim), 170 | nn.Linear(dim, num_classes) 171 | ) 172 | """ 173 | 174 | def forward(self, img): 175 | x = self.conv_layers(img) 176 | b, n, _ = x.shape 177 | 178 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 179 | x = torch.cat((cls_tokens, x), dim=1) 180 | x += self.pos_embedding[:, :(n + 1)] 181 | x = self.dropout(x) 182 | 183 | x = self.transformer(x) 184 | 185 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 186 | 187 | x = self.to_latent(x) 188 | return x 189 | #return self.mlp_head(x) 190 | 191 | # It can handle bsize of 50 to 60 192 | 
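The module above exposes a small global recording mechanism (`start_rec`/`stop_rec`/`get_attention_list`, plus the `post_` variants that capture the attended values instead of the pre-softmax scores) that attention-based distillation or inspection code can hook into. A minimal sketch of how it is driven, assuming an `EarlyConvViT` instance `model` and an image batch `imgs` built elsewhere:

```python
# a minimal sketch; `model` (an EarlyConvViT) and `imgs` are assumed
from networks import early_conv_vit_net as ecv

ecv.start_rec()                  # every Attention layer now appends its pre-softmax scores
_ = model(imgs)                  # one forward pass fills ATTENTION_LIST
attn = ecv.get_attention_list()  # one (batch, heads, tokens, tokens) tensor per block; the list is reset
ecv.stop_rec()                   # stop recording for subsequent forward passes
```
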
-------------------------------------------------------------------------------- /src/networks/efficient_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import timm 5 | 6 | 7 | # It can handle bsize of 50 to 60 8 | 9 | class Efficient_net(nn.Module): 10 | 11 | def __init__(self, num_classes=100, pretrained=False): 12 | super().__init__() 13 | 14 | self.net = torch.hub.load('szq0214/MEAL-V2','meal_v2', 'mealv2_efficientnet_b0', pretrained=False)#.module 15 | 16 | if pretrained: 17 | checkpoint = torch.hub.load_state_dict_from_url('https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_EfficientNet_B0_224.pth', map_location='cpu') 18 | state_dict = {key.replace("module.", ""): value for key, value in checkpoint.items()} 19 | self.net.load_state_dict(state_dict) 20 | 21 | 22 | # last classifier layer (head) with as many outputs as classes 23 | self.net.classifier = nn.Identity() 24 | # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments 25 | self.fc = nn.Linear(in_features=1280, out_features=num_classes, bias=True) 26 | self.head_var = 'fc' 27 | def forward(self, x): 28 | h = self.fc(self.net(x)) 29 | return h 30 | 31 | 32 | def efficient_net(num_out=100, pretrained=False): 33 | if pretrained: 34 | return Efficient_net(num_out, pretrained) 35 | else: 36 | raise NotImplementedError 37 | 38 | -------------------------------------------------------------------------------- /src/networks/fpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import rearrange 4 | from transformers import GPT2Model 5 | 6 | 7 | # It can handle bsize of 50 to 60 8 | 9 | class FPT(nn.Module): 10 | 11 | def __init__(self, num_classes=100, pretrained=False): 12 | super().__init__() 13 | 14 | if not pretrained: 15 | raise NotImplementedError('FPT cannot run without pretrained weights') 16 | 17 | self.patch_size = 8 18 | self.input_dim = 3 * self.patch_size**2 19 | 20 | self.in_net = nn.Linear(self.input_dim, 768, bias=True) 21 | self.fpt = GPT2Model.from_pretrained('gpt2') 22 | self.fc = nn.Linear(in_features=768, out_features=num_classes, bias=True) 23 | self.head_var = 'fc' 24 | 25 | 26 | 27 | 28 | def forward(self, x): 29 | 30 | x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size) 31 | x = self.in_net(x) 32 | with torch.no_grad(): 33 | x = self.fpt(inputs_embeds=x).last_hidden_state 34 | x = x.mean(dim=1) # pool over the token dimension 35 | h = self.fc(x) 36 | 37 | return h 38 | 39 | 40 | def fPT(num_out=100, pretrained=False): 41 | if pretrained: 42 | return FPT(num_out, pretrained) 43 | else: 44 | raise NotImplementedError 45 | 46 | -------------------------------------------------------------------------------- /src/networks/lenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class LeNet(nn.Module): 6 | """LeNet-like network for tests with MNIST (28x28).""" 7 | 8 | def __init__(self, in_channels=1, num_classes=10, **kwargs): 9 | super().__init__() 10 | # main part of the network 11 | self.conv1 = nn.Conv2d(in_channels, 6, 5) 12 | self.conv2 = nn.Conv2d(6, 16, 5) 13 | self.fc1 = nn.Linear(16 * 16, 120) 14 | self.fc2 = nn.Linear(120, 84) 15 | 16 | # last classifier layer (head) with as many outputs as classes 17 | self.fc = nn.Linear(84, num_classes) 18 | # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments 19 | self.head_var = 'fc' 20 | 21 | def forward(self, x): 22 | out = F.relu(self.conv1(x)) 23 | out = F.max_pool2d(out, 2) 24 | out = F.relu(self.conv2(out)) 25 | out = F.max_pool2d(out, 2) 26 | out = out.view(out.size(0), -1) 27 | out = F.relu(self.fc1(out)) 28 | out = F.relu(self.fc2(out)) 29 | out = self.fc(out) 30 | return out 31 | -------------------------------------------------------------------------------- /src/networks/mobile_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import timm 5 | 6 | 7 | # It can handle bsize of 50 to 60 8 | 9 | class Mobile_net(nn.Module): 10 | 11 | def __init__(self, num_classes=100, pretrained=False): 12 | super().__init__() 13 | 14 | self.net = torch.hub.load('szq0214/MEAL-V2','meal_v2', 'mealv2_mobilenet_v3_large_100', pretrained=False)#.module 15 | 16 | if pretrained: 17 | checkpoint = torch.hub.load_state_dict_from_url('https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_MobileNet_V3_Large_1.0_224.pth', map_location='cpu') 18 | state_dict = {key.replace("module.", ""): value for key, value in checkpoint.items()} 19 | self.net.load_state_dict(state_dict) 20 | 21 | 22 | # last classifier layer (head) with as many outputs as classes 23 | self.net.classifier = nn.Identity() 24 | # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments 25 | self.fc = nn.Linear(in_features=1280, out_features=num_classes, bias=True) 26 | self.head_var = 'fc' 27 | def forward(self, x): 28 | h = self.fc(self.net(x)) 29 | return h 30 | 31 | 32 | def mobile_net(num_out=100, pretrained=False): 33 | if pretrained: 34 | return Mobile_net(num_out, pretrained) 35 | else: 36 | raise NotImplementedError 37 | 38 | 39 | -------------------------------------------------------------------------------- /src/networks/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from 
copy import deepcopy 4 | 5 | 6 | class LLL_Net(nn.Module): 7 | """Basic class for implementing networks""" 8 | 9 | def __init__(self, model, remove_existing_head=False): 10 | head_var = model.head_var 11 | assert type(head_var) == str 12 | assert not remove_existing_head or hasattr(model, head_var), \ 13 | "Given model does not have a variable called {}".format(head_var) 14 | assert not remove_existing_head or type(getattr(model, head_var)) in [nn.Sequential, nn.Linear], \ 15 | "Given model's head {} is not an instance of nn.Sequential or nn.Linear".format(head_var) 16 | super(LLL_Net, self).__init__() 17 | 18 | self.model = model 19 | last_layer = getattr(self.model, head_var) 20 | 21 | if remove_existing_head: 22 | if type(last_layer) == nn.Sequential: 23 | self.out_size = last_layer[-1].in_features 24 | # strips off last linear layer of classifier 25 | del last_layer[-1] 26 | elif type(last_layer) == nn.Linear: 27 | self.out_size = last_layer.in_features 28 | # converts last layer into identity 29 | # setattr(self.model, head_var, nn.Identity()) 30 | # WARNING: this is for when pytorch version is <1.2 31 | setattr(self.model, head_var, nn.Sequential()) 32 | else: 33 | self.out_size = last_layer.out_features 34 | 35 | self.heads = nn.ModuleList() 36 | self.task_cls = [] 37 | self.task_offset = [] 38 | self._initialize_weights() 39 | 40 | def add_head(self, num_outputs): 41 | """Add a new head with the corresponding number of outputs. Also update the number of classes per task and the 42 | corresponding offsets 43 | """ 44 | self.heads.append(nn.Linear(self.out_size, num_outputs)) 45 | # we re-compute instead of append in case an approach makes changes to the heads 46 | self.task_cls = torch.tensor([head.out_features for head in self.heads]) 47 | self.task_offset = torch.cat([torch.LongTensor(1).zero_(), self.task_cls.cumsum(0)[:-1]]) 48 | 49 | def forward(self, x, return_features=False): 50 | """Applies the forward pass 51 | 52 | Simplification to work on multi-head only -- returns all head outputs in a list 53 | Args: 54 | x (tensor): input images 55 | return_features (bool): return the representations before the heads 56 | """ 57 | x = self.model(x) 58 | assert (len(self.heads) > 0), "Cannot access any head" 59 | y = [] 60 | for head in self.heads: 61 | y.append(head(x)) 62 | if return_features: 63 | return y, x 64 | else: 65 | return y 66 | 67 | def get_copy(self): 68 | """Get weights from the model""" 69 | return deepcopy(self.state_dict()) 70 | 71 | def set_state_dict(self, state_dict): 72 | """Load weights into the model""" 73 | self.load_state_dict(deepcopy(state_dict)) 74 | return 75 | 76 | def freeze_all(self): 77 | """Freeze all parameters from the model, including the heads""" 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def freeze_backbone(self): 82 | """Freeze all parameters from the main model, but not the heads""" 83 | for param in self.model.parameters(): 84 | param.requires_grad = False 85 | 86 | def freeze_bn(self): 87 | """Freeze all Batch Normalization layers from the model and use them in eval() mode""" 88 | for m in self.model.modules(): 89 | if isinstance(m, nn.BatchNorm2d): 90 | m.eval() 91 | 92 | def _initialize_weights(self): 93 | """Initialize weights using different strategies""" 94 | # TODO: add different initialization strategies 95 | pass 96 | 
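To make the multi-head mechanics concrete, here is a minimal sketch of how task-aware (TAw) and task-agnostic (TAg) predictions can be read out of the list of head outputs; `net` (an `LLL_Net` with heads already added), the batch `x` and the task index `t` are assumed to be built elsewhere:

```python
# a minimal sketch; `net`, `x` and `t` are assumed
import torch

outputs = net(x)  # list with one logits tensor per head/task
# task-aware (TAw): the task-ID is known at test time, so look only at head `t`
taw_pred = outputs[t].argmax(dim=1) + net.task_offset[t]
# task-agnostic (TAg): no task-ID, so treat all heads as one concatenated classifier
tag_pred = torch.cat(outputs, dim=1).argmax(dim=1)
```

-------------------------------------------------------------------------------- /src/networks/ovit_tiny_16_augreg_224.py: 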
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from .ovit import OVisionTransformer, _load_weights, get_attention_list 4 | 5 | 6 | # It can handle bsize of 50 to 60 7 | 8 | class OVit_tiny_16_augreg_224(nn.Module): 9 | 10 | def __init__(self, num_classes=100, pretrained=False): 11 | super().__init__() 12 | 13 | filename = 'src/networks/pretrained_weights/augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz' 14 | 15 | #import ipdb; ipdb.set_trace() 16 | self.ovit = OVisionTransformer(embed_dim=192, num_classes=0, num_heads=3) 17 | if pretrained: 18 | _load_weights(model=self.ovit, checkpoint_path=filename) 19 | 20 | # last classifier layer (head) with as many outputs as classes 21 | self.fc = nn.Linear(in_features=192, out_features=num_classes, bias=True) 22 | # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments 23 | self.head_var = 'fc' 24 | 25 | def forward(self, x): 26 | h = self.fc(self.ovit(x)) 27 | return h 28 | 29 | 30 | def ovit_tiny_16_augreg_224(num_out=100, pretrained=False): 31 | if pretrained: 32 | return OVit_tiny_16_augreg_224(num_out, pretrained) 33 | else: 34 | raise NotImplementedError 35 | assert 1==0, "you should not be here :/" 36 | -------------------------------------------------------------------------------- /src/networks/pretrained_weights/augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/pretrained_weights/augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz -------------------------------------------------------------------------------- /src/networks/resnet32.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | __all__ = ['resnet32'] 4 | 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | """3x3 convolution with padding""" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, inplanes, planes, stride=1, downsample=None): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = conv3x3(inplanes, planes, stride) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.downsample = downsample 22 | self.stride = stride 23 | 24 | def forward(self, x): 25 | residual = x 26 | out = self.relu(self.bn1(self.conv1(x))) 27 | out = self.bn2(self.conv2(out)) 28 | if self.downsample is not None: 29 | residual = self.downsample(x) 30 | out += residual 31 | return self.relu(out) 32 | 33 | 34 | class Bottleneck(nn.Module): 35 | expansion = 4 36 | 37 | def __init__(self, inplanes, planes, stride=1, downsample=None): 38 | super(Bottleneck, self).__init__() 39 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 40 | self.bn1 = nn.BatchNorm2d(planes) 41 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.conv3 = nn.Conv2d(planes, planes * 
self.expansion, kernel_size=1, bias=False) 44 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | residual = x 51 | out = self.relu(self.bn1(self.conv1(x))) 52 | out = self.relu(self.bn2(self.conv2(out))) 53 | out = self.bn3(self.conv3(out)) 54 | if self.downsample is not None: 55 | residual = self.downsample(x) 56 | out += residual 57 | return self.relu(out) 58 | 59 | 60 | class ResNet(nn.Module): 61 | 62 | def __init__(self, block, layers, num_classes=10): 63 | self.inplanes = 16 64 | super(ResNet, self).__init__() 65 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(16) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.layer1 = self._make_layer(block, 16, layers[0]) 69 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 70 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 71 | self.avgpool = nn.AvgPool2d(8, stride=1) 72 | # last classifier layer (head) with as many outputs as classes 73 | self.fc = nn.Linear(64 * block.expansion, num_classes) 74 | # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments 75 | self.head_var = 'fc' 76 | 77 | for m in self.modules(): 78 | if isinstance(m, nn.Conv2d): 79 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 80 | elif isinstance(m, nn.BatchNorm2d): 81 | nn.init.constant_(m.weight, 1) 82 | nn.init.constant_(m.bias, 0) 83 | 84 | def _make_layer(self, block, planes, blocks, stride=1): 85 | downsample = None 86 | if stride != 1 or self.inplanes != planes * block.expansion: 87 | downsample = nn.Sequential( 88 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 89 | nn.BatchNorm2d(planes * block.expansion), 90 | ) 91 | layers = [] 92 | layers.append(block(self.inplanes, planes, stride, downsample)) 93 | self.inplanes = planes * block.expansion 94 | for i in range(1, blocks): 95 | layers.append(block(self.inplanes, planes)) 96 | return nn.Sequential(*layers) 97 | 98 | def forward(self, x): 99 | x = self.relu(self.bn1(self.conv1(x))) 100 | x = self.layer1(x) 101 | x = self.layer2(x) 102 | x = self.layer3(x) 103 | x = self.avgpool(x) 104 | x = x.view(x.size(0), -1) 105 | x = self.fc(x) 106 | return x 107 | 108 | 109 | def resnet32(pretrained=False, **kwargs): 110 | if pretrained: 111 | raise NotImplementedError 112 | # change n=3 for ResNet-20, and n=9 for ResNet-56 113 | n = 5 114 | model = ResNet(BasicBlock, [n, n, n], **kwargs) 115 | return model 116 | -------------------------------------------------------------------------------- /src/networks/timm_vit_tiny_16_augreg_224.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | import timm 4 | 5 | 6 | # It can handle bsize of 50 to 60 7 | 8 | class Timm_Vit_tiny_16_augreg_224(nn.Module): 9 | 10 | def __init__(self, num_classes=100, pretrained=False): 11 | super().__init__() 12 | 13 | filename = 'src/networks/pretrained_weights/augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz' 14 | 15 | #import ipdb; ipdb.set_trace() 16 | if pretrained: 17 | self.vit = timm.create_model('vit_tiny_patch16_224', num_classes=0) 18 | timm.models.load_checkpoint(self.vit, filename) 19 | else: 20 | self.vit = 
timm.create_model('vit_tiny_patch16_224', num_classes=0) 21 | 22 | # last classifier layer (head) with as many outputs as classes 23 | self.fc = nn.Linear(in_features=192, out_features=num_classes, bias=True) 24 | # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments 25 | self.head_var = 'fc' 26 | 27 | def forward(self, x): 28 | h = self.fc(self.vit(x)) 29 | return h 30 | 31 | 32 | def timm_vit_tiny_16_augreg_224(num_out=100, pretrained=False): 33 | if pretrained: 34 | return Timm_Vit_tiny_16_augreg_224(num_out, pretrained) 35 | else: 36 | raise NotImplementedError 37 | assert 1==0, "you should not be here :/" 38 | -------------------------------------------------------------------------------- /src/networks/vggnet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class VggNet(nn.Module): 6 | """ Following the VGGnet based on VGG16 but for smaller input (64x64) 7 | Check this blog for some info: https://learningai.io/projects/2017/06/29/tiny-imagenet.html 8 | """ 9 | 10 | def __init__(self, num_classes=1000): 11 | super().__init__() 12 | 13 | self.features = nn.Sequential( 14 | nn.Conv2d(3, 64, kernel_size=3, padding=1), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 17 | nn.ReLU(inplace=True), 18 | nn.MaxPool2d(kernel_size=2, stride=2), 19 | nn.Conv2d(64, 128, kernel_size=3, padding=1), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(128, 128, kernel_size=3, padding=1), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(kernel_size=2, stride=2), 24 | nn.Conv2d(128, 256, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.MaxPool2d(kernel_size=2, stride=2), 31 | nn.Conv2d(256, 512, kernel_size=3, padding=1), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 36 | nn.ReLU(inplace=True), 37 | nn.MaxPool2d(kernel_size=2, stride=2), 38 | ) 39 | self.fc6 = nn.Linear(in_features=512 * 4 * 4, out_features=4096, bias=True) 40 | self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True) 41 | # last classifier layer (head) with as many outputs as classes 42 | self.fc = nn.Linear(in_features=4096, out_features=num_classes, bias=True) 43 | # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments 44 | self.head_var = 'fc' 45 | 46 | def forward(self, x): 47 | h = self.features(x) 48 | h = h.view(x.size(0), -1) 49 | h = F.dropout(F.relu(self.fc6(h))) 50 | h = F.dropout(F.relu(self.fc7(h))) 51 | h = self.fc(h) 52 | return h 53 | 54 | 55 | def vggnet(num_out=100, pretrained=False): 56 | if pretrained: 57 | raise NotImplementedError 58 | return VggNet(num_out) 59 | -------------------------------------------------------------------------------- /src/networks/vit_tiny_16_augreg_224.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from .vit_original import VisionTransformer, _load_weights 4 | 5 | 6 | # It can handle bsize of 50 to 60 7 | 8 | class Vit_tiny_16_augreg_224(nn.Module): 9 | 10 | def __init__(self, num_classes=100, pretrained=False): 11 | super().__init__() 12 | 13 | filename = 
14 | 
15 |         #import ipdb; ipdb.set_trace()
16 |         self.vit = VisionTransformer(embed_dim=192, num_heads=3, num_classes=0)
17 |         if pretrained:
18 |             _load_weights(model=self.vit, checkpoint_path=filename)
19 | 
20 |         # last classifier layer (head) with as many outputs as classes
21 |         self.fc = nn.Linear(in_features=192, out_features=num_classes, bias=True)
22 |         # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments
23 |         self.head_var = 'fc'
24 | 
25 |     def forward(self, x):
26 |         h = self.fc(self.vit(x))
27 |         return h
28 | 
29 | 
30 | def vit_tiny_16_augreg_224(num_out=100, pretrained=False):
31 |     if pretrained:
32 |         return Vit_tiny_16_augreg_224(num_out, pretrained)
33 |     else:
34 |         # only the pretrained variant is supported by this wrapper
35 |         raise NotImplementedError
36 | 
-------------------------------------------------------------------------------- /src/test.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import timm
5 | import ipdb
6 | from networks.ovit import VisionTransformer, _load_weights
7 | 
8 | 
9 | filename = '/home/fpelosin/vit_facil/src/networks/pretrained_weights/augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'
10 | 
11 | 
12 | # Custom implementation
13 | ovit = VisionTransformer(embed_dim=192, num_classes=0)
14 | _load_weights(model=ovit, checkpoint_path=filename)
15 | ovit.eval()
16 | 
17 | 
18 | # timm reference implementation
19 | #timm.models.vision_transformer
20 | vit = timm.create_model('vit_tiny_patch16_224', num_classes=0)
21 | timm.models.load_checkpoint(vit, filename)
22 | vit.eval()
23 | 
24 | inp = torch.rand(1, 3, 224, 224)
25 | 
26 | 
27 | for i in range(12):
28 | 
29 |     vit_chunk = torch.nn.Sequential(vit.patch_embed, vit.pos_drop, vit.blocks[:i])
30 |     ovit_chunk = torch.nn.Sequential(ovit.patch_embed, ovit.pos_drop, ovit.blocks[:i])
31 | 
32 |     out_vit = vit_chunk(inp)
33 |     out_ovit = ovit_chunk(inp)
34 | 
35 |     if torch.abs(out_vit - out_ovit).sum() > 1e-8:
36 |         print(f"diff@{i}")
37 |     else:
38 |         print(f"NOT diff@{i}")
39 | 
40 | 
41 | #print(vit(inp))
42 | #print(ovit(inp))
43 | 
44 | import ipdb; ipdb.set_trace()
45 | 
46 | # Sanity check: attention and MLP weights should match exactly
47 | ovit_state_dict = ovit.state_dict()
48 | 
49 | for name, param in vit.named_parameters():
50 |     if 'attn' in name or 'mlp' in name:
51 |         ovit_param = ovit_state_dict[name]
52 |         if torch.abs(param.data - ovit_param.data).sum() > 1e-8:
53 |             print(f'--->[ DIFF ] {name} is NOT the same')
54 |         else:
55 |             print(f'[ OK ] {name} is the same')
56 | 
-------------------------------------------------------------------------------- /src/tests/README.md: --------------------------------------------------------------------------------
1 | # Tests
2 | The tests in this folder are our tool to check that everything is working as intended and to catch any errors that might
3 | appear when introducing new features. They can also be used to make sure all dependencies are available before running
4 | any long experiments.
5 | 
6 | ## Running tests
7 | 
8 | ### From console
9 | Type the following in your console:
10 | ```bash
11 | cd src/
12 | py.test -s tests/
13 | ```
14 | 
15 | ### Running tests in parallel
16 | As the number of tests grows, it can be faster to run them in parallel, since each test can take a few minutes.
17 | Pytest supports this via the `pytest-xdist` plugin, e.g. running all tests with 5 workers:
18 | ```bash
19 | cd src/
20 | py.test -n 5 tests/
21 | ```
22 | And for a more verbose output:
23 | ```bash
24 | cd src/
25 | py.test -sv -n 5 tests/
26 | ```
27 | **_Warning:_** It is recommended to run a single test without parallelization the first time, since the first thing the
28 | tests do is download the dataset (MNIST). If run in parallel, multiple workers can start downloading it
29 | at the same time.
30 | 
31 | ### From your IDE (PyCharm, VSCode, ...)
32 | `py.test` tests are well supported. It's usually enough to select `py.test` as the framework, right-click on a test file or
33 | directory and select the option to run the pytest tests (not as a Python script!).
34 | 
-------------------------------------------------------------------------------- /src/tests/__init__.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import shutil
4 | 
5 | from main_incremental import main
6 | import datasets.dataset_config as c
7 | 
8 | 
9 | def run_main(args_line, result_dir='results_test', clean_run=False):
10 |     assert "--results-path" not in args_line
11 | 
12 |     print('Starting dir:', os.getcwd())
13 |     if os.getcwd().endswith('tests'):
14 |         os.chdir('..')
15 |     elif os.getcwd().endswith('IL_Survey'):
16 |         os.chdir('src')
17 |     elif os.getcwd().endswith('src'):
18 |         print('CWD is OK.')
19 |     print('Test CWD:', os.getcwd())
20 |     test_results_path = os.getcwd() + f"/../{result_dir}"
21 | 
22 |     # for testing - use relative path to CWD
23 |     c.dataset_config['mnist']['path'] = '../data'
24 |     if os.path.exists(test_results_path) and clean_run:
25 |         shutil.rmtree(test_results_path)
26 |     os.makedirs(test_results_path, exist_ok=True)
27 |     args_line += " --results-path {}".format(test_results_path)
28 | 
29 |     # if distributed test -- use all GPUs (worker ids are "gw0", "gw1", ... under pytest-xdist; defaults to -1 otherwise)
30 |     worker_id = int(os.environ.get("PYTEST_XDIST_WORKER", "gw-1")[2:])
31 |     if worker_id >= 0 and torch.cuda.is_available():
32 |         gpu_idx = worker_id % torch.cuda.device_count()
33 |         args_line += " --gpu {}".format(gpu_idx)
34 | 
35 |     print('ARGS:', args_line)
36 |     return main(args_line.split(' '))
37 | 
38 | 
39 | def run_main_and_assert(args_line,
40 |                         taw_current_task_min=0.01,
41 |                         tag_current_task_min=0.0,
42 |                         result_dir='results_test'):
43 |     acc_taw, acc_tag, forg_taw, forg_tag, exp_dir = run_main(args_line, result_dir)
44 | 
45 |     # acc matrices sanity check
46 |     assert acc_tag.shape == acc_taw.shape
47 |     assert acc_tag.shape == forg_tag.shape
48 |     assert acc_tag.shape == forg_taw.shape
49 | 
50 |     # check current task performance
51 |     assert all(acc_tag.diagonal() >= tag_current_task_min)
52 |     assert all(acc_taw.diagonal() >= taw_current_task_min)
53 | 
-------------------------------------------------------------------------------- /src/tests/test_bic.py: --------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 | 
3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \
4 |                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
5 |                        " --nepochs 3" \
6 |                        " --num-workers 0" \
7 |                        " --approach bic"
8 | 
9 | 
10 | def test_bic_exemplars():
11 |     args_line = FAST_LOCAL_TEST_ARGS
12 |     args_line += " --num-exemplars 200"
13 |     run_main_and_assert(args_line)
14 | 
15 | 
16 | def test_bic_exemplars_lambda():
17 |     args_line = FAST_LOCAL_TEST_ARGS
18 |     args_line += " --num-exemplars 200"
19 |     args_line += " --lamb 1"
20 |     run_main_and_assert(args_line)
21 | 
22 | 
23 | def test_bic_exemplars_per_class():
24 |     args_line = FAST_LOCAL_TEST_ARGS
25 |     args_line += " --num-exemplars-per-class 20"
26 |     run_main_and_assert(args_line)
27 | 
28 | 
29 | def test_bic_with_warmup():
30 |     args_line = FAST_LOCAL_TEST_ARGS
31 |     args_line += " --warmup-nepochs 5"
32 |     args_line += " --warmup-lr-factor 0.5"
33 |     args_line += " --num-exemplars 200"
34 |     run_main_and_assert(args_line)
35 | 
-------------------------------------------------------------------------------- /src/tests/test_dataloader.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.dataset import TensorDataset
4 | 
5 | from main_incremental import main
6 | 
7 | 
8 | def test_dataloader_dataset_swap():
9 |     # given
10 |     data1 = TensorDataset(torch.arange(10))
11 |     data2 = TensorDataset(torch.arange(10, 20))
12 |     dl = DataLoader(data1, batch_size=2, shuffle=True, num_workers=1)
13 |     # when
14 |     batches1 = list(dl)
15 | 
16 |     try:
17 |         dl.dataset += data2
18 |     except ValueError:
19 |         # In newer PyTorch versions this raises an error,
20 |         # which is expected and OK behaviour
21 |         return
22 | 
23 |     batches2 = list(dl)
24 |     all_data = list(dl.dataset)
25 | 
26 |     # then
27 |     assert len(all_data) == 20
28 |     assert len(batches1) == 5
29 |     assert len(batches2) == 5
30 |     # ^ is troublesome!
31 |     # The sampler is initialized in DataLoader.__init__
32 |     # and it holds a reference to the old dataset.
33 |     assert dl.sampler.data_source == data1
34 |     # Thus, we will not see the new data.
35 | 
36 | 
37 | def test_dataloader_multiple_datasets():
38 |     args_line = "--exp-name local_test --approach finetuning --datasets mnist mnist mnist" \
39 |                 " --network LeNet --num-tasks 2 --batch-size 32" \
40 |                 " --results-path ../results/ --num-workers 0 --nepochs 2"
41 |     print('ARGS:', args_line)
42 |     main(args_line.split(' '))
-------------------------------------------------------------------------------- /src/tests/test_datasets_transforms.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision.transforms import Lambda
3 | from torch.utils.data.dataset import ConcatDataset
4 | 
5 | from datasets.memory_dataset import MemoryDataset
6 | from datasets.exemplars_selection import override_dataset_transform
7 | 
8 | 
9 | def pic(i):
10 |     return np.array([[i]], dtype=np.int8)
11 | 
12 | 
13 | def test_dataset_transform_override():
14 |     # given
15 |     data1 = MemoryDataset({
16 |         'x': [pic(1), pic(2), pic(3)], 'y': ['a', 'b', 'c']
17 |     }, transform=Lambda(lambda x: np.array(x)[0, 0] * 2))
18 |     data2 = MemoryDataset({
19 |         'x': [pic(4), pic(5), pic(6)], 'y': ['d', 'e', 'f']
20 |     }, transform=Lambda(lambda x: np.array(x)[0, 0] * 3))
21 |     data3 = MemoryDataset({
22 |         'x': [pic(7), pic(8), pic(9)], 'y': ['g', 'h', 'i']
23 |     }, transform=Lambda(lambda x: np.array(x)[0, 0] + 10))
24 |     ds = ConcatDataset([data1, ConcatDataset([data2, data3])])
25 | 
26 |     # when
27 |     x1, y1 = zip(*[ds[i] for i in range(len(ds))])
28 |     with override_dataset_transform(ds, Lambda(lambda x: np.array(x)[0, 0])) as ds_overriden:
29 |         x2, y2 = zip(*[ds_overriden[i] for i in range(len(ds_overriden))])
30 |     x3, y3 = zip(*[ds[i] for i in range(len(ds))])
31 | 
32 |     # then
33 |     assert np.array_equal(x1, [2, 4, 6, 12, 15, 18, 17, 18, 19])
34 |     assert np.array_equal(x2, [1, 2, 3, 4, 5, 6, 7, 8, 9])
35 |     assert np.array_equal(x3, x1)  # after everything is back to normal
36 | 
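The test above exercises `override_dataset_transform` from `datasets/exemplars_selection.py`, which temporarily replaces the transforms of a (possibly nested) `ConcatDataset` and restores them on exit. As a rough illustration of the mechanism (a minimal sketch, not the repo's actual implementation; the helper names `_leaf_datasets` and `override_transform_sketch` are made up here, and it assumes every leaf dataset exposes a `.transform` attribute, as `MemoryDataset` does):

```python
from contextlib import contextmanager

from torch.utils.data.dataset import ConcatDataset


def _leaf_datasets(ds):
    """Recursively unwrap nested ConcatDatasets into their leaf datasets."""
    if isinstance(ds, ConcatDataset):
        return [leaf for child in ds.datasets for leaf in _leaf_datasets(child)]
    return [ds]


@contextmanager
def override_transform_sketch(ds, transform):
    leaves = _leaf_datasets(ds)
    originals = [leaf.transform for leaf in leaves]  # remember the old transforms
    try:
        for leaf in leaves:
            leaf.transform = transform  # apply the temporary transform everywhere
        yield ds
    finally:
        # restore the original transforms even if the body raises
        for leaf, original in zip(leaves, originals):
            leaf.transform = original
```

Restoring the transforms in a `finally` block is what makes the third read-back in the test (`x3`) match the first one; this pattern is useful whenever exemplar selection needs deterministic, un-augmented images.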
-------------------------------------------------------------------------------- /src/tests/test_dmc.py: -------------------------------------------------------------------------------- 1 | from tests import run_main_and_assert 2 | 3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets cifar100" \ 4 | " --network resnet32 --num-tasks 3 --seed 1 --batch-size 32" \ 5 | " --nepochs 3" \ 6 | " --num-workers 0" \ 7 | " --approach dmc" \ 8 | " --aux-dataset cifar100" 9 | 10 | 11 | def test_dmc(): 12 | run_main_and_assert(FAST_LOCAL_TEST_ARGS) 13 | 14 | -------------------------------------------------------------------------------- /src/tests/test_eeil.py: -------------------------------------------------------------------------------- 1 | from tests import run_main_and_assert 2 | 3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \ 4 | " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \ 5 | " --nepochs 3" \ 6 | " --num-workers 0" \ 7 | " --approach eeil" 8 | 9 | 10 | def test_eeil_exemplars_with_noise_grad(): 11 | args_line = FAST_LOCAL_TEST_ARGS 12 | args_line += " --num-exemplars 200" 13 | args_line += " --nepochs-finetuning 1" 14 | args_line += " --noise-grad" 15 | run_main_and_assert(args_line) 16 | 17 | 18 | def test_eeil_exemplars(): 19 | args_line = FAST_LOCAL_TEST_ARGS 20 | args_line += " --num-exemplars 200" 21 | args_line += " --nepochs-finetuning 1" 22 | run_main_and_assert(args_line) 23 | 24 | 25 | def test_eeil_with_warmup(): 26 | args_line = FAST_LOCAL_TEST_ARGS 27 | args_line += " --warmup-nepochs 5" 28 | args_line += " --warmup-lr-factor 0.5" 29 | args_line += " --num-exemplars 200" 30 | args_line += " --nepochs-finetuning 1" 31 | run_main_and_assert(args_line) 32 | -------------------------------------------------------------------------------- /src/tests/test_ewc.py: -------------------------------------------------------------------------------- 1 | from tests import run_main_and_assert 2 | 3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \ 4 | " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \ 5 | " --nepochs 3" \ 6 | " --num-workers 0" \ 7 | " --approach ewc" 8 | 9 | 10 | def test_ewc_without_exemplars(): 11 | run_main_and_assert(FAST_LOCAL_TEST_ARGS) 12 | 13 | 14 | def test_ewc_with_exemplars(): 15 | args_line = FAST_LOCAL_TEST_ARGS 16 | args_line += " --num-exemplars 200" 17 | run_main_and_assert(args_line) 18 | 19 | 20 | def test_ewc_with_warmup(): 21 | args_line = FAST_LOCAL_TEST_ARGS 22 | args_line += " --warmup-nepochs 5" 23 | args_line += " --warmup-lr-factor 0.5" 24 | args_line += " --num-exemplars 200" 25 | run_main_and_assert(args_line) 26 | -------------------------------------------------------------------------------- /src/tests/test_finetuning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from tests import run_main_and_assert 5 | 6 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \ 7 | " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \ 8 | " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7" \ 9 | " --num-workers 0" 10 | 11 | 12 | def test_finetuning_without_exemplars(): 13 | args_line = FAST_LOCAL_TEST_ARGS 14 | args_line += " --approach finetuning" 15 | run_main_and_assert(args_line) 16 | 17 | 18 | def test_finetuning_with_exemplars(): 19 | args_line = FAST_LOCAL_TEST_ARGS 20 | args_line += " --approach finetuning" 21 | args_line += " --num-exemplars 200" 22 | 
run_main_and_assert(args_line)
23 | 
24 | 
25 | @pytest.mark.xfail
26 | def test_finetuning_with_exemplars_per_class_and_herding():
27 |     args_line = FAST_LOCAL_TEST_ARGS
28 |     args_line += " --approach finetuning"
29 |     args_line += " --num-exemplars-per-class 10"
30 |     args_line += " --exemplar-selection herding"
31 |     run_main_and_assert(args_line)
32 | 
33 | 
34 | def test_finetuning_with_exemplars_per_class_and_entropy():
35 |     args_line = FAST_LOCAL_TEST_ARGS
36 |     args_line += " --approach finetuning"
37 |     args_line += " --num-exemplars-per-class 10"
38 |     args_line += " --exemplar-selection entropy"
39 |     run_main_and_assert(args_line)
40 | 
41 | 
42 | def test_finetuning_with_exemplars_per_class_and_distance():
43 |     args_line = FAST_LOCAL_TEST_ARGS
44 |     args_line += " --approach finetuning"
45 |     args_line += " --num-exemplars-per-class 10"
46 |     args_line += " --exemplar-selection distance"
47 |     run_main_and_assert(args_line)
48 | 
49 | 
50 | def test_wrong_args():
51 |     with pytest.raises(SystemExit):  # error of providing both args
52 |         args_line = FAST_LOCAL_TEST_ARGS
53 |         args_line += " --approach finetuning"
54 |         args_line += " --num-exemplars-per-class 10"
55 |         args_line += " --num-exemplars 200"
56 |         run_main_and_assert(args_line)
57 | 
58 | 
59 | def test_finetuning_with_eval_on_train():
60 |     args_line = FAST_LOCAL_TEST_ARGS
61 |     args_line += " --approach finetuning"
62 |     args_line += " --num-exemplars-per-class 10"
63 |     args_line += " --exemplar-selection distance"
64 |     args_line += " --eval-on-train"
65 |     run_main_and_assert(args_line)
66 | 
67 | def test_finetuning_with_no_cudnn_deterministic():
68 |     args_line = FAST_LOCAL_TEST_ARGS
69 |     args_line += " --approach finetuning"
70 |     args_line += " --num-exemplars-per-class 10"
71 |     args_line += " --exemplar-selection distance"
72 | 
73 |     run_main_and_assert(args_line)
74 |     assert torch.backends.cudnn.deterministic is True
75 | 
76 |     args_line += " --no-cudnn-deterministic"
77 |     run_main_and_assert(args_line)
78 |     assert torch.backends.cudnn.deterministic is False
79 | 
-------------------------------------------------------------------------------- /src/tests/test_fix_bn.py: --------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 | 
3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \
4 |                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
5 |                        " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7" \
6 |                        " --num-workers 0 --fix-bn"
7 | 
8 | 
9 | def test_finetuning_fix_bn():
10 |     args_line = FAST_LOCAL_TEST_ARGS
11 |     args_line += " --approach finetuning"
12 |     run_main_and_assert(args_line)
13 | 
14 | 
15 | def test_joint_fix_bn():
16 |     args_line = FAST_LOCAL_TEST_ARGS
17 |     args_line += " --approach joint"
18 |     run_main_and_assert(args_line)
19 | 
20 | 
21 | def test_freezing_fix_bn():
22 |     args_line = FAST_LOCAL_TEST_ARGS
23 |     args_line += " --approach freezing"
24 |     run_main_and_assert(args_line)
25 | 
26 | 
27 | def test_icarl_fix_bn():
28 |     args_line = FAST_LOCAL_TEST_ARGS
29 |     args_line += " --num-exemplars 200"
30 |     args_line += " --approach icarl"
31 |     run_main_and_assert(args_line)
32 | 
33 | 
34 | def test_eeil_fix_bn():
35 |     args_line = FAST_LOCAL_TEST_ARGS
36 |     args_line += " --num-exemplars 200"
37 |     args_line += " --approach eeil"
38 |     run_main_and_assert(args_line)
39 | 
40 | 
41 | def test_mas_fix_bn():
42 |     args_line = FAST_LOCAL_TEST_ARGS
43 |     args_line += " --approach mas"
44 |     run_main_and_assert(args_line)
45 | 
46 | 
47 | def test_lwf_fix_bn():
48 |     args_line = FAST_LOCAL_TEST_ARGS
49 |     args_line += " --approach lwf"
50 |     run_main_and_assert(args_line)
51 | 
52 | 
53 | def test_lwm_fix_bn():
54 |     args_line = FAST_LOCAL_TEST_ARGS
55 |     args_line += " --approach lwm --gradcam-layer conv2 --log-gradcam-samples 16"
56 |     run_main_and_assert(args_line)
57 | 
58 | 
59 | def test_r_walk_fix_bn():
60 |     args_line = FAST_LOCAL_TEST_ARGS
61 |     args_line += " --approach r_walk"
62 |     run_main_and_assert(args_line)
63 | 
64 | 
65 | def test_path_integral_fix_bn():
66 |     args_line = FAST_LOCAL_TEST_ARGS
67 |     args_line += " --approach path_integral"
68 |     run_main_and_assert(args_line)
69 | 
70 | 
71 | def test_lucir_fix_bn():
72 |     args_line = FAST_LOCAL_TEST_ARGS
73 |     args_line += " --approach lucir"
74 |     run_main_and_assert(args_line)
75 | 
76 | 
77 | def test_ewc_fix_bn():
78 |     args_line = FAST_LOCAL_TEST_ARGS
79 |     args_line += " --approach ewc"
80 |     run_main_and_assert(args_line)
81 | 
-------------------------------------------------------------------------------- /src/tests/test_freezing.py: --------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 | 
3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \
4 |                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
5 |                        " --nepochs 3" \
6 |                        " --num-workers 0" \
7 |                        " --approach freezing"
8 | 
9 | 
10 | def test_freezing_without_exemplars():
11 |     run_main_and_assert(FAST_LOCAL_TEST_ARGS)
12 | 
13 | 
14 | def test_freezing_with_exemplars():
15 |     args_line = FAST_LOCAL_TEST_ARGS
16 |     args_line += " --num-exemplars 200"
17 |     run_main_and_assert(args_line)
18 | 
19 | 
20 | def test_freezing_with_warmup():
21 |     args_line = FAST_LOCAL_TEST_ARGS
22 |     args_line += " --warmup-nepochs 5"
23 |     args_line += " --warmup-lr-factor 0.5"
24 |     run_main_and_assert(args_line)
25 | 
-------------------------------------------------------------------------------- /src/tests/test_gridsearch.py: --------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 | 
3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \
4 |                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
5 |                        " --nepochs 3 --num-workers 0" \
6 |                        " --gridsearch-tasks 3 --gridsearch-config gridsearch_config" \
7 |                        " --gridsearch-acc-drop-thr 0.2 --gridsearch-hparam-decay 0.5"
8 | 
9 | 
10 | def test_gridsearch_finetuning():
11 |     args_line = FAST_LOCAL_TEST_ARGS
12 |     args_line += " --approach finetuning --num-exemplars 200"
13 |     run_main_and_assert(args_line)
14 | 
15 | 
16 | def test_gridsearch_freezing():
17 |     args_line = FAST_LOCAL_TEST_ARGS
18 |     args_line += " --approach freezing --num-exemplars 200"
19 |     run_main_and_assert(args_line)
20 | 
21 | 
22 | def test_gridsearch_joint():
23 |     args_line = FAST_LOCAL_TEST_ARGS
24 |     args_line += " --approach joint"
25 |     run_main_and_assert(args_line)
26 | 
27 | 
28 | def test_gridsearch_lwf():
29 |     args_line = FAST_LOCAL_TEST_ARGS
30 |     args_line += " --approach lwf --num-exemplars 200"
31 |     run_main_and_assert(args_line)
32 | 
33 | 
34 | def test_gridsearch_icarl():
35 |     args_line = FAST_LOCAL_TEST_ARGS
36 |     args_line += " --approach icarl --num-exemplars 200"
37 |     run_main_and_assert(args_line)
38 | 
39 | 
40 | def test_gridsearch_eeil():
41 |     args_line = FAST_LOCAL_TEST_ARGS
42 |     args_line += " --approach eeil --nepochs-finetuning 3 --num-exemplars 200"
43 |     run_main_and_assert(args_line)
44 | 
45 | 
46 | def test_gridsearch_bic():
47 |     args_line = FAST_LOCAL_TEST_ARGS
48 |     args_line
+= " --approach bic --num-bias-epochs 3 --num-exemplars 200" 49 | run_main_and_assert(args_line) 50 | 51 | 52 | def test_gridsearch_lucir(): 53 | args_line = FAST_LOCAL_TEST_ARGS 54 | args_line += " --approach lucir --num-exemplars 200" 55 | run_main_and_assert(args_line) 56 | 57 | 58 | def test_gridsearch_lwm(): 59 | args_line = FAST_LOCAL_TEST_ARGS 60 | args_line += " --approach lwm --gradcam-layer conv2 --log-gradcam-samples 16 --num-exemplars 200" 61 | run_main_and_assert(args_line) 62 | 63 | 64 | def test_gridsearch_ewc(): 65 | args_line = FAST_LOCAL_TEST_ARGS 66 | args_line += " --approach ewc --num-exemplars 200" 67 | run_main_and_assert(args_line) 68 | 69 | 70 | def test_gridsearch_mas(): 71 | args_line = FAST_LOCAL_TEST_ARGS 72 | args_line += " --approach mas --num-exemplars 200" 73 | run_main_and_assert(args_line) 74 | 75 | 76 | def test_gridsearch_pathint(): 77 | args_line = FAST_LOCAL_TEST_ARGS 78 | args_line += " --approach path_integral --num-exemplars 200" 79 | run_main_and_assert(args_line) 80 | 81 | 82 | def test_gridsearch_rwalk(): 83 | args_line = FAST_LOCAL_TEST_ARGS 84 | args_line += " --approach r_walk --num-exemplars 200" 85 | run_main_and_assert(args_line) 86 | 87 | 88 | def test_gridsearch_dmc(): 89 | args_line = FAST_LOCAL_TEST_ARGS 90 | args_line += " --approach dmc" 91 | args_line += " --aux-dataset mnist" # just to test the grid search fast 92 | run_main_and_assert(args_line) 93 | -------------------------------------------------------------------------------- /src/tests/test_icarl.py: -------------------------------------------------------------------------------- 1 | from tests import run_main_and_assert 2 | 3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \ 4 | " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \ 5 | " --nepochs 3" \ 6 | " --num-workers 0" \ 7 | " --approach icarl" 8 | 9 | 10 | def test_icarl_exemplars(): 11 | args_line = FAST_LOCAL_TEST_ARGS 12 | args_line += " --num-exemplars 200" 13 | args_line += " --lamb 1" 14 | run_main_and_assert(args_line) 15 | 16 | 17 | def test_icarl_exemplars_without_lamb(): 18 | args_line = FAST_LOCAL_TEST_ARGS 19 | args_line += " --num-exemplars 200" 20 | args_line += " --lamb 0" 21 | run_main_and_assert(args_line) 22 | 23 | 24 | def test_icarl_with_warmup(): 25 | args_line = FAST_LOCAL_TEST_ARGS 26 | args_line += " --warmup-nepochs 5" 27 | args_line += " --warmup-lr-factor 0.5" 28 | args_line += " --num-exemplars 200" 29 | args_line += " --lamb 1" 30 | run_main_and_assert(args_line) 31 | -------------------------------------------------------------------------------- /src/tests/test_il2m.py: -------------------------------------------------------------------------------- 1 | from tests import run_main_and_assert 2 | 3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \ 4 | " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \ 5 | " --nepochs 3" \ 6 | " --num-workers 0" \ 7 | " --approach il2m" 8 | 9 | 10 | def test_il2m(): 11 | args_line = FAST_LOCAL_TEST_ARGS 12 | args_line += " --num-exemplars 200" 13 | run_main_and_assert(args_line) 14 | 15 | 16 | def test_il2m_with_warmup(): 17 | args_line = FAST_LOCAL_TEST_ARGS 18 | args_line += " --warmup-nepochs 5" 19 | args_line += " --warmup-lr-factor 0.5" 20 | args_line += " --num-exemplars 200" 21 | run_main_and_assert(args_line) 22 | -------------------------------------------------------------------------------- /src/tests/test_joint.py: -------------------------------------------------------------------------------- 
1 | from tests import run_main_and_assert
2 | 
3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \
4 |                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
5 |                        " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7" \
6 |                        " --num-workers 0"
7 | 
8 | 
9 | def test_joint():
10 |     args_line = FAST_LOCAL_TEST_ARGS
11 |     args_line += " --approach joint"
12 |     run_main_and_assert(args_line)
13 | 
-------------------------------------------------------------------------------- /src/tests/test_last_layer_analysis.py: --------------------------------------------------------------------------------
1 | from tests import run_main
2 | 
3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \
4 |                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
5 |                        " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7" \
6 |                        " --num-workers 0" \
7 |                        " --approach finetuning"
8 | 
9 | 
10 | def test_last_layer_analysis():
11 |     args_line = FAST_LOCAL_TEST_ARGS
12 |     args_line += " --last-layer-analysis"
13 |     run_main(args_line)
14 | 
15 | 
-------------------------------------------------------------------------------- /src/tests/test_loggers.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | from pathlib import Path
3 | 
4 | from tests import run_main
5 | 
6 | FAST_LOCAL_TEST_ARGS = "--exp-name loggers_test --datasets mnist" \
7 |                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
8 |                        " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7" \
9 |                        " --num-workers 0 --approach finetuning"
10 | 
11 | 
12 | def test_disk_and_tensorboard_logger():
13 | 
14 |     args_line = FAST_LOCAL_TEST_ARGS
15 |     args_line += " --log disk tensorboard"
16 |     result = run_main(args_line, 'results_test_loggers', clean_run=True)
17 |     experiment_dir = Path(result[-1])
18 | 
19 |     # check disk logger
20 |     assert experiment_dir.is_dir()
21 |     raw_logs = list(experiment_dir.glob('raw_log-*.txt'))
22 |     assert len(raw_logs) == 1
23 |     df = pd.read_json(raw_logs[0], lines=True)
24 |     assert sorted(df.iter.unique()) == [0, 1, 2]
25 |     assert sorted(df.group.unique()) == ['test', 'train', 'valid']
26 |     assert len(df.group.unique()) == 3
27 | 
28 |     # check tb logger
29 |     tb_events_logs = list(experiment_dir.glob('events.out.tfevents*'))
30 |     assert len(tb_events_logs) == 1
31 |     assert experiment_dir.joinpath(tb_events_logs[0]).is_file()
-------------------------------------------------------------------------------- /src/tests/test_lucir.py: --------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 | 
3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \
4 |                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
5 |                        " --nepochs 3" \
6 |                        " --num-workers 0" \
7 |                        " --gridsearch-tasks -1" \
8 |                        " --approach lucir"
9 | 
10 | 
11 | def test_lucir_exemplars():
12 |     args_line = FAST_LOCAL_TEST_ARGS
13 |     args_line += " --num-exemplars-per-class 20"
14 |     run_main_and_assert(args_line)
15 | 
16 | 
17 | def test_lucir_exemplars_with_gridsearch():
18 |     args_line = FAST_LOCAL_TEST_ARGS
19 |     args_line += " --num-exemplars-per-class 20"
20 |     args_line = args_line.replace('--gridsearch-tasks -1', '--gridsearch-tasks 3')
21 |     run_main_and_assert(args_line)
22 | 
23 | 
24 | def test_lucir_exemplars_remove_margin_ranking():
25 |     args_line = FAST_LOCAL_TEST_ARGS
26 |     args_line += " --num-exemplars-per-class 20"
27 |     args_line += " --remove-margin-ranking"
28 |     run_main_and_assert(args_line)
29 | 
30 | 
31 | def test_lucir_exemplars_remove_adapt_lamda():
32 |     args_line = FAST_LOCAL_TEST_ARGS
33 |     args_line += " --num-exemplars-per-class 20"
34 |     args_line += " --remove-adapt-lamda"
35 |     run_main_and_assert(args_line)
36 | 
37 | 
38 | def test_lucir_exemplars_warmup():
39 |     args_line = FAST_LOCAL_TEST_ARGS
40 |     args_line += " --num-exemplars-per-class 20"
41 |     args_line += " --warmup-nepochs 5"
42 |     args_line += " --warmup-lr-factor 0.5"
43 |     run_main_and_assert(args_line)
44 | 
-------------------------------------------------------------------------------- /src/tests/test_lwf.py: --------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 | 
3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \
4 |                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
5 |                        " --nepochs 3" \
6 |                        " --num-workers 0" \
7 |                        " --approach lwf"
8 | 
9 | 
10 | def test_lwf_without_exemplars():
11 |     run_main_and_assert(FAST_LOCAL_TEST_ARGS)
12 | 
13 | 
14 | def test_lwf_with_exemplars():
15 |     args_line = FAST_LOCAL_TEST_ARGS
16 |     args_line += " --num-exemplars 200"
17 |     run_main_and_assert(args_line)
18 | 
19 | 
20 | def test_lwf_with_warmup():
21 |     args_line = FAST_LOCAL_TEST_ARGS
22 |     args_line += " --warmup-nepochs 5"
23 |     args_line += " --warmup-lr-factor 0.5"
24 |     args_line += " --num-exemplars 200"
25 |     run_main_and_assert(args_line)
26 | 
-------------------------------------------------------------------------------- /src/tests/test_lwm.py: --------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 | 
3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \
4 |                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
5 |                        " --nepochs 3" \
6 |                        " --num-workers 0" \
7 |                        " --approach lwm --gradcam-layer conv2"
8 | 
9 | 
10 | def test_lwm_without_exemplars():
11 |     args_line = FAST_LOCAL_TEST_ARGS
12 |     args_line += " --log-gradcam-samples 16"
13 |     run_main_and_assert(args_line)
14 | 
15 | 
16 | def test_lwm_with_exemplars():
17 |     args_line = FAST_LOCAL_TEST_ARGS
18 |     args_line += " --num-exemplars 200"
19 |     run_main_and_assert(args_line)
20 | 
21 | 
22 | def test_lwm_with_warmup():
23 |     args_line = FAST_LOCAL_TEST_ARGS
24 |     args_line += " --warmup-nepochs 5"
25 |     args_line += " --warmup-lr-factor 0.5"
26 |     args_line += " --num-exemplars 200"
27 |     run_main_and_assert(args_line)
28 | 
-------------------------------------------------------------------------------- /src/tests/test_mas.py: --------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 | 
3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \
4 |                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
5 |                        " --nepochs 3" \
6 |                        " --num-workers 0" \
7 |                        " --approach mas"
8 | 
9 | 
10 | def test_mas_without_exemplars():
11 |     run_main_and_assert(FAST_LOCAL_TEST_ARGS)
12 | 
13 | 
14 | def test_mas_with_exemplars():
15 |     args_line = FAST_LOCAL_TEST_ARGS
16 |     args_line += " --num-exemplars 200"
17 |     run_main_and_assert(args_line)
18 | 
19 | 
20 | def test_mas_with_warmup():
21 |     args_line = FAST_LOCAL_TEST_ARGS
22 |     args_line += " --warmup-nepochs 5"
23 |     args_line += " --warmup-lr-factor 0.5"
24 |     args_line += " --num-exemplars 200"
25 | 
run_main_and_assert(args_line) 26 | -------------------------------------------------------------------------------- /src/tests/test_multisoftmax.py: -------------------------------------------------------------------------------- 1 | from tests import run_main_and_assert 2 | 3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \ 4 | " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \ 5 | " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7" \ 6 | " --num-workers 0" 7 | 8 | 9 | def test_finetuning_without_multisoftmax(): 10 | args_line = FAST_LOCAL_TEST_ARGS 11 | args_line += " --approach finetuning" 12 | run_main_and_assert(args_line) 13 | 14 | 15 | def test_finetuning_with_multisoftmax(): 16 | args_line = FAST_LOCAL_TEST_ARGS 17 | args_line += " --approach finetuning" 18 | args_line += " --multi-softmax" 19 | run_main_and_assert(args_line) 20 | -------------------------------------------------------------------------------- /src/tests/test_path_integral.py: -------------------------------------------------------------------------------- 1 | from tests import run_main_and_assert 2 | 3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \ 4 | " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \ 5 | " --nepochs 3" \ 6 | " --num-workers 0" \ 7 | " --approach path_integral" 8 | 9 | 10 | def test_pi_without_exemplars(): 11 | run_main_and_assert(FAST_LOCAL_TEST_ARGS) 12 | 13 | 14 | def test_pi_with_exemplars(): 15 | args_line = FAST_LOCAL_TEST_ARGS 16 | args_line += " --num-exemplars 200" 17 | run_main_and_assert(args_line) 18 | 19 | 20 | def test_pi_with_warmup(): 21 | args_line = FAST_LOCAL_TEST_ARGS 22 | args_line += " --warmup-nepochs 5" 23 | args_line += " --warmup-lr-factor 0.5" 24 | args_line += " --num-exemplars 200" 25 | run_main_and_assert(args_line) 26 | -------------------------------------------------------------------------------- /src/tests/test_rwalk.py: -------------------------------------------------------------------------------- 1 | from tests import run_main_and_assert 2 | 3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \ 4 | " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \ 5 | " --nepochs 3" \ 6 | " --num-workers 0" \ 7 | " --approach r_walk" 8 | 9 | 10 | def test_rwalk_without_exemplars(): 11 | args_line = FAST_LOCAL_TEST_ARGS 12 | args_line += " --num-exemplars 0" 13 | run_main_and_assert(args_line) 14 | 15 | 16 | def test_rwalk_with_exemplars(): 17 | args_line = FAST_LOCAL_TEST_ARGS 18 | args_line += " --num-exemplars 200" 19 | run_main_and_assert(args_line) 20 | 21 | 22 | def test_rwalk_with_warmup(): 23 | args_line = FAST_LOCAL_TEST_ARGS 24 | args_line += " --warmup-nepochs 5" 25 | args_line += " --warmup-lr-factor 0.5" 26 | args_line += " --num-exemplars 200" 27 | run_main_and_assert(args_line) 28 | -------------------------------------------------------------------------------- /src/tests/test_stop_at_task.py: -------------------------------------------------------------------------------- 1 | from tests import run_main_and_assert 2 | 3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \ 4 | " --network LeNet --num-tasks 5 --seed 1 --batch-size 32" \ 5 | " --nepochs 2 --num-workers 0 --stop-at-task 3" 6 | 7 | 8 | def test_finetuning_stop_at_task(): 9 | args_line = FAST_LOCAL_TEST_ARGS 10 | args_line += " --approach finetuning" 11 | run_main_and_assert(args_line) 12 | -------------------------------------------------------------------------------- 
/src/tests/test_warmup.py: --------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 | 
3 | FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \
4 |                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
5 |                        " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7" \
6 |                        " --num-workers 0"
7 | 
8 | 
9 | def test_finetuning_without_warmup():
10 |     args_line = FAST_LOCAL_TEST_ARGS
11 |     args_line += " --approach finetuning"
12 |     run_main_and_assert(args_line)
13 | 
14 | 
15 | def test_finetuning_with_warmup():
16 |     args_line = FAST_LOCAL_TEST_ARGS
17 |     args_line += " --approach finetuning"
18 |     args_line += " --warmup-nepochs 5"
19 |     args_line += " --warmup-lr-factor 0.5"
20 |     run_main_and_assert(args_line)
21 | 
-------------------------------------------------------------------------------- /src/utils.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 | 
6 | cudnn_deterministic = True
7 | 
8 | 
9 | def seed_everything(seed=0):
10 |     """Fix all random seeds"""
11 |     random.seed(seed)
12 |     np.random.seed(seed)
13 |     torch.manual_seed(seed)
14 |     torch.cuda.manual_seed_all(seed)
15 |     os.environ['PYTHONHASHSEED'] = str(seed)
16 |     torch.backends.cudnn.deterministic = cudnn_deterministic
17 | 
18 | 
19 | def print_summary(acc_taw, acc_tag, forg_taw, forg_tag):
20 |     """Print summary of results"""
21 |     for name, metric in zip(['TAw Acc', 'TAg Acc', 'TAw Forg', 'TAg Forg'], [acc_taw, acc_tag, forg_taw, forg_tag]):
22 |         print('*' * 108)
23 |         print(name)
24 |         for i in range(metric.shape[0]):
25 |             print('\t', end='')
26 |             for j in range(metric.shape[1]):
27 |                 print('{:5.1f}% '.format(100 * metric[i, j]), end='')
28 |             if np.trace(metric) == 0.0:  # forgetting matrices have an all-zero diagonal
29 |                 if i > 0:
30 |                     print('\tAvg.:{:5.1f}% '.format(100 * metric[i, :i].mean()), end='')  # average over previous tasks only
31 |             else:  # accuracy matrices: average over all tasks seen so far
32 |                 print('\tAvg.:{:5.1f}% '.format(100 * metric[i, :i + 1].mean()), end='')
33 |             print()
34 |     print('*' * 108)
35 | 
--------------------------------------------------------------------------------
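To make the `utils.py` helpers concrete, here is a hypothetical usage run from `src/` (so that `utils` is importable). The 3-task matrices below are made-up toy values, shaped like the accuracy/forgetting matrices the experiments produce: entry `[i, j]` is the metric on task `j` after training task `i`.

```python
import numpy as np

from utils import seed_everything, print_summary

seed_everything(seed=1)  # fixes python/numpy/torch seeds and enables cudnn determinism

# Toy 3-task accuracy matrices (lower-triangular by construction: future tasks are unseen).
acc_taw = np.array([[0.95, 0.00, 0.00],
                    [0.90, 0.93, 0.00],
                    [0.85, 0.88, 0.92]])
acc_tag = np.tril(acc_taw - 0.05)  # task-agnostic accuracy is typically a bit lower

# Forgetting matrices keep a zero diagonal; print_summary relies on this
# (via np.trace) to average each row over previous tasks only.
forg_taw = np.array([[0.00, 0.00, 0.00],
                     [0.05, 0.00, 0.00],
                     [0.10, 0.05, 0.00]])
forg_tag = forg_taw * 1.5  # still zero on the diagonal

print_summary(acc_taw, acc_tag, forg_taw, forg_tag)
```

The zero-diagonal convention is what lets `print_summary` distinguish accuracy from forgetting matrices without an extra flag.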