├── LICENSE
├── README.md
├── docs
└── _static
│ ├── acc_mat.png
│ ├── asym_illustration(1).png
│ ├── att_fun.png
│ ├── cil_survey_approaches.png
│ ├── facil_logo.jpg
│ ├── facil_logo.png
│ ├── tb1.png
│ └── tb2.png
├── environment.yml
├── requirements.txt
├── scripts
├── script_cifar100.sh
└── script_imagenet.sh
└── src
├── README.md
├── __pycache__
├── last_layer_analysis.cpython-36.pyc
├── last_layer_analysis.cpython-38.pyc
├── utils.cpython-36.pyc
└── utils.cpython-38.pyc
├── approach
├── README.md
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── __init__.cpython-38.pyc
│ ├── ewc.cpython-38.pyc
│ ├── incremental_learning.cpython-38.pyc
│ ├── lwf.cpython-38.pyc
│ ├── olwf.cpython-38.pyc
│ ├── olwf_asym.cpython-38.pyc
│ ├── olwf_asympost.cpython-38.pyc
│ ├── olwf_jsd.cpython-38.pyc
│ └── path_integral.cpython-38.pyc
├── bic.py
├── dmc.py
├── eeil.py
├── ewc.py
├── finetuning.py
├── freezing.py
├── icarl.py
├── il2m.py
├── incremental_learning.py
├── joint.py
├── lucir.py
├── lwf.py
├── lwm.py
├── mas.py
├── oewc.py
├── olwf_asym.py
├── olwf_asym_original.py
├── olwf_asympost.py
├── path_integral.py
└── r_walk.py
├── datasets
├── README.md
├── __pycache__
│ ├── base_dataset.cpython-36.pyc
│ ├── base_dataset.cpython-38.pyc
│ ├── data_loader.cpython-36.pyc
│ ├── data_loader.cpython-38.pyc
│ ├── dataset_config.cpython-36.pyc
│ ├── dataset_config.cpython-38.pyc
│ ├── exemplars_dataset.cpython-38.pyc
│ ├── exemplars_selection.cpython-38.pyc
│ ├── memory_dataset.cpython-36.pyc
│ └── memory_dataset.cpython-38.pyc
├── base_dataset.py
├── data_loader.py
├── dataset_config.py
├── exemplars_dataset.py
├── exemplars_selection.py
└── memory_dataset.py
├── gridsearch.py
├── gridsearch_config.py
├── last_layer_analysis.py
├── loggers
├── README.md
├── __pycache__
│ ├── disk_logger.cpython-38.pyc
│ ├── exp_logger.cpython-36.pyc
│ ├── exp_logger.cpython-38.pyc
│ └── tensorboard_logger.cpython-38.pyc
├── disk_logger.py
├── exp_logger.py
└── tensorboard_logger.py
├── main_incremental.py
├── networks
├── README.md
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── __init__.cpython-38.pyc
│ ├── early_conv_vit.cpython-38.pyc
│ ├── early_conv_vit_net.cpython-38.pyc
│ ├── efficient_net.cpython-38.pyc
│ ├── lenet.cpython-36.pyc
│ ├── lenet.cpython-38.pyc
│ ├── mobile_net.cpython-38.pyc
│ ├── network.cpython-38.pyc
│ ├── ovit.cpython-36.pyc
│ ├── ovit.cpython-38.pyc
│ ├── ovit_tiny_16_augreg_224.cpython-36.pyc
│ ├── ovit_tiny_16_augreg_224.cpython-38.pyc
│ ├── resnet32.cpython-36.pyc
│ ├── resnet32.cpython-38.pyc
│ ├── timm_vit_tiny_16_augreg_224.cpython-38.pyc
│ ├── vggnet.cpython-36.pyc
│ ├── vggnet.cpython-38.pyc
│ ├── vit_original.cpython-38.pyc
│ └── vit_tiny_16_augreg_224.cpython-38.pyc
├── early_conv_vit.py
├── early_conv_vit_net.py
├── efficient_net.py
├── fpt.py
├── lenet.py
├── mobile_net.py
├── network.py
├── ovit.py
├── ovit_tiny_16_augreg_224.py
├── pretrained_weights
│ └── augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz
├── resnet32.py
├── timm_vit_tiny_16_augreg_224.py
├── vggnet.py
├── vit_original.py
└── vit_tiny_16_augreg_224.py
├── test.ipynb
├── test.py
├── tests
├── README.md
├── __init__.py
├── test_bic.py
├── test_dataloader.py
├── test_datasets_transforms.py
├── test_dmc.py
├── test_eeil.py
├── test_ewc.py
├── test_finetuning.py
├── test_fix_bn.py
├── test_freezing.py
├── test_gridsearch.py
├── test_icarl.py
├── test_il2m.py
├── test_joint.py
├── test_last_layer_analysis.py
├── test_loggers.py
├── test_lucir.py
├── test_lwf.py
├── test_lwm.py
├── test_mas.py
├── test_multisoftmax.py
├── test_path_integral.py
├── test_rwalk.py
├── test_stop_at_task.py
└── test_warmup.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Marc Masana Castrillo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Continual Learning with Vision Transformers
2 |
3 | > Update: Our paper wins the best runner-up award at the [3rd CLVision workshop](https://sites.google.com/view/clvision2022/call-for-papers/accepted-papers).
4 |
5 | This repo hosts the official implementation of our CVPR 2022 workshop paper [Towards Exemplar-Free Continual Learning in Vision Transformers: an Account of Attention, Functional and Weight Regularization](https://openaccess.thecvf.com/content/CVPR2022W/CLVision/html/Pelosin_Towards_Exemplar-Free_Continual_Learning_in_Vision_Transformers_An_Account_of_CVPRW_2022_paper.html).
6 |
7 | TLDR; We introduce attentional and functional variants for asymmetric and symmetric Pooled Attention Distillation (PAD) losses in Vision Transformers:
8 |
9 |

10 |
11 |
12 | ## Running the code
13 |
14 |
15 | <img src="docs/_static/asym_illustration(1).png" alt="Asymmetric PAD illustration">
16 |
17 |
18 |
19 | Given below are two examples for the asymmetric attentional and functional variants pooling along the height dimension on ImageNet-100.
20 | 1. Attentional variant:
21 |
22 | ```python
23 | >>> python3 -u src/main_incremental.py --datasets imagenet_32_reduced --network Early_conv_vit --approach olwf_asym --nepochs $NEPOCHS --log disk --batch-size 1024 --gpu $GPU --exp-name dummy_attentional_exp --lr 0.01 --seed ${seed} --lamb 1.0 --num-tasks $NUM_TASKS --nc-first-task $NC_FIRST_TASK --lr-patience 20 --plast_mu 1.0 --pool-along 'height'
24 | ```
25 |
26 | 2. Functional variant:
27 | ```python
28 | >>> python3 -u src/main_incremental.py --datasets imagenet_32_reduced --network Early_conv_vit --approach olwf_asympost --nepochs $NEPOCHS --log disk --batch-size 1024 --gpu $GPU --exp-name dummy_functional_exp --lr 0.01 --seed ${seed} --lamb 1.0 --num-tasks $NUM_TASKS --nc-first-task $NC_FIRST_TASK --lr-patience 20 --plast_mu 1.0 --pool-along 'height'
29 | ```
30 |
31 | The corresponding runs for symmetric variants would then be:
32 | 1. Attentional variant:
33 |
34 | ```python
35 | >>> python3 -u src/main_incremental.py --datasets imagenet_32_reduced --network Early_conv_vit --approach olwf_asym --nepochs $NEPOCHS --log disk --batch-size 1024 --gpu $GPU --exp-name dummy_attentional_exp --lr 0.01 --seed ${seed} --lamb 1.0 --num-tasks $NUM_TASKS --nc-first-task $NC_FIRST_TASK --lr-patience 20 --plast_mu 1.0 --pool-along 'height' --sym
36 | ```
37 |
38 | 2. Functional variant:
39 | ```python
40 | >>> python3 -u src/main_incremental.py --datasets imagenet_32_reduced --network Early_conv_vit --approach olwf_asympost --nepochs $NEPOCHS --log disk --batch-size 1024 --gpu $GPU --exp-name dummy_functional_exp --lr 0.01 --seed ${seed} --lamb 1.0 --num-tasks $NUM_TASKS --nc-first-task $NC_FIRST_TASK --lr-patience 20 --plast_mu 1.0 --pool-along 'height' --sym
41 | ```
42 |
43 | Other available continual learning approaches with Vision Transformers include:
44 |
45 |
46 | EWC • Finetuning • LwF • PathInt
47 |
48 |
49 |
50 | The detailed scripts for our experiments can be found in `scripts/`.
51 |
52 | ## Cite
53 | If you found our implementation to be useful, feel free to use the citation:
54 | ```bibtex
55 | @InProceedings{Pelosin_Jha_CVPR,
56 | author = {Pelosin, Francesco and Jha, Saurav and Torsello, Andrea and Raducanu, Bogdan and van de Weijer, Joost},
57 | title = {Towards Exemplar-Free Continual Learning in Vision Transformers: An Account of Attention, Functional and Weight Regularization},
58 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
59 | month = {June},
60 | year = {2022},
61 | pages = {3820-3829}
62 | }
63 | ```
64 | ## Acknowledgement
65 | This repo is based on [FACIL](https://github.com/mmasana/FACIL).
66 |
67 |
68 |

69 |
70 |
--------------------------------------------------------------------------------
/docs/_static/acc_mat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/acc_mat.png
--------------------------------------------------------------------------------
/docs/_static/asym_illustration(1).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/asym_illustration(1).png
--------------------------------------------------------------------------------
/docs/_static/att_fun.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/att_fun.png
--------------------------------------------------------------------------------
/docs/_static/cil_survey_approaches.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/cil_survey_approaches.png
--------------------------------------------------------------------------------
/docs/_static/facil_logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/facil_logo.jpg
--------------------------------------------------------------------------------
/docs/_static/facil_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/facil_logo.png
--------------------------------------------------------------------------------
/docs/_static/tb1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/tb1.png
--------------------------------------------------------------------------------
/docs/_static/tb2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/docs/_static/tb2.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: FACIL
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - absl-py=0.11.0=pyhd3eb1b0_1
8 | - aiohttp=3.7.3=py38h27cfd23_1
9 | - apipkg=1.5=py38_0
10 | - async-timeout=3.0.1=py38h06a4308_0
11 | - attrs=20.3.0=pyhd3eb1b0_0
12 | - blas=1.0=mkl
13 | - blinker=1.4=py38h06a4308_0
14 | - brotlipy=0.7.0=py38h27cfd23_1003
15 | - c-ares=1.17.1=h27cfd23_0
16 | - ca-certificates=2021.1.19=h06a4308_0
17 | - cachetools=4.2.1=pyhd3eb1b0_0
18 | - certifi=2020.12.5=py38h06a4308_0
19 | - cffi=1.14.4=py38h261ae71_0
20 | - chardet=3.0.4=py38h06a4308_1003
21 | - click=7.1.2=pyhd3eb1b0_0
22 | - cryptography=2.9.2=py38h1ba5d50_0
23 | - cudatoolkit=10.1.243=h6bb024c_0
24 | - cycler=0.10.0=py38_0
25 | - dbus=1.13.18=hb2f20db_0
26 | - execnet=1.8.0=pyhd3eb1b0_0
27 | - expat=2.2.10=he6710b0_2
28 | - fontconfig=2.13.1=h6c09931_0
29 | - freetype=2.10.4=h5ab3b9f_0
30 | - glib=2.66.1=h92f7085_0
31 | - google-auth=1.24.0=pyhd3eb1b0_0
32 | - google-auth-oauthlib=0.4.2=pyhd3eb1b0_2
33 | - grpcio=1.31.0=py38hf8bcb03_0
34 | - gst-plugins-base=1.14.0=h8213a91_2
35 | - gstreamer=1.14.0=h28cd5cc_2
36 | - icu=58.2=he6710b0_3
37 | - idna=2.10=pyhd3eb1b0_0
38 | - importlib-metadata=2.0.0=py_1
39 | - iniconfig=1.1.1=pyhd3eb1b0_0
40 | - intel-openmp=2020.2=254
41 | - jpeg=9b=h024ee3a_2
42 | - kiwisolver=1.3.1=py38h2531618_0
43 | - lcms2=2.11=h396b838_0
44 | - ld_impl_linux-64=2.33.1=h53a641e_7
45 | - libedit=3.1.20191231=h14c3975_1
46 | - libffi=3.3=he6710b0_2
47 | - libgcc-ng=9.1.0=hdf63c60_0
48 | - libpng=1.6.37=hbc83047_0
49 | - libprotobuf=3.14.0=h8c45485_0
50 | - libstdcxx-ng=9.1.0=hdf63c60_0
51 | - libtiff=4.1.0=h2733197_1
52 | - libuuid=1.0.3=h1bed415_2
53 | - libuv=1.40.0=h7b6447c_0
54 | - libxcb=1.14=h7b6447c_0
55 | - libxml2=2.9.10=hb55368b_3
56 | - lz4-c=1.9.3=h2531618_0
57 | - markdown=3.3.3=py38h06a4308_0
58 | - matplotlib=3.3.2=h06a4308_0
59 | - matplotlib-base=3.3.2=py38h817c723_0
60 | - mkl=2020.2=256
61 | - mkl-service=2.3.0=py38he904b0f_0
62 | - mkl_fft=1.2.0=py38h23d657b_0
63 | - mkl_random=1.1.1=py38h0573a6f_0
64 | - more-itertools=8.6.0=pyhd3eb1b0_0
65 | - multidict=4.7.6=py38h7b6447c_1
66 | - ncurses=6.2=he6710b0_1
67 | - ninja=1.10.2=py38hff7bd54_0
68 | - numpy=1.19.2=py38h54aff64_0
69 | - numpy-base=1.19.2=py38hfa32c7d_0
70 | - oauthlib=3.1.0=py_0
71 | - olefile=0.46=py_0
72 | - openssl=1.1.1i=h27cfd23_0
73 | - packaging=20.9=pyhd3eb1b0_0
74 | - pandas=1.2.1=py38ha9443f7_0
75 | - pcre=8.44=he6710b0_0
76 | - pillow=8.1.0=py38he98fc37_0
77 | - pip=20.3.3=py38h06a4308_0
78 | - pluggy=0.13.1=py38_0
79 | - protobuf=3.14.0=py38h2531618_1
80 | - py=1.10.0=pyhd3eb1b0_0
81 | - pyasn1=0.4.8=py_0
82 | - pyasn1-modules=0.2.8=py_0
83 | - pycparser=2.20=py_2
84 | - pyjwt=2.0.1=py38h06a4308_0
85 | - pyopenssl=20.0.1=pyhd3eb1b0_1
86 | - pyparsing=2.4.7=pyhd3eb1b0_0
87 | - pyqt=5.9.2=py38h05f1152_4
88 | - pysocks=1.7.1=py38h06a4308_0
89 | - pytest=6.2.2=py38h06a4308_2
90 | - pytest-forked=1.3.0=pyhd3eb1b0_0
91 | - pytest-xdist=2.2.0=pyhd3eb1b0_0
92 | - python=3.8.5=h7579374_1
93 | - python-dateutil=2.8.1=pyhd3eb1b0_0
94 | - pytorch=1.7.1=py3.8_cuda10.1.243_cudnn7.6.3_0
95 | - pytz=2021.1=pyhd3eb1b0_0
96 | - qt=5.9.7=h5867ecd_1
97 | - readline=8.1=h27cfd23_0
98 | - requests=2.25.1=pyhd3eb1b0_0
99 | - requests-oauthlib=1.3.0=py_0
100 | - rsa=4.7=pyhd3eb1b0_1
101 | - setuptools=52.0.0=py38h06a4308_0
102 | - sip=4.19.13=py38he6710b0_0
103 | - six=1.15.0=py38h06a4308_0
104 | - sqlite=3.33.0=h62c20be_0
105 | - tensorboard=2.3.0=pyh4dce500_0
106 | - tensorboard-plugin-wit=1.6.0=py_0
107 | - tk=8.6.10=hbc83047_0
108 | - toml=0.10.1=py_0
109 | - torchvision=0.8.2=py38_cu101
110 | - tornado=6.1=py38h27cfd23_0
111 | - typing-extensions=3.7.4.3=hd3eb1b0_0
112 | - typing_extensions=3.7.4.3=pyh06a4308_0
113 | - urllib3=1.26.3=pyhd3eb1b0_0
114 | - werkzeug=1.0.1=pyhd3eb1b0_0
115 | - wheel=0.36.2=pyhd3eb1b0_0
116 | - xz=5.2.5=h7b6447c_0
117 | - yarl=1.5.1=py38h7b6447c_0
118 | - zipp=3.4.0=pyhd3eb1b0_0
119 | - zlib=1.2.11=h7b6447c_3
120 | - zstd=1.4.5=h9ceee32_0
121 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # NOTE: Previous versions of pytorch and torchvision might also work,
2 | # but we haven't tested them yet
3 | torch>=1.7.1
4 | torchvision>=0.8.2
5 | matplotlib
6 | numpy
7 | tensorboard
--------------------------------------------------------------------------------
/scripts/script_cifar100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | if [ "$1" != "" ]; then
5 | echo "Running approach: $1"
6 | else
7 | echo "No approach has been assigned."
8 | fi
9 | if [ "$2" != "" ]; then
10 | echo "Running on gpu: $2"
11 | else
12 | echo "No gpu has been assigned."
13 | fi
14 |
15 | PROJECT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && cd .. && pwd )"
16 | SRC_DIR="$PROJECT_DIR/src"
17 | echo "Project dir: $PROJECT_DIR"
18 | echo "Sources dir: $SRC_DIR"
19 |
20 | RESULTS_DIR="$PROJECT_DIR/results"
21 | if [ "$4" != "" ]; then
22 | RESULTS_DIR=$4
23 | else
24 | echo "No results dir is given. Default will be used."
25 | fi
26 | echo "Results dir: $RESULTS_DIR"
27 |
28 | MU=1.0
29 | # LAMBDA=1.0
30 |
31 | for LAMBDA in 0.0
32 | do
33 | for seed in $(seq 1 1 1)
34 | do
35 | python3 -u src/main_incremental.py --datasets cifar100_224 --network OVit_tiny_16_augreg_224 --approach olwf_asym --nepochs 100 --log disk --batch-size 96 --gpu 0 --exp-name 10epochs_sym_pool_layer_6_${SPARSENESS}_lambda_${LAMBDA}_50_cls_6tasks_${seed}_mu_${MU} --after-norm --sym --lr 0.01 --seed ${seed} --lamb $LAMBDA --plast_mu $MU --num-tasks 6 --nc-first-task 50 --lr-patience 15 --int-layer --pool-layers 6
36 | python3 -u src/main_incremental.py --datasets cifar100_224 --network OVit_tiny_16_augreg_224 --approach olwf_asym --nepochs 100 --log disk --batch-size 96 --gpu 0 --exp-name 10epochs_sym_pool_layer_1_${SPARSENESS}_lambda_${LAMBDA}_50_cls_6tasks_${seed}_mu_${MU} --after-norm --sym --lr 0.01 --seed ${seed} --lamb $LAMBDA --plast_mu $MU --num-tasks 6 --nc-first-task 50 --lr-patience 15 --int-layer --pool-layers 1
37 | python3 -u src/main_incremental.py --datasets cifar100_224 --network OVit_tiny_16_augreg_224 --approach olwf_asym --nepochs 100 --log disk --batch-size 96 --gpu 0 --exp-name 10epochs_sym_pool_layer_12_${SPARSENESS}_lambda_${LAMBDA}_50_cls_6tasks_${seed}_mu_${MU} --after-norm --sym --lr 0.01 --seed ${seed} --lamb $LAMBDA --plast_mu $MU --num-tasks 6 --nc-first-task 50 --lr-patience 15 --int-layer --pool-layers 12
38 |
39 | done
40 | done
41 |
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | # Framework for Analysis of Class-Incremental Learning
2 | Run the code with:
3 | ```
4 | python3 -u src/main_incremental.py
5 | ```
6 | followed by general options:
7 |
8 | * `--gpu`: index of GPU to run the experiment on (default=0)
9 | * `--results-path`: path where results are stored (default='../results')
10 | * `--exp-name`: experiment name (default=None)
11 | * `--seed`: random seed (default=0)
12 | * `--save-models`: save trained models (default=False)
13 | * `--last-layer-analysis`: plot last layer analysis (default=False)
14 | * `--no-cudnn-deterministic`: disable CUDNN deterministic (default=False)
15 |
16 | and specific options for each of the code parts (corresponding to folders):
17 |
18 | * `--approach`: learning approach used (default='finetuning') [[more](approach/README.md)]
19 | * `--datasets`: dataset or datasets used (default=['cifar100']) [[more](datasets/README.md)]
20 | * `--network`: network architecture used (default='resnet32') [[more](networks/README.md)]
21 | * `--log`: loggers used (default='disk') [[more](loggers/README.md)]
22 |
23 | go to each of their respective readme to see all available options for each of them.
24 |
25 | ## Approaches
26 | Initially, the approaches included in the framework correspond to the ones presented in
27 | _**Class-incremental learning: survey and performance evaluation**_ (preprint, 2020). The regularization-based
28 | approaches are EWC, MAS, PathInt, LwF, LwM and DMC (green). The rehearsal approaches are iCaRL, EEIL and RWalk (blue).
29 | The bias-correction approaches are IL2M, BiC and LUCIR (orange).
30 |
31 | 
32 |
33 | More approaches will be included in the future. To learn more about them refer to the readme in
34 | [src/approach](approach).
35 |
36 | ## Datasets
37 | To learn about the dataset management refer to the readme in [src/datasets](datasets).
38 |
39 | ## Networks
40 | To learn about the different torchvision and custom networks refer to the readme in [src/networks](networks).
41 |
42 | ## GridSearch
43 | We implement the option to use a realistic grid search for hyperparameters which only takes into account the task at
44 | hand, without access to previous or future information not available in the incremental learning scenario. It
45 | corresponds to the one introduced in _**Class-incremental learning: survey and performance evaluation**_. The GridSearch
46 | can be applied by using:
47 |
48 | * `--gridsearch-tasks`: number of tasks to apply GridSearch (-1: all tasks) (default=-1)
49 |
50 | which we recommend to set to the total number of tasks of the experiment for a more realistic setting of the correct
51 | learning rate and possible forgetting-intransigence trade-off. However, since this has a considerable extra
52 | computational cost, it can also be set to the first 3 tasks, which would fix those hyperparameters for the remaining
53 | tasks. Other GridSearch options include:
54 |
55 | * `--gridsearch-config`: configuration file for GridSearch options (default='gridsearch_config') [[more](gridsearch_config.py)]
56 | * `--gridsearch-acc-drop-thr`: GridSearch accuracy drop threshold (default=0.2)
57 | * `--gridsearch-hparam-decay`: GridSearch hyperparameter decay (default=0.5)
58 | * `--gridsearch-max-num-searches`: GridSearch maximum number of hyperparameter search (default=7)
59 |
60 | ## Utils
61 | We have some utility functions added into `utils.py`.
--------------------------------------------------------------------------------
/src/__pycache__/last_layer_analysis.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/__pycache__/last_layer_analysis.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/last_layer_analysis.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/__pycache__/last_layer_analysis.cpython-38.pyc
--------------------------------------------------------------------------------
/src/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/src/approach/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # list all approaches available
4 | __all__ = list(
5 | map(lambda x: x[:-3],
6 | filter(lambda x: x not in ['__init__.py', 'incremental_learning.py'] and x.endswith('.py'),
7 | os.listdir(os.path.dirname(__file__))
8 | )
9 | )
10 | )
11 |
--------------------------------------------------------------------------------
/src/approach/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/approach/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/approach/__pycache__/ewc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/ewc.cpython-38.pyc
--------------------------------------------------------------------------------
/src/approach/__pycache__/incremental_learning.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/incremental_learning.cpython-38.pyc
--------------------------------------------------------------------------------
/src/approach/__pycache__/lwf.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/lwf.cpython-38.pyc
--------------------------------------------------------------------------------
/src/approach/__pycache__/olwf.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/olwf.cpython-38.pyc
--------------------------------------------------------------------------------
/src/approach/__pycache__/olwf_asym.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/olwf_asym.cpython-38.pyc
--------------------------------------------------------------------------------
/src/approach/__pycache__/olwf_asympost.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/olwf_asympost.cpython-38.pyc
--------------------------------------------------------------------------------
/src/approach/__pycache__/olwf_jsd.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/olwf_jsd.cpython-38.pyc
--------------------------------------------------------------------------------
/src/approach/__pycache__/path_integral.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/approach/__pycache__/path_integral.cpython-38.pyc
--------------------------------------------------------------------------------
/src/approach/ewc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import itertools
3 | from argparse import ArgumentParser
4 |
5 | from datasets.exemplars_dataset import ExemplarsDataset
6 | from .incremental_learning import Inc_Learning_Appr
7 |
8 |
9 | class Appr(Inc_Learning_Appr):
10 | """Class implementing the Elastic Weight Consolidation (EWC) approach
11 | described in http://arxiv.org/abs/1612.00796
12 | """
13 |
14 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
15 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
16 | logger=None, exemplars_dataset=None, lamb=5000, alpha=0.5, fi_sampling_type='max_pred',
17 | fi_num_samples=-1):
18 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
19 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
20 | exemplars_dataset)
21 | self.lamb = lamb
22 | self.alpha = alpha
23 | self.sampling_type = fi_sampling_type
24 | self.num_samples = fi_num_samples
25 |
26 | # In all cases, we only keep importance weights for the model, but not for the heads.
27 | feat_ext = self.model.model
28 | # Store current parameters as the initial parameters before first task starts
29 | self.older_params = {n: p.clone().detach() for n, p in feat_ext.named_parameters() if p.requires_grad}
30 | # Store fisher information weight importance
31 | self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters()
32 | if p.requires_grad}
33 |
34 | @staticmethod
35 | def exemplars_dataset_class():
36 | return ExemplarsDataset
37 |
38 | @staticmethod
39 | def extra_parser(args):
40 | """Returns a parser containing the approach specific parameters"""
41 | parser = ArgumentParser()
42 | # Eq. 3: "lambda sets how important the old task is compared to the new one"
43 | parser.add_argument('--lamb', default=5000, type=float, required=False,
44 | help='Forgetting-intransigence trade-off (default=%(default)s)')
45 | # Define how old and new fisher is fused, by default it is a 50-50 fusion
46 | parser.add_argument('--alpha', default=0.5, type=float, required=False,
47 | help='EWC alpha (default=%(default)s)')
48 | parser.add_argument('--fi-sampling-type', default='max_pred', type=str, required=False,
49 | choices=['true', 'max_pred', 'multinomial'],
50 | help='Sampling type for Fisher information (default=%(default)s)')
51 | parser.add_argument('--fi-num-samples', default=-1, type=int, required=False,
52 | help='Number of samples for Fisher information (-1: all available) (default=%(default)s)')
53 |
54 | return parser.parse_known_args(args)
55 |
56 | def _get_optimizer(self):
57 | """Returns the optimizer"""
58 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
59 | # if there are no exemplars, previous heads are not modified
60 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
61 | else:
62 | params = self.model.parameters()
63 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)
64 |
65 | def compute_fisher_matrix_diag(self, trn_loader):
66 | # Store Fisher Information
67 | fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters()
68 | if p.requires_grad}
69 | # Compute fisher information for specified number of samples -- rounded to the batch size
70 | n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \
71 | else (len(trn_loader.dataset) // trn_loader.batch_size)
72 | # Do forward and backward pass to compute the fisher information
73 | self.model.train()
74 | for images, targets in itertools.islice(trn_loader, n_samples_batches):
75 | outputs = self.model.forward(images.to(self.device))
76 |
77 | if self.sampling_type == 'true':
78 | # Use the labels to compute the gradients based on the CE-loss with the ground truth
79 | preds = targets.to(self.device)
80 | elif self.sampling_type == 'max_pred':
81 | # Not use labels and compute the gradients related to the prediction the model has learned
82 | preds = torch.cat(outputs, dim=1).argmax(1).flatten()
83 | elif self.sampling_type == 'multinomial':
84 | # Use a multinomial sampling to compute the gradients
85 | probs = torch.nn.functional.softmax(torch.cat(outputs, dim=1), dim=1)
86 | preds = torch.multinomial(probs, len(targets)).flatten()
87 |
88 | loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), preds)
89 | self.optimizer.zero_grad()
90 | loss.backward()
91 | # Accumulate all gradients from loss with regularization
92 | for n, p in self.model.model.named_parameters():
93 | if p.grad is not None:
94 | fisher[n] += p.grad.pow(2) * len(targets)
95 | # Apply mean across all samples
96 | n_samples = n_samples_batches * trn_loader.batch_size
97 | fisher = {n: (p / n_samples) for n, p in fisher.items()}
98 | return fisher
99 |
100 | def train_loop(self, t, trn_loader, val_loader):
101 | """Contains the epochs loop"""
102 |
103 | # add exemplars to train_loader
104 | if len(self.exemplars_dataset) > 0 and t > 0:
105 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
106 | batch_size=trn_loader.batch_size,
107 | shuffle=True,
108 | num_workers=trn_loader.num_workers,
109 | pin_memory=trn_loader.pin_memory)
110 |
111 | # FINETUNING TRAINING -- contains the epochs loop
112 | super().train_loop(t, trn_loader, val_loader)
113 |
114 | # EXEMPLAR MANAGEMENT -- select training subset
115 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)
116 |
117 | def post_train_process(self, t, trn_loader):
118 | """Runs after training all the epochs of the task (after the train session)"""
119 |
120 | # Store current parameters for the next task
121 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad}
122 |
123 | # calculate Fisher information
124 | curr_fisher = self.compute_fisher_matrix_diag(trn_loader)
125 | # merge fisher information, we do not want to keep fisher information for each task in memory
126 | for n in self.fisher.keys():
127 | # Added option to accumulate fisher over time with a pre-fixed growing alpha
128 | if self.alpha == -1:
129 | alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device)
130 | self.fisher[n] = alpha * self.fisher[n] + (1 - alpha) * curr_fisher[n]
131 | else:
132 | self.fisher[n] = (self.alpha * self.fisher[n] + (1 - self.alpha) * curr_fisher[n])
133 |
134 | def criterion(self, t, outputs, targets):
135 | """Returns the loss value"""
136 | loss = 0
137 | if t > 0:
138 | loss_reg = 0
139 | # Eq. 3: elastic weight consolidation quadratic penalty
140 | for n, p in self.model.model.named_parameters():
141 | if n in self.fisher.keys():
142 | loss_reg += torch.sum(self.fisher[n] * (p - self.older_params[n]).pow(2)) / 2
143 | loss += self.lamb * loss_reg
144 | # Current cross-entropy loss -- with exemplars use all heads
145 | if len(self.exemplars_dataset) > 0:
146 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
147 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
148 |
--------------------------------------------------------------------------------
/src/approach/finetuning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from argparse import ArgumentParser
3 |
4 | from .incremental_learning import Inc_Learning_Appr
5 | from datasets.exemplars_dataset import ExemplarsDataset
6 |
7 |
class Appr(Inc_Learning_Appr):
    """Finetuning baseline: each task is learned with plain cross-entropy.

    Without exemplars only the current head is trained; with exemplars, or
    with ``--all-outputs``, the loss covers the outputs of every head.
    """

    def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
                 momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
                 logger=None, exemplars_dataset=None, all_outputs=False):
        super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
                                   multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
                                   exemplars_dataset)
        # when True, gradients flow through the outputs of every head
        self.all_out = all_outputs

    @staticmethod
    def exemplars_dataset_class():
        """Returns the dataset class used to store exemplars."""
        return ExemplarsDataset

    @staticmethod
    def extra_parser(args):
        """Parses the approach-specific command line arguments."""
        parser = ArgumentParser()
        parser.add_argument('--all-outputs', action='store_true', required=False,
                            help='Allow all weights related to all outputs to be modified (default=%(default)s)')
        return parser.parse_known_args(args)

    def _get_optimizer(self):
        """Creates the SGD optimizer over the currently trainable parameters."""
        single_head = len(self.exemplars_dataset) == 0 and not self.all_out
        if single_head and len(self.model.heads) > 1:
            # without exemplars, the heads of previous tasks stay untouched
            params = list(self.model.model.parameters())
            params += list(self.model.heads[-1].parameters())
        else:
            params = self.model.parameters()
        return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)

    def train_loop(self, t, trn_loader, val_loader):
        """Runs the training epochs for task ``t`` and updates the exemplar set."""
        # mix the stored exemplars into the current task's training data
        if t > 0 and len(self.exemplars_dataset) > 0:
            combined = trn_loader.dataset + self.exemplars_dataset
            trn_loader = torch.utils.data.DataLoader(combined,
                                                     batch_size=trn_loader.batch_size,
                                                     shuffle=True,
                                                     num_workers=trn_loader.num_workers,
                                                     pin_memory=trn_loader.pin_memory)
        # regular epoch loop from the base approach
        super().train_loop(t, trn_loader, val_loader)
        # refresh the exemplar memory with samples from the current task
        self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)

    def criterion(self, t, outputs, targets):
        """Returns the cross-entropy loss for task ``t``."""
        if not self.all_out and len(self.exemplars_dataset) == 0:
            # task-aware: only the current head, with targets shifted to its range
            return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
        # task-agnostic: concatenate the outputs of every head
        return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
62 |
--------------------------------------------------------------------------------
/src/approach/freezing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from argparse import ArgumentParser
3 |
4 | from .incremental_learning import Inc_Learning_Appr
5 | from datasets.exemplars_dataset import ExemplarsDataset
6 |
7 |
class Appr(Inc_Learning_Appr):
    """Freezing baseline: after a chosen task the backbone is frozen and only heads keep training."""

    def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
                 momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
                 logger=None, exemplars_dataset=None, freeze_after=0, all_outputs=False):
        super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
                                   multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
                                   exemplars_dataset)
        # task index after which the backbone gets frozen
        self.freeze_after = freeze_after
        # when True, the loss is computed over the outputs of every head
        self.all_out = all_outputs

    @staticmethod
    def exemplars_dataset_class():
        """Returns the dataset class used to store exemplars."""
        return ExemplarsDataset

    @staticmethod
    def extra_parser(args):
        """Parses the approach-specific command line arguments."""
        parser = ArgumentParser()
        parser.add_argument('--freeze-after', default=0, type=int, required=False,
                            help='Freeze model except current head after the specified task (default=%(default)s)')
        parser.add_argument('--all-outputs', action='store_true', required=False,
                            help='Allow all weights related to all outputs to be modified (default=%(default)s)')
        return parser.parse_known_args(args)

    def _get_optimizer(self):
        """Creates the SGD optimizer over the currently trainable parameters."""
        return torch.optim.SGD(self._train_parameters(), lr=self.lr, weight_decay=self.wd, momentum=self.momentum)

    def _has_exemplars(self):
        """Returns True in case exemplars are being used"""
        return self.exemplars_dataset is not None and len(self.exemplars_dataset) > 0

    def post_train_process(self, t, trn_loader):
        """Freezes the backbone once the configured task index has been reached."""
        if t >= self.freeze_after:
            self.model.freeze_backbone()

    def train_loop(self, t, trn_loader, val_loader):
        """Runs the training epochs for task ``t`` and updates the exemplar set."""
        # mix the stored exemplars into the current task's training data
        if t > 0 and self._has_exemplars():
            combined = trn_loader.dataset + self.exemplars_dataset
            trn_loader = torch.utils.data.DataLoader(combined,
                                                     batch_size=trn_loader.batch_size,
                                                     shuffle=True,
                                                     num_workers=trn_loader.num_workers,
                                                     pin_memory=trn_loader.pin_memory)
        # regular epoch loop from the base approach
        super().train_loop(t, trn_loader, val_loader)
        # refresh the exemplar memory with samples from the current task
        self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)

    def train_epoch(self, t, trn_loader):
        """Runs a single epoch, clipping gradients only for the trainable parameters."""
        self._model_train(t)
        for images, targets in trn_loader:
            images, targets = images.to(self.device), targets.to(self.device)
            # Forward current model
            outputs = self.model(images)
            loss = self.criterion(t, outputs, targets)
            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self._train_parameters(), self.clipgrad)
            self.optimizer.step()

    def _model_train(self, t):
        """Puts the right submodules in train/eval mode depending on the freezing state."""
        if self.fix_bn and t > 0:
            self.model.freeze_bn()
        not_yet_frozen = self.freeze_after >= 0 and t <= self.freeze_after
        if not_yet_frozen:
            # whole model is still being trained
            self.model.train()
            return
        # backbone is frozen -> keep it in eval mode and train only the head(s)
        self.model.model.eval()
        heads = self.model.heads if self._has_exemplars() else self.model.heads[-1:]
        for head in heads:
            head.train()

    def _train_parameters(self):
        """Selects which parameters the optimizer should update."""
        frozen = len(self.model.heads) > (self.freeze_after + 1)
        if not frozen:
            return self.model.parameters()
        if self._has_exemplars():
            # all heads remain trainable when exemplars are kept
            return [p for head in self.model.heads for p in head.parameters()]
        # otherwise only the newest head is trained
        return self.model.heads[-1].parameters()

    def criterion(self, t, outputs, targets):
        """Returns the cross-entropy loss for task ``t``."""
        if not self.all_out and not self._has_exemplars():
            # task-aware: only the current head, with targets shifted to its range
            return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
        # task-agnostic: concatenate the outputs of every head
        return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
108 |
--------------------------------------------------------------------------------
/src/approach/il2m.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from .incremental_learning import Inc_Learning_Appr
5 | from datasets.exemplars_dataset import ExemplarsDataset
6 |
7 |
class Appr(Inc_Learning_Appr):
    """Class implementing the Class Incremental Learning With Dual Memory (IL2M) approach described in
    https://openaccess.thecvf.com/content_ICCV_2019/papers/Belouadah_IL2M_Class_Incremental_Learning_With_Dual_Memory_ICCV_2019_paper.pdf
    """

    def __init__(self, model, device, nepochs=100, lr=0.1, lr_min=1e-4, lr_factor=3, lr_patience=10, clipgrad=10000,
                 momentum=0.9, wd=0.0001, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False,
                 eval_on_train=False, logger=None, exemplars_dataset=None):
        super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
                                   multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
                                   exemplars_dataset)
        # mean prediction score of each class, measured in the state where it was first learned
        self.init_classes_means = []
        # mean prediction score of past classes as re-measured in the current state
        self.current_classes_means = []
        # average top-1 prediction score of the model in each state (model confidence)
        self.models_confidence = []
        # FLAG to not do scores rectification while finetuning training
        self.ft_train = False

        have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class
        assert (have_exemplars > 0), 'Error: IL2M needs exemplars.'

    @staticmethod
    def exemplars_dataset_class():
        """Returns the dataset class used to store exemplars."""
        return ExemplarsDataset

    def il2m(self, t, trn_loader):
        """Compute and store statistics for score rectification"""
        old_classes_number = sum(self.model.task_cls[:t])
        classes_counts = [0 for _ in range(sum(self.model.task_cls))]
        models_counts = 0

        # to store statistics for the classes as learned in the current incremental state
        self.current_classes_means = [0 for _ in range(old_classes_number)]
        # to store statistics for past classes as learned in their initial states
        for cls in range(old_classes_number, old_classes_number + self.model.task_cls[t]):
            self.init_classes_means.append(0)
        # to store statistics for model confidence in different states (i.e. avg top-1 pred scores)
        self.models_confidence.append(0)

        # compute the mean prediction scores that will be used to rectify scores in subsequent tasks
        with torch.no_grad():
            self.model.eval()
            for images, targets in trn_loader:
                outputs = self.model(images.to(self.device))
                # BUGFIX: `np.float` (alias of builtin float / float64) was removed in NumPy 1.24;
                # `np.float64` is the exact same dtype, so behavior is unchanged
                scores = np.array(torch.cat(outputs, dim=1).data.cpu().numpy(), dtype=np.float64)
                for m in range(len(targets)):
                    if targets[m] < old_classes_number:
                        # computation of class means for past classes of the current state.
                        self.current_classes_means[targets[m]] += scores[m, targets[m]]
                        classes_counts[targets[m]] += 1
                    else:
                        # compute the mean prediction scores for the new classes of the current state
                        self.init_classes_means[targets[m]] += scores[m, targets[m]]
                        classes_counts[targets[m]] += 1
                        # compute the mean top scores for the new classes of the current state
                        self.models_confidence[t] += np.max(scores[m, ])
                        models_counts += 1
        # Normalize by corresponding number of images
        for cls in range(old_classes_number):
            self.current_classes_means[cls] /= classes_counts[cls]
        for cls in range(old_classes_number, old_classes_number + self.model.task_cls[t]):
            self.init_classes_means[cls] /= classes_counts[cls]
        self.models_confidence[t] /= models_counts

    def train_loop(self, t, trn_loader, val_loader):
        """Contains the epochs loop"""

        # add exemplars to train_loader
        if len(self.exemplars_dataset) > 0 and t > 0:
            trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
                                                     batch_size=trn_loader.batch_size,
                                                     shuffle=True,
                                                     num_workers=trn_loader.num_workers,
                                                     pin_memory=trn_loader.pin_memory)

        # FINETUNING TRAINING -- contains the epochs loop; the flag disables rectification meanwhile
        self.ft_train = True
        super().train_loop(t, trn_loader, val_loader)
        self.ft_train = False

        # IL2M outputs rectification -- gather the statistics used at evaluation time
        self.il2m(t, trn_loader)

        # EXEMPLAR MANAGEMENT -- select training subset
        self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)

    def calculate_metrics(self, outputs, targets):
        """Contains the main Task-Aware and Task-Agnostic metrics"""
        if self.ft_train:
            # no score rectification while training
            hits_taw, hits_tag = super().calculate_metrics(outputs, targets)
        else:
            # Task-Aware Multi-Head
            pred = torch.zeros_like(targets.to(self.device))
            for m in range(len(pred)):
                this_task = (self.model.task_cls.cumsum(0) <= targets[m]).sum()
                pred[m] = outputs[this_task][m].argmax() + self.model.task_offset[this_task]
            hits_taw = (pred == targets.to(self.device)).float()
            # Task-Agnostic Multi-Head
            if self.multi_softmax:
                outputs = [torch.nn.functional.log_softmax(output, dim=1) for output in outputs]
            # Eq. 1: rectify predicted scores
            old_classes_number = sum(self.model.task_cls[:-1])
            # PERF: concatenate once instead of on every sample; each sample only rectifies
            # its own row m, so sharing the tensor across iterations does not change results
            rectified_outputs = torch.cat(outputs, dim=1)
            for m in range(len(targets)):
                pred[m] = rectified_outputs[m].argmax()
                if old_classes_number:
                    # if the top-1 class predicted by the network is a new one, rectify the score
                    if int(pred[m]) >= old_classes_number:
                        for o in range(old_classes_number):
                            o_task = int((self.model.task_cls.cumsum(0) <= o).sum())
                            rectified_outputs[m, o] *= (self.init_classes_means[o] / self.current_classes_means[o]) * \
                                                       (self.models_confidence[-1] / self.models_confidence[o_task])
                        pred[m] = rectified_outputs[m].argmax()
                    # otherwise, rectification is not done because an old class is directly predicted
            hits_tag = (pred == targets.to(self.device)).float()
        return hits_taw, hits_tag

    def criterion(self, t, outputs, targets):
        """Returns the cross-entropy loss for task ``t`` (all heads when exemplars exist)."""
        if len(self.exemplars_dataset) > 0:
            return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
        return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
130 |
--------------------------------------------------------------------------------
/src/approach/joint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from argparse import ArgumentParser
3 | from torch.utils.data import DataLoader, Dataset
4 |
5 | from .incremental_learning import Inc_Learning_Appr
6 | from datasets.exemplars_dataset import ExemplarsDataset
7 |
8 |
class Appr(Inc_Learning_Appr):
    """Incremental Joint Training baseline: each task is trained on all data seen so far."""

    def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
                 momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
                 logger=None, exemplars_dataset=None, freeze_after=-1):
        super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
                                   multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
                                   exemplars_dataset)
        # cumulative train/validation datasets over all seen tasks
        self.trn_datasets = []
        self.val_datasets = []
        # task index after which everything but the heads is frozen (-1: never)
        self.freeze_after = freeze_after

        have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class
        assert (have_exemplars == 0), 'Warning: Joint does not use exemplars. Comment this line to force it.'

    @staticmethod
    def exemplars_dataset_class():
        """Returns the dataset class used to store exemplars."""
        return ExemplarsDataset

    @staticmethod
    def extra_parser(args):
        """Parses the approach-specific command line arguments."""
        parser = ArgumentParser()
        parser.add_argument('--freeze-after', default=-1, type=int, required=False,
                            help='Freeze model except heads after the specified task'
                                 '(-1: normal Incremental Joint Training, no freeze) (default=%(default)s)')
        return parser.parse_known_args(args)

    def post_train_process(self, t, trn_loader):
        """Freezes everything but the heads once the configured task index has been reached."""
        if self.freeze_after > -1 and t >= self.freeze_after:
            self.model.freeze_all()
            # heads remain trainable
            head_params = (p for head in self.model.heads for p in head.parameters())
            for param in head_params:
                param.requires_grad = True

    def train_loop(self, t, trn_loader, val_loader):
        """Accumulates the new task data and trains on the union of all tasks."""
        self.trn_datasets.append(trn_loader.dataset)
        self.val_datasets.append(val_loader.dataset)
        joint_trn = JointDataset(self.trn_datasets)
        joint_val = JointDataset(self.val_datasets)
        trn_loader = DataLoader(joint_trn,
                                batch_size=trn_loader.batch_size,
                                shuffle=True,
                                num_workers=trn_loader.num_workers,
                                pin_memory=trn_loader.pin_memory)
        val_loader = DataLoader(joint_val,
                                batch_size=val_loader.batch_size,
                                shuffle=False,
                                num_workers=val_loader.num_workers,
                                pin_memory=val_loader.pin_memory)
        # continue training as usual
        super().train_loop(t, trn_loader, val_loader)

    def train_epoch(self, t, trn_loader):
        """Runs a single epoch, keeping the backbone in eval mode when frozen."""
        frozen = self.freeze_after >= 0 and t > self.freeze_after
        if frozen:
            self.model.eval()
            for head in self.model.heads:
                head.train()
        else:
            self.model.train()
            if self.fix_bn and t > 0:
                self.model.freeze_bn()
        for images, targets in trn_loader:
            images, targets = images.to(self.device), targets.to(self.device)
            # Forward current model
            outputs = self.model(images)
            loss = self.criterion(t, outputs, targets)
            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            self.optimizer.step()

    def criterion(self, t, outputs, targets):
        """Returns the cross-entropy loss over the concatenated outputs of all heads."""
        return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
90 |
91 |
class JointDataset(Dataset):
    """Characterizes a dataset for PyTorch -- this dataset accumulates each task dataset incrementally.

    A global index is mapped onto the sub-dataset that contains it, so the
    concatenation behaves like one flat dataset.
    """

    def __init__(self, datasets):
        self.datasets = datasets
        # total length is cached since it is queried repeatedly by samplers
        self._len = sum(len(d) for d in self.datasets)

    def __len__(self):
        """Denotes the total number of samples"""
        return self._len

    def __getitem__(self, index):
        """Maps a global index onto the right sub-dataset and returns its (x, y) pair."""
        offset = index
        for d in self.datasets:
            if offset < len(d):
                x, y = d[offset]
                return x, y
            offset -= len(d)
        # BUGFIX: the original silently fell off the end and returned None for an
        # out-of-range index; raising IndexError is the sequence-protocol contract
        # and lets plain iteration terminate correctly.
        raise IndexError(index)
110 |
--------------------------------------------------------------------------------
/src/approach/lwf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 | from argparse import ArgumentParser
4 |
5 | from .incremental_learning import Inc_Learning_Appr
6 | from datasets.exemplars_dataset import ExemplarsDataset
7 |
8 |
class Appr(Inc_Learning_Appr):
    """Class implementing the Learning Without Forgetting (LwF) approach
    described in https://arxiv.org/abs/1606.09282
    """

    # Weight decay of 0.0005 is used in the original article (page 4).
    # Page 4: "The warm-up step greatly enhances fine-tuning’s old-task performance, but is not so crucial to either our
    # method or the compared Less Forgetting Learning (see Table 2(b))."
    def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
                 momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
                 logger=None, exemplars_dataset=None, lamb=1, T=2):
        super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
                                   multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
                                   exemplars_dataset)
        # frozen copy of the model after the previous task; acts as distillation teacher
        self.model_old = None
        # trade-off between distillation (old tasks) and cross-entropy (current task)
        self.lamb = lamb
        # distillation temperature
        self.T = T

    @staticmethod
    def exemplars_dataset_class():
        """Returns the dataset class used to store exemplars."""
        return ExemplarsDataset

    @staticmethod
    def extra_parser(args):
        """Returns a parser containing the approach specific parameters"""
        parser = ArgumentParser()
        # Page 5: "lambda is a loss balance weight, set to 1 for most our experiments. Making lambda larger will favor
        # the old task performance over the new task’s, so we can obtain a old-task-new-task performance line by
        # changing lambda."
        parser.add_argument('--lamb', default=1, type=float, required=False,
                            help='Forgetting-intransigence trade-off (default=%(default)s)')
        # Page 5: "We use T=2 according to a grid search on a held out set, which aligns with the authors’
        # recommendations." -- Using a higher value for T produces a softer probability distribution over classes.
        parser.add_argument('--T', default=2, type=int, required=False,
                            help='Temperature scaling (default=%(default)s)')
        return parser.parse_known_args(args)

    def _get_optimizer(self):
        """Returns the optimizer"""
        if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
            # if there are no exemplars, previous heads are not modified
            params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
        else:
            params = self.model.parameters()
        return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)

    def train_loop(self, t, trn_loader, val_loader):
        """Contains the epochs loop"""

        # add exemplars to train_loader
        if len(self.exemplars_dataset) > 0 and t > 0:
            trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
                                                     batch_size=trn_loader.batch_size,
                                                     shuffle=True,
                                                     num_workers=trn_loader.num_workers,
                                                     pin_memory=trn_loader.pin_memory)

        # FINETUNING TRAINING -- contains the epochs loop
        super().train_loop(t, trn_loader, val_loader)

        # EXEMPLAR MANAGEMENT -- select training subset
        self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)

    def post_train_process(self, t, trn_loader):
        """Runs after training all the epochs of the task (after the train session)"""

        # Restore best and save model for future tasks -- this frozen copy provides
        # the distillation targets while learning the next task
        self.model_old = deepcopy(self.model)
        self.model_old.eval()
        self.model_old.freeze_all()

    def train_epoch(self, t, trn_loader):
        """Runs a single epoch"""
        self.model.train()
        if self.fix_bn and t > 0:
            self.model.freeze_bn()
        for images, targets in trn_loader:
            # Forward old model -- provides the distillation targets when t > 0
            targets_old = None
            if t > 0:
                targets_old = self.model_old(images.to(self.device))
            # Forward current model
            outputs = self.model(images.to(self.device))
            loss = self.criterion(t, outputs, targets.to(self.device), targets_old)
            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            self.optimizer.step()

    def eval(self, t, val_loader):
        """Contains the evaluation code"""
        with torch.no_grad():
            total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0
            self.model.eval()
            for images, targets in val_loader:
                # Forward old model -- needed so the distillation term is included in the reported loss
                targets_old = None
                if t > 0:
                    targets_old = self.model_old(images.to(self.device))
                # Forward current model
                outputs = self.model(images.to(self.device))
                loss = self.criterion(t, outputs, targets.to(self.device), targets_old)
                hits_taw, hits_tag = self.calculate_metrics(outputs, targets)
                # Log -- totals are weighted by batch size so the averages below are per-sample
                total_loss += loss.data.cpu().numpy().item() * len(targets)
                total_acc_taw += hits_taw.sum().data.cpu().numpy().item()
                total_acc_tag += hits_tag.sum().data.cpu().numpy().item()
                total_num += len(targets)
        return total_loss / total_num, total_acc_taw / total_num, total_acc_tag / total_num

    def cross_entropy(self, outputs, targets, exp=1.0, size_average=True, eps=1e-5):
        """Calculates cross-entropy with temperature scaling

        ``outputs`` are the current model logits and ``targets`` the old model logits
        (both pre-softmax); ``exp`` is 1/T, applied to the probabilities after the
        softmax, which matches the original LwF formulation.
        """
        out = torch.nn.functional.softmax(outputs, dim=1)
        tar = torch.nn.functional.softmax(targets, dim=1)
        if exp != 1:
            # temperature scaling: raise probabilities to 1/T and renormalize each row
            out = out.pow(exp)
            out = out / out.sum(1).view(-1, 1).expand_as(out)
            tar = tar.pow(exp)
            tar = tar / tar.sum(1).view(-1, 1).expand_as(tar)
        # add eps to avoid log(0), then renormalize so rows still sum to 1
        out = out + eps / out.size(1)
        out = out / out.sum(1).view(-1, 1).expand_as(out)
        ce = -(tar * out.log()).sum(1)
        if size_average:
            ce = ce.mean()
        return ce

    def criterion(self, t, outputs, targets, outputs_old=None):
        """Returns the loss value"""
        loss = 0
        if t > 0:
            # Knowledge distillation loss for all previous tasks
            loss += self.lamb * self.cross_entropy(torch.cat(outputs[:t], dim=1),
                                                   torch.cat(outputs_old[:t], dim=1), exp=1.0 / self.T)
        # Current cross-entropy loss -- with exemplars use all heads
        if len(self.exemplars_dataset) > 0:
            return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
        return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
147 |
--------------------------------------------------------------------------------
/src/approach/mas.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import itertools
3 | from argparse import ArgumentParser
4 |
5 | from .incremental_learning import Inc_Learning_Appr
6 | from datasets.exemplars_dataset import ExemplarsDataset
7 |
8 |
9 | class Appr(Inc_Learning_Appr):
10 | """Class implementing the Memory Aware Synapses (MAS) approach (global version)
11 | described in https://arxiv.org/abs/1711.09601
12 | Original code available at https://github.com/rahafaljundi/MAS-Memory-Aware-Synapses
13 | """
14 |
15 | def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
16 | momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
17 | logger=None, exemplars_dataset=None, lamb=1, alpha=0.5, fi_num_samples=-1):
18 | super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
19 | multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
20 | exemplars_dataset)
21 | self.lamb = lamb
22 | self.alpha = alpha
23 | self.num_samples = fi_num_samples
24 |
25 | # In all cases, we only keep importance weights for the model, but not for the heads.
26 | feat_ext = self.model.model
27 | # Store current parameters as the initial parameters before first task starts
28 | self.older_params = {n: p.clone().detach() for n, p in feat_ext.named_parameters() if p.requires_grad}
29 | # Store fisher information weight importance
30 | self.importance = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters()
31 | if p.requires_grad}
32 |
33 | @staticmethod
34 | def exemplars_dataset_class():
35 | return ExemplarsDataset
36 |
37 | @staticmethod
38 | def extra_parser(args):
39 | """Returns a parser containing the approach specific parameters"""
40 | parser = ArgumentParser()
41 | # Eq. 3: lambda is the regularizer trade-off -- In original code: MAS.ipynb block [4]: lambda set to 1
42 | parser.add_argument('--lamb', default=1, type=float, required=False,
43 | help='Forgetting-intransigence trade-off (default=%(default)s)')
44 | # Define how old and new importance is fused, by default it is a 50-50 fusion
45 | parser.add_argument('--alpha', default=0.5, type=float, required=False,
46 | help='MAS alpha (default=%(default)s)')
47 | # Number of samples from train for estimating importance
48 | parser.add_argument('--fi-num-samples', default=-1, type=int, required=False,
49 | help='Number of samples for Fisher information (-1: all available) (default=%(default)s)')
50 | return parser.parse_known_args(args)
51 |
52 | def _get_optimizer(self):
53 | """Returns the optimizer"""
54 | if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
55 | # if there are no exemplars, previous heads are not modified
56 | params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
57 | else:
58 | params = self.model.parameters()
59 | return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)
60 |
61 | # Section 4.1: MAS (global) is implemented since the paper shows is more efficient than l-MAS (local)
62 | def estimate_parameter_importance(self, trn_loader):
63 | # Initialize importance matrices
64 | importance = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters()
65 | if p.requires_grad}
66 | # Compute fisher information for specified number of samples -- rounded to the batch size
67 | n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \
68 | else (len(trn_loader.dataset) // trn_loader.batch_size)
69 | # Do forward and backward pass to accumulate L2-loss gradients
70 | self.model.train()
71 | for images, targets in itertools.islice(trn_loader, n_samples_batches):
72 | # MAS allows any unlabeled data to do the estimation, we choose the current data as in main experiments
73 | outputs = self.model.forward(images.to(self.device))
74 | # Page 6: labels not required, "...use the gradients of the squared L2-norm of the learned function output."
75 | loss = torch.norm(torch.cat(outputs, dim=1), p=2, dim=1).mean()
76 | self.optimizer.zero_grad()
77 | loss.backward()
78 | # Eq. 2: accumulate the gradients over the inputs to obtain importance weights
79 | for n, p in self.model.model.named_parameters():
80 | if p.grad is not None:
81 | importance[n] += p.grad.abs() * len(targets)
82 | # Eq. 2: divide by N total number of samples
83 | n_samples = n_samples_batches * trn_loader.batch_size
84 | importance = {n: (p / n_samples) for n, p in importance.items()}
85 | return importance
86 |
87 | def train_loop(self, t, trn_loader, val_loader):
88 | """Contains the epochs loop"""
89 |
90 | # add exemplars to train_loader
91 | if len(self.exemplars_dataset) > 0 and t > 0:
92 | trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
93 | batch_size=trn_loader.batch_size,
94 | shuffle=True,
95 | num_workers=trn_loader.num_workers,
96 | pin_memory=trn_loader.pin_memory)
97 |
98 | # FINETUNING TRAINING -- contains the epochs loop
99 | super().train_loop(t, trn_loader, val_loader)
100 |
101 | # EXEMPLAR MANAGEMENT -- select training subset
102 | self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)
103 |
104 | def post_train_process(self, t, trn_loader):
105 | """Runs after training all the epochs of the task (after the train session)"""
106 |
107 | # Store current parameters for the next task
108 | self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad}
109 |
110 | # calculate Fisher information
111 | curr_importance = self.estimate_parameter_importance(trn_loader)
112 | # merge fisher information, we do not want to keep fisher information for each task in memory
113 | for n in self.importance.keys():
114 | # Added option to accumulate importance over time with a pre-fixed growing alpha
115 | if self.alpha == -1:
116 | alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device)
117 | self.importance[n] = alpha * self.importance[n] + (1 - alpha) * curr_importance[n]
118 | else:
119 | # As in original code: MAS_utils/MAS_based_Training.py line 638 -- just add prev and new
120 | self.importance[n] = self.alpha * self.importance[n] + (1 - self.alpha) * curr_importance[n]
121 |
122 | def criterion(self, t, outputs, targets):
123 | """Returns the loss value"""
124 | loss = 0
125 | if t > 0:
126 | loss_reg = 0
127 | # Eq. 3: memory aware synapses regularizer penalty
128 | for n, p in self.model.model.named_parameters():
129 | if n in self.importance.keys():
130 | loss_reg += torch.sum(self.importance[n] * (p - self.older_params[n]).pow(2)) / 2
131 | loss += self.lamb * loss_reg
132 | # Current cross-entropy loss -- with exemplars use all heads
133 | if len(self.exemplars_dataset) > 0:
134 | return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
135 | return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
136 |
--------------------------------------------------------------------------------
/src/approach/oewc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import itertools
3 | from argparse import ArgumentParser
4 |
5 | from datasets.exemplars_dataset import ExemplarsDataset
6 | from .incremental_learning import Inc_Learning_Appr
7 |
8 |
class Appr(Inc_Learning_Appr):
    """Class implementing the Elastic Weight Consolidation (EWC) approach
    described in http://arxiv.org/abs/1612.00796

    Parameter importance is estimated with a diagonal approximation of the
    Fisher information matrix over the backbone (not the heads); a quadratic
    penalty (Eq. 3) then keeps important parameters close to their values
    after the previous tasks.
    """

    def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
                 momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
                 logger=None, exemplars_dataset=None, lamb=5000, alpha=0.5, fi_sampling_type='max_pred',
                 fi_num_samples=-1):
        super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
                                   multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
                                   exemplars_dataset)
        # Trade-off between the EWC quadratic penalty and the current-task loss (Eq. 3)
        self.lamb = lamb
        # Fusion factor between old and new Fisher; -1 switches to a growing alpha (see post_train_process)
        self.alpha = alpha
        # How pseudo-labels are drawn for the Fisher estimation: 'true', 'max_pred' or 'multinomial'
        self.sampling_type = fi_sampling_type
        # Number of training samples used for the estimation (-1: all available)
        self.num_samples = fi_num_samples

        # In all cases, we only keep importance weights for the model, but not for the heads.
        feat_ext = self.model.model
        # Store current parameters as the initial parameters before first task starts
        self.older_params = {n: p.clone().detach() for n, p in feat_ext.named_parameters() if p.requires_grad}
        # Store fisher information weight importance
        self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters()
                       if p.requires_grad}

    @staticmethod
    def exemplars_dataset_class():
        """Dataset class used to store exemplars for this approach."""
        return ExemplarsDataset

    @staticmethod
    def extra_parser(args):
        """Returns a parser containing the approach specific parameters"""
        parser = ArgumentParser()
        # Eq. 3: "lambda sets how important the old task is compared to the new one"
        parser.add_argument('--lamb', default=5000, type=float, required=False,
                            help='Forgetting-intransigence trade-off (default=%(default)s)')
        # Define how old and new fisher is fused, by default it is a 50-50 fusion
        parser.add_argument('--alpha', default=0.5, type=float, required=False,
                            help='EWC alpha (default=%(default)s)')
        parser.add_argument('--fi-sampling-type', default='max_pred', type=str, required=False,
                            choices=['true', 'max_pred', 'multinomial'],
                            help='Sampling type for Fisher information (default=%(default)s)')
        parser.add_argument('--fi-num-samples', default=-1, type=int, required=False,
                            help='Number of samples for Fisher information (-1: all available) (default=%(default)s)')

        return parser.parse_known_args(args)

    def _get_optimizer(self):
        """Returns the optimizer"""
        if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
            # if there are no exemplars, previous heads are not modified
            params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
        else:
            params = self.model.parameters()
        return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)

    def compute_fisher_matrix_diag(self, trn_loader):
        """Estimate the diagonal Fisher information over (a subset of) `trn_loader`.

        Returns a dict mapping backbone parameter names to tensors of squared
        gradients averaged over the processed samples.
        """
        # Store Fisher Information
        fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters()
                  if p.requires_grad}
        # Compute fisher information for specified number of samples -- rounded to the batch size
        n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \
            else (len(trn_loader.dataset) // trn_loader.batch_size)
        # Do forward and backward pass to compute the fisher information
        self.model.train()
        for images, targets in itertools.islice(trn_loader, n_samples_batches):
            outputs = self.model.forward(images.to(self.device))

            if self.sampling_type == 'true':
                # Use the labels to compute the gradients based on the CE-loss with the ground truth
                preds = targets.to(self.device)
            elif self.sampling_type == 'max_pred':
                # Not use labels and compute the gradients related to the prediction the model has learned
                preds = torch.cat(outputs, dim=1).argmax(1).flatten()
            elif self.sampling_type == 'multinomial':
                # Use a multinomial sampling to compute the gradients
                probs = torch.nn.functional.softmax(torch.cat(outputs, dim=1), dim=1)
                preds = torch.multinomial(probs, len(targets)).flatten()

            # NOTE(review): an unrecognized sampling_type would leave `preds` unbound
            # here; the argparse `choices` list currently guards against that.
            loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), preds)
            self.optimizer.zero_grad()
            loss.backward()
            # Accumulate all gradients from loss with regularization
            for n, p in self.model.model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.pow(2) * len(targets)
        # Apply mean across all samples
        # NOTE(review): assumes full batches -- a partial final batch would make
        # this denominator slightly too large.
        n_samples = n_samples_batches * trn_loader.batch_size
        fisher = {n: (p / n_samples) for n, p in fisher.items()}
        return fisher

    def train_loop(self, t, trn_loader, val_loader):
        """Contains the epochs loop"""

        # add exemplars to train_loader
        if len(self.exemplars_dataset) > 0 and t > 0:
            trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
                                                     batch_size=trn_loader.batch_size,
                                                     shuffle=True,
                                                     num_workers=trn_loader.num_workers,
                                                     pin_memory=trn_loader.pin_memory)

        # FINETUNING TRAINING -- contains the epochs loop
        super().train_loop(t, trn_loader, val_loader)

        # EXEMPLAR MANAGEMENT -- select training subset
        self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)

    def post_train_process(self, t, trn_loader):
        """Runs after training all the epochs of the task (after the train session)"""

        # Store current parameters for the next task
        self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad}

        # calculate Fisher information
        curr_fisher = self.compute_fisher_matrix_diag(trn_loader)
        # merge fisher information, we do not want to keep fisher information for each task in memory
        for n in self.fisher.keys():
            # Added option to accumulate fisher over time with a pre-fixed growing alpha
            if self.alpha == -1:
                # weight the old Fisher by the fraction of classes seen before this task
                alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device)
                self.fisher[n] = alpha * self.fisher[n] + (1 - alpha) * curr_fisher[n]
            else:
                # fixed-alpha fusion of old and new Fisher estimates
                self.fisher[n] = (self.alpha * self.fisher[n] + (1 - self.alpha) * curr_fisher[n])

    def compute_orth_loss(self, old_attention_list, attention_list):
        """Penalty between old and current attention maps, weighted by block depth.

        NOTE(review): 197 looks like a ViT-Base token count (14x14 patches + CLS
        token) -- TODO confirm; this hard-codes the backbone/input resolution.
        This method is not called anywhere within this class.
        """
        totloss = 0.
        for i in range(len(attention_list)):
            al = attention_list[i].view(-1, 197,197).to(self.device)
            ol = old_attention_list[i].view(-1, 197,197).to(self.device)
            # deeper blocks are weighted more via the factor `i` (block 0 contributes nothing)
            totloss += (torch.linalg.matrix_norm((al-ol)+1e-08)*float(i)).mean()
        return totloss

    def criterion(self, t, outputs, targets):
        """Returns the loss value"""
        loss = 0
        if t > 0:
            loss_reg = 0
            # Eq. 3: elastic weight consolidation quadratic penalty
            for n, p in self.model.model.named_parameters():
                if n in self.fisher.keys():
                    loss_reg += torch.sum(self.fisher[n] * (p - self.older_params[n]).pow(2)) / 2
            loss += self.lamb * loss_reg
        # Current cross-entropy loss -- with exemplars use all heads
        if len(self.exemplars_dataset) > 0:
            return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
        return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
156 |
--------------------------------------------------------------------------------
/src/approach/path_integral.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from argparse import ArgumentParser
3 |
4 | from .incremental_learning import Inc_Learning_Appr
5 | from datasets.exemplars_dataset import ExemplarsDataset
6 |
7 |
class Appr(Inc_Learning_Appr):
    """Class implementing the Path Integral (aka Synaptic Intelligence) approach
    described in http://proceedings.mlr.press/v70/zenke17a.html
    Original code available at https://github.com/ganguli-lab/pathint

    Per-parameter importance is accumulated online during training (the path
    integral `w`), folded into a running Omega matrix after each task, and used
    in a quadratic surrogate loss (Eq. 4) to protect important parameters.
    """

    def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
                 momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
                 logger=None, exemplars_dataset=None, lamb=0.1, damping=0.1):
        super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
                                   multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
                                   exemplars_dataset)
        # 'c' trade-off parameter of the surrogate loss (Eq. 4)
        self.lamb = lamb
        # damping term of Eq. 5: keeps the Omega update finite when parameters barely move
        self.damping = damping

        # In all cases, we only keep importance weights for the model, but not for the heads.
        feat_ext = self.model.model
        # Page 3, following Eq. 3: "The w now have an intuitive interpretation as the parameter specific contribution to
        # changes in the total loss."
        self.w = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters() if p.requires_grad}
        # Store current parameters as the initial parameters before first task starts
        self.older_params = {n: p.clone().detach().to(self.device) for n, p in feat_ext.named_parameters()
                             if p.requires_grad}
        # Store importance weights matrices
        self.importance = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters()
                           if p.requires_grad}

    @staticmethod
    def exemplars_dataset_class():
        """Dataset class used to store exemplars for this approach."""
        return ExemplarsDataset

    @staticmethod
    def extra_parser(args):
        """Returns a parser containing the approach specific parameters"""
        parser = ArgumentParser()
        # Eq. 4: lamb is the 'c' trade-off parameter from the surrogate loss -- 1e-3 < c < 0.1
        parser.add_argument('--lamb', default=0.1, type=float, required=False,
                            help='Forgetting-intransigence trade-off (default=%(default)s)')
        # Eq. 5: damping parameter is set to 0.1 in the MNIST case
        parser.add_argument('--damping', default=0.1, type=float, required=False,
                            help='Damping (default=%(default)s)')
        return parser.parse_known_args(args)

    def _get_optimizer(self):
        """Returns the optimizer"""
        if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
            # if there are no exemplars, previous heads are not modified
            params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
        else:
            params = self.model.parameters()
        return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)

    def train_loop(self, t, trn_loader, val_loader):
        """Contains the epochs loop"""

        # add exemplars to train_loader
        if len(self.exemplars_dataset) > 0 and t > 0:
            trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
                                                     batch_size=trn_loader.batch_size,
                                                     shuffle=True,
                                                     num_workers=trn_loader.num_workers,
                                                     pin_memory=trn_loader.pin_memory)

        # FINETUNING TRAINING -- contains the epochs loop
        super().train_loop(t, trn_loader, val_loader)

        # EXEMPLAR MANAGEMENT -- select training subset
        self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)

    def post_train_process(self, t, trn_loader):
        """Runs after training all the epochs of the task (after the train session)"""

        # Eq. 5: accumulate Omega regularization strength (importance matrix)
        with torch.no_grad():
            curr_params = {n: p for n, p in self.model.model.named_parameters() if p.requires_grad}
            for n, p in self.importance.items():
                # fold the accumulated path integral into Omega, normalized by the
                # (damped) squared total displacement of the parameter over this task
                p += self.w[n] / ((curr_params[n] - self.older_params[n]) ** 2 + self.damping)
                # reset the running path integral for the next task
                self.w[n].zero_()

        # Store current parameters for the next task
        self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad}

    def train_epoch(self, t, trn_loader):
        """Runs a single epoch

        Performs two backward passes per batch: first on the plain task loss to
        capture the UNregularized gradients used by the path integral, then on
        the full regularized loss to actually update the model.
        """
        self.model.train()
        if self.fix_bn and t > 0:
            self.model.freeze_bn()
        for images, targets in trn_loader:
            # store current model without heads
            curr_feat_ext = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad}

            # Forward current model
            outputs = self.model(images.to(self.device))
            # theoretically this is the correct one for 2 tasks, however, for more tasks maybe is the current loss
            # check https://github.com/ganguli-lab/pathint/blob/master/pathint/optimizers.py line 123
            # cross-entropy loss on current task
            if len(self.exemplars_dataset) == 0:
                loss = torch.nn.functional.cross_entropy(outputs[t], targets.to(self.device) - self.model.task_offset[t])
            else:
                # with exemplars we check output from all heads (train data has all labels)
                loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets.to(self.device))
            self.optimizer.zero_grad()
            # retain the graph: the regularized loss below backpropagates through the same outputs
            loss.backward(retain_graph=True)
            # store gradients without regularization term
            unreg_grads = {n: p.grad.clone().detach() for n, p in self.model.model.named_parameters()
                           if p.grad is not None}
            # apply loss with path integral regularization
            loss = self.criterion(t, outputs, targets.to(self.device))

            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            self.optimizer.step()

            # Eq. 3: accumulate w, compute the path integral -- "In practice, we can approximate w online as the running
            # sum of the product of the gradient with the parameter update".
            with torch.no_grad():
                for n, p in self.model.model.named_parameters():
                    if n in unreg_grads.keys():
                        # w[n] >=0, but minus for loss decrease
                        self.w[n] -= unreg_grads[n] * (p.detach() - curr_feat_ext[n])

    def criterion(self, t, outputs, targets):
        """Returns the loss value"""
        loss = 0
        if t > 0:
            loss_reg = 0
            # Eq. 4: quadratic surrogate loss
            for n, p in self.model.model.named_parameters():
                loss_reg += torch.sum(self.importance[n] * (p - self.older_params[n]).pow(2))
            loss += self.lamb * loss_reg
        # Current cross-entropy loss -- with exemplars use all heads
        if len(self.exemplars_dataset) > 0:
            return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
        return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
144 |
--------------------------------------------------------------------------------
/src/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 | We include predefined datasets MNIST, CIFAR-100, SVHN, VGGFace2, ImageNet, ImageNet-subset. The first three are
3 | integrated directly from their respective torchvision classes. The other follow our proposed dataset implementation
4 | (see below). We also include `imagenet_32_reduced` to be used with the DMC approach as an external dataset.
5 |
6 | ## Main usage
7 | When running an experiment, datasets used can be defined in [main_incremental.py](../main_incremental.py) using
8 | `--datasets`. A single dataset or a list of datasets can be provided, and will be learned in the given order. If
9 | `--num_tasks` is larger than 1, each dataset will further be split into the given number of tasks. All tasks from the
10 | first dataset will be learned before moving to the next dataset in the list. Other main arguments related to datasets
11 | are:
12 |
13 | * `--num-workers`: number of subprocesses to use for dataloader (default=4)
14 | * `--pin-memory`: copy Tensors into CUDA pinned memory before returning them (default=False)
15 | * `--batch-size`: number of samples per batch to load (default=64)
16 | * `--nc-first-task`: number of classes of the first task (default=None)
17 | * `--use-valid-only`: use validation split instead of test (default=False)
18 | * `--stop-at-task`: stop training after specified task (0: no stop) (default=0)
19 |
20 | Datasets are defined in [dataset_config.py](dataset_config.py). Each entry key is considered a dataset name. For a
21 | proper configuration, the mandatory entry value is `path`, which contains the path to the dataset folder (or for
22 | torchvision datasets, the path to where it will be downloaded). The rest of values are possible transformations to be
23 | applied to the datasets (see [data_loader.py](data_loader.py)):
24 |
25 | * `resize`: resize the input image to the given size on both train and eval
26 | [[source](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.Resize)].
27 | * `pad`: pad the given image on all sides with the given “pad” value on both train and eval
28 | [[source](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.functional.pad)].
29 | * `crop`: crop the given image to random size and aspect ratio for train
30 | [[source](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.RandomResizedCrop)], and at the center
31 | on eval [[source](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.CenterCrop)].
32 | * `flip`: horizontally flip the given image randomly with a 50% probability on train only
33 | [[source](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.RandomHorizontalFlip)].
34 | * `normalize`: normalize a tensor image with mean and standard deviation
35 | [[source](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.Normalize)].
36 | * `class_order`: fixes the class order given the list of ordered labels. It also allows to limit the dataset to only the
37 | classes provided.
38 | * `extend_channel`: times that the input channels have to be extended.
39 |
40 | where the first ones are adapted from PyTorch transforms, and the last two are our own additions.
41 |
42 | ### Dataset types
43 | MNIST, CIFAR-100 and SVHN are small enough to use [memory_dataset.py](memory_dataset.py), which loads all images in
44 | memory. The other datasets are too large to fully load in memory and, therefore, use [base_dataset.py](base_dataset.py).
This dataset type loads the corresponding paths and labels in memory and makes use of the PyTorch DataLoader to
46 | load the images in batches when needed. This can be modified in [data_loader.py](data_loader.py#L64) by modifying the
47 | type of dataset used.
48 |
49 | ### Dataset path
50 | Modify variable `_BASE_DATA_PATH` in [dataset_config.py](dataset_config.py) pointing to the root folder containing your
51 | datasets. You can also modify the `path` field in the entries in [dataset_config.py](dataset_config.py) if a specific
dataset is stored at a different location in your machine.
53 |
54 | ### Exemplars
55 | For those approaches that can use exemplars, those can be used with either `--num_exemplars` or `--num_exemplars_per_class`. The first one is for a fixed memory where the number of exemplars is the same for all classes. The second one is a growing memory which specifies the number of exemplars per class that will be stored.
56 |
Different exemplar sampling strategies are implemented in [exemplars_selection.py](exemplars_selection.py). Those can be selected by using `--exemplar_selection` followed by one of these strategies:
58 |
59 | * `random`: produces a random list of samples.
60 | * `herding`: sampling based on distance to the mean sample of each class. From [iCaRL](https://openaccess.thecvf.com/content_cvpr_2017/papers/Rebuffi_iCaRL_Incremental_Classifier_CVPR_2017_paper.pdf) algorithms 4 and 5.
61 | * `entropy`: sampling based on entropy of each sample. From [RWalk](http://arxiv-export-lb.library.cornell.edu/pdf/1801.10112).
62 | * `distance`: sampling based on closeness to decision boundary of each sample. From [RWalk](http://arxiv-export-lb.library.cornell.edu/pdf/1801.10112).
63 |
64 |
65 | ## Adding new datasets
To add a new dataset, follow these steps:
67 |
68 | 1. Create a new entry in [dataset_config.py](dataset_config.py), add the folder with the data in `path` and any other transformations or class ordering needed.
69 | 2. Depending if the dataset is in [torchvision](https://pytorch.org/docs/stable/torchvision/datasets.html) or custom:
70 | * If the new dataset is available in **torchvision**, you can add a new option like in lines 67, 79 or 91 from [data_loader.py](data_loader.py). If the dataset is too large to fit in memory, use `base_dataset`, else `memory_dataset`.
   * If the dataset is **custom**, the option from line 135 in [data_loader.py](data_loader.py) should be enough. In the same folder as the data, add `train.txt` and `test.txt` files, each containing one line per sample with the path and the corresponding class label:
72 | ```
73 | /data/train/sample1.jpg 0
74 | /data/train/sample2.jpg 1
75 | /data/train/sample3.jpg 2
76 | ...
77 | ```
78 |
79 | No need to modify any other file. If the dataset is a subset or modification of an already added dataset, only use step 1.
80 |
81 | ### Available custom datasets
82 | We provide `.txt` and `.zip` files for the following datasets for an easier integration with FACIL:
83 |
84 | * (COMING SOON)
85 |
86 | ## Changing dataset transformations
87 |
There are some cases when it's necessary to get data from the existing dataset with a different transformation (e.g. self-supervision) or none at all (e.g. when selecting exemplars). For these particular cases we prepared a simple solution to override transformations for a dataset within a selected context, e.g. taking raw pictures in `numpy` format:
89 |
90 | ```python
91 | with override_dataset_transform(trn_loader.dataset, Lambda(lambda x: np.array(x))) as ds_for_raw:
92 | x, y = zip(*(ds_for_raw[idx] for idx in selected_indices))
93 | ```
94 |
This can be used in other cases too, like train/eval change of transformation when checking the training process, etc. Also, you can check out a simple unit test in [test_datasets_transforms.py](src/tests/test_datasets_transforms.py).
96 |
97 | ## Notes
98 | * As an example, we include two versions of CIFAR-100. The one with entry `cifar100` is the default one which by default
99 | shuffles the class order depending on the fixed seed. We also include `cifar100_icarl` which fixes the class order
100 | from iCaRL (given by seed 1993), and thus makes the comparison more fair with results from papers that use that class
101 | order (e.g. iCaRL, LUCIR, BiC).
102 | * When using multiple machines, you can create different [dataset_config.py](dataset_config.py) files for each of them.
103 | Otherwise, remember you can create symbolic links that point to a specific folder with
104 | `ln -s TARGET_DIRECTORY LINK_NAME` in Linux/UNIX.
105 |
--------------------------------------------------------------------------------
/src/datasets/__pycache__/base_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/base_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/src/datasets/__pycache__/base_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/base_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/src/datasets/__pycache__/data_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/data_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/src/datasets/__pycache__/data_loader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/data_loader.cpython-38.pyc
--------------------------------------------------------------------------------
/src/datasets/__pycache__/dataset_config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/dataset_config.cpython-36.pyc
--------------------------------------------------------------------------------
/src/datasets/__pycache__/dataset_config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/dataset_config.cpython-38.pyc
--------------------------------------------------------------------------------
/src/datasets/__pycache__/exemplars_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/exemplars_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/src/datasets/__pycache__/exemplars_selection.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/exemplars_selection.cpython-38.pyc
--------------------------------------------------------------------------------
/src/datasets/__pycache__/memory_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/memory_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/src/datasets/__pycache__/memory_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/datasets/__pycache__/memory_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/src/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 |
class BaseDataset(Dataset):
    """PyTorch dataset that pre-loads image *paths* (not pixels) in memory.

    Images are read from disk lazily, one at a time, inside `__getitem__`.
    """

    def __init__(self, data, transform, class_indices=None):
        """Store sample paths, labels and the transform applied on load."""
        self.labels = data['y']  # one integer label per sample
        self.images = data['x']  # list of image file paths
        self.transform = transform
        self.class_indices = class_indices

    def __len__(self):
        """Total number of samples."""
        return len(self.images)

    def __getitem__(self, index):
        """Read one image from disk, apply the transform and return (x, y)."""
        img = Image.open(self.images[index]).convert('RGB')
        return self.transform(img), self.labels[index]
28 |
29 |
def get_data(path, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None):
    """Prepare data: dataset splits, task partition, class order.

    Reads 'train.txt' / 'test.txt' (one "<image> <label>" pair per line) from
    `path`, remaps labels according to `class_order`, partitions the classes
    into `num_tasks` tasks and optionally carves a per-class validation split.

    Args:
        path: dataset root directory containing the two list files.
        num_tasks: number of tasks to partition the classes into.
        nc_first_task: classes assigned to the first task; None splits evenly.
        validation: fraction of each class's training samples moved to 'val'.
        shuffle_classes: shuffle the class order before splitting.
        class_order: explicit class ordering; labels not listed are dropped.
            When None, all classes found in the training list are used.

    Returns:
        (data, taskcla, class_order): per-task split dict, list of
        (task, num_classes) tuples, and the class order actually used.
    """
    data = {}
    taskcla = []

    # read filenames and labels
    trn_lines = np.loadtxt(os.path.join(path, 'train.txt'), dtype=str)
    tst_lines = np.loadtxt(os.path.join(path, 'test.txt'), dtype=str)
    if class_order is None:
        num_classes = len(np.unique(trn_lines[:, 1]))
        class_order = list(range(num_classes))
    else:
        num_classes = len(class_order)
        class_order = class_order.copy()
        if shuffle_classes:
            np.random.shuffle(class_order)
    # original label -> position in class_order; O(1) lookup instead of calling
    # class_order.index() once per sample (which is O(num_classes) each time)
    class_pos = {label: pos for pos, label in enumerate(class_order)}

    # compute classes per task and num_tasks
    if nc_first_task is None:
        cpertask = np.array([num_classes // num_tasks] * num_tasks)
        # spread the remainder over the first tasks
        for i in range(num_classes % num_tasks):
            cpertask[i] += 1
    else:
        assert nc_first_task < num_classes, "first task wants more classes than exist"
        remaining_classes = num_classes - nc_first_task
        assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task"  # better minimum 2
        cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1))
        for i in range(remaining_classes % (num_tasks - 1)):
            cpertask[i + 1] += 1

    assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes"
    cpertask_cumsum = np.cumsum(cpertask)
    init_class = np.concatenate(([0], cpertask_cumsum[:-1]))

    # initialize data structure
    for tt in range(num_tasks):
        data[tt] = {}
        data[tt]['name'] = 'task-' + str(tt)
        data[tt]['trn'] = {'x': [], 'y': []}
        data[tt]['val'] = {'x': [], 'y': []}
        data[tt]['tst'] = {'x': [], 'y': []}

    def _fill_split(lines, split):
        """Route every (image, label) line into the task that owns the label."""
        for this_image, this_label in lines:
            if not os.path.isabs(this_image):
                this_image = os.path.join(path, this_image)
            this_label = int(this_label)
            if this_label not in class_pos:
                continue  # label excluded by class_order
            # If shuffling is false, it won't change the class number
            this_label = class_pos[this_label]
            # add it to the corresponding split
            this_task = (this_label >= cpertask_cumsum).sum()
            data[this_task][split]['x'].append(this_image)
            data[this_task][split]['y'].append(this_label - init_class[this_task])

    _fill_split(trn_lines, 'trn')
    _fill_split(tst_lines, 'tst')

    # check classes
    for tt in range(num_tasks):
        data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y']))
        assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes"

    # validation: move a random per-class fraction of the training samples
    if validation > 0.0:
        for tt in data.keys():
            for cc in range(data[tt]['ncla']):
                cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0])
                rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation)))
                rnd_img.sort(reverse=True)  # pop from the back so earlier indices stay valid
                for ii in range(len(rnd_img)):
                    data[tt]['val']['x'].append(data[tt]['trn']['x'][rnd_img[ii]])
                    data[tt]['val']['y'].append(data[tt]['trn']['y'][rnd_img[ii]])
                    data[tt]['trn']['x'].pop(rnd_img[ii])
                    data[tt]['trn']['y'].pop(rnd_img[ii])

    # other
    n = 0
    for t in data.keys():
        taskcla.append((t, data[t]['ncla']))
        n += data[t]['ncla']
    data['ncla'] = n

    return data, taskcla, class_order
129 |
--------------------------------------------------------------------------------
/src/datasets/dataset_config.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 |
# Root directory that holds every dataset listed below.
_BASE_DATA_PATH = "/datatmp/users/fpelosin/data/"

# Per-dataset configuration. Recognized keys (missing ones are defaulted by
# the fix-up loop at the bottom of this module):
#   path                   - dataset root directory
#   resize / pad / crop / flip - torchvision preprocessing parameters
#   normalize              - (mean, std) tuples for channel normalization
#   class_order            - fixed class ordering (e.g. the iCaRL ordering)
#   extend_channel         - replicate grayscale input to this many channels
dataset_config = {
    'mnist': {
        'path': join(_BASE_DATA_PATH, 'mnist'),
        'normalize': ((0.1307,), (0.3081,)),
        # Use the next 3 lines to use MNIST with a 3x32x32 input
        # 'extend_channel': 3,
        # 'pad': 2,
        # 'normalize': ((0.1,), (0.2752,)) # values including padding
    },
    'svhn': {
        'path': join(_BASE_DATA_PATH, 'svhn'),
        'resize': (224, 224),
        'crop': None,
        'flip': False,
        'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    },
    # CIFAR-100 upscaled to 224x224 (e.g. for ImageNet-pretrained backbones / ViTs)
    'cifar100_224': {
        'path': join(_BASE_DATA_PATH, 'cifar100'),
        'resize': 224,
        'pad': 4,
        'crop': 224,
        'flip': True,
        'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))
    },
    'cifar100': {
        'path': join(_BASE_DATA_PATH, 'cifar100'),
        'resize': None,
        'pad': 4,
        'crop': 32,
        'flip': True,
        'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))
    },
    'cifar10': {
        'path': join(_BASE_DATA_PATH, 'cifar10'),
        'resize': None,
        'pad': 4,
        'crop': 32,
        'flip': True,
        'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))
    },
    # CIFAR-100 with the fixed class order used by the iCaRL paper
    'cifar100_icarl': {
        'path': join(_BASE_DATA_PATH, 'cifar100'),
        'resize': None,
        'pad': 4,
        'crop': 32,
        'flip': True,
        'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)),
        'class_order': [
            68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50,
            28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,
            98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69,
            36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33
        ]
    },
    'vggface2': {
        'path': join(_BASE_DATA_PATH, 'VGGFace2'),
        'resize': 256,
        'crop': 224,
        'flip': True,
        'normalize': ((0.5199, 0.4116, 0.3610), (0.2604, 0.2297, 0.2169))
    },
    'imagenet_256': {
        'path': join(_BASE_DATA_PATH, 'ILSVRC12_256'),
        'resize': None,
        'crop': 224,
        'flip': True,
        'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    },
    # 100-class ImageNet subset, same fixed class order as cifar100_icarl
    'imagenet_subset': {
        'path': join(_BASE_DATA_PATH, 'ILSVRC12_256'),
        'resize': None,
        'crop': 224,
        'flip': True,
        'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        'class_order': [
            68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50,
            28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,
            98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69,
            36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33
        ]
    },
    # Downscaled 32x32 ImageNet with a fixed class-order subset (used by DMC as aux data)
    'imagenet_32_reduced': {
        'path': join(_BASE_DATA_PATH, 'ILSVRC12_32'),
        'resize': None,
        'pad': 4,
        'crop': 32,
        'flip': True,
        'normalize': ((0.481, 0.457, 0.408), (0.260, 0.253, 0.268)),
        'class_order': [
            472, 46, 536, 806, 547, 976, 662, 12, 955, 651, 492, 80, 999, 996, 788, 471, 911, 907, 680, 126, 42, 882,
            327, 719, 716, 224, 918, 647, 808, 261, 140, 908, 833, 925, 57, 388, 407, 215, 45, 479, 525, 641, 915, 923,
            108, 461, 186, 843, 115, 250, 829, 625, 769, 323, 974, 291, 438, 50, 825, 441, 446, 200, 162, 373, 872, 112,
            212, 501, 91, 672, 791, 370, 942, 172, 315, 959, 636, 635, 66, 86, 197, 182, 59, 736, 175, 445, 947, 268,
            238, 298, 926, 851, 494, 760, 61, 293, 696, 659, 69, 819, 912, 486, 706, 343, 390, 484, 282, 729, 575, 731,
            530, 32, 534, 838, 466, 734, 425, 400, 290, 660, 254, 266, 551, 775, 721, 134, 886, 338, 465, 236, 522, 655,
            209, 861, 88, 491, 985, 304, 981, 560, 405, 902, 521, 909, 763, 455, 341, 905, 280, 776, 113, 434, 274, 581,
            158, 738, 671, 702, 147, 718, 148, 35, 13, 585, 591, 371, 745, 281, 956, 935, 346, 352, 284, 604, 447, 415,
            98, 921, 118, 978, 880, 509, 381, 71, 552, 169, 600, 334, 171, 835, 798, 77, 249, 318, 419, 990, 335, 374,
            949, 316, 755, 878, 946, 142, 299, 863, 558, 306, 183, 417, 64, 765, 565, 432, 440, 939, 297, 805, 364, 735,
            251, 270, 493, 94, 773, 610, 278, 16, 363, 92, 15, 593, 96, 468, 252, 699, 377, 95, 799, 868, 820, 328, 756,
            81, 991, 464, 774, 584, 809, 844, 940, 720, 498, 310, 384, 619, 56, 406, 639, 285, 67, 634, 792, 232, 54,
            664, 818, 513, 349, 330, 207, 361, 345, 279, 549, 944, 817, 353, 228, 312, 796, 193, 179, 520, 451, 871,
            692, 60, 481, 480, 929, 499, 673, 331, 506, 70, 645, 759, 744, 459]
    }
}
110 |
# Fill in defaults so every dataset entry exposes the complete key set.
for _cfg in dataset_config.values():
    for _key in ('resize', 'pad', 'crop', 'normalize', 'class_order', 'extend_channel'):
        _cfg.setdefault(_key, None)
    _cfg.setdefault('flip', False)
118 |
--------------------------------------------------------------------------------
/src/datasets/exemplars_dataset.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from argparse import ArgumentParser
3 |
4 | from datasets.memory_dataset import MemoryDataset
5 |
6 |
class ExemplarsDataset(MemoryDataset):
    """In-memory exemplar storage exposing the regular Dataset interface."""

    def __init__(self, transform, class_indices,
                 num_exemplars=0, num_exemplars_per_class=0, exemplar_selection='random'):
        """Create an empty exemplar store and instantiate the selection strategy."""
        super().__init__({'x': [], 'y': []}, transform, class_indices=class_indices)
        self.max_num_exemplars_per_class = num_exemplars_per_class
        self.max_num_exemplars = num_exemplars
        assert (num_exemplars_per_class == 0) or (num_exemplars == 0), 'Cannot use both limits at once!'
        # resolve e.g. 'random' -> datasets.exemplars_selection.RandomExemplarsSelector
        selector_name = "{}ExemplarsSelector".format(exemplar_selection.capitalize())
        selection_module = importlib.import_module(name='datasets.exemplars_selection')
        self.exemplars_selector = getattr(selection_module, selector_name)(self)

    @staticmethod
    def extra_parser(args):
        """Return a parser with the exemplar-management specific parameters."""
        parser = ArgumentParser("Exemplars Management Parameters")
        _group = parser.add_mutually_exclusive_group()
        _group.add_argument('--num-exemplars', default=0, type=int, required=False,
                            help='Fixed memory, total number of exemplars (default=%(default)s)')
        _group.add_argument('--num-exemplars-per-class', default=0, type=int, required=False,
                            help='Growing memory, number of exemplars per class (default=%(default)s)')
        parser.add_argument('--exemplar-selection', default='random', type=str,
                            choices=['herding', 'random', 'entropy', 'distance'],
                            required=False, help='Exemplar selection strategy (default=%(default)s)')
        return parser.parse_known_args(args)

    def _is_active(self):
        """True when either memory limit enables exemplar storage."""
        return self.max_num_exemplars_per_class > 0 or self.max_num_exemplars > 0

    def collect_exemplars(self, model, trn_loader, selection_transform):
        """Refresh the stored exemplars using the configured selection strategy."""
        if not self._is_active():
            return
        self.images, self.labels = self.exemplars_selector(model, trn_loader, selection_transform)
40 |
--------------------------------------------------------------------------------
/src/datasets/memory_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 |
6 |
class MemoryDataset(Dataset):
    """PyTorch dataset whose images live entirely in memory as arrays."""

    def __init__(self, data, transform, class_indices=None):
        """Keep references to the in-memory images, labels and transform."""
        self.images = data['x']
        self.labels = data['y']
        self.transform = transform
        self.class_indices = class_indices

    def __len__(self):
        """Number of samples held in memory."""
        return len(self.images)

    def __getitem__(self, index):
        """Wrap the stored array as a PIL image, transform it and return (x, y)."""
        sample = Image.fromarray(self.images[index])
        return self.transform(sample), self.labels[index]
27 |
28 |
def get_data(trn_data, tst_data, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None):
    """Prepare data: dataset splits, task partition, class order.

    Partitions in-memory train/test data into `num_tasks` tasks according to
    `class_order`, optionally carving a per-class validation split.

    Args:
        trn_data: dict with 'x' (images) and 'y' (labels). Filtered in place
            when `class_order` drops some labels.
        tst_data: same structure as `trn_data` for the test split.
        num_tasks: number of tasks to partition the classes into.
        nc_first_task: classes assigned to the first task; None splits evenly.
        validation: fraction of each class's training samples moved to 'val'.
        shuffle_classes: shuffle the class order before splitting.
        class_order: explicit class ordering; labels not listed are dropped.
            When None, all classes found in `trn_data` are used.

    Returns:
        (data, taskcla, class_order): per-task split dict, list of
        (task, num_classes) tuples, and the class order actually used.
    """
    data = {}
    taskcla = []
    if class_order is None:
        num_classes = len(np.unique(trn_data['y']))
        class_order = list(range(num_classes))
    else:
        num_classes = len(class_order)
        class_order = class_order.copy()
        if shuffle_classes:
            np.random.shuffle(class_order)
    # original label -> position in class_order; O(1) lookup instead of calling
    # class_order.index() once per sample
    class_pos = {label: pos for pos, label in enumerate(class_order)}

    # compute classes per task and num_tasks
    if nc_first_task is None:
        cpertask = np.array([num_classes // num_tasks] * num_tasks)
        # spread the remainder over the first tasks
        for i in range(num_classes % num_tasks):
            cpertask[i] += 1
    else:
        assert nc_first_task < num_classes, "first task wants more classes than exist"
        remaining_classes = num_classes - nc_first_task
        assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task"  # better minimum 2
        cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1))
        for i in range(remaining_classes % (num_tasks - 1)):
            cpertask[i + 1] += 1

    assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes"
    cpertask_cumsum = np.cumsum(cpertask)
    init_class = np.concatenate(([0], cpertask_cumsum[:-1]))

    # initialize data structure
    for tt in range(num_tasks):
        data[tt] = {}
        data[tt]['name'] = 'task-' + str(tt)
        data[tt]['trn'] = {'x': [], 'y': []}
        data[tt]['val'] = {'x': [], 'y': []}
        data[tt]['tst'] = {'x': [], 'y': []}

    def _fill_split(split_data, split):
        """Drop samples outside class_order, then route each one to its task."""
        filtering = np.isin(split_data['y'], class_order)
        if filtering.sum() != len(split_data['y']):
            # Fix: wrap both containers in np.asarray so boolean-mask indexing
            # also works when 'x'/'y' are plain Python lists (the old code only
            # did this for the training labels and crashed on the test split)
            split_data['x'] = np.asarray(split_data['x'])[filtering]
            split_data['y'] = np.asarray(split_data['y'])[filtering]
        for this_image, this_label in zip(split_data['x'], split_data['y']):
            # If shuffling is false, it won't change the class number
            this_label = class_pos[this_label]
            # add it to the corresponding split
            this_task = (this_label >= cpertask_cumsum).sum()
            data[this_task][split]['x'].append(this_image)
            data[this_task][split]['y'].append(this_label - init_class[this_task])

    _fill_split(trn_data, 'trn')
    _fill_split(tst_data, 'tst')

    # check classes
    for tt in range(num_tasks):
        data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y']))
        assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes"

    # validation: move a random per-class fraction of the training samples
    if validation > 0.0:
        for tt in data.keys():
            for cc in range(data[tt]['ncla']):
                cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0])
                rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation)))
                rnd_img.sort(reverse=True)  # pop from the back so earlier indices stay valid
                for ii in range(len(rnd_img)):
                    data[tt]['val']['x'].append(data[tt]['trn']['x'][rnd_img[ii]])
                    data[tt]['val']['y'].append(data[tt]['trn']['y'][rnd_img[ii]])
                    data[tt]['trn']['x'].pop(rnd_img[ii])
                    data[tt]['trn']['y'].pop(rnd_img[ii])

    # convert the image containers to numpy arrays
    for tt in data.keys():
        for split in ['trn', 'val', 'tst']:
            data[tt][split]['x'] = np.asarray(data[tt][split]['x'])

    # other
    n = 0
    for t in data.keys():
        taskcla.append((t, data[t]['ncla']))
        n += data[t]['ncla']
    data['ncla'] = n

    return data, taskcla, class_order
136 |
--------------------------------------------------------------------------------
/src/gridsearch.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from argparse import ArgumentParser
4 |
5 | import utils
6 |
7 |
class GridSearch:
    """Basic class for implementing hyperparameter grid search.

    Two phases per task: `search_lr` finds the best finetuning learning rate,
    then `search_tradeoff` decays the approach's stability trade-off until the
    new-task accuracy stays within `acc_drop_thr` of the finetuning accuracy.
    """

    def __init__(self, appr_ft, seed, gs_config='gridsearch_config', acc_drop_thr=0.2, hparam_decay=0.5,
                 max_num_searches=7):
        """Store the finetuning reference approach and load the search space.

        Args:
            appr_ft: finetuning approach instance used as the search baseline.
            seed: seed re-applied before every training run for comparability.
            gs_config: module name providing a `GridSearchConfig` class.
            acc_drop_thr: relative accuracy drop tolerated w.r.t. finetuning.
            hparam_decay: multiplicative decay applied to the trade-off value.
            max_num_searches: maximum trade-off values tried per task.
        """
        self.seed = seed
        GridSearchConfig = getattr(importlib.import_module(name=gs_config), 'GridSearchConfig')
        self.appr_ft = appr_ft
        self.gs_config = GridSearchConfig()
        self.acc_drop_thr = acc_drop_thr
        self.hparam_decay = hparam_decay
        self.max_num_searches = max_num_searches
        # best LR found on the first task; later tasks only search LRs below it
        self.lr_first = 1.0

    @staticmethod
    def extra_parser(args):
        """Returns a parser containing the GridSearch specific parameters"""
        parser = ArgumentParser()
        # Configuration file with a GridSearchConfig class with all necessary args
        parser.add_argument('--gridsearch-config', type=str, default='gridsearch_config', required=False,
                            help='Configuration file for GridSearch options (default=%(default)s)')
        # Accuracy threshold drop below which the search stops for that phase
        parser.add_argument('--gridsearch-acc-drop-thr', default=0.2, type=float, required=False,
                            help='GridSearch accuracy drop threshold (default=%(default)f)')
        # Value at which hyperparameters decay
        parser.add_argument('--gridsearch-hparam-decay', default=0.5, type=float, required=False,
                            help='GridSearch hyperparameter decay (default=%(default)f)')
        # Maximum number of searches before the search stops for that phase
        # (fix: %(default)d -- the default is an int, %(default)f rendered "7.000000")
        parser.add_argument('--gridsearch-max-num-searches', default=7, type=int, required=False,
                            help='GridSearch maximum number of hyperparameter search (default=%(default)d)')
        return parser.parse_known_args(args)

    def search_lr(self, model, t, trn_loader, val_loader):
        """Search for accuracy and best LR on finetuning.

        Trains a finetuning copy of `model` on task `t` for every candidate LR
        and returns (best accuracy, best LR) measured on `val_loader`.
        """
        best_ft_acc = 0.0
        best_ft_lr = 0.0

        # Get general parameters and fix the ones with only one value
        gen_params = self.gs_config.get_params('general')
        for k, v in gen_params.items():
            if not isinstance(v, list):
                setattr(self.appr_ft, k, v)
        if t > 0:
            # LR for search are 'lr_searches' largest LR below 'lr_first'
            list_lr = [lr for lr in gen_params['lr'] if lr < self.lr_first][:gen_params['lr_searches'][0]]
        else:
            # For first task, try larger LR range
            list_lr = gen_params['lr_first']

        # Iterate through the other variable parameters
        for curr_lr in list_lr:
            utils.seed_everything(seed=self.seed)
            self.appr_ft.model = deepcopy(model)
            self.appr_ft.lr = curr_lr
            self.appr_ft.train(t, trn_loader, val_loader)
            _, ft_acc_taw, _ = self.appr_ft.eval(t, val_loader)
            if ft_acc_taw > best_ft_acc:
                best_ft_acc = ft_acc_taw
                best_ft_lr = curr_lr
            print('Current best LR: ' + str(best_ft_lr))
            self.gs_config.current_lr = best_ft_lr
        print('Current best acc: {:5.1f}'.format(best_ft_acc * 100))
        # After first task, keep LR used
        if t == 0:
            self.lr_first = best_ft_lr

        return best_ft_acc, best_ft_lr

    def search_tradeoff(self, appr_name, appr, t, trn_loader, val_loader, best_ft_acc):
        """Search for less-forgetting tradeoff with minimum accuracy loss.

        Starting from the largest candidate, the trade-off is decayed by
        `hparam_decay` until the task accuracy is within `acc_drop_thr` of
        `best_ft_acc` (or `max_num_searches` attempts are exhausted).
        Returns (best trade-off value or None, trade-off hyperparameter name).
        """
        best_tradeoff = None
        tradeoff_name = None

        # Get general parameters and fix all the ones that have only one option
        appr_params = self.gs_config.get_params(appr_name)
        for k, v in appr_params.items():
            if isinstance(v, list):
                # get tradeoff name as the only one with multiple values
                tradeoff_name = k
            else:
                # Any other hyperparameters are fixed
                setattr(appr, k, v)

        # If there is no tradeoff, no need to gridsearch more
        if tradeoff_name is not None and t > 0:
            # get starting value for trade-off hyperparameter
            best_tradeoff = appr_params[tradeoff_name][0]
            # iterate through decreasing trade-off values -- limit to `max_num_searches` searches
            num_searches = 0
            while num_searches < self.max_num_searches:
                utils.seed_everything(seed=self.seed)
                # Make deepcopy of the appr without duplicating the logger
                appr_gs = type(appr)(deepcopy(appr.model), appr.device, exemplars_dataset=appr.exemplars_dataset)
                for attr, value in vars(appr).items():
                    if attr == 'logger':
                        setattr(appr_gs, attr, value)
                    else:
                        setattr(appr_gs, attr, deepcopy(value))

                # update tradeoff value
                setattr(appr_gs, tradeoff_name, best_tradeoff)
                # train this iteration
                appr_gs.train(t, trn_loader, val_loader)
                _, curr_acc, _ = appr_gs.eval(t, val_loader)
                print('Current acc: ' + str(curr_acc) + ' for ' + tradeoff_name + '=' + str(best_tradeoff))
                # Check if accuracy is within acceptable threshold drop
                if curr_acc < ((1 - self.acc_drop_thr) * best_ft_acc):
                    best_tradeoff = best_tradeoff * self.hparam_decay
                else:
                    break
                num_searches += 1
        else:
            print('There is no trade-off to gridsearch.')

        return best_tradeoff, tradeoff_name
123 |
--------------------------------------------------------------------------------
/src/gridsearch_config.py:
--------------------------------------------------------------------------------
class GridSearchConfig():
    """Hyperparameter search space for the grid search.

    Values given as a list are searched over (largest candidate first);
    scalar values are fixed. The 'general' entry holds optimizer settings
    shared by every approach; the remaining entries hold the per-approach
    trade-off hyperparameters.
    """
    def __init__(self):
        # approach name -> candidate hyperparameters
        self.params = {
            'general': {
                'lr_first': [5e-1, 1e-1, 5e-2],
                'lr': [1e-1, 5e-2, 1e-2, 5e-3, 1e-3],
                'lr_searches': [3],
                'lr_min': 1e-4,
                'lr_factor': 3,
                'lr_patience': 10,
                'clipping': 10000,
                'momentum': 0.9,
                'wd': 0.0002
            },
            # approaches without a stability/plasticity trade-off
            'finetuning': {
            },
            'freezing': {
            },
            'joint': {
            },
            'lwf': {
                'lamb': [10],
                'T': 2
            },
            'icarl': {
                'lamb': [4]
            },
            'dmc': {
                'aux_dataset': 'imagenet_32_reduced',
                'aux_batch_size': 128
            },
            'il2m': {
            },
            'eeil': {
                'lamb': [10],
                'T': 2,
                'lr_finetuning_factor': 0.1,
                'nepochs_finetuning': 40,
                'noise_grad': False
            },
            'bic': {
                'T': 2,
                'val_percentage': 0.1,
                'bias_epochs': 200
            },
            'lucir': {
                'lamda_base': [10],
                'lamda_mr': 1.0,
                'dist': 0.5,
                'K': 2
            },
            'lwm': {
                'beta': [2],
                'gamma': 1.0
            },
            'ewc': {
                'lamb': [10000]
            },
            'mas': {
                'lamb': [400]
            },
            'path_integral': {
                'lamb': [10],
            },
            'r_walk': {
                'lamb': [20],
            },
        }
        # current best values, updated by the grid search as it progresses
        self.current_lr = self.params['general']['lr'][0]
        self.current_tradeoff = 0

    def get_params(self, approach):
        """Return the search space dictionary for the given approach name."""
        return self.params[approach]
74 |
--------------------------------------------------------------------------------
/src/last_layer_analysis.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import matplotlib
# Select the non-interactive Agg backend BEFORE pyplot is imported: calling
# matplotlib.use() after `import matplotlib.pyplot` is not guaranteed to take
# effect (pyplot binds a backend at import time on older matplotlib versions).
matplotlib.use('Agg')
import matplotlib.pyplot as plt
6 |
7 |
def last_layer_analysis(heads, task, taskcla, y_lim=False, sort_weights=False):
    """Plot last layer weight and bias analysis"""
    print('Plotting last layer analysis...')
    num_classes = sum(x for (_, x) in taskcla)

    # Gather the per-task weight L2-norms, biases and class-index ranges
    weights, biases, indexes = [], [], []
    class_id = 0
    with torch.no_grad():
        for t in range(task + 1):
            n_classes_t = taskcla[t][1]
            indexes.append(np.arange(class_id, class_id + n_classes_t))
            if type(heads) == torch.nn.Linear:  # Single head: slice this task's rows
                rows = slice(class_id, class_id + n_classes_t)
                biases.append(heads.bias[rows].detach().cpu().numpy())
                weights.append((heads.weight[rows] ** 2).sum(1).sqrt().detach().cpu().numpy())
            else:  # Multi-head: one head module per task
                weights.append((heads[t].weight ** 2).sum(1).sqrt().detach().cpu().numpy())
                if type(heads[t]) == torch.nn.Linear:
                    biases.append(heads[t].bias.detach().cpu().numpy())
                else:
                    biases.append(np.zeros(weights[-1].shape))  # For LUCIR (bias-free head)
            class_id += n_classes_t

    def _bar_figure(values, ylabel, ylim):
        """Render one bar chart (one color per task) for the given values."""
        fig = plt.figure(dpi=300)
        ax = fig.subplots(nrows=1, ncols=1)
        for i, (x, y) in enumerate(zip(indexes, values)):
            bars = sorted(y, reverse=True) if sort_weights else y
            ax.bar(x, bars, label="Task {}".format(i))
        ax.set_xlabel("Classes", fontsize=11, fontfamily='serif')
        ax.set_ylabel(ylabel, fontsize=11, fontfamily='serif')
        if num_classes is not None:
            ax.set_xlim(0, num_classes)
        if y_lim:
            ax.set_ylim(*ylim)
        ax.legend(loc='upper left', fontsize='11')  # , fontfamily='serif')
        return fig

    f_weights = _bar_figure(weights, "Weights L2-norm", (0, 5))
    f_biases = _bar_figure(biases, "Bias values", (-1.0, 1.0))
    return f_weights, f_biases
62 |
--------------------------------------------------------------------------------
/src/loggers/README.md:
--------------------------------------------------------------------------------
1 | # Loggers
2 |
3 | We include a disk logger, which logs into files and folders on the disk. We also provide a tensorboard logger, which
4 | provides a faster way of analysing a training process without the need for further development. They can be specified with
5 | `--log` followed by `disk`, `tensorboard` or both. Custom loggers can be defined by inheriting the `ExperimentLogger`
6 | in [exp_logger.py](exp_logger.py).
7 |
8 | When enabled, both loggers will output everything in the path `[RESULTS_PATH]/[DATASETS]_[APPROACH]_[EXP_NAME]` or
9 | `[RESULTS_PATH]/[DATASETS]_[APPROACH]` if `--exp-name` is not set.
10 |
11 | ## Disk logger
12 | The disk logger outputs the following file and folder structure:
13 | - **figures/**: folder where generated figures are logged.
14 | - **models/**: folder where model weight checkpoints are saved.
15 | - **results/**: folder containing the results.
16 | - **acc_tag**: task-agnostic accuracy table.
17 | - **acc_taw**: task-aware accuracy table.
18 | - **avg_acc_tag**: task-agnostic average accuracies.
19 | - **avg_acc_taw**: task-aware average accuracies.
20 | - **forg_tag**: task-agnostic forgetting table.
21 | - **forg_taw**: task-aware forgetting table.
22 | - **wavg_acc_tag**: task-agnostic average accuracies weighted according to the number of classes of each task.
23 | - **wavg_acc_taw**: task-aware average accuracies weighted according to the number of classes of each task.
24 | - **raw_log**: json file containing all the logged metrics easily read by many tools (e.g. `pandas`).
25 | - stdout: a copy from the standard output of the terminal.
26 | - stderr: a copy from the error output of the terminal.
27 |
28 | ## TensorBoard logger
29 | The tensorboard logger outputs analogous metrics to the disk logger separated into different tabs according to the task
30 | and different graphs according to the data splits.
31 |
32 | Screenshot for a 10 task experiment, showing the last task plots:
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
/src/loggers/__pycache__/disk_logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/loggers/__pycache__/disk_logger.cpython-38.pyc
--------------------------------------------------------------------------------
/src/loggers/__pycache__/exp_logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/loggers/__pycache__/exp_logger.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loggers/__pycache__/exp_logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/loggers/__pycache__/exp_logger.cpython-38.pyc
--------------------------------------------------------------------------------
/src/loggers/__pycache__/tensorboard_logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/loggers/__pycache__/tensorboard_logger.cpython-38.pyc
--------------------------------------------------------------------------------
/src/loggers/disk_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import torch
5 | import numpy as np
6 | from datetime import datetime
7 |
8 | from loggers.exp_logger import ExperimentLogger
9 |
10 |
class Logger(ExperimentLogger):
    """Disk logger: mirrors stdout/stderr to files and writes metrics, args,
    result tables, figures and model checkpoints under the experiment path."""

    def __init__(self, log_path, exp_name, begin_time=None):
        super(Logger, self).__init__(log_path, exp_name, begin_time)

        self.begin_time_str = self.begin_time.strftime("%Y-%m-%d-%H-%M")

        # Duplicate standard outputs into timestamped files
        sys.stdout = FileOutputDuplicator(sys.stdout,
                                          os.path.join(self.exp_path, 'stdout-{}.txt'.format(self.begin_time_str)), 'w')
        sys.stderr = FileOutputDuplicator(sys.stderr,
                                          os.path.join(self.exp_path, 'stderr-{}.txt'.format(self.begin_time_str)), 'w')

        # Raw log file (append mode, so a rerun within the same minute does not truncate)
        self.raw_log_file = open(os.path.join(self.exp_path, "raw_log-{}.txt".format(self.begin_time_str)), 'a')

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        """Append one metric entry as a JSON line to the raw log."""
        if curtime is None:
            curtime = datetime.now()

        # Raw dump
        entry = {"task": task, "iter": iter, "name": name, "value": value, "group": group,
                 "time": curtime.strftime("%Y-%m-%d-%H-%M")}
        self.raw_log_file.write(json.dumps(entry, sort_keys=True) + "\n")
        self.raw_log_file.flush()

    def log_args(self, args):
        """Dump the experiment argument namespace as JSON."""
        with open(os.path.join(self.exp_path, 'args-{}.txt'.format(self.begin_time_str)), 'w') as f:
            json.dump(args.__dict__, f, separators=(',\n', ' : '), sort_keys=True)

    def log_result(self, array, name, step):
        """Save a result array as a tab-separated text table.

        Note: `step` is unused; kept for interface compatibility with the base class.
        """
        if array.ndim <= 1:
            array = array[None]  # promote to 2-D so savetxt writes a single row
        np.savetxt(os.path.join(self.exp_path, 'results', '{}-{}.txt'.format(name, self.begin_time_str)),
                   array, '%.6f', delimiter='\t')

    def log_figure(self, name, iter, figure, curtime=None):
        """Save a figure as PNG and PDF, timestamping the filenames with `curtime`."""
        # Fix: honor a caller-supplied `curtime` instead of unconditionally
        # overwriting it with datetime.now() (the parameter was ignored before)
        if curtime is None:
            curtime = datetime.now()
        figure.savefig(os.path.join(self.exp_path, 'figures',
                                    '{}_{}-{}.png'.format(name, iter, curtime.strftime("%Y-%m-%d-%H-%M-%S"))))
        figure.savefig(os.path.join(self.exp_path, 'figures',
                                    '{}_{}-{}.pdf'.format(name, iter, curtime.strftime("%Y-%m-%d-%H-%M-%S"))))

    def save_model(self, state_dict, task):
        """Checkpoint the model state dict for the given task."""
        torch.save(state_dict, os.path.join(self.exp_path, "models", "task{}.ckpt".format(task)))

    def __del__(self):
        # best-effort close of the raw log on garbage collection
        self.raw_log_file.close()
60 |
61 |
class FileOutputDuplicator(object):
    """Tee-like stream wrapper: every write goes to both a file and the
    original (duplicated) stream, e.g. to mirror stdout into a log file."""

    def __init__(self, duplicate, fname, mode):
        self.file = open(fname, mode)
        self.duplicate = duplicate

    def __del__(self):
        # Guard: __init__ may have raised before self.file was assigned
        if getattr(self, 'file', None) is not None:
            self.file.close()

    def write(self, data):
        self.file.write(data)
        self.duplicate.write(data)

    def flush(self):
        # BUGFIX: also flush the duplicated stream, not only the file, so
        # interleaved console output stays in order with the log file.
        self.file.flush()
        self.duplicate.flush()
76 |
--------------------------------------------------------------------------------
/src/loggers/exp_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 | from datetime import datetime
4 |
5 |
class ExperimentLogger:
    """Base class defining the experiment-logging interface.

    Concrete loggers (disk, tensorboard, ...) override the ``log_*`` hooks;
    the defaults are no-ops so a subclass only implements what it needs.
    """

    def __init__(self, log_path, exp_name, begin_time=None):
        self.log_path = log_path
        self.exp_name = exp_name
        self.exp_path = os.path.join(log_path, exp_name)
        # Default the experiment start time to "now" when not provided
        self.begin_time = datetime.now() if begin_time is None else begin_time

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        """Hook: record a single scalar value."""
        pass

    def log_args(self, args):
        """Hook: record the experiment arguments."""
        pass

    def log_result(self, array, name, step):
        """Hook: record a result array."""
        pass

    def log_figure(self, name, iter, figure, curtime=None):
        """Hook: record a matplotlib figure."""
        pass

    def save_model(self, state_dict, task):
        """Hook: persist a model checkpoint."""
        pass
32 |
33 |
class MultiLogger(ExperimentLogger):
    """Fan-out logger that forwards every logging call to a list of sub-loggers."""

    def __init__(self, log_path, exp_name, loggers=None, save_models=True):
        super(MultiLogger, self).__init__(log_path, exp_name)
        if os.path.exists(self.exp_path):
            print("WARNING: {} already exists!".format(self.exp_path))
        else:
            # Fresh experiment: create the standard output sub-directories
            os.makedirs(os.path.join(self.exp_path, 'models'))
            os.makedirs(os.path.join(self.exp_path, 'results'))
            os.makedirs(os.path.join(self.exp_path, 'figures'))

        self.save_models = save_models
        self.loggers = []
        # BUGFIX: the default ``loggers=None`` used to raise TypeError in the
        # loop below; treat None as "no sub-loggers".
        for l in (loggers or []):
            # Each name maps to module ``loggers.<name>_logger`` exposing a
            # ``Logger`` class (e.g. 'disk', 'tensorboard').
            lclass = getattr(importlib.import_module(name='loggers.' + l + '_logger'), 'Logger')
            self.loggers.append(lclass(self.log_path, self.exp_name))

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        """Broadcast a scalar to all sub-loggers with a shared timestamp."""
        if curtime is None:
            curtime = datetime.now()
        for l in self.loggers:
            l.log_scalar(task, iter, name, value, group, curtime)

    def log_args(self, args):
        """Broadcast the experiment arguments to all sub-loggers."""
        for l in self.loggers:
            l.log_args(args)

    def log_result(self, array, name, step):
        """Broadcast a result array to all sub-loggers."""
        for l in self.loggers:
            l.log_result(array, name, step)

    def log_figure(self, name, iter, figure, curtime=None):
        """Broadcast a figure to all sub-loggers with a shared timestamp."""
        if curtime is None:
            curtime = datetime.now()
        for l in self.loggers:
            l.log_figure(name, iter, figure, curtime)

    def save_model(self, state_dict, task):
        """Broadcast a checkpoint save when model saving is enabled."""
        if self.save_models:
            for l in self.loggers:
                l.save_model(state_dict, task)
76 |
--------------------------------------------------------------------------------
/src/loggers/tensorboard_logger.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 |
3 | from loggers.exp_logger import ExperimentLogger
4 | import json
5 | import numpy as np
6 |
7 |
class Logger(ExperimentLogger):
    """Tensorboard logger: mirrors scalars, figures, args and result arrays
    into a SummaryWriter rooted at the experiment directory."""

    def __init__(self, log_path, exp_name, begin_time=None):
        super(Logger, self).__init__(log_path, exp_name, begin_time)
        # One writer for the whole experiment; all tags live under exp_path
        self.tbwriter = SummaryWriter(self.exp_path)

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        # Tag layout groups scalars per task, e.g. "t0/train_loss".
        # ``curtime`` is accepted for interface compatibility but unused here.
        self.tbwriter.add_scalar(tag="t{}/{}_{}".format(task, group, name),
                                 scalar_value=value,
                                 global_step=iter)
        # Flush eagerly so the event file is readable while training runs
        self.tbwriter.file_writer.flush()

    def log_figure(self, name, iter, figure, curtime=None):
        # ``curtime`` is accepted for interface compatibility but unused here
        self.tbwriter.add_figure(tag=name, figure=figure, global_step=iter)
        self.tbwriter.file_writer.flush()

    def log_args(self, args):
        # Store the full argparse namespace as a sorted JSON text record
        self.tbwriter.add_text(
            'args',
            json.dumps(args.__dict__,
                       separators=(',\n', ' : '),
                       sort_keys=True))
        self.tbwriter.file_writer.flush()

    def log_result(self, array, name, step):
        if array.ndim == 1:
            # log as scalars
            self.tbwriter.add_scalar(f'results/{name}', array[step], step)

        elif array.ndim == 2:
            # Render row ``step`` of a (task x task) matrix as percentage text
            s = ""
            i = step
            # for i in range(array.shape[0]):
            for j in range(array.shape[1]):
                s += '{:5.1f}% '.format(100 * array[i, j])
            # NOTE(review): the trace==0 test appears to distinguish matrices
            # with a zero diagonal (e.g. forgetting) from accuracy matrices
            # when choosing how many row entries to average — confirm intent.
            if np.trace(array) == 0.0:
                if i > 0:
                    s += '\tAvg.:{:5.1f}% \n'.format(100 * array[i, :i].mean())
            else:
                s += '\tAvg.:{:5.1f}% \n'.format(100 * array[i, :i + 1].mean())
            self.tbwriter.add_text(f'results/{name}', s, step)

    def __del__(self):
        self.tbwriter.close()
53 |
--------------------------------------------------------------------------------
/src/networks/README.md:
--------------------------------------------------------------------------------
1 | # Networks
2 | We include a core [network](network.py) class which handles the architecture used as well as the heads needed to do
3 | image classification in an incremental learning setting.
4 |
5 | ## Main usage
6 | When running an experiment, the network model used can be defined in [main_incremental.py](../main_incremental.py) using
7 | `--network`. By default, the existing head of the architecture (usually with 1,000 outputs because of ImageNet) will be
8 | removed since we create a head each time a task is learned. This default behaviour can be disabled by using
9 | `--keep-existing-head`. If the architecture used has the option to use a pretrained model, it can be called with
10 | `--pretrained`.
11 |
12 | We define a [network](network.py) class which contains the architecture model class to be used (torchvision models or
13 | custom), and also a `ModuleList()` of heads that grows incrementally as tasks are learned, called `model` and `heads`
14 | respectively. When doing a forward pass, inputs are passed through the `model` specific forward pass, and the outputs of
15 | it are then fed to the `heads`. This results in a list of outputs, corresponding to the different tasks learned so far
(multi-head base). However, it has to be noted that when using approaches for class-incremental learning, which have no
access to the task-ID at test time, the heads are treated as if they were concatenated, so the task-ID has no influence.
18 |
19 | We use this system since it would be equivalent to having a head that grows at each task and it would concatenate the
20 | heads after (or create a new head with the new number of outputs and copy the previous heads weights to their respective
positions). However, an advantage of this system is that it allows updating previous heads by adding them to the
22 | optimizer (needed when using exemplars). This is also important when using some regularization such as weight decay,
23 | which would affect previous task heads/outputs if the corresponding weights are included in the optimizer. Furthermore,
24 | it makes it very easy to evaluate on both task-incremental and class-incremental scenarios.
25 |
26 | ### Torchvision models
27 | * Alexnet: `alexnet`
28 | * DenseNet: `densenet121, densenet169, densenet201, densenet161`
29 | * Googlenet: `googlenet`
30 | * Inception: `inception_v3`
31 | * MobileNet: `mobilenet_v2`
32 | * ResNet: `resnet18`, `resnet34`, `resnet50`, `resnet101`, `resnet152`, `resnext50_32x4d`, `resnext101_32x8d`
33 | * ShuffleNet: `shufflenet_v2_x0_5`, `shufflenet_v2_x1_0`, `shufflenet_v2_x1_5`, `shufflenet_v2_x2_0`
34 | * Squeezenet: `squeezenet1_0`, `squeezenet1_1`
35 | * VGG: `vgg11`, `vgg11_bn`, `vgg13`, `vgg13_bn`, `vgg16`, `vgg16_bn`, `vgg19_bn`, `vgg19`
36 | * WideResNet: `wide_resnet50_2`, `wide_resnet101_2`
37 |
38 | ### Custom models
39 | We include versions of [LeNet](lenet.py), [ResNet-32](resnet32.py) and [VGGnet](vggnet.py), which use a smaller input
40 | size than the torchvision models. LeNet together with MNIST is useful for quick tests and debugging.
41 |
42 | ## Adding new networks
43 | To add a new custom model architecture, follow this:
44 |
45 | 1. Take as an example [vggnet.py](vggnet.py) and define a Class for the architecture. Initialize all necessary layers
46 | and modules and define a non-incremental last layer (e.g. `self.fc`). Then add `self.head_var = 'fc'` to point to the
47 | variable containing the non-incremental head (it is not important how many classes it has as output since we remove
48 | it when using it for incremental learning).
49 | 2. Define the forward pass of the architecture inside the Class and any other necessary functions.
50 | 3. Define a function outside of the Class to call the model. It needs to contain `num_out` and `pretrained` as inputs.
51 | 4. Include the import to [\_\_init\_\_.py](__init__.py) and add the architecture name to `allmodels`.
52 |
53 | ## Notes
54 | * We provide an implementation of ResNet-32 (see [resnet32.py](resnet32.py)) which is commonly used by several works in
55 | the literature for learning CIFAR-100 in incremental learning scenarios. This network architecture is an adaptation of
56 | ResNet for smaller input size. The number of blocks can be modified in
57 | [this line](https://github.com/mmasana/IL_Survey/blob/9837386d9efddf48d22fc4d23e031248decce68d/src/networks/resnet32.py#L113)
58 | by changing `n=5` to `n=3` for ResNet-20, and `n=9` for ResNet-56.
59 |
--------------------------------------------------------------------------------
/src/networks/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from torchvision import models
3 |
4 | from .lenet import LeNet
5 | from .vggnet import VggNet
6 | from .resnet32 import resnet32
7 | from .ovit_tiny_16_augreg_224 import OVit_tiny_16_augreg_224
8 | from .vit_tiny_16_augreg_224 import Vit_tiny_16_augreg_224
9 | from .timm_vit_tiny_16_augreg_224 import Timm_Vit_tiny_16_augreg_224
10 | from .efficient_net import Efficient_net
11 | from .mobile_net import Mobile_net
12 | from .early_conv_vit import Early_conv_vit
13 |
14 |
# Available torchvision architectures, grouped by family for readability.
_TV_FAMILIES = [
    ['alexnet'],
    ['densenet121', 'densenet169', 'densenet201', 'densenet161'],
    ['googlenet'],
    ['inception_v3'],
    ['mobilenet_v2'],
    ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'],
    ['shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'],
    ['squeezenet1_0', 'squeezenet1_1'],
    ['vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19'],
    ['wide_resnet50_2', 'wide_resnet101_2'],
]
tvmodels = [name for family in _TV_FAMILIES for name in family]

# All selectable models: torchvision ones plus the custom/local architectures
allmodels = tvmodels + ['resnet32', 'LeNet', 'VggNet', "OVit_tiny_16_augreg_224", "Vit_tiny_16_augreg_224",
                        "Timm_Vit_tiny_16_augreg_224", "Efficient_net", "Mobile_net", "Early_conv_vit"]
29 |
30 |
def set_tvmodel_head_var(model):
    """Attach a ``head_var`` attribute naming the classifier layer of a
    torchvision model so it can later be stripped/replaced for incremental
    learning.

    Raises:
        ModuleNotFoundError: if ``model`` is not one of the supported
            torchvision architectures.
    """
    # Map torchvision model classes to the attribute holding their head.
    # An exact type() lookup (not isinstance) keeps the original matching
    # semantics of the if/elif chain this replaces.
    head_attr_by_class = {
        models.AlexNet: 'classifier',
        models.DenseNet: 'classifier',
        models.Inception3: 'fc',
        models.ResNet: 'fc',
        models.VGG: 'classifier',
        models.GoogLeNet: 'fc',
        models.MobileNetV2: 'classifier',
        models.ShuffleNetV2: 'fc',
        models.SqueezeNet: 'classifier',
    }
    try:
        model.head_var = head_attr_by_class[type(model)]
    except KeyError:
        # BUGFIX: raise with a useful message instead of a bare exception
        raise ModuleNotFoundError(
            "Unsupported torchvision model type: {}".format(type(model).__name__)) from None
--------------------------------------------------------------------------------
/src/networks/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/early_conv_vit.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/early_conv_vit.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/early_conv_vit_net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/early_conv_vit_net.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/efficient_net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/efficient_net.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/lenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/lenet.cpython-36.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/lenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/lenet.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/mobile_net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/mobile_net.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/network.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/network.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/ovit.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/ovit.cpython-36.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/ovit.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/ovit.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/ovit_tiny_16_augreg_224.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/ovit_tiny_16_augreg_224.cpython-36.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/ovit_tiny_16_augreg_224.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/ovit_tiny_16_augreg_224.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/resnet32.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/resnet32.cpython-36.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/resnet32.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/resnet32.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/timm_vit_tiny_16_augreg_224.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/timm_vit_tiny_16_augreg_224.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/vggnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/vggnet.cpython-36.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/vggnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/vggnet.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/vit_original.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/vit_original.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/__pycache__/vit_tiny_16_augreg_224.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/__pycache__/vit_tiny_16_augreg_224.cpython-38.pyc
--------------------------------------------------------------------------------
/src/networks/early_conv_vit.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | from .early_conv_vit_net import EarlyConvViT
4 |
5 |
class Early_conv_vit(nn.Module):
    """Early-convolution ViT backbone with an incremental-learning head."""

    def __init__(self, num_classes=100, pretrained=False):
        super().__init__()

        # Backbone; its internal classifier is unused (num_classes=2 here is a
        # dummy value — the real head is ``self.fc`` below).
        self.model = EarlyConvViT(
            dim=768,
            num_classes=2,
            depth=12,
            heads=12,
            mlp_dim=2048,
            channels=3,
        )

        if pretrained:
            # No pretrained weights are available for this backbone
            raise NotImplementedError

        # Classification head with as many outputs as classes; ``head_var``
        # names it so it can be removed for incremental learning experiments.
        self.fc = nn.Linear(in_features=768, out_features=num_classes, bias=True)
        self.head_var = 'fc'

    def forward(self, x):
        features = self.model(x)
        return self.fc(features)
32 |
33 |
def early_conv_vit(num_out=100, pretrained=False):
    """Factory for ``Early_conv_vit``.

    Only the ``pretrained=True`` call path is wired up here (mirroring the
    other network factories); otherwise the model is not supported.
    """
    if pretrained:
        return Early_conv_vit(num_out, pretrained)
    # BUGFIX: removed the unreachable ``assert`` that followed this raise
    raise NotImplementedError
40 |
41 |
--------------------------------------------------------------------------------
/src/networks/early_conv_vit_net.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | import torch
4 | from einops import rearrange, repeat
5 | from einops.layers.torch import Rearrange
6 |
# --- Module-level hooks for recording attention tensors during a forward ---
# ``start_*``/``stop_*`` toggle recording; ``get_*_list`` drains the collected
# tensors and resets the corresponding buffer.

POST_ATTENTION_LIST = []
POST_REC = False


def start_post_rec():
    """Enable recording of post-attention outputs."""
    global POST_REC
    POST_REC = True


def stop_post_rec():
    """Disable recording of post-attention outputs."""
    global POST_REC
    POST_REC = False


def get_post_attention_list():
    """Return the recorded post-attention outputs and clear the buffer."""
    global POST_ATTENTION_LIST
    drained = POST_ATTENTION_LIST
    POST_ATTENTION_LIST = []
    return drained


ATTENTION_LIST = []
REC = False


def start_rec():
    """Enable recording of raw attention score matrices."""
    global REC
    REC = True


def stop_rec():
    """Disable recording of raw attention score matrices."""
    global REC
    REC = False


def get_attention_list():
    """Return the recorded attention score matrices and clear the buffer."""
    global ATTENTION_LIST
    drained = ATTENTION_LIST
    ATTENTION_LIST = []
    return drained
40 |
41 |
42 | # helpers
def make_tuple(t):
    """Return ``t`` unchanged if it is already a tuple, else the pair ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)
49 |
50 | # classes
class PreNorm(nn.Module):
    """Apply LayerNorm to the input before delegating to the wrapped ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
58 |
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
71 |
class Attention(nn.Module):
    """Multi-head self-attention with optional global recording of the raw
    attention scores (pre-softmax) and the post-attention outputs via the
    module-level REC/POST_REC flags."""

    def __init__(self, dim, heads=7, dim_head=64, dropout=0.):
        """
        reduced the default number of heads by 1 per https://arxiv.org/pdf/2106.14881v2.pdf
        """
        super().__init__()
        inner_dim = dim_head * heads
        # Skip the output projection when a single head already matches dim
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        # Single projection producing q, k and v, split in forward()
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # Module-level flags/buffers used to optionally record tensors for
        # later analysis (see start_rec/get_attention_list and friends)
        global REC
        global ATTENTION_LIST
        global POST_REC
        global POST_ATTENTION_LIST

        # Split the fused projection into q, k, v and reshape per-head
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # Scaled dot-product scores (pre-softmax)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        if REC:
            # Record raw attention scores when recording is enabled
            ATTENTION_LIST.append(dots)

        attn = self.attend(dots)

        out = torch.matmul(attn, v)

        if POST_REC:
            # Record the attention-weighted values when recording is enabled
            POST_ATTENTION_LIST.append(out)

        # Merge heads back into the embedding dimension
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
115 |
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm attention + feed-forward residual blocks."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attn_block = PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))
            ff_block = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            self.layers.append(nn.ModuleList([attn_block, ff_block]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = x + attn(x)  # residual around attention
            x = x + ff(x)    # residual around feed-forward
        return x
130 |
class EarlyConvViT(nn.Module):
    """ViT variant whose patch embedding is a small convolutional stem
    (four stride-2 3x3 convs + a 1x1 projection) instead of patchify."""

    def __init__(self, *, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        """
        3x3 conv, stride 1, 5 conv layers per https://arxiv.org/pdf/2106.14881v2.pdf

        ``num_classes`` is currently unused: the classification head below is
        commented out and forward() returns pooled features instead.
        """
        super().__init__()

        n_filter_list = (channels, 24, 48, 96, 196) # hardcoding for now because that's what the paper used

        # Convolutional stem: four stride-2 convs, so spatial size shrinks 16x
        self.conv_layers = nn.Sequential(
            *[nn.Sequential(
                nn.Conv2d(in_channels=n_filter_list[i],
                          out_channels=n_filter_list[i + 1],
                          kernel_size=3,  # hardcoding for now because that's what the paper used
                          stride=2,  # hardcoding for now because that's what the paper used
                          padding=1),  # hardcoding for now because that's what the paper used
            )
                for i in range(len(n_filter_list)-1)
            ])

        # 1x1 projection into the transformer embedding dimension
        self.conv_layers.add_module("conv_1x1", torch.nn.Conv2d(in_channels=n_filter_list[-1],
                                                                out_channels=dim,
                                                                stride=1,  # hardcoding for now because that's what the paper used
                                                                kernel_size=1,  # hardcoding for now because that's what the paper used
                                                                padding=0))  # hardcoding for now because that's what the paper used
        # Flatten the spatial grid into a token sequence
        self.conv_layers.add_module("flatten image",
                                    Rearrange('batch channels height width -> batch (height width) channels'))
        # NOTE(review): the positional embedding length is n_filter_list[-1]+1
        # = 197, which only matches the token count for 224x224 inputs
        # (224/16 = 14, 14*14 + 1 cls = 197) — confirm input size assumption.
        self.pos_embedding = nn.Parameter(torch.randn(1, n_filter_list[-1] + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # 'cls' pools via the class token; 'mean' averages all tokens
        self.pool = pool
        self.to_latent = nn.Identity()

        """
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
        """

    def forward(self, img):
        # img -> token sequence via the convolutional stem
        x = self.conv_layers(img)
        b, n, _ = x.shape

        # Prepend one class token per batch element
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # Pool: mean over tokens, or take the class token
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        # Returns pooled features; the mlp_head is intentionally disabled
        x = self.to_latent(x)
        return x
        #return self.mlp_head(x)
190 |
191 | # It can handle bsize of 50 to 60
192 |
--------------------------------------------------------------------------------
/src/networks/efficient_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import timm
5 |
6 |
7 | # It can handle bsize of 50 to 60
8 |
class Efficient_net(nn.Module):
    """MEAL-V2 EfficientNet-B0 backbone with a replaceable linear head."""

    def __init__(self, num_classes=100, pretrained=False):
        super().__init__()

        # Backbone architecture from torch.hub (weights loaded below if requested)
        self.net = torch.hub.load('szq0214/MEAL-V2', 'meal_v2', 'mealv2_efficientnet_b0', pretrained=False)

        if pretrained:
            # BUGFIX: map_location was hard-coded to "cuda:3", which crashes on
            # machines without that device; load on CPU and let callers move it.
            checkpoint = torch.hub.load_state_dict_from_url(
                'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_EfficientNet_B0_224.pth',
                map_location='cpu')
            # Checkpoint was saved from DataParallel; strip the "module." prefix
            state_dict = {key.replace("module.", ""): value for key, value in checkpoint.items()}
            self.net.load_state_dict(state_dict)

        # Drop the backbone's own classifier; features come out at 1280-d
        self.net.classifier = nn.Identity()
        # Head used for incremental learning; ``head_var`` names it for removal
        self.fc = nn.Linear(in_features=1280, out_features=num_classes, bias=True)
        self.head_var = 'fc'

    def forward(self, x):
        return self.fc(self.net(x))
30 |
31 |
def efficient_net(num_out=100, pretrained=False):
    """Factory for ``Efficient_net``; only the pretrained path is supported."""
    if pretrained:
        return Efficient_net(num_out, pretrained)
    # BUGFIX: removed the unreachable ``assert`` that followed this raise
    raise NotImplementedError
38 |
--------------------------------------------------------------------------------
/src/networks/fpt.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | from .vit_original import VisionTransformer, _load_weights
4 | from transformers import GPT2Model
5 |
6 |
7 | # It can handle bsize of 50 to 60
8 |
class FPT(nn.Module):
    """Frozen Pretrained Transformer: image patches -> linear embedding ->
    frozen GPT-2 -> per-token linear head.

    BUGFIX: the original did not compile (unclosed ``super().__init__(``,
    ``asser`` typo, reference to undefined ``input_dim``) and used
    ``rearrange``/``torch`` without importing them; both are replaced with
    in-scope equivalents below.
    """

    def __init__(self, num_classes=100, pretrained=False):
        super().__init__()

        if not pretrained:
            # Mirrors the original assert: a frozen GPT-2 only makes sense
            # with pretrained weights.
            raise NotImplementedError('cannot run without pretrained')

        self.patch_size = 8
        # Flattened patch dimensionality: channels * patch_size^2
        self.input_dim = 3 * self.patch_size ** 2

        self.in_net = nn.Linear(self.input_dim, 768, bias=True)
        self.fpt = GPT2Model.from_pretrained('gpt2')
        self.fc = nn.Linear(in_features=768, out_features=num_classes, bias=True)
        self.head_var = 'fc'

    def forward(self, x):
        import torch  # local import: the file only does ``from torch import nn``

        # Patch extraction equivalent to
        # rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)') without einops
        b, c, height, width = x.shape
        p = self.patch_size
        x = x.view(b, c, height // p, p, width // p, p)
        x = x.permute(0, 2, 4, 3, 5, 1).reshape(b, (height // p) * (width // p), p * p * c)

        x = self.in_net(x)
        with torch.no_grad():
            # GPT-2 stays frozen; feed embeddings directly.
            # NOTE(review): GPT2Model returns a ModelOutput, not a tensor —
            # take last_hidden_state before the head (the original passed the
            # raw output to ``fc``, which would fail at runtime). Confirm the
            # intended pooling; this returns per-token logits.
            x = self.fpt(inputs_embeds=x).last_hidden_state

        return self.fc(x)
38 |
39 |
def fPT(num_out=100, pretrained=False):
    """Factory for ``FPT``; only pretrained GPT-2 weights are supported."""
    if pretrained:
        return FPT(num_out, pretrained)
    # BUGFIX: removed the unreachable ``assert`` that followed this raise
    raise NotImplementedError
46 |
--------------------------------------------------------------------------------
/src/networks/lenet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 |
4 |
class LeNet(nn.Module):
    """LeNet-like network for tests with MNIST (28x28)."""

    def __init__(self, in_channels=1, num_classes=10, **kwargs):
        super().__init__()
        # Feature extractor: two conv layers followed by two linear layers
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 16, 120)
        self.fc2 = nn.Linear(120, 84)

        # Classification head; ``head_var`` names it so the incremental
        # learning framework can remove/replace it.
        self.fc = nn.Linear(84, num_classes)
        self.head_var = 'fc'

    def forward(self, x):
        h = F.max_pool2d(F.relu(self.conv1(x)), 2)
        h = F.max_pool2d(F.relu(self.conv2(h)), 2)
        h = h.view(h.size(0), -1)
        h = F.relu(self.fc1(h))
        h = F.relu(self.fc2(h))
        return self.fc(h)
31 |
--------------------------------------------------------------------------------
/src/networks/mobile_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import timm
5 |
6 |
7 | # It can handle bsize of 50 to 60
8 |
class Mobile_net(nn.Module):
    """MEAL-V2 MobileNetV3-Large backbone with a replaceable linear head."""

    def __init__(self, num_classes=100, pretrained=False):
        super().__init__()

        # Backbone architecture from torch.hub (weights loaded below if requested)
        self.net = torch.hub.load('szq0214/MEAL-V2', 'meal_v2', 'mealv2_mobilenet_v3_large_100', pretrained=False)

        if pretrained:
            # BUGFIX: map_location was hard-coded to "cuda:2", which crashes on
            # machines without that device; load on CPU and let callers move it.
            checkpoint = torch.hub.load_state_dict_from_url(
                'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_MobileNet_V3_Large_1.0_224.pth',
                map_location='cpu')
            # Checkpoint was saved from DataParallel; strip the "module." prefix
            state_dict = {key.replace("module.", ""): value for key, value in checkpoint.items()}
            self.net.load_state_dict(state_dict)

        # Drop the backbone's own classifier; features come out at 1280-d
        self.net.classifier = nn.Identity()
        # Head used for incremental learning; ``head_var`` names it for removal
        self.fc = nn.Linear(in_features=1280, out_features=num_classes, bias=True)
        self.head_var = 'fc'

    def forward(self, x):
        return self.fc(self.net(x))
30 |
31 |
def mobile_net(num_out=100, pretrained=False):
    """Factory for Mobile_net.

    Only the pretrained configuration is supported.

    Raises:
        NotImplementedError: if `pretrained` is False.
    """
    # FIX: removed an unreachable `assert` that sat after the raise (dead code)
    if not pretrained:
        raise NotImplementedError
    return Mobile_net(num_out, pretrained)
38 |
39 |
40 |
--------------------------------------------------------------------------------
/src/networks/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from copy import deepcopy
4 |
5 |
class LLL_Net(nn.Module):
    """Basic class for implementing networks.

    Wraps a backbone `model` and manages one linear head per task
    (`self.heads`), as used in class-incremental learning experiments.
    """

    def __init__(self, model, remove_existing_head=False):
        """
        Args:
            model: backbone exposing `head_var`, the attribute name of its
                classifier head (an nn.Linear or nn.Sequential).
            remove_existing_head (bool): strip the backbone's own classifier so
                task heads attach directly to the feature representation.
        """
        head_var = model.head_var
        # FIX: use isinstance instead of `type(...) ==` so subclasses of the
        # expected head types are accepted as well (backward compatible).
        assert isinstance(head_var, str)
        assert not remove_existing_head or hasattr(model, head_var), \
            "Given model does not have a variable called {}".format(head_var)
        assert not remove_existing_head or isinstance(getattr(model, head_var), (nn.Sequential, nn.Linear)), \
            "Given model's head {} is not an instance of nn.Sequential or nn.Linear".format(head_var)
        super(LLL_Net, self).__init__()

        self.model = model
        last_layer = getattr(self.model, head_var)

        if remove_existing_head:
            if isinstance(last_layer, nn.Sequential):
                self.out_size = last_layer[-1].in_features
                # strips off last linear layer of classifier
                del last_layer[-1]
            elif isinstance(last_layer, nn.Linear):
                self.out_size = last_layer.in_features
                # converts last layer into identity
                # setattr(self.model, head_var, nn.Identity())
                # WARNING: this is for when pytorch version is <1.2
                setattr(self.model, head_var, nn.Sequential())
        else:
            self.out_size = last_layer.out_features

        self.heads = nn.ModuleList()
        self.task_cls = []      # number of classes of each task
        self.task_offset = []   # cumulative class offset of each task
        self._initialize_weights()

    def add_head(self, num_outputs):
        """Add a new head with the corresponding number of outputs. Also update the number of classes per task and the
        corresponding offsets
        """
        self.heads.append(nn.Linear(self.out_size, num_outputs))
        # we re-compute instead of append in case an approach makes changes to the heads
        self.task_cls = torch.tensor([head.out_features for head in self.heads])
        self.task_offset = torch.cat([torch.LongTensor(1).zero_(), self.task_cls.cumsum(0)[:-1]])

    def forward(self, x, return_features=False):
        """Applies the forward pass

        Simplification to work on multi-head only -- returns all head outputs in a list
        Args:
            x (tensor): input images
            return_features (bool): return the representations before the heads
        """
        x = self.model(x)
        assert (len(self.heads) > 0), "Cannot access any head"
        y = []
        for head in self.heads:
            y.append(head(x))
        if return_features:
            return y, x
        return y

    def get_copy(self):
        """Get weights from the model"""
        return deepcopy(self.state_dict())

    def set_state_dict(self, state_dict):
        """Load weights into the model"""
        self.load_state_dict(deepcopy(state_dict))

    def freeze_all(self):
        """Freeze all parameters from the model, including the heads"""
        for param in self.parameters():
            param.requires_grad = False

    def freeze_backbone(self):
        """Freeze all parameters from the main model, but not the heads"""
        for param in self.model.parameters():
            param.requires_grad = False

    def freeze_bn(self):
        """Freeze all Batch Normalization layers from the model and use them in eval() mode"""
        for m in self.model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def _initialize_weights(self):
        """Initialize weights using different strategies"""
        # TODO: add different initialization strategies
        pass
96 |
--------------------------------------------------------------------------------
/src/networks/ovit_tiny_16_augreg_224.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | from .ovit import OVisionTransformer, _load_weights, get_attention_list
4 |
5 |
6 | # It can handle bsize of 50 to 60
7 |
class OVit_tiny_16_augreg_224(nn.Module):
    """ViT-Tiny/16 (AugReg, 224x224) built on the project's OVisionTransformer,
    with a separate linear head suitable for incremental learning."""

    def __init__(self, num_classes=100, pretrained=False):
        super().__init__()

        # AugReg ImageNet-1k checkpoint shipped with the repository
        filename = 'src/networks/pretrained_weights/augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'

        # backbone without its own classifier (num_classes=0)
        self.ovit = OVisionTransformer(embed_dim=192, num_classes=0, num_heads=3)
        if pretrained:
            _load_weights(model=self.ovit, checkpoint_path=filename)

        # last classifier layer (head) with as many outputs as classes
        self.fc = nn.Linear(in_features=192, out_features=num_classes, bias=True)
        # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments
        self.head_var = 'fc'

    def forward(self, x):
        features = self.ovit(x)
        return self.fc(features)
28 |
29 |
def ovit_tiny_16_augreg_224(num_out=100, pretrained=False):
    """Factory for OVit_tiny_16_augreg_224; only the pretrained model is supported.

    Raises:
        NotImplementedError: if `pretrained` is False.
    """
    # FIX: removed an unreachable `assert` that sat after the raise (dead code)
    if not pretrained:
        raise NotImplementedError
    return OVit_tiny_16_augreg_224(num_out, pretrained)
36 |
--------------------------------------------------------------------------------
/src/networks/pretrained_weights/augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/continual_learning_with_vit/1021a5a93d0093c1c7e358ace47a7a5235a12b19/src/networks/pretrained_weights/augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz
--------------------------------------------------------------------------------
/src/networks/resnet32.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | __all__ = ['resnet32']
4 |
5 |
def conv3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3 convolution; padding=1 preserves spatial size at stride 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
9 |
10 |
class BasicBlock(nn.Module):
    """Residual unit with two 3x3 convolutions and an (optionally projected)
    skip connection; output channels equal `planes` (expansion = 1)."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # conv path: conv-bn-relu-conv-bn
        h = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        # skip path: identity, or 1x1 projection when shape changes
        shortcut = x if self.downsample is None else self.downsample(x)
        h += shortcut
        return self.relu(h)
32 |
33 |
class Bottleneck(nn.Module):
    """Residual unit with a 1x1 / 3x3 / 1x1 convolution stack; the last conv
    expands the channels by `expansion` (4)."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # reduce -> transform -> expand
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        # skip path: identity, or projection when shape changes
        shortcut = x if self.downsample is None else self.downsample(x)
        h += shortcut
        return self.relu(h)
58 |
59 |
class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem, three stages of 16/32/64 planes, 8x8
    average pooling and a linear head exposed via `head_var`."""

    def __init__(self, block, layers, num_classes=10):
        self.inplanes = 16
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AvgPool2d(8, stride=1)
        # last classifier layer (head) with as many outputs as classes
        self.fc = nn.Linear(64 * block.expansion, num_classes)
        # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments
        self.head_var = 'fc'

        # He init for convs, constant init for batch-norm affine params
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                nn.init.kaiming_normal_(mod.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(mod, nn.BatchNorm2d):
                nn.init.constant_(mod.weight, 1)
                nn.init.constant_(mod.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual units; only the first may downsample/expand."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        units = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        units.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*units)

    def forward(self, x):
        h = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            h = stage(h)
        h = self.avgpool(h)
        h = h.view(h.size(0), -1)
        return self.fc(h)
107 |
108 |
def resnet32(pretrained=False, **kwargs):
    """Build a ResNet-32 (n = 5 BasicBlocks per stage); no pretrained weights."""
    if pretrained:
        raise NotImplementedError
    # change n=3 for ResNet-20, and n=9 for ResNet-56
    depth_n = 5
    return ResNet(BasicBlock, [depth_n] * 3, **kwargs)
116 |
--------------------------------------------------------------------------------
/src/networks/timm_vit_tiny_16_augreg_224.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | import timm
4 |
5 |
6 | # It can handle bsize of 50 to 60
7 |
class Timm_Vit_tiny_16_augreg_224(nn.Module):
    """timm ViT-Tiny/16 (AugReg, 224x224) with a separate linear head.

    Args:
        num_classes (int): size of the classifier head `fc`.
        pretrained (bool): load the local AugReg checkpoint into the backbone.
    """

    def __init__(self, num_classes=100, pretrained=False):
        super().__init__()

        # NOTE(review): absolute, machine-specific checkpoint path; only usable
        # with pretrained=True on the machine it was authored on -- TODO make configurable
        filename = '/home/fpelosin/vit_facil/src/networks/pretrained_weights/augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'

        # FIX: both branches of the original if/else created the identical
        # model; build it once and load the checkpoint only when requested.
        self.vit = timm.create_model('vit_tiny_patch16_224', num_classes=0)
        if pretrained:
            timm.models.load_checkpoint(self.vit, filename)

        # last classifier layer (head) with as many outputs as classes
        self.fc = nn.Linear(in_features=192, out_features=num_classes, bias=True)
        # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments
        self.head_var = 'fc'

    def forward(self, x):
        return self.fc(self.vit(x))
30 |
31 |
def timm_vit_tiny_16_augreg_224(num_out=100, pretrained=False):
    """Factory for Timm_Vit_tiny_16_augreg_224; only pretrained is supported.

    Raises:
        NotImplementedError: if `pretrained` is False.
    """
    # FIX: removed an unreachable `assert` that sat after the raise (dead code)
    if not pretrained:
        raise NotImplementedError
    return Timm_Vit_tiny_16_augreg_224(num_out, pretrained)
38 |
--------------------------------------------------------------------------------
/src/networks/vggnet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 |
4 |
class VggNet(nn.Module):
    """ Following the VGGnet based on VGG16 but for smaller input (64x64)
    Check this blog for some info: https://learningai.io/projects/2017/06/29/tiny-imagenet.html
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        # four conv stages; each MaxPool halves the spatial size: 64->32->16->8->4
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc6 = nn.Linear(in_features=512 * 4 * 4, out_features=4096, bias=True)
        self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True)
        # last classifier layer (head) with as many outputs as classes
        self.fc = nn.Linear(in_features=4096, out_features=num_classes, bias=True)
        # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments
        self.head_var = 'fc'

    def forward(self, x):
        h = self.features(x)
        h = h.view(x.size(0), -1)
        # BUGFIX: F.dropout defaults to training=True, which kept dropout active
        # even in eval() mode; tie it to the module's training flag instead.
        h = F.dropout(F.relu(self.fc6(h)), training=self.training)
        h = F.dropout(F.relu(self.fc7(h)), training=self.training)
        h = self.fc(h)
        return h
53 |
54 |
def vggnet(num_out=100, pretrained=False):
    """Build a VggNet backbone; pretrained weights are not available."""
    if pretrained:
        raise NotImplementedError
    model = VggNet(num_out)
    return model
59 |
--------------------------------------------------------------------------------
/src/networks/vit_tiny_16_augreg_224.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | from .vit_original import VisionTransformer, _load_weights
4 |
5 |
6 | # It can handle bsize of 50 to 60
7 |
class Vit_tiny_16_augreg_224(nn.Module):
    """ViT-Tiny/16 (AugReg, 224x224) built on the project's VisionTransformer,
    with a separate linear head suitable for incremental learning."""

    def __init__(self, num_classes=100, pretrained=False):
        super().__init__()

        # NOTE(review): absolute, machine-specific checkpoint path -- only used
        # when pretrained=True
        filename = '/home/fpelosin/transformers/FACIL/src/networks/augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'

        # backbone without its own classifier (num_classes=0)
        self.vit = VisionTransformer(embed_dim=192, num_heads=3, num_classes=0)
        if pretrained:
            _load_weights(model=self.vit, checkpoint_path=filename)

        # last classifier layer (head) with as many outputs as classes
        self.fc = nn.Linear(in_features=192, out_features=num_classes, bias=True)
        # and `head_var` with the name of the head, so it can be removed when doing incremental learning experiments
        self.head_var = 'fc'

    def forward(self, x):
        features = self.vit(x)
        return self.fc(features)
28 |
29 |
def vit_tiny_16_augreg_224(num_out=100, pretrained=False):
    """Factory for Vit_tiny_16_augreg_224; only the pretrained model is supported.

    Raises:
        NotImplementedError: if `pretrained` is False.
    """
    # FIX: removed an unreachable `assert` that sat after the raise (dead code)
    if not pretrained:
        raise NotImplementedError
    return Vit_tiny_16_augreg_224(num_out, pretrained)
36 |
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
"""Ad-hoc debugging script: compares the project's custom VisionTransformer
(`networks.ovit`) against timm's `vit_tiny_patch16_224`, block by block and
parameter by parameter, after loading the same AugReg checkpoint into both.

NOTE(review): developer scratch script -- it hard-codes a machine-specific
checkpoint path and drops into an interactive ipdb breakpoint halfway through.
"""
import torch
from torch import nn
import torch.nn.functional as F
import timm
import ipdb
from networks.ovit import VisionTransformer, _load_weights


# Machine-specific AugReg ViT-Tiny checkpoint (ImageNet-1k, 224x224)
filename = '/home/fpelosin/vit_facil/src/networks/pretrained_weights/augreg_Ti_16-i1k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'


# Custom implementation, loaded with the same weights as the timm reference
ovit = VisionTransformer(embed_dim=192, num_classes=0)
_load_weights(model=ovit, checkpoint_path=filename)
ovit.eval()


# timm reference implementation
#timm.model.visiontransformer
vit = timm.create_model('vit_tiny_patch16_224', num_classes=0)
timm.models.load_checkpoint(vit, filename)
vit.eval()

inp = torch.rand(1,3,224,224)


# Compare the two models prefix-by-prefix: embed + first i transformer blocks.
# NOTE(review): at i=0 both chunks contain zero blocks; also the cls token /
# position embedding are not included here -- confirm that is intentional.
for i in range(12):

    vit_chunck = torch.nn.Sequential(vit.patch_embed, vit.pos_drop, vit.blocks[:i])
    ovit_chunck = torch.nn.Sequential(ovit.patch_embed, ovit.pos_drop, ovit.blocks[:i])

    out_vit = vit_chunck(inp)
    out_ovit = ovit_chunck(inp)

    # elementwise absolute difference summed over the whole activation tensor
    if torch.abs(out_vit - out_ovit).sum() > 1e-8:
        print(f"diff@{i}")
    else:
        print(f"NOT diff@{i}")


#print(vit(inp))
#print(ovit(inp))

# NOTE(review): leftover interactive breakpoint -- remove before automation
import ipdb; ipdb.set_trace()

# Sanity check: attention/MLP weights must be identical across the two models
ovit_state_dict = ovit.state_dict()

for name, param in vit.named_parameters():
    if 'attn' in name or 'mlp' in name:
        ovit_param = ovit_state_dict[name]
        if torch.abs(param.data - ovit_param.data).sum() > 1e-8:
            print(f'--->[ DIFF ] {name} is NOT same')
        else:
            print(f'[ OK ] {name} is same')
56 |
57 |
--------------------------------------------------------------------------------
/src/tests/README.md:
--------------------------------------------------------------------------------
1 | # Tests
2 | The tests in this folder are our tool to check if everything is working as intended and to pick any errors that might
3 | appear when introducing new features. They can also be used to make sure all dependencies are available before running
4 | any long experiments.
5 |
6 | ## Running tests
7 |
8 | ### From console
9 | Type the following code in your console:
10 | ```bash
11 | cd src/
12 | py.test -s tests/
13 | ```
14 |
15 | ### Running tests in parallel
16 | As the number of tests grows, it can be faster to run them in parallel since they can take a few minutes each.
17 | Pytest can support it by using the `pytest-xdist` plugin, e.g. running all tests with 5 workers:
18 | ```bash
19 | cd src/
20 | py.test -n 5 tests/
21 | ```
22 | And for a more verbose output:
23 | ```bash
24 | cd src/
25 | py.test -sv -n 5 tests/
26 | ```
27 | **_Warning:_** It is recommended to run a single test without parallelization the first time, since the first thing the
28 | tests do is download the dataset (MNIST). If run in parallel, multiple workers can start downloading it
29 | at the same time.
30 |
31 | ### From your IDE (PyCharm, VSCode, ...)
32 | `py.tests` are well supported. It's usually enough to select `py.test` as a framework, right click on a test file or
33 | directory and select the option to run pytest tests (not as a python script!).
34 |
--------------------------------------------------------------------------------
/src/tests/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import shutil
4 |
5 | from main_incremental import main
6 | import datasets.dataset_config as c
7 |
8 |
def run_main(args_line, result_dir='results_test', clean_run=False):
    """Run `main_incremental.main` with the given CLI string, normalising CWD.

    Args:
        args_line (str): space-separated CLI arguments; must not already
            contain --results-path (it is appended here).
        result_dir (str): results directory name, created next to the repo root.
        clean_run (bool): remove a pre-existing results directory first.

    Returns:
        whatever `main()` returns (accuracy/forgetting matrices and exp dir).
    """
    assert "--results-path" not in args_line

    # tests may be launched from src/tests, the repo root or src -- normalise
    print('Starting dir:', os.getcwd())  # FIX: message typo ("Staring dir")
    if os.getcwd().endswith('tests'):
        os.chdir('..')
    elif os.getcwd().endswith('IL_Survey'):
        os.chdir('src')
    elif os.getcwd().endswith('src'):
        print('CWD is OK.')
    print('Test CWD:', os.getcwd())
    test_results_path = os.getcwd() + f"/../{result_dir}"

    # for testing - use relative path to CWD
    c.dataset_config['mnist']['path'] = '../data'
    if os.path.exists(test_results_path) and clean_run:
        shutil.rmtree(test_results_path)
    os.makedirs(test_results_path, exist_ok=True)
    args_line += " --results-path {}".format(test_results_path)

    # if distributed test (pytest-xdist) -- spread workers across all GPUs
    worker_id = int(os.environ.get("PYTEST_XDIST_WORKER", "gw-1")[2:])
    if worker_id >= 0 and torch.cuda.is_available():
        gpu_idx = worker_id % torch.cuda.device_count()
        args_line += " --gpu {}".format(gpu_idx)

    print('ARGS:', args_line)
    return main(args_line.split(' '))
37 |
38 |
def run_main_and_assert(args_line,
                        taw_current_task_min=0.01,
                        tag_current_task_min=0.0,
                        result_dir='results_test'):
    """Run one experiment and sanity-check the returned accuracy matrices."""
    acc_taw, acc_tag, forg_taw, forg_tag, exp_dir = run_main(args_line, result_dir)

    # all four matrices must share a single shape
    for matrix in (acc_taw, forg_tag, forg_taw):
        assert acc_tag.shape == matrix.shape

    # the diagonal holds each task's accuracy right after training on it
    assert all(acc_tag.diagonal() >= tag_current_task_min)
    assert all(acc_taw.diagonal() >= taw_current_task_min)
--------------------------------------------------------------------------------
/src/tests/test_bic.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Minimal, fast configuration shared by every BiC test in this module.
FAST_LOCAL_TEST_ARGS = ("--exp-name local_test --datasets mnist"
                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
                        " --nepochs 3"
                        " --num-workers 0"
                        " --approach bic")
8 |
9 |
def test_bic_exemplars():
    """BiC with a fixed memory of 200 exemplars."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --num-exemplars 200")
14 |
15 |
def test_bic_exemplars_lambda():
    """BiC with exemplars and an explicit distillation trade-off lambda."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS
                        + " --num-exemplars 200"
                        + " --lamb 1")
21 |
22 |
def test_bic_exemplars_per_class():
    """BiC with a growing memory of 20 exemplars per class."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --num-exemplars-per-class 20")
27 |
28 |
def test_bic_with_warmup():
    """BiC with head warm-up epochs on top of the exemplar memory."""
    extra = (" --warmup-nepochs 5"
             " --warmup-lr-factor 0.5"
             " --num-exemplars 200")
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + extra)
35 |
--------------------------------------------------------------------------------
/src/tests/test_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.dataset import TensorDataset
4 |
5 | from main_incremental import main
6 |
7 |
def test_dataloader_dataset_swap():
    """Replacing a live DataLoader's dataset does not refresh its sampler."""
    # given -- a loader built on the first dataset only
    old_ds = TensorDataset(torch.arange(10))
    extra_ds = TensorDataset(torch.arange(10, 20))
    dl = DataLoader(old_ds, batch_size=2, shuffle=True, num_workers=1)
    # when
    first_pass = list(dl)

    try:
        dl.dataset += extra_ds
    except ValueError:
        # newer pytorch forbids this attribute swap entirely,
        # which is the expected and OK behaviour
        return

    second_pass = list(dl)
    merged = list(dl.dataset)

    # then -- the dataset grew, but the loader still yields the old length
    assert len(merged) == 20
    assert len(first_pass) == 5
    assert len(second_pass) == 5
    # ^ troublesome: the sampler was built in DataLoader.__init__
    # and still references the original dataset...
    assert dl.sampler.data_source == old_ds
    # ...so the new data is never seen.
35 |
36 |
def test_dataloader_multiple_datasets():
    """The main entry point accepts the same dataset listed several times."""
    args_line = ("--exp-name local_test --approach finetuning --datasets mnist mnist mnist"
                 " --network LeNet --num-tasks 2 --batch-size 32"
                 " --results-path ../results/ --num-workers 0 --nepochs 2")
    print('ARGS:', args_line)
    main(args_line.split(' '))
43 |
--------------------------------------------------------------------------------
/src/tests/test_datasets_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision.transforms import Lambda
3 | from torch.utils.data.dataset import ConcatDataset
4 |
5 | from datasets.memory_dataset import MemoryDataset
6 | from datasets.exemplars_selection import override_dataset_transform
7 |
8 |
def pic(i):
    """Return a 1x1 int8 'image' whose single pixel equals *i*."""
    return np.full((1, 1), i, dtype=np.int8)
11 |
12 |
def test_dataset_transform_override():
    """Transforms are swapped inside the context manager and restored after."""
    # given -- three datasets, each with a distinguishable transform
    data1 = MemoryDataset({
        'x': [pic(1), pic(2), pic(3)], 'y': ['a', 'b', 'c']
    }, transform=Lambda(lambda x: np.array(x)[0, 0] * 2))
    data2 = MemoryDataset({
        'x': [pic(4), pic(5), pic(6)], 'y': ['d', 'e', 'f']
    }, transform=Lambda(lambda x: np.array(x)[0, 0] * 3))
    data3 = MemoryDataset({
        'x': [pic(7), pic(8), pic(9)], 'y': ['g', 'h', 'i']
    }, transform=Lambda(lambda x: np.array(x)[0, 0] + 10))
    ds = ConcatDataset([data1, ConcatDataset([data2, data3])])

    def collect(dataset):
        # materialise every (x, y) pair of a dataset into two tuples
        return zip(*[dataset[i] for i in range(len(dataset))])

    # when
    x1, y1 = collect(ds)
    with override_dataset_transform(ds, Lambda(lambda x: np.array(x)[0, 0])) as ds_overriden:
        x2, y2 = collect(ds_overriden)
    x3, y3 = collect(ds)

    # then
    assert np.array_equal(x1, [2, 4, 6, 12, 15, 18, 17, 18, 19])
    assert np.array_equal(x2, [1, 2, 3, 4, 5, 6, 7, 8, 9])
    assert np.array_equal(x3, x1)  # after everything is back to normal
36 |
--------------------------------------------------------------------------------
/src/tests/test_dmc.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Minimal, fast configuration for the DMC test (needs an auxiliary dataset).
FAST_LOCAL_TEST_ARGS = ("--exp-name local_test --datasets cifar100"
                        " --network resnet32 --num-tasks 3 --seed 1 --batch-size 32"
                        " --nepochs 3"
                        " --num-workers 0"
                        " --approach dmc"
                        " --aux-dataset cifar100")
9 |
10 |
def test_dmc():
    """DMC with cifar100 acting as both the target and the auxiliary dataset."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS)
13 |
14 |
--------------------------------------------------------------------------------
/src/tests/test_eeil.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Minimal, fast configuration shared by every EEIL test in this module.
FAST_LOCAL_TEST_ARGS = ("--exp-name local_test --datasets mnist"
                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
                        " --nepochs 3"
                        " --num-workers 0"
                        " --approach eeil")
8 |
9 |
def test_eeil_exemplars_with_noise_grad():
    """EEIL with exemplars, one finetuning epoch and gradient noise."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS
                        + " --num-exemplars 200"
                        + " --nepochs-finetuning 1"
                        + " --noise-grad")
16 |
17 |
def test_eeil_exemplars():
    """EEIL with exemplars and a single finetuning epoch."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS
                        + " --num-exemplars 200"
                        + " --nepochs-finetuning 1")
23 |
24 |
def test_eeil_with_warmup():
    """EEIL with head warm-up on top of exemplars and short finetuning."""
    extra = (" --warmup-nepochs 5"
             " --warmup-lr-factor 0.5"
             " --num-exemplars 200"
             " --nepochs-finetuning 1")
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + extra)
32 |
--------------------------------------------------------------------------------
/src/tests/test_ewc.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Minimal, fast configuration shared by every EWC test in this module.
FAST_LOCAL_TEST_ARGS = ("--exp-name local_test --datasets mnist"
                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
                        " --nepochs 3"
                        " --num-workers 0"
                        " --approach ewc")
8 |
9 |
def test_ewc_without_exemplars():
    """Plain EWC, no exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS)
12 |
13 |
def test_ewc_with_exemplars():
    """EWC combined with a fixed exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --num-exemplars 200")
18 |
19 |
def test_ewc_with_warmup():
    """EWC with head warm-up epochs on top of the exemplar memory."""
    extra = (" --warmup-nepochs 5"
             " --warmup-lr-factor 0.5"
             " --num-exemplars 200")
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + extra)
26 |
--------------------------------------------------------------------------------
/src/tests/test_finetuning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 |
4 | from tests import run_main_and_assert
5 |
# Minimal, fast configuration shared by every finetuning test in this module.
FAST_LOCAL_TEST_ARGS = ("--exp-name local_test --datasets mnist"
                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
                        " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7"
                        " --num-workers 0")
10 |
11 |
def test_finetuning_without_exemplars():
    """Plain finetuning, no exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach finetuning")
16 |
17 |
def test_finetuning_with_exemplars():
    """Finetuning with a fixed memory of 200 exemplars."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS
                        + " --approach finetuning"
                        + " --num-exemplars 200")
23 |
24 |
@pytest.mark.xfail
def test_finetuning_with_exemplars_per_class_and_herding():
    """Herding exemplar selection -- currently marked as expected to fail."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS
                        + " --approach finetuning"
                        + " --num-exemplars-per-class 10"
                        + " --exemplar-selection herding")
32 |
33 |
def test_finetuning_with_exemplars_per_class_and_entropy():
    """Finetuning with per-class exemplars chosen by entropy."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS
                        + " --approach finetuning"
                        + " --num-exemplars-per-class 10"
                        + " --exemplar-selection entropy")
40 |
41 |
def test_finetuning_with_exemplars_per_class_and_distance():
    """Finetuning with per-class exemplars chosen by distance."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS
                        + " --approach finetuning"
                        + " --num-exemplars-per-class 10"
                        + " --exemplar-selection distance")
48 |
49 |
def test_wrong_args():
    """Providing both exemplar arguments must abort with SystemExit."""
    with pytest.raises(SystemExit):  # error of providing both args
        args_line = (FAST_LOCAL_TEST_ARGS
                     + " --approach finetuning"
                     + " --num-exemplars-per-class 10"
                     + " --num-exemplars 200")
        run_main_and_assert(args_line)
57 |
58 |
def test_finetuning_with_eval_on_train():
    """Finetuning with additional evaluation on the training set."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS
                        + " --approach finetuning"
                        + " --num-exemplars-per-class 10"
                        + " --exemplar-selection distance"
                        + " --eval-on-train")
66 |
def test_finetuning_with_no_cudnn_deterministic():
    """cudnn.deterministic defaults to True and is disabled by the CLI flag."""
    args_line = FAST_LOCAL_TEST_ARGS
    args_line += " --approach finetuning"
    args_line += " --num-exemplars-per-class 10"
    args_line += " --exemplar-selection distance"

    run_main_and_assert(args_line)
    # FIX: idiomatic boolean assertions instead of `== True` / `== False`
    assert torch.backends.cudnn.deterministic

    args_line += " --no-cudnn-deterministic"
    run_main_and_assert(args_line)
    assert not torch.backends.cudnn.deterministic
79 |
--------------------------------------------------------------------------------
/src/tests/test_fix_bn.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Shared configuration: every approach in this module runs with --fix-bn.
FAST_LOCAL_TEST_ARGS = ("--exp-name local_test --datasets mnist"
                        " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
                        " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7"
                        " --num-workers 0 --fix-bn")
7 |
8 |
def test_finetuning_fix_bn():
    """Finetuning with frozen batch-norm statistics."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach finetuning")
13 |
14 |
def test_joint_fix_bn():
    """Joint training with frozen batch-norm statistics."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach joint")
19 |
20 |
# NOTE(review): function name has a typo ("freezingt") -- kept so pytest
# selection expressions relying on it keep working.
def test_freezingt_fix_bn():
    """Freezing approach with frozen batch-norm statistics."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach freezing")
25 |
26 |
def test_icarl_fix_bn():
    """iCaRL (requires exemplars) with frozen batch-norm statistics."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS
                        + " --num-exemplars 200"
                        + " --approach icarl")
32 |
33 |
34 | def test_eeil_fix_bn():
35 | args_line = FAST_LOCAL_TEST_ARGS
36 | args_line += " --num-exemplars 200"
37 | args_line += " --approach eeil"
38 | run_main_and_assert(args_line)
39 |
40 |
41 | def test_mas_fix_bn():
42 | args_line = FAST_LOCAL_TEST_ARGS
43 | args_line += " --approach mas"
44 | run_main_and_assert(args_line)
45 |
46 |
def test_lwf_fix_bn():
    """LwF with batch norm fixed."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach lwf")


def test_lwm_fix_bn():
    """LwM (needs a Grad-CAM layer) with batch norm fixed."""
    extra = " --approach lwm --gradcam-layer conv2 --log-gradcam-samples 16"
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + extra)


def test_r_walk_fix_bn():
    """RWalk with batch norm fixed."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach r_walk")
63 |
64 |
def test_path_integral_fix_bn():
    """Path Integral (SI) with batch norm fixed."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach path_integral")


def test_lucir_fix_bn():
    """LUCIR with batch norm fixed.

    Fix: the function was misspelled ``test_luci_fix_bn``; nothing references
    test functions by name, pytest still collects it.
    """
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach lucir")


def test_ewc_fix_bn():
    """EWC with batch norm fixed."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach ewc")
81 |
--------------------------------------------------------------------------------
/src/tests/test_freezing.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run with the freezing approach.
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 3"
    " --num-workers 0"
    " --approach freezing"
)


def test_freezing_without_exemplars():
    """Freezing approach with no exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS)


def test_freezing_with_exemplars():
    """Freezing approach with a fixed-size exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --num-exemplars 200")


def test_freezing_with_warmup():
    """Freezing approach with a head warm-up phase on each new task."""
    extra = " --warmup-nepochs 5 --warmup-lr-factor 0.5"
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + extra)
25 |
--------------------------------------------------------------------------------
/src/tests/test_gridsearch.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run with the continual-learning grid search enabled.
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 3 --num-workers 0"
    " --gridsearch-tasks 3 --gridsearch-config gridsearch_config"
    " --gridsearch-acc-drop-thr 0.2 --gridsearch-hparam-decay 0.5"
)


def test_gridsearch_finetuning():
    """Grid search with finetuning and an exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach finetuning --num-exemplars 200")


def test_gridsearch_freezing():
    """Grid search with the freezing approach."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach freezing --num-exemplars 200")


def test_gridsearch_joint():
    """Grid search with joint (upper-bound) training."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach joint")


def test_gridsearch_lwf():
    """Grid search with LwF."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach lwf --num-exemplars 200")


def test_gridsearch_icarl():
    """Grid search with iCaRL."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach icarl --num-exemplars 200")


def test_gridsearch_eeil():
    """Grid search with EEIL (short finetuning phase)."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach eeil --nepochs-finetuning 3 --num-exemplars 200")


def test_gridsearch_bic():
    """Grid search with BiC (short bias-correction phase)."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach bic --num-bias-epochs 3 --num-exemplars 200")


def test_gridsearch_lucir():
    """Grid search with LUCIR."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach lucir --num-exemplars 200")


def test_gridsearch_lwm():
    """Grid search with LwM (needs a Grad-CAM layer)."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach lwm --gradcam-layer conv2 --log-gradcam-samples 16 --num-exemplars 200")


def test_gridsearch_ewc():
    """Grid search with EWC."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach ewc --num-exemplars 200")


def test_gridsearch_mas():
    """Grid search with MAS."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach mas --num-exemplars 200")


def test_gridsearch_pathint():
    """Grid search with Path Integral (SI)."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach path_integral --num-exemplars 200")


def test_gridsearch_rwalk():
    """Grid search with RWalk."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach r_walk --num-exemplars 200")


def test_gridsearch_dmc():
    """Grid search with DMC."""
    # mnist as auxiliary dataset just to keep the grid search fast
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach dmc --aux-dataset mnist")
93 |
--------------------------------------------------------------------------------
/src/tests/test_icarl.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run with the iCaRL approach.
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 3"
    " --num-workers 0"
    " --approach icarl"
)


def test_icarl_exemplars():
    """iCaRL with exemplars and the distillation term enabled (lamb=1)."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --num-exemplars 200 --lamb 1")


def test_icarl_exemplars_without_lamb():
    """iCaRL with exemplars but the distillation term disabled (lamb=0)."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --num-exemplars 200 --lamb 0")


def test_icarl_with_warmup():
    """iCaRL with a warm-up phase, exemplars and distillation."""
    extra = " --warmup-nepochs 5 --warmup-lr-factor 0.5 --num-exemplars 200 --lamb 1"
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + extra)
31 |
--------------------------------------------------------------------------------
/src/tests/test_il2m.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run with the IL2M approach.
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 3"
    " --num-workers 0"
    " --approach il2m"
)


def test_il2m():
    """IL2M with a fixed-size exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --num-exemplars 200")


def test_il2m_with_warmup():
    """IL2M with a warm-up phase and exemplars."""
    extra = " --warmup-nepochs 5 --warmup-lr-factor 0.5 --num-exemplars 200"
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + extra)
22 |
--------------------------------------------------------------------------------
/src/tests/test_joint.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run (approach added per test).
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7"
    " --num-workers 0"
)


def test_joint():
    """Joint training over all seen tasks (incremental upper bound)."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach joint")
13 |
--------------------------------------------------------------------------------
/src/tests/test_last_layer_analysis.py:
--------------------------------------------------------------------------------
1 | from tests import run_main
2 |
# Quick 3-task MNIST/LeNet finetuning run.
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7"
    " --num-workers 0"
    " --approach finetuning"
)


def test_last_layer_analysis():
    """Finetuning with the last-layer weight/bias analysis enabled."""
    run_main(FAST_LOCAL_TEST_ARGS + " --last-layer-analysis")
14 |
15 |
--------------------------------------------------------------------------------
/src/tests/test_loggers.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from pathlib import Path
3 |
4 | from tests import run_main
5 |
# Quick 3-task MNIST/LeNet finetuning run used to exercise the loggers.
FAST_LOCAL_TEST_ARGS = (
    "--exp-name loggers_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7"
    " --num-workers 0 --approach finetuning"
)


def test_disk_and_tensorflow_logger():
    """End-to-end check of the disk and tensorboard loggers' output files."""
    cmdline = FAST_LOCAL_TEST_ARGS + " --log disk tensorboard"
    result = run_main(cmdline, 'results_test_loggers', clean_run=True)
    exp_dir = Path(result[-1])

    # disk logger: exactly one raw json-lines log covering every group/iteration
    assert exp_dir.is_dir()
    raw_logs = list(exp_dir.glob('raw_log-*.txt'))
    assert len(raw_logs) == 1
    log_df = pd.read_json(raw_logs[0], lines=True)
    assert sorted(log_df.iter.unique()) == [0, 1, 2]
    groups = sorted(log_df.group.unique())
    assert groups == ['test', 'train', 'valid']
    assert len(groups) == 3

    # tensorboard logger: exactly one events file
    tb_events_logs = list(exp_dir.glob('events.out.tfevents*'))
    assert len(tb_events_logs) == 1
    assert exp_dir.joinpath(tb_events_logs[0]).is_file()
32 |
--------------------------------------------------------------------------------
/src/tests/test_lucir.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run with LUCIR; grid search disabled by default.
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 3"
    " --num-workers 0"
    " --gridsearch-tasks -1"
    " --approach lucir"
)


def test_lucir_exemplars():
    """LUCIR with a per-class exemplar memory.

    Fix: this test was defined twice with identical bodies; the later
    definition shadowed this one at import time, so the duplicate was removed.
    """
    args_line = FAST_LOCAL_TEST_ARGS
    args_line += " --num-exemplars-per-class 20"
    run_main_and_assert(args_line)


def test_lucir_exemplars_with_gridsearch():
    """LUCIR with exemplars, switching the grid search on for all 3 tasks."""
    args_line = FAST_LOCAL_TEST_ARGS
    args_line += " --num-exemplars-per-class 20"
    args_line = args_line.replace('--gridsearch-tasks -1', '--gridsearch-tasks 3')
    run_main_and_assert(args_line)


def test_lucir_exemplars_remove_margin_ranking():
    """LUCIR with exemplars but without the margin-ranking loss term."""
    args_line = FAST_LOCAL_TEST_ARGS
    args_line += " --num-exemplars-per-class 20"
    args_line += " --remove-margin-ranking"
    run_main_and_assert(args_line)


def test_lucir_exemplars_remove_adapt_lamda():
    """LUCIR with exemplars but without the adaptive lambda."""
    args_line = FAST_LOCAL_TEST_ARGS
    args_line += " --num-exemplars-per-class 20"
    args_line += " --remove-adapt-lamda"
    run_main_and_assert(args_line)


def test_lucir_exemplars_warmup():
    """LUCIR with exemplars and a warm-up phase."""
    args_line = FAST_LOCAL_TEST_ARGS
    args_line += " --num-exemplars-per-class 20"
    args_line += " --warmup-nepochs 5"
    args_line += " --warmup-lr-factor 0.5"
    run_main_and_assert(args_line)
50 |
--------------------------------------------------------------------------------
/src/tests/test_lwf.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run with the LwF approach.
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 3"
    " --num-workers 0"
    " --approach lwf"
)


def test_lwf_without_exemplars():
    """LwF with distillation only, no exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS)


def test_lwf_with_exemplars():
    """LwF with a fixed-size exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --num-exemplars 200")


def test_lwf_with_warmup():
    """LwF with a warm-up phase and exemplars."""
    extra = " --warmup-nepochs 5 --warmup-lr-factor 0.5 --num-exemplars 200"
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + extra)
26 |
--------------------------------------------------------------------------------
/src/tests/test_lwm.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run with LwM (needs a Grad-CAM layer).
FAST_LOCAL_TEST_ARGS = "--exp-name local_test --datasets mnist" \
                       " --network LeNet --num-tasks 3 --seed 1 --batch-size 32" \
                       " --nepochs 3" \
                       " --num-workers 0" \
                       " --approach lwm --gradcam-layer conv2"


def test_lwm_without_exemplars():
    """LwM without exemplars, logging Grad-CAM samples.

    Fix: the built ``args_line`` was previously discarded —
    ``run_main_and_assert`` was called with the bare FAST_LOCAL_TEST_ARGS,
    so ``--log-gradcam-samples 16`` never took effect.
    """
    args_line = FAST_LOCAL_TEST_ARGS
    args_line += " --log-gradcam-samples 16"
    run_main_and_assert(args_line)


def test_lwm_with_exemplars():
    """LwM with a fixed-size exemplar memory."""
    args_line = FAST_LOCAL_TEST_ARGS
    args_line += " --num-exemplars 200"
    run_main_and_assert(args_line)


def test_lwm_with_warmup():
    """LwM with a warm-up phase and exemplars."""
    args_line = FAST_LOCAL_TEST_ARGS
    args_line += " --warmup-nepochs 5"
    args_line += " --warmup-lr-factor 0.5"
    args_line += " --num-exemplars 200"
    run_main_and_assert(args_line)
28 |
--------------------------------------------------------------------------------
/src/tests/test_mas.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run with the MAS approach.
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 3"
    " --num-workers 0"
    " --approach mas"
)


def test_mas_without_exemplars():
    """MAS with regularization only, no exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS)


def test_mas_with_exemplars():
    """MAS with a fixed-size exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --num-exemplars 200")


def test_mas_with_warmup():
    """MAS with a warm-up phase and exemplars."""
    extra = " --warmup-nepochs 5 --warmup-lr-factor 0.5 --num-exemplars 200"
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + extra)
26 |
--------------------------------------------------------------------------------
/src/tests/test_multisoftmax.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run (approach added per test).
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7"
    " --num-workers 0"
)


def test_finetuning_without_multisoftmax():
    """Finetuning with the default single softmax over all heads."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach finetuning")


def test_finetuning_with_multisoftmax():
    """Finetuning with a separate softmax per task head."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach finetuning --multi-softmax")
20 |
--------------------------------------------------------------------------------
/src/tests/test_path_integral.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run with the Path Integral (SI) approach.
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 3"
    " --num-workers 0"
    " --approach path_integral"
)


def test_pi_without_exemplars():
    """Path Integral with regularization only, no exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS)


def test_pi_with_exemplars():
    """Path Integral with a fixed-size exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --num-exemplars 200")


def test_pi_with_warmup():
    """Path Integral with a warm-up phase and exemplars."""
    extra = " --warmup-nepochs 5 --warmup-lr-factor 0.5 --num-exemplars 200"
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + extra)
26 |
--------------------------------------------------------------------------------
/src/tests/test_rwalk.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run with the RWalk approach.
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 3"
    " --num-workers 0"
    " --approach r_walk"
)


def test_rwalk_without_exemplars():
    """RWalk with an explicitly empty exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --num-exemplars 0")


def test_rwalk_with_exemplars():
    """RWalk with a fixed-size exemplar memory."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --num-exemplars 200")


def test_rwalk_with_warmup():
    """RWalk with a warm-up phase and exemplars."""
    extra = " --warmup-nepochs 5 --warmup-lr-factor 0.5 --num-exemplars 200"
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + extra)
28 |
--------------------------------------------------------------------------------
/src/tests/test_stop_at_task.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# 5-task MNIST/LeNet run that halts early via --stop-at-task.
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 5 --seed 1 --batch-size 32"
    " --nepochs 2 --num-workers 0 --stop-at-task 3"
)


def test_finetuning_stop_at_task():
    """Finetuning on 5 tasks, stopping the run after task 3."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach finetuning")
12 |
--------------------------------------------------------------------------------
/src/tests/test_warmup.py:
--------------------------------------------------------------------------------
1 | from tests import run_main_and_assert
2 |
# Quick 3-task MNIST/LeNet run (approach added per test).
FAST_LOCAL_TEST_ARGS = (
    "--exp-name local_test --datasets mnist"
    " --network LeNet --num-tasks 3 --seed 1 --batch-size 32"
    " --nepochs 2 --lr-factor 10 --momentum 0.9 --lr-min 1e-7"
    " --num-workers 0"
)


def test_finetuning_without_warmup():
    """Plain finetuning, no warm-up phase."""
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + " --approach finetuning")


def test_finetuning_with_warmup():
    """Finetuning with a head warm-up phase on each new task."""
    extra = " --approach finetuning --warmup-nepochs 5 --warmup-lr-factor 0.5"
    run_main_and_assert(FAST_LOCAL_TEST_ARGS + extra)
21 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 |
# cuDNN determinism switch applied by seed_everything() below.
# NOTE(review): presumably cleared when --no-cudnn-deterministic is passed to
# the main script — confirm against main_incremental.py.
cudnn_deterministic = True
7 |
8 |
def seed_everything(seed=0):
    """Seed every RNG in use (Python, NumPy, PyTorch, hashing) for reproducible runs."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # honor the module-level switch for deterministic cuDNN kernels
    torch.backends.cudnn.deterministic = cudnn_deterministic
17 |
18 |
def print_summary(acc_taw, acc_tag, forg_taw, forg_tag):
    """Print a formatted summary of the accuracy and forgetting matrices.

    Args:
        acc_taw: task-aware accuracy matrix, ``metric[i, j]`` = result on task j
            after learning task i.
        acc_tag: task-agnostic accuracy matrix, same layout.
        forg_taw: task-aware forgetting matrix, same layout.
        forg_tag: task-agnostic forgetting matrix, same layout.

    Forgetting matrices are detected by their zero diagonal; for those, the
    per-row average excludes the just-learned task (and is undefined for the
    first row, so none is printed there).
    """
    for name, metric in zip(['TAw Acc', 'TAg Acc', 'TAw Forg', 'TAg Forg'], [acc_taw, acc_tag, forg_taw, forg_tag]):
        print('*' * 108)
        print(name)
        # hoisted out of the row loop: the trace does not change per row
        is_forgetting = np.trace(metric) == 0.0
        for i in range(metric.shape[0]):
            print('\t', end='')
            for j in range(metric.shape[1]):
                print('{:5.1f}% '.format(100 * metric[i, j]), end='')
            if is_forgetting:
                if i > 0:
                    print('\tAvg.:{:5.1f}% '.format(100 * metric[i, :i].mean()), end='')
            else:
                print('\tAvg.:{:5.1f}% '.format(100 * metric[i, :i + 1].mean()), end='')
            print()
        print('*' * 108)
35 |
--------------------------------------------------------------------------------