├── agents ├── __init__.py ├── __pycache__ │ ├── lwf.cpython-38.pyc │ ├── lwf.cpython-39.pyc │ ├── pcr.cpython-39.pyc │ ├── scr.cpython-38.pyc │ ├── scr.cpython-39.pyc │ ├── FOAL.cpython-39.pyc │ ├── agem.cpython-38.pyc │ ├── agem.cpython-39.pyc │ ├── base.cpython-38.pyc │ ├── base.cpython-39.pyc │ ├── cndpm.cpython-38.pyc │ ├── cndpm.cpython-39.pyc │ ├── gdumb.cpython-38.pyc │ ├── gdumb.cpython-39.pyc │ ├── icarl.cpython-38.pyc │ ├── icarl.cpython-39.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── ewc_pp.cpython-38.pyc │ ├── ewc_pp.cpython-39.pyc │ ├── exp_replay.cpython-38.pyc │ ├── exp_replay.cpython-39.pyc │ └── exp_replay_dvc.cpython-39.pyc ├── cndpm.py ├── FOAL.py ├── scr.py ├── icarl.py ├── gdumb.py ├── lwf.py ├── agem.py ├── pcr.py └── ewc_pp.py ├── utils ├── __init__.py ├── buffer │ ├── __init__.py │ ├── __pycache__ │ │ ├── buffer.cpython-38.pyc │ │ ├── buffer.cpython-39.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── mem_match.cpython-39.pyc │ │ ├── aser_update.cpython-39.pyc │ │ ├── aser_utils.cpython-39.pyc │ │ ├── buffer_utils.cpython-38.pyc │ │ ├── buffer_utils.cpython-39.pyc │ │ ├── mgi_retrieve.cpython-39.pyc │ │ ├── mir_retrieve.cpython-39.pyc │ │ ├── sc_retrieve.cpython-39.pyc │ │ ├── aser_retrieve.cpython-39.pyc │ │ ├── gss_greedy_update.cpython-39.pyc │ │ ├── random_retrieve.cpython-39.pyc │ │ └── reservoir_update.cpython-39.pyc │ ├── random_retrieve.py │ ├── sc_retrieve.py │ ├── mem_match.py │ ├── buffer.py │ ├── reservoir_update.py │ ├── mir_retrieve.py │ ├── mgi_retrieve.py │ ├── aser_retrieve.py │ └── aser_update.py ├── __pycache__ │ ├── io.cpython-39.pyc │ ├── loss.cpython-38.pyc │ ├── loss.cpython-39.pyc │ ├── utils.cpython-38.pyc │ ├── utils.cpython-39.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── global_vars.cpython-38.pyc │ ├── global_vars.cpython-39.pyc │ ├── kd_manager.cpython-38.pyc │ ├── kd_manager.cpython-39.pyc │ ├── name_match.cpython-38.pyc │ ├── 
name_match.cpython-39.pyc │ ├── setup_elements.cpython-38.pyc │ └── setup_elements.cpython-39.pyc ├── kd_manager.py ├── global_vars.py ├── io.py ├── name_match.py ├── loss.py ├── utils.py └── setup_elements.py ├── continuum ├── __init__.py ├── dataset_scripts │ ├── __init__.py │ ├── __pycache__ │ │ ├── DTD.cpython-39.pyc │ │ ├── CelebA.cpython-39.pyc │ │ ├── core50.cpython-38.pyc │ │ ├── core50.cpython-39.pyc │ │ ├── Aircraft.cpython-39.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── cifar10.cpython-38.pyc │ │ ├── cifar10.cpython-39.pyc │ │ ├── cifar100.cpython-38.pyc │ │ ├── cifar100.cpython-39.pyc │ │ ├── flower102.cpython-39.pyc │ │ ├── food101.cpython-39.pyc │ │ ├── openloris.cpython-38.pyc │ │ ├── openloris.cpython-39.pyc │ │ ├── Country211.cpython-39.pyc │ │ ├── inaturalist.cpython-39.pyc │ │ ├── dataset_base.cpython-38.pyc │ │ ├── dataset_base.cpython-39.pyc │ │ ├── mini_imagenet.cpython-38.pyc │ │ └── mini_imagenet.cpython-39.pyc │ ├── dataset_base.py │ ├── cifar100.py │ ├── cifar10.py │ ├── flower102.py │ ├── food101.py │ ├── mini_imagenet.py │ ├── openloris.py │ ├── DTD.py │ ├── tinyimagenet.py │ ├── CelebA.py │ ├── place365.py │ ├── Country211.py │ ├── Aircraft.py │ ├── inaturalist.py │ ├── SUN397.py │ └── core50.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── continuum.cpython-38.pyc │ ├── continuum.cpython-39.pyc │ ├── data_utils.cpython-38.pyc │ ├── data_utils.cpython-39.pyc │ ├── non_stationary.cpython-38.pyc │ └── non_stationary.cpython-39.pyc ├── continuum.py └── data_utils.py ├── experiment ├── __init__.py ├── __pycache__ │ ├── run.cpython-38.pyc │ ├── run.cpython-39.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── metrics.cpython-39.pyc │ └── tune_hyperparam.cpython-39.pyc ├── tune_hyperparam.py └── metrics.py ├── models ├── ndpm │ ├── __init__.py │ ├── __pycache__ │ │ ├── vae.cpython-38.pyc │ │ ├── vae.cpython-39.pyc │ │ ├── loss.cpython-38.pyc │ │ ├── 
loss.cpython-39.pyc │ │ ├── ndpm.cpython-38.pyc │ │ ├── ndpm.cpython-39.pyc │ │ ├── utils.cpython-38.pyc │ │ ├── utils.cpython-39.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── expert.cpython-38.pyc │ │ ├── expert.cpython-39.pyc │ │ ├── priors.cpython-38.pyc │ │ ├── priors.cpython-39.pyc │ │ ├── classifier.cpython-38.pyc │ │ ├── classifier.cpython-39.pyc │ │ ├── component.cpython-38.pyc │ │ └── component.cpython-39.pyc │ ├── utils.py │ ├── loss.py │ ├── priors.py │ ├── expert.py │ └── component.py ├── __init__.py ├── __pycache__ │ ├── vit.cpython-39.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── resnet.cpython-38.pyc │ ├── resnet.cpython-39.pyc │ ├── pretrained.cpython-38.pyc │ └── pretrained.cpython-39.pyc ├── vit.py └── pretrained.py ├── requirements.txt ├── main_config.py ├── main_tune.py └── README.md /agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /continuum/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiment/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/ndpm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/buffer/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /continuum/dataset_scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import resnet -------------------------------------------------------------------------------- /agents/__pycache__/lwf.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/lwf.cpython-38.pyc -------------------------------------------------------------------------------- /agents/__pycache__/lwf.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/lwf.cpython-39.pyc -------------------------------------------------------------------------------- /agents/__pycache__/pcr.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/pcr.cpython-39.pyc -------------------------------------------------------------------------------- /agents/__pycache__/scr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/scr.cpython-38.pyc -------------------------------------------------------------------------------- /agents/__pycache__/scr.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/scr.cpython-39.pyc -------------------------------------------------------------------------------- 
/models/__pycache__/vit.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/__pycache__/vit.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/io.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/io.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /agents/__pycache__/FOAL.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/FOAL.cpython-39.pyc -------------------------------------------------------------------------------- /agents/__pycache__/agem.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/agem.cpython-38.pyc -------------------------------------------------------------------------------- /agents/__pycache__/agem.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/agem.cpython-39.pyc -------------------------------------------------------------------------------- /agents/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /agents/__pycache__/base.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/base.cpython-39.pyc -------------------------------------------------------------------------------- /agents/__pycache__/cndpm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/cndpm.cpython-38.pyc -------------------------------------------------------------------------------- /agents/__pycache__/cndpm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/cndpm.cpython-39.pyc -------------------------------------------------------------------------------- /agents/__pycache__/gdumb.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/gdumb.cpython-38.pyc -------------------------------------------------------------------------------- /agents/__pycache__/gdumb.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/gdumb.cpython-39.pyc -------------------------------------------------------------------------------- 
/agents/__pycache__/icarl.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/icarl.cpython-38.pyc -------------------------------------------------------------------------------- /agents/__pycache__/icarl.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/icarl.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /agents/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /agents/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /agents/__pycache__/ewc_pp.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/ewc_pp.cpython-38.pyc -------------------------------------------------------------------------------- /agents/__pycache__/ewc_pp.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/ewc_pp.cpython-39.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/run.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/experiment/__pycache__/run.cpython-38.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/run.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/experiment/__pycache__/run.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/__pycache__/resnet.cpython-38.pyc 
-------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/__pycache__/resnet.cpython-39.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/vae.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/vae.cpython-38.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/vae.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/vae.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /agents/__pycache__/exp_replay.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/exp_replay.cpython-38.pyc -------------------------------------------------------------------------------- /agents/__pycache__/exp_replay.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/exp_replay.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/pretrained.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/__pycache__/pretrained.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/pretrained.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/__pycache__/pretrained.cpython-39.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/ndpm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/ndpm.cpython-38.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/ndpm.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/ndpm.cpython-39.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/global_vars.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/global_vars.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/global_vars.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/global_vars.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/kd_manager.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/kd_manager.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/kd_manager.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/kd_manager.cpython-39.pyc 
-------------------------------------------------------------------------------- /utils/__pycache__/name_match.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/name_match.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/name_match.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/name_match.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /continuum/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/__pycache__/continuum.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/__pycache__/continuum.cpython-38.pyc -------------------------------------------------------------------------------- /continuum/__pycache__/continuum.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/__pycache__/continuum.cpython-39.pyc -------------------------------------------------------------------------------- 
/continuum/__pycache__/data_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/__pycache__/data_utils.cpython-38.pyc -------------------------------------------------------------------------------- /continuum/__pycache__/data_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/__pycache__/data_utils.cpython-39.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/experiment/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/experiment/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/metrics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/experiment/__pycache__/metrics.cpython-39.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/__init__.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/expert.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/expert.cpython-38.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/expert.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/expert.cpython-39.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/priors.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/priors.cpython-38.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/priors.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/priors.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/setup_elements.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/setup_elements.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/setup_elements.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/__pycache__/setup_elements.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/buffer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/buffer.cpython-38.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/buffer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/buffer.cpython-39.pyc -------------------------------------------------------------------------------- /agents/__pycache__/exp_replay_dvc.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/agents/__pycache__/exp_replay_dvc.cpython-39.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/classifier.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/classifier.cpython-38.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/classifier.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/classifier.cpython-39.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/component.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/component.cpython-38.pyc -------------------------------------------------------------------------------- /models/ndpm/__pycache__/component.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/models/ndpm/__pycache__/component.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/mem_match.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/mem_match.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/__pycache__/non_stationary.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/__pycache__/non_stationary.cpython-38.pyc -------------------------------------------------------------------------------- /continuum/__pycache__/non_stationary.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/__pycache__/non_stationary.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/aser_update.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/aser_update.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/aser_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/aser_utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/buffer_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/buffer_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/buffer_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/buffer_utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/mgi_retrieve.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/mgi_retrieve.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/mir_retrieve.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/mir_retrieve.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/sc_retrieve.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/sc_retrieve.cpython-39.pyc -------------------------------------------------------------------------------- /experiment/__pycache__/tune_hyperparam.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/experiment/__pycache__/tune_hyperparam.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/aser_retrieve.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/aser_retrieve.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/DTD.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/DTD.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/gss_greedy_update.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/gss_greedy_update.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/random_retrieve.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/random_retrieve.cpython-39.pyc -------------------------------------------------------------------------------- /utils/buffer/__pycache__/reservoir_update.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/utils/buffer/__pycache__/reservoir_update.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/CelebA.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/CelebA.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/core50.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/core50.cpython-38.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/core50.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/core50.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/Aircraft.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/Aircraft.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/cifar10.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/cifar10.cpython-38.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/cifar10.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/cifar10.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/cifar100.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/cifar100.cpython-38.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/cifar100.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/cifar100.cpython-39.pyc 
-------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/flower102.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/flower102.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/food101.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/food101.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/openloris.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/openloris.cpython-38.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/openloris.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/openloris.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/Country211.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/Country211.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/inaturalist.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/inaturalist.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/dataset_base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/dataset_base.cpython-38.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/dataset_base.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/dataset_base.cpython-39.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/mini_imagenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/mini_imagenet.cpython-38.pyc -------------------------------------------------------------------------------- /continuum/dataset_scripts/__pycache__/mini_imagenet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuyuchen-cz/F-OAL/HEAD/continuum/dataset_scripts/__pycache__/mini_imagenet.cpython-39.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | kornia==0.6.12 2 | matplotlib==3.6.2 3 | numpy==1.23.5 4 | pandas==2.2.3 5 | Pillow==9.3.0 6 | psutil==5.9.5 7 | PyYAML==6.0.1 8 | scikit_learn==1.3.1 9 | scipy==1.9.1 10 | skimage==0.0 11 | timm==0.9.7 12 | torch==1.13.1 13 | torchvision==0.13.0 14 | tqdm==4.64.1 15 | 
class Lambda(nn.Module):
    """Adapter that lets a plain callable be used wherever an ``nn.Module`` is expected.

    If no callable is given, the module behaves as the identity on a single
    positional argument.
    """

    def __init__(self, f=None):
        super().__init__()
        if f is None:
            # Default: one-argument identity.
            self.f = lambda x: x
        else:
            self.f = f

    def forward(self, *args, **kwargs):
        """Pass all positional and keyword arguments straight to the wrapped callable."""
        wrapped = self.f
        return wrapped(*args, **kwargs)
class KdManager:
    """Hold a snapshot ("teacher") of the model and compute a distillation loss against it.

    Until ``update_teacher`` has been called there is no teacher and
    ``get_kd_loss`` returns ``0``.
    """

    def __init__(self):
        # No teacher until the first call to update_teacher().
        self.teacher_model = None

    def update_teacher(self, model):
        """Replace the teacher with a deep copy of ``model``.

        The copy is switched to evaluation mode so that layers such as dropout
        and batch normalisation behave deterministically (and do not update
        running statistics) while producing distillation targets. ``model``
        itself is left untouched.

        Args:
            model: the current network; must support ``copy.deepcopy`` and
                ``eval()`` (e.g. a ``torch.nn.Module``).
        """
        teacher = copy.deepcopy(model)
        # A deep copy inherits the student's train/eval flag; the teacher must
        # always run in eval mode.
        teacher.eval()
        self.teacher_model = teacher

    def get_kd_loss(self, cur_model_logits, x):
        """Return the distillation loss between the current logits and the teacher's.

        Args:
            cur_model_logits: logits produced by the current model for ``x``.
            x: the input batch, forwarded through the teacher.

        Returns:
            ``0`` when no teacher has been stored yet, otherwise the value of
            ``loss_fn_kd(cur_model_logits, teacher_logits)``.
        """
        if self.teacher_model is None:
            return 0
        # Teacher outputs are targets only; no gradient is needed through them.
        with torch.no_grad():
            prev_model_logits = self.teacher_model.forward(x)
        return loss_fn_kd(cur_model_logits, prev_model_logits)
class continuum(object):
    """Iterator that hands out the tasks of a dataset one after another.

    Each ``next()`` call asks the underlying data object for the task at the
    current position and advances; iteration stops once ``task_nums`` tasks
    have been served.
    """

    def __init__(self, dataset, scenario, params):
        """Create the data object registered under ``dataset`` in ``data_objects``."""
        dataset_cls = data_objects[dataset]
        self.data_object = dataset_cls(scenario, params)
        self.run = params.num_runs
        self.task_nums = self.data_object.task_nums
        # Position inside the current run, and index of the current run
        # (-1 until new_run() is called for the first time).
        self.cur_task = 0
        self.cur_run = -1

    def __iter__(self):
        return self

    def __next__(self):
        """Return ``(x_train, y_train, labels)`` for the next task."""
        source = self.data_object
        if self.cur_task == source.task_nums:
            raise StopIteration
        x_train, y_train, labels = source.new_task(self.cur_task, cur_run=self.cur_run)
        self.cur_task += 1
        return x_train, y_train, labels

    def test_data(self):
        """Return the test set held by the data object."""
        return self.data_object.get_test_set()

    def clean_mem_test_set(self):
        """Ask the data object to drop its test data."""
        self.data_object.clean_mem_test_set()

    def reset_run(self):
        """Rewind to the first task without starting a new run."""
        self.cur_task = 0

    def new_run(self):
        """Rewind to the first task, advance the run counter and notify the data object."""
        self.cur_task = 0
        self.cur_run += 1
        self.data_object.new_run(cur_run=self.cur_run)
class DatasetBase(ABC):
    """Abstract base for dataset wrappers.

    The constructor stores the configuration, derives the dataset root
    directory under ``./datasets``, then calls ``_is_properly_setup()``
    followed by ``download_load()``.
    """

    def __init__(self, dataset, scenario, task_nums, run, params):
        super().__init__()
        self.params, self.scenario = params, scenario
        self.dataset, self.task_nums, self.run = dataset, task_nums, run
        self.root = os.path.join('./datasets', self.dataset)
        self.test_set, self.val_set = [], []
        # Hooks run last so that implementations can rely on the attributes above.
        self._is_properly_setup()
        self.download_load()

    @abstractmethod
    def download_load(self):
        """Obtain and load the underlying data (implemented by subclasses)."""

    @abstractmethod
    def setup(self, **kwargs):
        """Prepare the dataset for use (implemented by subclasses)."""

    @abstractmethod
    def new_task(self, cur_task, **kwargs):
        """Return the data for task ``cur_task`` (implemented by subclasses)."""

    def _is_properly_setup(self):
        """Validation hook called from the constructor; a no-op by default."""

    @abstractmethod
    def new_run(self, **kwargs):
        """Start a new run (implemented by subclasses)."""

    @property
    def dataset_info(self):
        """The dataset identifier passed to the constructor."""
        return self.dataset

    def get_test_set(self):
        """Return the stored test set."""
        return self.test_set

    def clean_mem_test_set(self):
        """Release test data by resetting the test-related attributes to ``None``."""
        self.test_set = self.test_data = self.test_label = None
def load_dataframe_csv(path, name=None, delimiter=None, names=None):
    """Read a CSV file into a pandas DataFrame.

    Args:
        path: full path of the CSV file when ``name`` is not given, otherwise
            the directory that contains it.
        name: optional file name inside ``path``.
        delimiter: forwarded to ``pandas.read_csv``.
        names: forwarded to ``pandas.read_csv`` (column names).

    Returns:
        The DataFrame produced by ``pandas.read_csv``.
    """
    # Join with a path separator so that (path, name) means the same thing here
    # as in save_dataframe_csv; plain concatenation produced "dirfile.csv"
    # whenever ``path`` had no trailing separator.
    target = os.path.join(path, name) if name else path
    return pd.read_csv(target, delimiter=delimiter, names=names)
def bernoulli_nll(x, p):
    """Element-wise Bernoulli negative log-likelihood of ``x`` under probabilities ``p``.

    ``x`` and ``p`` are broadcast against each other first, because
    ``binary_cross_entropy`` requires input and target of identical shape.

    Args:
        x: tensor of targets.
        p: tensor of Bernoulli probabilities, broadcastable with ``x``.

    Returns:
        Tensor of the broadcast shape holding the un-reduced binary cross
        entropy of ``p`` against ``x``.

    Raises:
        RuntimeError: if the two shapes cannot be broadcast together.
    """
    # torch.broadcast_tensors follows the standard broadcasting rules, so it
    # also covers operands whose number of dimensions differ.
    x, p = torch.broadcast_tensors(x, p)
    return binary_cross_entropy(p, x, reduction='none')
def main(args):
    """Load the YAML configurations, seed the random generators and launch the runs.

    Args:
        args: parsed command-line namespace providing ``general``, ``data`` and
            ``agent`` (paths to YAML files) and ``verbose``.

    The three parameter dictionaries are merged into one ``SimpleNamespace``;
    a key that appears in more than one of them raises ``TypeError``.
    """
    general_params = load_yaml(args.general)
    data_params = load_yaml(args.data)
    agent_params = load_yaml(args.agent)
    general_params['verbose'] = args.verbose
    general_params['cuda'] = torch.cuda.is_available()
    final_params = SimpleNamespace(**general_params, **data_params, **agent_params)
    print(final_params)

    # reproduce: seed every RNG in use from the configured seed
    seed = final_params.seed
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if final_params.cuda:
        torch.cuda.manual_seed(seed)
        # Trade cuDNN autotuning for deterministic kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # run
    multiple_run(final_params)
class Cndpm(ContinualLearner):
    """Agent that feeds training batches to its model through ``model.learn``."""
    def __init__(self, model, opt, params):
        super(Cndpm, self).__init__(model, opt, params)
        # Keep a direct handle on the model: train_learner calls its
        # ``train``/``learn`` methods and reads ``stm_x`` and ``experts``.
        self.model = model


    def train_learner(self, x_train, y_train):
        """Train the model on the data of one task.

        Args:
            x_train: training inputs, handed to ``dataset_transform``.
            y_train: training labels, handed to ``dataset_transform``.
        """
        # set up loader
        # NOTE(review): self.data, self.batch, self.epoch and self.cuda are not
        # assigned in this class -- presumably ContinualLearner.__init__ sets
        # them; confirm.
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # setup tracker
        # NOTE(review): these two meters are created but never read in this method.
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        self.model.train()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                self.model.learn(batch_x, batch_y)
                if self.params.verbose:
                    # '\r' with end='' rewrites the same console line on every step.
                    print('\r[Step {:4}] STM: {:5}/{} | #Expert: {}'.format(
                        i,
                        len(self.model.stm_x), self.params.stm_capacity,
                        len(self.model.experts) - 1
                    ), end='')
        # Terminate the progress line written above.
        print()
class Prior(ABC):
    """Abstract interface for a prior over experts.

    Concrete priors must implement ``add_expert``, ``record_usage`` and
    ``nl_prior``; the base class only stores ``params``.
    """

    def __init__(self, params):
        self.params = params

    @abstractmethod
    def add_expert(self):
        """Extend the prior by one expert."""

    @abstractmethod
    def record_usage(self, usage, index=None):
        """Record expert usage, for all experts or only for the one at ``index``."""

    @abstractmethod
    def nl_prior(self, normalize=False):
        """Return the negative log prior, optionally normalised."""
def tune_hyper(tune_data, tune_test_loaders, default_params, tune_params):
    """Grid-search hyper-parameters and return the best combination.

    Every point of ``ParameterGrid(tune_params)`` is overlaid on
    ``default_params``; a fresh model, optimizer and agent are trained on
    ``tune_data`` for ``num_runs_val`` runs and scored with
    ``compute_performance``. The combination with the highest average end
    accuracy wins (the first one on ties).

    Args:
        tune_data: iterable of ``(x_train, y_train, labels)`` tasks; it is
            iterated once per run, so it must be re-iterable.
        tune_test_loaders: passed unchanged to ``agent.evaluate``.
        default_params: namespace-like object holding the baseline parameters.
            It is not modified.
        tune_params: parameter grid accepted by ``sklearn``'s ``ParameterGrid``.

    Returns:
        dict: the best-scoring parameter combination.
    """
    param_grid_list = list(ParameterGrid(tune_params))
    print(len(param_grid_list))
    tune_accs = []
    for param_set in param_grid_list:
        print(param_set)
        # vars() returns the namespace's live __dict__; merge into a new dict
        # so the caller's defaults are not overwritten by each grid point.
        final_params = SimpleNamespace(**{**vars(default_params), **param_set})
        accuracy_list = []
        for run in range(final_params.num_runs_val):
            tmp_acc = []
            model = setup_architecture(final_params)
            model = maybe_cuda(model, final_params.cuda)
            opt = setup_opt(final_params.optimizer, model, final_params.learning_rate, final_params.weight_decay)
            agent = agents[final_params.agent](model, opt, final_params)
            for i, (x_train, y_train, _labels) in enumerate(tune_data):
                print("-----------tune run {} task {}-------------".format(run, i))
                print('size: {}, {}'.format(x_train.shape, y_train.shape))
                agent.train_learner(x_train, y_train)
                acc_array = agent.evaluate(tune_test_loaders)
                tmp_acc.append(acc_array)
            print(
                "-----------tune run {}-----------avg_end_acc {}-----------".format(run, np.mean(tmp_acc[-1])))
            accuracy_list.append(np.array(tmp_acc))
        accuracy_list = np.array(accuracy_list)
        # Only the mean of the end accuracy is used for model selection.
        avg_end_acc = compute_performance(accuracy_list)[0]
        tune_accs.append(avg_end_acc[0])
    best_tune = param_grid_list[tune_accs.index(max(tune_accs))]
    return best_tune
initialized d if it's a placeholder 27 | if self.id == 0 and self.d is not None: 28 | for p in self.d.parameters(): 29 | p.requires_grad = False 30 | 31 | def forward(self, x): 32 | return self.d(x) 33 | 34 | def nll(self, x, y, step=None): 35 | """Negative log likelihood""" 36 | nll = self.g.nll(x, step) 37 | if self.d is not None: 38 | d_nll = self.d.nll(x, y, step) 39 | nll = nll + d_nll 40 | return nll 41 | 42 | def collect_nll(self, x, y, step=None): 43 | if self.id == 0: 44 | nll = self.nll(x, y, step) 45 | return nll.unsqueeze(1) 46 | 47 | nll = self.g.collect_nll(x, step) 48 | if self.d is not None: 49 | d_nll = self.d.collect_nll(x, y, step) 50 | nll = nll + d_nll 51 | 52 | return nll 53 | 54 | def lr_scheduler_step(self): 55 | if self.g.lr_scheduler is not NotImplemented: 56 | self.g.lr_scheduler.step() 57 | if self.d is not None and self.d.lr_scheduler is not NotImplemented: 58 | self.d.lr_scheduler.step() 59 | 60 | def clip_grad(self): 61 | self.g.clip_grad() 62 | if self.d is not None: 63 | self.d.clip_grad() 64 | 65 | def optimizer_step(self): 66 | self.g.optimizer.step() 67 | if self.d is not None: 68 | self.d.optimizer.step() 69 | -------------------------------------------------------------------------------- /main_tune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils.io import load_yaml 3 | from types import SimpleNamespace 4 | from utils.utils import boolean_string 5 | import time 6 | import torch 7 | import random 8 | import numpy as np 9 | from experiment.run import multiple_run_tune_separate 10 | from utils.setup_elements import default_trick 11 | 12 | def main(args): 13 | genereal_params = load_yaml(args.general) 14 | data_params = load_yaml(args.data) 15 | default_params = load_yaml(args.default) 16 | tune_params = load_yaml(args.tune) 17 | genereal_params['verbose'] = args.verbose 18 | genereal_params['cuda'] = torch.cuda.is_available() 19 | genereal_params['train_val'] 
def compute_performance(end_task_acc_arr):
    """
    Aggregate continual-learning metrics over several runs.

    The input is indexed as ``[run, training stage, task]``: the last slice
    along axis 1 is used as the end-of-training accuracy, the lower triangle
    of the last two axes as accuracy on already-seen tasks and the strict
    upper triangle as accuracy on not-yet-seen tasks.

    Every returned metric is a ``(mean over runs, half-width of the 95%
    confidence interval)`` tuple; the interval uses Student's t with
    ``n_run - 1`` degrees of freedom. With a single run the interval is NaN,
    and with a single task the BWT+/FWT normalisers are zero.

    :param end_task_acc_arr: (array-like) accuracies of shape
        (n_run, n_tasks, n_tasks); nested lists are accepted
    :return: (avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt)
    """
    # The docstring has always promised list input; make that actually work.
    end_task_acc_arr = np.asarray(end_task_acc_arr)
    n_run, n_tasks = end_task_acc_arr.shape[:2]
    # t coefficient used to compute 95% CIs: mean +- t * standard error
    t_coef = stats.t.ppf((1 + 0.95) / 2, n_run - 1)

    # average end-of-training accuracy and CI
    end_acc = end_task_acc_arr[:, -1, :]  # shape: (num_run, num_task)
    avg_acc_per_run = np.mean(end_acc, axis=1)
    avg_end_acc = (np.mean(avg_acc_per_run), t_coef * sem(avg_acc_per_run))

    # forgetting: best accuracy ever reached on a task minus its final accuracy
    best_acc = np.max(end_task_acc_arr, axis=1)
    final_forgets = best_acc - end_acc
    avg_fgt = np.mean(final_forgets, axis=1)
    avg_end_fgt = (np.mean(avg_fgt), t_coef * sem(avg_fgt))

    # ACC: after stage i, average accuracy over the i + 1 tasks seen so far
    seen_counts = np.arange(n_tasks) + 1
    acc_per_run = np.mean(np.sum(np.tril(end_task_acc_arr), axis=2) / seen_counts,
                          axis=1)
    avg_acc = (np.mean(acc_per_run), t_coef * sem(acc_per_run))

    # Number of (stage, task) pairs in a strict triangle of the matrix.
    n_pairs = n_tasks * (n_tasks - 1) / 2

    # BWT+: accuracy on earlier tasks after later stages, relative to the
    # accuracy obtained right after learning each task, clipped at zero.
    later_stage_acc = np.sum(np.tril(end_task_acc_arr, -1), axis=(1, 2))
    just_learned_acc = np.diagonal(end_task_acc_arr, axis1=1, axis2=2)
    # Task j is re-evaluated at (n_tasks - 1 - j) later stages.
    reuse_counts = np.arange(n_tasks, 0, -1) - 1
    bwt_per_run = (later_stage_acc - np.sum(just_learned_acc * reuse_counts, axis=1)) / n_pairs
    bwtp_per_run = np.maximum(bwt_per_run, 0)
    avg_bwtp = (np.mean(bwtp_per_run), t_coef * sem(bwtp_per_run))

    # FWT: mean accuracy on tasks that have not been trained on yet
    fwt_per_run = np.sum(np.triu(end_task_acc_arr, 1), axis=(1, 2)) / n_pairs
    avg_fwt = (np.mean(fwt_per_run), t_coef * sem(fwt_per_run))
    return avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt


def single_run_avg_end_fgt(acc_array):
    """
    Average end-of-training forgetting for one run.

    :param acc_array: (array-like) accuracies of shape (n_stages, n_tasks),
        indexed ``[training stage, task]`` like one run of the array taken by
        ``compute_performance``
    :return: mean over tasks of (best accuracy over stages - final accuracy)
    """
    acc_array = np.asarray(acc_array)
    # The best accuracy of a task is taken over training stages (axis 0) so
    # that it lines up with the per-task final accuracies in ``acc_array[-1]``.
    best_acc = np.max(acc_array, axis=0)
    end_acc = acc_array[-1]
    final_forgets = best_acc - end_acc
    avg_fgt = np.mean(final_forgets)
    return avg_fgt
class Reservoir_update(object):
    """Reservoir-sampling rule for writing a new batch into the replay buffer.

    The buffer object is expected to expose ``buffer_img`` / ``buffer_label``
    storage tensors, the counters ``current_index`` and ``n_seen_so_far``,
    ``params.buffer_tracker`` and, when that flag is set, a ``buffer_tracker``
    with ``update_cache(labels, new_labels, indices)``.
    """

    def __init__(self, params):
        super().__init__()

    def update(self, buffer, x, y, **kwargs):
        """Store the batch ``(x, y)`` in ``buffer``.

        Free slots are filled first; the samples that do not fit are then
        written over random slots with reservoir sampling.

        :param buffer: replay buffer, modified in place
        :param x: batch of inputs, first dimension is the batch
        :param y: labels matching ``x``
        :return: list of buffer indices written by this call. This includes
            slots filled while the buffer was not yet full, also when only
            part of the batch fitted.
        """
        batch_size = x.size(0)
        capacity = buffer.buffer_img.size(0)
        filled_idx = []

        # add whatever still fits in the buffer
        place_left = max(0, capacity - buffer.current_index)
        if place_left:
            offset = min(place_left, batch_size)
            start = buffer.current_index
            buffer.buffer_img[start: start + offset].data.copy_(x[:offset])
            buffer.buffer_label[start: start + offset].data.copy_(y[:offset])

            buffer.current_index += offset
            buffer.n_seen_so_far += offset

            # Report the filled slots even if part of the batch is left over;
            # previously a partial fill was invisible to the tracker and to
            # the caller whenever the buffer size was not a multiple of the
            # batch size.
            filled_idx = list(range(start, start + offset))
            if buffer.params.buffer_tracker:
                buffer.buffer_tracker.update_cache(buffer.buffer_label, y[:offset], filled_idx)

            # everything was added
            if offset == batch_size:
                return filled_idx

        # remove what is already in the buffer
        x, y = x[place_left:], y[place_left:]

        # One draw per remaining sample in [0, n_seen_so_far); a sample is
        # kept only if its draw is a valid buffer slot.
        indices = torch.FloatTensor(x.size(0)).to(x.device).uniform_(0, buffer.n_seen_so_far).long()
        valid_indices = (indices < capacity).long()

        idx_new_data = valid_indices.nonzero().squeeze(-1)
        idx_buffer = indices[idx_new_data]

        buffer.n_seen_so_far += x.size(0)

        if idx_buffer.numel() == 0:
            return filled_idx

        assert idx_buffer.max() < buffer.buffer_img.size(0)
        assert idx_buffer.max() < buffer.buffer_label.size(0)

        assert idx_new_data.max() < x.size(0)
        assert idx_new_data.max() < y.size(0)

        # buffer slot -> index of the sample written there; if two samples
        # drew the same slot, the later one wins.
        idx_map = dict(zip(idx_buffer.tolist(), idx_new_data.tolist()))
        slots = list(idx_map.keys())
        sources = list(idx_map.values())

        replace_y = y[sources]
        if buffer.params.buffer_tracker:
            # Called before the overwrite so the tracker still sees the labels
            # that are about to be replaced.
            buffer.buffer_tracker.update_cache(buffer.buffer_label, replace_y, slots)
        # perform overwrite op
        buffer.buffer_img[slots] = x[sources]
        buffer.buffer_label[slots] = replace_y

        already_reported = set(filled_idx)
        return filled_idx + [slot for slot in slots if slot not in already_reported]
buffer.model.parameters(): 19 | grad_dims.append(param.data.numel()) 20 | grad_vector = get_grad_vector(buffer.model.parameters, grad_dims) 21 | model_temp = self.get_future_step_parameters(buffer.model, grad_vector, grad_dims) 22 | if sub_x.size(0) > 0: 23 | with torch.no_grad(): 24 | logits_pre = buffer.model.forward(sub_x) 25 | logits_post = model_temp.forward(sub_x) 26 | pre_loss = F.cross_entropy(logits_pre, sub_y, reduction='none') 27 | post_loss = F.cross_entropy(logits_post, sub_y, reduction='none') 28 | scores = post_loss - pre_loss 29 | big_ind = scores.sort(descending=True)[1][:self.num_retrieve] 30 | return sub_x[big_ind], sub_y[big_ind] 31 | else: 32 | return sub_x, sub_y 33 | 34 | def get_future_step_parameters(self, model, grad_vector, grad_dims): 35 | """ 36 | computes \theta-\delta\theta 37 | :param this_net: 38 | :param grad_vector: 39 | :return: 40 | """ 41 | new_model = copy.deepcopy(model) 42 | self.overwrite_grad(new_model.parameters, grad_vector, grad_dims) 43 | with torch.no_grad(): 44 | for param in new_model.parameters(): 45 | if param.grad is not None: 46 | param.data = param.data - self.params.learning_rate * param.grad.data 47 | return new_model 48 | 49 | def overwrite_grad(self, pp, new_grad, grad_dims): 50 | """ 51 | This is used to overwrite the gradients with a new gradient 52 | vector, whenever violations occur. 
class FOAL(ContinualLearner):
    """Forward-only analytic learner: fits the linear head by recursive least squares.

    No gradient step is taken. For every mini-batch the expanded features
    ``X = model.expansion(batch_x)`` update two double-precision matrices:

    * ``R`` (projection x projection), started at the identity,
      ``R <- R - R X^T pinv(I + X R X^T) X R``
    * ``W`` (projection x classes), started at zero,
      ``W <- W + R X^T (Y - X W)`` with one-hot targets ``Y``

    and ``W^T`` is written back into ``model.fc.weight``.
    """

    def __init__(self, model, opt, params):
        super(FOAL, self).__init__(model, opt, params)
        # Keep the solver state on the same device as the classifier head
        # instead of forcing CUDA, so the agent also runs on CPU-only hosts.
        device = self.model.fc.weight.device
        self.R = torch.eye(params.projection).double().to(device)
        # init.zeros_ works in place on the transposed view, so this also
        # resets the model's own fc weight to zero.
        self.W = init.zeros_(self.model.fc.weight.t()).double().to(device)

    def train_learner(self, x_train, y_train):
        """Run the recursive least-squares update over one task's data.

        :param x_train: training inputs of the task
        :param y_train: training labels of the task
        """
        self.before_train(x_train, y_train)
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # No loss is computed by this agent; the meter is never updated and is
        # only kept for the log line below.
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()
        self.model.eval()

        with torch.no_grad():
            for ep in range(self.epoch):
                for i, batch_data in enumerate(train_loader):
                    # batch update
                    batch_x, batch_y = batch_data

                    batch_x = maybe_cuda(batch_x, self.cuda)
                    batch_y = maybe_cuda(batch_y, self.cuda)
                    new_activation = self.model.expansion(batch_x)
                    new_activation = new_activation.double()

                    label_onehot = F.one_hot(batch_y, self.model.numclass)
                    # Identity created directly with the activations' dtype
                    # and device (previously a hard-coded .cuda() call).
                    identity = torch.eye(new_activation.size(0),
                                         dtype=new_activation.dtype,
                                         device=new_activation.device)
                    gain = self.R @ new_activation.t() @ torch.pinverse(
                        identity + new_activation @ self.R @ new_activation.t())
                    self.R = self.R - gain @ new_activation @ self.R
                    self.W = self.W + self.R @ new_activation.t() @ (label_onehot - new_activation @ self.W)
                    self.model.fc.weight = torch.nn.parameter.Parameter(torch.t(self.W.float()))
                    logits = self.model(batch_x)
                    _, pred_label = torch.max(logits, 1)
                    correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)

                    # update tracker
                    acc_batch.update(correct_cnt, batch_y.size(0))
                    if i % 100 == 1 and self.verbose:
                        print(
                            '==>>> it: {}, avg. loss: {:.6f}, '
                            'running train acc: {:.3f}'
                            .format(i, losses_batch.avg(), acc_batch.avg())
                        )
        self.after_train()
train=False, download=True) 23 | self.test_data = dataset_test.data 24 | self.test_label = np.array(dataset_test.targets) 25 | 26 | def setup(self): 27 | if self.scenario == 'ni': 28 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 29 | self.train_label, 30 | self.test_data, self.test_label, 31 | self.task_nums, 32, 32 | self.params.val_size, 33 | self.params.ns_type, self.params.ns_factor, 34 | plot=self.params.plot_sample) 35 | elif self.scenario == 'nc': 36 | self.task_labels = create_task_composition(class_nums=100, num_tasks=self.task_nums, fixed_order=self.params.fix_order) 37 | self.test_set = [] 38 | for labels in self.task_labels: 39 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 40 | self.test_set.append((x_test, y_test)) 41 | else: 42 | raise Exception('wrong scenario') 43 | 44 | def new_task(self, cur_task, **kwargs): 45 | if self.scenario == 'ni': 46 | x_train, y_train = self.train_set[cur_task] 47 | labels = set(y_train) 48 | elif self.scenario == 'nc': 49 | labels = self.task_labels[cur_task] 50 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 51 | return x_train, y_train, labels 52 | 53 | def new_run(self, **kwargs): 54 | self.setup() 55 | return self.test_set 56 | 57 | def test_plot(self): 58 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 59 | self.params.ns_factor) 60 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import datasets 3 | from continuum.data_utils import create_task_composition, load_task_with_labels, shuffle_data 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns 6 | 7 | 8 | class 
CIFAR10(DatasetBase): 9 | def __init__(self, scenario, params): 10 | dataset = 'cifar10' 11 | if scenario == 'ni': 12 | num_tasks = len(params.ns_factor) 13 | else: 14 | num_tasks = params.num_tasks 15 | super(CIFAR10, self).__init__(dataset, scenario, num_tasks, params.num_runs, params) 16 | 17 | 18 | def download_load(self): 19 | dataset_train = datasets.CIFAR10(root=self.root, train=True, download=True) 20 | self.train_data = dataset_train.data 21 | self.train_label = np.array(dataset_train.targets) 22 | dataset_test = datasets.CIFAR10(root=self.root, train=False, download=True) 23 | self.test_data = dataset_test.data 24 | self.test_label = np.array(dataset_test.targets) 25 | 26 | def setup(self): 27 | if self.scenario == 'ni': 28 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 29 | self.train_label, 30 | self.test_data, self.test_label, 31 | self.task_nums, 32, 32 | self.params.val_size, 33 | self.params.ns_type, self.params.ns_factor, 34 | plot=self.params.plot_sample) 35 | elif self.scenario == 'nc': 36 | self.task_labels = create_task_composition(class_nums=10, num_tasks=self.task_nums, fixed_order=self.params.fix_order) 37 | self.test_set = [] 38 | for labels in self.task_labels: 39 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 40 | self.test_set.append((x_test, y_test)) 41 | else: 42 | raise Exception('wrong scenario') 43 | 44 | def new_task(self, cur_task, **kwargs): 45 | if self.scenario == 'ni': 46 | x_train, y_train = self.train_set[cur_task] 47 | labels = set(y_train) 48 | elif self.scenario == 'nc': 49 | labels = self.task_labels[cur_task] 50 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 51 | return x_train, y_train, labels 52 | 53 | def new_run(self, **kwargs): 54 | self.setup() 55 | return self.test_set 56 | 57 | def test_plot(self): 58 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 59 | 
class FLOWER102(DatasetBase):
    """Oxford 102 Flowers wrapped for the 'ni' and 'nc' continual-learning scenarios."""

    # Images in this dataset differ in size, so they are resized to a common
    # shape before being stacked into one array.
    # NOTE(review): 224x224 mirrors the FOOD101 loader of this repository;
    # confirm it matches the input size/transforms registered for 'flower102'.
    image_size = (224, 224)

    def __init__(self, scenario, params):
        dataset = 'flower102'
        if scenario == 'ni':
            num_tasks = len(params.ns_factor)
        else:
            num_tasks = params.num_tasks

        super(FLOWER102, self).__init__(dataset, scenario, num_tasks, params.num_runs, params)

    @staticmethod
    def _to_arrays(dataset, size):
        """Turn a dataset of (PIL image, label) pairs into (images, labels) arrays."""
        images, labels = [], []
        for image, label in dataset:
            images.append(np.array(image.resize(size)))
            labels.append(label)
        return np.array(images), np.array(labels)

    def download_load(self):
        """Download both splits and keep them as arrays on the instance.

        The splits used to be tuple-unpacked directly
        (``data, label = dataset``), which raises ``ValueError`` because the
        dataset yields one (image, label) pair per sample.
        """
        dataset_train = datasets.Flowers102(root=self.root, split='train', download=True)
        self.train_data, self.train_label = self._to_arrays(dataset_train, self.image_size)
        dataset_test = datasets.Flowers102(root=self.root, split='test', download=True)
        self.test_data, self.test_label = self._to_arrays(dataset_test, self.image_size)

    def setup(self):
        """Build the per-task splits for the configured scenario."""
        if self.scenario == 'ni':
            # NOTE(review): the literal 32 is the same value the CIFAR loaders
            # pass in this position — confirm it is right for this dataset.
            self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data,
                                                                                        self.train_label,
                                                                                        self.test_data, self.test_label,
                                                                                        self.task_nums, 32,
                                                                                        self.params.val_size,
                                                                                        self.params.ns_type, self.params.ns_factor,
                                                                                        plot=self.params.plot_sample)
        elif self.scenario == 'nc':
            self.task_labels = create_task_composition(class_nums=102, num_tasks=self.task_nums, fixed_order=self.params.fix_order)
            self.test_set = []
            for labels in self.task_labels:
                x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels)
                self.test_set.append((x_test, y_test))
        else:
            raise Exception('wrong scenario')

    def new_task(self, cur_task, **kwargs):
        """Return ``(x_train, y_train, labels)`` for task number ``cur_task``."""
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        return x_train, y_train, labels

    def new_run(self, **kwargs):
        """Rebuild the task splits for a new run and return the test sets."""
        self.setup()
        return self.test_set

    def test_plot(self):
        """Run the non-stationary transform preview on the first ten training samples."""
        test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type,
                self.params.ns_factor)
Cooperating with a pre-trained frozen encoder with Feature Fusion, F-OAL only needs to update a linear classifier by recursive least squares. This approach simultaneously achieves high accuracy and low resource consumption. Extensive experiments on benchmark datasets demonstrate F-OAL's robust performance in OCIL scenarios. 11 | 12 | 13 | 14 | ## About Analytic Continual Learning 15 | Analytic Continual Learning (ACL) is a new branch of Continual Learning (CL). ACL reformulates the CL problem in a recursive least-squares form, which is solved in one epoch with high accuracy. F-OAL extends ACL to the OCIL setting. Other ACL papers and code can be found at https://github.com/ZHUANGHP/Analytic-continual-learning. 16 | 17 | ## Reproduce the results in the paper 18 | Our code is built on [Survey](https://github.com/RaptorMai/online-continual-learning). The following commands reproduce some of the results. The reproduction of `SLCA`, `LAE` and `EASE` can be found at [PILOT](https://github.com/sun-hailong/LAMDA-PILOT). To reproduce results on `CORe`, please check [Survey](https://github.com/RaptorMai/online-continual-learning) for where to place the related files. 
class SupContrastReplay(ContinualLearner):
    """Supervised contrastive replay agent.

    Every incoming batch is joined with samples retrieved from the replay
    buffer; the model is trained on the joined batch together with an
    augmented view of it, and the incoming batch is then written to the
    buffer.
    """

    def __init__(self, model, opt, params):
        super(SupContrastReplay, self).__init__(model, opt, params)
        self.buffer = Buffer(model, params)
        self.mem_size = params.mem_size
        self.eps_mem_batch = params.eps_mem_batch
        self.mem_iters = params.mem_iters
        # Entries 1 and 2 of the registered input size are used as the crop size.
        input_size = input_size_match[self.params.data]
        self.transform = nn.Sequential(
            RandomResizedCrop(size=(input_size[1], input_size[2]), scale=(0.2, 1.)),
            RandomHorizontalFlip(),
            ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
            RandomGrayscale(p=0.2)
        )

    def train_learner(self, x_train, y_train):
        """Train on one task.

        :param x_train: training inputs of the task
        :param y_train: training labels of the task
        """
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # set up model
        self.model = self.model.train()

        # setup tracker
        losses = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                for j in range(self.mem_iters):
                    mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y)

                    # No optimisation step is taken while the buffer returns
                    # nothing; the batch is still stored below.
                    if mem_x.size(0) > 0:
                        mem_x = maybe_cuda(mem_x, self.cuda)
                        mem_y = maybe_cuda(mem_y, self.cuda)
                        combined_batch = torch.cat((mem_x, batch_x))
                        combined_labels = torch.cat((mem_y, batch_y))
                        combined_batch_aug = self.transform(combined_batch)
                        # Stack the plain and augmented embeddings as two
                        # views along dim 1.
                        features = torch.cat([self.model.forward(combined_batch).unsqueeze(1),
                                              self.model.forward(combined_batch_aug).unsqueeze(1)], dim=1)
                        loss = self.criterion(features, combined_labels)
                        # Record a plain float: storing the loss tensor would
                        # keep its autograd graph referenced by the meter.
                        losses.update(loss.item(), batch_y.size(0))
                        self.opt.zero_grad()
                        loss.backward()
                        self.opt.step()

                # update mem
                self.buffer.update(batch_x, batch_y)
                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        .format(i, losses.avg())
                    )
        self.after_train()
class FOOD101(DatasetBase):
    """Food-101 wrapped for the 'ni' and 'nc' continual-learning scenarios."""

    def __init__(self, scenario, params):
        dataset = 'food101'
        num_tasks = len(params.ns_factor) if scenario == 'ni' else params.num_tasks

        super(FOOD101, self).__init__(dataset, scenario, num_tasks, params.num_runs, params)

    @staticmethod
    def _load_split(split_dataset):
        """Resize every image of a split to 224x224 and return (images, labels) arrays."""
        images = []
        targets = []
        for picture, target in split_dataset:
            images.append(picture.resize((224, 224), Image.Resampling.LANCZOS))
            targets.append(target)
        return np.array(images), np.array(targets)

    def download_load(self):
        """Download both splits and keep them as arrays on the instance."""
        train = datasets.Food101(root=self.root, split='train', download=True)
        test = datasets.Food101(root=self.root, split='test', download=True)
        self.train_data, self.train_label = self._load_split(train)
        self.test_data, self.test_label = self._load_split(test)

    def setup(self):
        """Build the per-task splits for the configured scenario."""
        if self.scenario == 'nc':
            self.task_labels = create_task_composition(class_nums=101, num_tasks=self.task_nums,
                                                       fixed_order=self.params.fix_order)
            self.test_set = []
            for task_classes in self.task_labels:
                self.test_set.append(load_task_with_labels(self.test_data, self.test_label, task_classes))
            return
        if self.scenario != 'ni':
            raise Exception('wrong scenario')
        splits = construct_ns_multiple_wrapper(self.train_data,
                                               self.train_label,
                                               self.test_data, self.test_label,
                                               self.task_nums, 32,
                                               self.params.val_size,
                                               self.params.ns_type, self.params.ns_factor,
                                               plot=self.params.plot_sample)
        self.train_set, self.val_set, self.test_set = splits

    def new_task(self, cur_task, **kwargs):
        """Return ``(x_train, y_train, labels)`` for task number ``cur_task``."""
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        return x_train, y_train, labels

    def new_run(self, **kwargs):
        """Rebuild the task splits for a new run and return the test sets."""
        self.setup()
        return self.test_set

    def test_plot(self):
        """Run the non-stationary transform preview on the first ten training samples."""
        preview_x = self.train_data[:10]
        preview_y = self.train_label[:10]
        test_ns(preview_x, preview_y, self.params.ns_type, self.params.ns_factor)
def train_mem(self):
    """Re-create the model and train it from scratch on the stored memory samples.

    The memory (`self.mem_img`, with per-class counts in `self.mem_c`) is
    flattened into one tensor pair, a fresh architecture and optimizer are
    built from `self.params`, and the model is trained for
    `self.params.mem_epoch` epochs over a new random permutation each epoch.
    Only full batches of `self.params.batch` samples are used.
    """
    samples = []
    targets = []
    for cls in self.mem_img.keys():
        samples += self.mem_img[cls]
        targets += [cls] * self.mem_c[cls]

    mem_x = torch.stack(samples)
    mem_y = torch.LongTensor(targets)

    # Start from a freshly initialised network every time the memory is trained.
    self.model = setup_architecture(self.params)
    self.model = maybe_cuda(self.model, self.cuda)
    opt = setup_opt(self.params.optimizer, self.model,
                    self.params.learning_rate, self.params.weight_decay)

    for _ in range(self.params.mem_epoch):
        order = np.random.permutation(len(mem_x)).tolist()
        mem_x = maybe_cuda(mem_x[order], self.cuda)
        mem_y = maybe_cuda(mem_y[order], self.cuda)
        self.model = self.model.train()
        batch_size = self.params.batch
        full_batches = len(mem_y) // batch_size
        for start in range(0, full_batches * batch_size, batch_size):
            stop = start + batch_size
            opt.zero_grad()
            logits = self.model.forward(mem_x[start:stop])
            loss = self.criterion(logits, mem_y[start:stop])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip)
            opt.step()
def __init__(self, params, **kwargs):
    """Store the retrieval settings and build the augmentation pipeline.

    Args:
        params: experiment parameters; reads `subsample`, `eps_mem_batch`
            and `data`.
        **kwargs: accepted for interface compatibility and ignored.
    """
    super().__init__()
    self.params = params
    self.subsample = params.subsample
    self.num_retrieve = params.eps_mem_batch

    # Crop size comes from entries 1 and 2 of the per-dataset input size.
    input_size = input_size_match[self.params.data]
    crop = RandomResizedCrop(size=(input_size[1], input_size[2]), scale=(0.2, 1.)).cuda()
    flip = RandomHorizontalFlip().cuda()
    jitter = ColorJitter(0.4, 0.4, 0.4, 0.1)
    gray = RandomGrayscale(p=0.2)
    self.transform = nn.Sequential(crop, flip, jitter, gray)

    self.out_dim = n_classes[params.data]
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""

    def __init__(self, temperature=0.07, contrast_mode='all'):
        """Store the softmax temperature and the anchor/contrast mode.

        Args:
            temperature: divisor applied to the feature dot products.
            contrast_mode: 'all' (every view is an anchor), 'one' (only the
                first view is an anchor) or 'proxy' (view 0 is the anchor set,
                view 1 is the contrast set).
        """
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar. Anchors without any positive contribute 0 instead
            of turning the whole loss into NaN.
        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does not
                match the batch size, or if `contrast_mode` is unknown.
        """
        # Follow the tensor's own device so non-default GPUs work too.
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        elif self.contrast_mode == 'proxy':
            anchor_feature = features[:, 0]
            contrast_feature = features[:, 1]
            anchor_count = 1
            contrast_count = 1
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count, device=device).view(-1, 1),
            0
        )

        # compute log_prob
        if self.contrast_mode == 'proxy':
            exp_logits = torch.exp(logits)
        else:
            mask = mask * logits_mask
            exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive; an anchor with no
        # positive has a zero numerator, so dividing by 1 yields 0, not 0/0
        pos_per_anchor = mask.sum(1)
        safe_count = torch.where(pos_per_anchor < 1e-6,
                                 torch.ones_like(pos_per_anchor),
                                 pos_per_anchor)
        mean_log_prob_pos = (mask * log_prob).sum(1) / safe_count

        # loss
        loss = -1 * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
23 | train_x = train["image_data"].reshape([64, 600, 84, 84, 3]) 24 | val_in = open("datasets/mini_imagenet/mini-imagenet-cache-val.pkl", "rb") 25 | val = pickle.load(val_in) 26 | val_x = val['image_data'].reshape([16, 600, 84, 84, 3]) 27 | test_in = open("datasets/mini_imagenet/mini-imagenet-cache-test.pkl", "rb") 28 | test = pickle.load(test_in) 29 | test_x = test['image_data'].reshape([20, 600, 84, 84, 3]) 30 | all_data = np.vstack((train_x, val_x, test_x)) 31 | train_data = [] 32 | train_label = [] 33 | test_data = [] 34 | test_label = [] 35 | for i in range(len(all_data)): 36 | cur_x = all_data[i] 37 | cur_y = np.ones((600,)) * i 38 | rdm_x, rdm_y = shuffle_data(cur_x, cur_y) 39 | x_test = rdm_x[: int(600 * TEST_SPLIT)] 40 | y_test = rdm_y[: int(600 * TEST_SPLIT)] 41 | x_train = rdm_x[int(600 * TEST_SPLIT):] 42 | y_train = rdm_y[int(600 * TEST_SPLIT):] 43 | train_data.append(x_train) 44 | train_label.append(y_train) 45 | test_data.append(x_test) 46 | test_label.append(y_test) 47 | self.train_data = np.concatenate(train_data) 48 | self.train_label = np.concatenate(train_label) 49 | self.test_data = np.concatenate(test_data) 50 | self.test_label = np.concatenate(test_label) 51 | 52 | def new_run(self, **kwargs): 53 | self.setup() 54 | return self.test_set 55 | 56 | def new_task(self, cur_task, **kwargs): 57 | if self.scenario == 'ni': 58 | x_train, y_train = self.train_set[cur_task] 59 | labels = set(y_train) 60 | elif self.scenario == 'nc': 61 | labels = self.task_labels[cur_task] 62 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 63 | else: 64 | raise Exception('unrecognized scenario') 65 | return x_train, y_train, labels 66 | 67 | def setup(self): 68 | if self.scenario == 'ni': 69 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 70 | self.train_label, 71 | self.test_data, self.test_label, 72 | self.task_nums, 84, 73 | self.params.val_size, 74 | self.params.ns_type, 
self.params.ns_factor, 75 | plot=self.params.plot_sample) 76 | 77 | elif self.scenario == 'nc': 78 | self.task_labels = create_task_composition(class_nums=100, num_tasks=self.task_nums, 79 | fixed_order=self.params.fix_order) 80 | self.test_set = [] 81 | for labels in self.task_labels: 82 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 83 | self.test_set.append((x_test, y_test)) 84 | 85 | def test_plot(self): 86 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 87 | self.params.ns_factor) 88 | -------------------------------------------------------------------------------- /models/ndpm/component.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | import torch.nn as nn 4 | from typing import Tuple 5 | 6 | from utils.utils import maybe_cuda 7 | from utils.global_vars import * 8 | 9 | 10 | class Component(nn.Module, ABC): 11 | def __init__(self, params, experts: Tuple): 12 | super().__init__() 13 | self.params = params 14 | self.experts = experts 15 | 16 | self.optimizer = NotImplemented 17 | self.lr_scheduler = NotImplemented 18 | 19 | @abstractmethod 20 | def nll(self, x, y, step=None): 21 | """Return NLL""" 22 | pass 23 | 24 | @abstractmethod 25 | def collect_nll(self, x, y, step=None): 26 | """Return NLLs including previous experts""" 27 | pass 28 | 29 | def _clip_grad_value(self, clip_value): 30 | for group in self.optimizer.param_groups: 31 | nn.utils.clip_grad_value_(group['params'], clip_value) 32 | 33 | def _clip_grad_norm(self, max_norm, norm_type=2): 34 | for group in self.optimizer.param_groups: 35 | nn.utils.clip_grad_norm_(group['params'], max_norm, norm_type) 36 | 37 | def clip_grad(self): 38 | clip_grad_config = MODELS_NDPM_COMPONENT_CLIP_GRAD 39 | if clip_grad_config['type'] == 'value': 40 | self._clip_grad_value(**clip_grad_config['options']) 41 | elif clip_grad_config['type'] == 'norm': 42 | 
class ComponentG(Component, ABC):
    """Component variant whose experts are reached through `expert.g`."""

    def setup_optimizer(self):
        """Create the optimizer from `self.params` and attach the G lr scheduler."""
        optim_config = {
            'type': self.params.optimizer,
            'options': {'lr': self.params.learning_rate},
        }
        self.optimizer = self.build_optimizer(optim_config, self.parameters())
        self.lr_scheduler = self.build_lr_scheduler(
            MODELS_NDPM_COMPONENT_LR_SCHEDULER_G, self.optimizer)

    def collect_nll(self, x, y=None, step=None):
        """Default `collect_nll`

        Warning: Parameter-sharing components should implement their own
            `collect_nll`

        Returns:
            nll: Tensor of shape [B, 1+K]
        """
        # NLL of every previous expert first, this component's own NLL last.
        nll = [expert.g.nll(x, y, step) for expert in self.experts]
        nll.append(self.nll(x, y, step))
        return torch.stack(nll, dim=1)
class Lwf(ContinualLearner):
    """Learning-without-Forgetting agent.

    After each call to `train_learner` it also saves a bar chart of the L2
    norm of every row of `self.model.fc.weight`, coloured by whether the class
    belongs to the current, an earlier, or a later task.
    """

    # Class ids assigned to each task, used only to colour the chart.
    # NOTE(review): hard-coded for a 40-class / 10-task split; confirm it
    # matches the task composition actually used by the experiment.
    TASKS = (
        (22, 20, 25, 4),
        (10, 15, 28, 11),
        (18, 29, 27, 35),
        (37, 2, 39, 30),
        (34, 16, 36, 8),
        (13, 5, 17, 14),
        (33, 7, 32, 1),
        (26, 12, 31, 24),
        (6, 23, 21, 19),
        (9, 38, 3, 0),
    )

    def __init__(self, model, opt, params):
        super(Lwf, self).__init__(model, opt, params)
        self.k = 0       # index of the task currently being trained
        self.done = []   # class ids already coloured as belonging to a task

    def train_learner(self, x_train, y_train):
        """Train on one task with the LwF loss, then plot the classifier weight norms.

        Args:
            x_train: training inputs of the current task.
            y_train: training labels of the current task.
        """
        self.before_train(x_train, y_train)

        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)

        # set up model
        self.model = self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                logits = self.forward(batch_x)
                loss_old = self.kd_manager.get_kd_loss(logits, batch_x)
                loss_new = self.criterion(logits, batch_y)
                # Weight of the new-task loss shrinks as more tasks are seen.
                loss = 1/(self.task_seen + 1) * loss_new + (1 - 1/(self.task_seen + 1)) * loss_old
                _, pred_label = torch.max(logits, 1)
                correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                # update tracker; store a plain float so the running sum does
                # not chain the autograd history of every batch together
                acc_batch.update(correct_cnt, batch_y.size(0))
                losses_batch.update(loss.item(), batch_y.size(0))
                # backward
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                        .format(i, losses_batch.avg(), acc_batch.avg())
                    )

        self._plot_weight_norms()

        # Advance the task counter; _plot_weight_norms tolerates values past
        # the end of TASKS.
        self.k += 1
        self.after_train()

    def _plot_weight_norms(self):
        """Save a bar chart of per-class classifier weight norms coloured by task status."""
        weights = self.model.fc.weight.data
        weights = weights.cpu()
        row_norms = torch.norm(weights, p=2, dim=1)

        # Beyond the last listed task there is no "current" class set.
        current = set(self.TASKS[self.k]) if self.k < len(self.TASKS) else set()

        colors = []
        for cls in range(weights.size(0)):
            if cls in current:
                colors.append('red')  # class of the current task
                if cls not in self.done:
                    self.done.append(cls)
            elif cls in self.done:
                colors.append('blue')  # class of an already visited task
            else:
                colors.append('yellow')  # class of a task not visited yet

        plt.rc('axes', labelsize=20)  # Axes labels size
        plt.bar(range(weights.size(0)), row_norms.numpy(), color=colors)
        plt.xlabel(f'Classes of DTD at Task {self.k + 1}', labelpad=20)
        plt.ylabel('L2 Norm of Weights')
        red_patch = mpatches.Patch(color='red', label='Current Task')
        blue_patch = mpatches.Patch(color='blue', label='Completed Tasks')
        yellow_patch = mpatches.Patch(color='yellow', label='Future Tasks')
        plt.subplots_adjust(bottom=0.2)
        plt.legend(handles=[red_patch, blue_patch, yellow_patch], loc='upper center',
                   bbox_to_anchor=(0.5, -0.05), ncol=3, frameon=False)
        plt.savefig(f'recencybiasLWF{self.k + 1}.pdf')
        plt.close()
39 | buffer_x (tensor): data buffer. 40 | buffer_y (tensor): label buffer. 41 | cur_x (tensor): current input data tensor. 42 | cur_y (tensor): current input label tensor. 43 | num_retrieve (int): number of data instances to be retrieved. 44 | Returns 45 | ret_x (tensor): retrieved data tensor. 46 | ret_y (tensor): retrieved label tensor. 47 | """ 48 | cur_x = maybe_cuda(cur_x) 49 | cur_y = maybe_cuda(cur_y) 50 | 51 | # Reset and update ClassBalancedRandomSampling cache if ASER update is not enabled 52 | if not self.is_aser_upt: 53 | ClassBalancedRandomSampling.update_cache(buffer_y, self.out_dim) 54 | 55 | # Get candidate data for retrieval (i.e., cand <- class balanced subsamples from memory) 56 | cand_x, cand_y, cand_ind = \ 57 | ClassBalancedRandomSampling.sample(buffer_x, buffer_y, self.n_smp_cls, device=self.device) 58 | 59 | # Type 1 - Adversarial SV 60 | # Get evaluation data for type 1 (i.e., eval <- current input) 61 | eval_adv_x, eval_adv_y = cur_x, cur_y 62 | # Compute adversarial Shapley value of candidate data 63 | # (i.e., sv wrt current input) 64 | sv_matrix_adv = compute_knn_sv(model, eval_adv_x, eval_adv_y, cand_x, cand_y, self.k, device=self.device) 65 | 66 | if self.aser_type != "neg_sv": 67 | # Type 2 - Cooperative SV 68 | # Get evaluation data for type 2 69 | # (i.e., eval <- class balanced subsamples from memory excluding those already in candidate set) 70 | excl_indices = set(cand_ind.tolist()) 71 | eval_coop_x, eval_coop_y, _ = \ 72 | ClassBalancedRandomSampling.sample(buffer_x, buffer_y, self.n_smp_cls, 73 | excl_indices=excl_indices, device=self.device) 74 | # Compute Shapley value 75 | sv_matrix_coop = \ 76 | compute_knn_sv(model, eval_coop_x, eval_coop_y, cand_x, cand_y, self.k, device=self.device) 77 | if self.aser_type == "asv": 78 | # Use extremal SVs for computation 79 | sv = sv_matrix_coop.max(0).values - sv_matrix_adv.min(0).values 80 | else: 81 | # Use mean variation for aser_type == "asvm" or anything else 82 | sv = 
def mini_batch_deep_features(model, total_x, num):
    """
        Compute deep features with mini-batches.
            Args:
                model (object): neural network.
                total_x (tensor): data tensor.
                num (int): number of data.
            Returns
                deep_features (tensor): deep feature representation of data tensor,
                    flattened to one row per sample.

        If the model is in training mode it is switched to eval mode for the
        duration of the call and switched back afterwards, even when feature
        extraction raises.
    """
    was_training = model.training
    if was_training:
        model.eval()

    if hasattr(model, "features"):
        extract = model.features
    else:
        # No dedicated feature extractor: drop the last child module and use
        # the remaining ones, in order, as the extractor.
        backbone = torch.nn.Sequential(*list(model.children())[:-1])

        def extract(batch):
            return torch.squeeze(backbone(batch))

    batch_size = 64
    try:
        with torch.no_grad():
            chunks = []
            for start in range(0, num, batch_size):
                batch_x = total_x[start: min(start + batch_size, num)]
                # reshape restores the batch axis that squeeze may have removed
                chunks.append(extract(batch_x).reshape((batch_x.size(0), -1)))
            if len(chunks) == 1:
                deep_features_ = chunks[0]
            else:
                deep_features_ = torch.cat(chunks, 0)
    finally:
        # Restore the caller's mode on both the normal and the error path.
        if was_training:
            model.train()
    return deep_features_
class AGEM(ContinualLearner):
    """Averaged-GEM agent: projects the current gradient so that it does not
    point against the gradient computed on memory samples."""

    def __init__(self, model, opt, params):
        super(AGEM, self).__init__(model, opt, params)
        self.buffer = Buffer(model, params)
        self.mem_size = params.mem_size
        self.eps_mem_batch = params.eps_mem_batch
        self.mem_iters = params.mem_iters

    def train_learner(self, x_train, y_train):
        """Train on one task, projecting gradients against memory once a task has been seen.

        Args:
            x_train: training inputs of the current task.
            y_train: training labels of the current task.
        """
        self.before_train(x_train, y_train)

        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # set up model
        self.model = self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                for j in range(self.mem_iters):
                    logits = self.forward(batch_x)
                    loss = self.criterion(logits, batch_y)
                    if self.params.trick['kd_trick']:
                        loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                               self.kd_manager.get_kd_loss(logits, batch_x)
                    if self.params.trick['kd_trick_star']:
                        loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \
                               (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x)
                    _, pred_label = torch.max(logits, 1)
                    correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                    # update tracker; store a plain float so the running sum
                    # does not chain the autograd history of every batch
                    acc_batch.update(correct_cnt, batch_y.size(0))
                    losses_batch.update(loss.item(), batch_y.size(0))
                    # backward
                    self.opt.zero_grad()
                    loss.backward()

                    if self.task_seen > 0:
                        self._project_gradient()
                    self.opt.step()
                # update mem
                self.buffer.update(batch_x, batch_y)

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                        .format(i, losses_batch.avg(), acc_batch.avg())
                    )
        self.after_train()

    def _project_gradient(self):
        """Replace the current gradients by their projection w.r.t. the memory gradient.

        Expects `.grad` to hold the gradient of the current batch. Leaves the
        gradients untouched when the buffer returns no samples.
        """
        # sample from memory of previous tasks
        mem_x, mem_y = self.buffer.retrieve()
        if mem_x.size(0) == 0:
            return

        # Trainable parameters that received no gradient from the current
        # batch have nothing to project; skipping them avoids calling
        # .clone() on None.
        params = [p for p in self.model.parameters()
                  if p.requires_grad and p.grad is not None]
        # gradient computed using current batch
        grad = [p.grad.clone() for p in params]

        mem_x = maybe_cuda(mem_x, self.cuda)
        mem_y = maybe_cuda(mem_y, self.cuda)
        mem_logits = self.forward(mem_x)
        loss_mem = self.criterion(mem_logits, mem_y)
        self.opt.zero_grad()
        loss_mem.backward()
        # gradient computed using memory samples
        grad_ref = [p.grad.clone() for p in params]

        # inner product of grad and grad_ref
        prod = sum([torch.sum(g * g_r) for g, g_r in zip(grad, grad_ref)])
        if prod < 0:
            prod_ref = sum([torch.sum(g_r ** 2) for g_r in grad_ref])
            # do projection
            grad = [g - prod / prod_ref * g_r for g, g_r in zip(grad, grad_ref)]

        # Always write back: after the memory backward pass `.grad` holds the
        # memory gradient, so the (possibly projected) batch gradient must be
        # restored before the optimizer step.
        for g, p in zip(grad, params):
            p.grad.data.copy_(g)
However, it 20 | can also be used in non-online scenarios 21 | """ 22 | def __init__(self, model, opt, params): 23 | super(ProxyContrastiveReplay, self).__init__(model, opt, params) 24 | self.buffer = Buffer(model, params) 25 | self.mem_size = params.mem_size 26 | self.eps_mem_batch = params.eps_mem_batch 27 | self.mem_iters = params.mem_iters 28 | 29 | def train_learner(self, x_train, y_train): 30 | self.before_train(x_train, y_train) 31 | # set up loader 32 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 33 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 34 | drop_last=True) 35 | # set up model 36 | self.model = self.model.train() 37 | for ep in range(self.epoch): 38 | for i, batch_data in enumerate(train_loader): 39 | # batch update 40 | batch_x, batch_y = batch_data 41 | batch_x_aug = torch.stack([transforms_aug[self.data](batch_x[idx].cpu()) 42 | for idx in range(batch_x.size(0))]) 43 | batch_x = maybe_cuda(batch_x, self.cuda) 44 | batch_x_aug = maybe_cuda(batch_x_aug, self.cuda) 45 | batch_y = maybe_cuda(batch_y, self.cuda) 46 | batch_x_combine = torch.cat((batch_x, batch_x_aug)) 47 | batch_y_combine = torch.cat((batch_y, batch_y)) 48 | for j in range(self.mem_iters): 49 | logits, feas= self.model.pcrForward(batch_x_combine) 50 | novel_loss = 0*self.criterion(logits, batch_y_combine) 51 | self.opt.zero_grad() 52 | 53 | 54 | mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y) 55 | if mem_x.size(0) > 0: 56 | # mem_x, mem_y = Rotation(mem_x, mem_y) 57 | mem_x_aug = torch.stack([transforms_aug[self.data](mem_x[idx].cpu()) 58 | for idx in range(mem_x.size(0))]) 59 | mem_x = maybe_cuda(mem_x, self.cuda) 60 | mem_x_aug = maybe_cuda(mem_x_aug, self.cuda) 61 | mem_y = maybe_cuda(mem_y, self.cuda) 62 | mem_x_combine = torch.cat([mem_x, mem_x_aug]) 63 | mem_y_combine = torch.cat([mem_y, mem_y]) 64 | 65 | 66 | mem_logits, mem_fea= self.model.pcrForward(mem_x_combine) 67 
class OpenLORIS(DatasetBase):
    """
    OpenLORIS loader. The number of tasks is not configurable: it is looked up
    in ``openloris_ntask`` from ``params.ns_type``.
    """
    def __init__(self, scenario, params):  # scenario refers to "ni" or "nc"
        # ns_type can be (illumination, occlusion, pixel, clutter, sequence);
        # it is assigned before the base initialiser runs.
        self.ns_type = params.ns_type
        super(OpenLORIS, self).__init__('openloris', scenario, openloris_ntask[self.ns_type],
                                        params.num_runs, params)

    def download_load(self):
        """Read every task's train/test JPEGs from disk as 50x50 RGB arrays.

        One ``(images, labels)`` pair per task is appended to ``self.train_set``
        and to ``self.test_set``; a label is the index of the object folder in
        ``datapath``.
        """
        def read_images(files):
            return [np.array(Image.open(f).convert('RGB').resize((50, 50))) for f in files]

        started = time.time()
        self.train_set = []
        for batch_num in range(1, self.task_nums + 1):
            train_x, train_y = [], []
            test_x, test_y = [], []
            for class_idx, folder in enumerate(datapath):
                train_files = glob.glob(
                    'datasets/openloris/' + self.ns_type + '/train/task{}/{}/*.jpg'.format(batch_num, folder))
                train_x += read_images(train_files)
                train_y += [class_idx] * len(train_files)

                test_files = glob.glob(
                    'datasets/openloris/' + self.ns_type + '/test/task{}/{}/*.jpg'.format(batch_num, folder))
                test_x += read_images(test_files)
                test_y += [class_idx] * len(test_files)

            print("  --> batch{}'-dataset consisting of {} samples".format(batch_num, len(train_x)))
            print("  --> test'-dataset consisting of {} samples".format(len(test_x)))
            self.train_set.append((np.array(train_x), np.array(train_y)))
            self.test_set.append((np.array(test_x), np.array(test_y)))
        finished = time.time()
        print('loading time: {}'.format(str(finished - started)))

    def new_run(self, **kwargs):
        pass

    def new_task(self, cur_task, **kwargs):
        """Shuffle task ``cur_task``, split off a validation part and return the rest.

        The first ``int(len * params.val_size)`` shuffled samples are appended
        to ``self.val_set``; the remaining samples, their labels and the set of
        labels among them are returned.
        """
        task_x, task_y = self.train_set[cur_task]
        shuffled_x, shuffled_y = shuffle_data(task_x, task_y)
        split = int(len(shuffled_x) * self.params.val_size)
        # get val set
        self.val_set.append((shuffled_x[:split], shuffled_y[:split]))
        train_data_rdm = shuffled_x[split:]
        train_label_rdm = shuffled_y[split:]
        return train_data_rdm, train_label_rdm, set(train_label_rdm)

    def setup(self, **kwargs):
        pass


# number of tasks for each ns_type
openloris_ntask = {
    'illumination': 9,
    'occlusion': 9,
    'pixel': 9,
    'clutter': 9,
    'sequence': 12
}

# object folders; a folder's position in this list is its class label
datapath = ['bottle_01', 'bottle_02', 'bottle_03', 'bottle_04', 'bowl_01', 'bowl_02', 'bowl_03', 'bowl_04', 'bowl_05',
            'corkscrew_01', 'cottonswab_01', 'cottonswab_02', 'cup_01', 'cup_02', 'cup_03', 'cup_04', 'cup_05',
            'cup_06', 'cup_07', 'cup_08', 'cup_10', 'cushion_01', 'cushion_02', 'cushion_03', 'glasses_01',
            'glasses_02', 'glasses_03', 'glasses_04', 'knife_01', 'ladle_01', 'ladle_02', 'ladle_03', 'ladle_04',
            'mask_01', 'mask_02', 'mask_03', 'mask_04', 'mask_05', 'paper_cutter_01', 'paper_cutter_02',
            'paper_cutter_03', 'paper_cutter_04', 'pencil_01', 'pencil_02', 'pencil_03', 'pencil_04', 'pencil_05',
            'plasticbag_01', 'plasticbag_02', 'plasticbag_03', 'plug_01', 'plug_02', 'plug_03', 'plug_04', 'pot_01',
            'scissors_01', 'scissors_02', 'scissors_03', 'stapler_01', 'stapler_02', 'stapler_03', 'thermometer_01',
            'thermometer_02', 'thermometer_03', 'toy_01', 'toy_02', 'toy_03', 'toy_04', 'toy_05',
            'nail_clippers_01', 'nail_clippers_02', 'nail_clippers_03',
            'bracelet_01', 'bracelet_02', 'bracelet_03',
            'comb_01', 'comb_02', 'comb_03',
            'umbrella_01', 'umbrella_02', 'umbrella_03',
            'socks_01', 'socks_02', 'socks_03',
            'toothpaste_01', 'toothpaste_02', 'toothpaste_03',
            'wallet_01', 'wallet_02', 'wallet_03',
            'headphone_01', 'headphone_02', 'headphone_03',
            'key_01', 'key_02', 'key_03',
            'battery_01', 'battery_02', 'mouse_01', 'pencilcase_01', 'pencilcase_02', 'tape_01',
            'chopsticks_01', 'chopsticks_02', 'chopsticks_03',
            'notebook_01', 'notebook_02', 'notebook_03',
            'spoon_01', 'spoon_02', 'spoon_03',
            'tissue_01', 'tissue_02', 'tissue_03',
            'clamp_01', 'clamp_02', 'hat_01', 'hat_02', 'u_disk_01', 'u_disk_02', 'swimming_glasses_01'
            ]
= np.array(image) 42 | self.train_data.append(image) 43 | self.train_label.append(label) 44 | self.train_data = np.array(self.train_data) 45 | self.train_label = np.array(self.train_label) 46 | with open('./datasets/DTD/DTD_train_data.pkl', 'wb') as f: 47 | pickle.dump(self.train_data, f) 48 | with open('./datasets/DTD/DTD_train_label.pkl', 'wb') as f: 49 | pickle.dump(self.train_label, f) 50 | for image, label in test: 51 | image = image.resize((224, 224), Image.Resampling.LANCZOS) 52 | self.test_data.append(image) 53 | self.test_label.append(label) 54 | self.test_data = np.array(self.test_data) 55 | self.test_label = np.array(self.test_label) 56 | with open('./datasets/DTD/DTD_test_data.pkl', 'wb') as f: 57 | pickle.dump(self.test_data, f) 58 | with open('./datasets/DTD/DTD_test_label.pkl', 'wb') as f: 59 | pickle.dump(self.test_label, f) 60 | 61 | 62 | def setup(self): 63 | if self.scenario == 'ni': 64 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 65 | self.train_label, 66 | self.test_data, self.test_label, 67 | self.task_nums, 32, 68 | self.params.val_size, 69 | self.params.ns_type, self.params.ns_factor, 70 | plot=self.params.plot_sample) 71 | elif self.scenario == 'nc': 72 | self.task_labels = create_task_composition(class_nums=47, num_tasks=self.task_nums, fixed_order=self.params.fix_order) 73 | self.test_set = [] 74 | for labels in self.task_labels: 75 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 76 | self.test_set.append((x_test, y_test)) 77 | else: 78 | raise Exception('wrong scenario') 79 | 80 | def new_task(self, cur_task, **kwargs): 81 | if self.scenario == 'ni': 82 | x_train, y_train = self.train_set[cur_task] 83 | labels = set(y_train) 84 | elif self.scenario == 'nc': 85 | labels = self.task_labels[cur_task] 86 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 87 | return x_train, y_train, labels 88 | 89 | def new_run(self, 
class EWC_pp(ContinualLearner):
    """Learner that penalises parameter drift, weighted by a running squared-gradient estimate.

    Squared gradients are accumulated per batch, folded into a running average
    every ``fisher_update_after`` steps, and min-max normalised at the end of
    each task to weight the quadratic penalty in ``total_loss``.
    """

    def __init__(self, model, opt, params):
        super(EWC_pp, self).__init__(model, opt, params)
        # trainable parameters by name; these are the live Parameter objects
        self.weights = {name: param for name, param in self.model.named_parameters() if param.requires_grad}
        self.lambda_ = params.lambda_
        self.alpha = params.alpha
        self.fisher_update_after = params.fisher_update_after
        # snapshot of the weights taken at the end of the previous task
        self.prev_params = {}
        self.running_fisher = self.init_fisher()
        self.tmp_fisher = self.init_fisher()
        self.normalized_fisher = self.init_fisher()

    def train_learner(self, x_train, y_train):
        """Train on one task, then snapshot the weights and renormalise the fisher."""
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)
        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        # set up model
        self.model.train()

        batches_per_epoch = len(train_loader)
        for ep in range(self.epoch):
            for i, (batch_x, batch_y) in enumerate(train_loader):
                # batch update
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                # update the running fisher every `fisher_update_after` global steps
                global_step = ep * batches_per_epoch + i + 1
                if global_step % self.fisher_update_after == 0:
                    self.update_running_fisher()

                out = self.forward(batch_x)
                loss = self.total_loss(out, batch_y)
                if self.params.trick['kd_trick']:
                    ce_weight = 1 / (self.task_seen + 1)
                    loss = ce_weight * loss + (1 - ce_weight) * \
                        self.kd_manager.get_kd_loss(out, batch_x)
                if self.params.trick['kd_trick_star']:
                    ce_weight = 1 / ((self.task_seen + 1) ** 0.5)
                    loss = ce_weight * loss + \
                        (1 - ce_weight) * self.kd_manager.get_kd_loss(out, batch_x)

                # update tracker
                n_samples = batch_y.size(0)
                losses_batch.update(loss.item(), n_samples)
                _, pred_label = torch.max(out, 1)
                acc_batch.update((pred_label == batch_y).sum().item() / n_samples, n_samples)

                # backward
                self.opt.zero_grad()
                loss.backward()

                # accumulate the fisher of current batch
                self.accum_fisher()
                self.opt.step()

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                        .format(i, losses_batch.avg(), acc_batch.avg())
                    )

        # save params for current task
        for name, param in self.weights.items():
            self.prev_params[name] = param.clone().detach()

        # update normalized fisher of current task (min-max over all entries)
        max_fisher = max(torch.max(m) for m in self.running_fisher.values())
        min_fisher = min(torch.min(m) for m in self.running_fisher.values())
        spread = max_fisher - min_fisher + 1e-32
        for name, fisher in self.running_fisher.items():
            self.normalized_fisher[name] = (fisher - min_fisher) / spread
        self.after_train()

    def total_loss(self, inputs, targets):
        """Return the criterion loss, plus the weighted drift penalty once a task has been seen."""
        # cross entropy loss
        loss = self.criterion(inputs, targets)
        if not self.prev_params:
            return loss
        # add regularization loss
        reg_loss = 0
        for name, param in self.weights.items():
            drift = param - self.prev_params[name]
            reg_loss += (self.normalized_fisher[name] * drift ** 2).sum()
        loss += self.lambda_ * reg_loss
        return loss

    def init_fisher(self):
        """Return a zero tensor shaped like each trainable parameter, keyed by name."""
        zeros = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                zeros[name] = param.detach().clone().zero_()
        return zeros

    def update_running_fisher(self):
        """Blend the accumulated squared gradients into the running estimate and reset them."""
        keep = 1. - self.alpha
        # the accumulator is divided by the number of steps it covers
        gain = 1. / self.fisher_update_after * self.alpha
        for name in self.running_fisher:
            self.running_fisher[name] = keep * self.running_fisher[name] + gain * self.tmp_fisher[name]
        # reset the accumulated fisher
        self.tmp_fisher = self.init_fisher()

    def accum_fisher(self):
        """Add the squared gradient of every parameter that has one to ``tmp_fisher``."""
        for name, accumulator in self.tmp_fisher.items():
            grad = self.weights[name].grad
            if grad is None:
                continue
            accumulator += grad ** 2
open('./datasets/tiny/tiny_test_label.pkl', 'rb') as f: 31 | self.test_label = pickle.load(f) 32 | else: 33 | train = datasets.ImageFolder( './datasets/tinyimagenet/train') 34 | test = datasets.ImageFolder( './datasets/tinyimagenet/test') 35 | self.train_data = [] 36 | self.train_label = [] 37 | self.test_data = [] 38 | self.test_label = [] 39 | for image, label in train: 40 | image = image.resize((224, 224), Image.Resampling.LANCZOS) 41 | image = np.array(image) 42 | self.train_data.append(image) 43 | self.train_label.append(label) 44 | self.train_data = np.array(self.train_data) 45 | self.train_label = np.array(self.train_label) 46 | with open('./datasets/tiny/tiny_train_data.pkl', 'wb') as f: 47 | pickle.dump(self.train_data, f) 48 | with open('./datasets/tiny/tiny_train_label.pkl', 'wb') as f: 49 | pickle.dump(self.train_label, f) 50 | for image, label in test: 51 | image = image.resize((224, 224), Image.Resampling.LANCZOS) 52 | self.test_data.append(image) 53 | self.test_label.append(label) 54 | self.test_data = np.array(self.test_data) 55 | self.test_label = np.array(self.test_label) 56 | with open('./datasets/tiny/tiny_test_data.pkl', 'wb') as f: 57 | pickle.dump(self.test_data, f) 58 | with open('./datasets/tiny/tiny_test_label.pkl', 'wb') as f: 59 | pickle.dump(self.test_label, f) 60 | 61 | 62 | def setup(self): 63 | if self.scenario == 'ni': 64 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 65 | self.train_label, 66 | self.test_data, self.test_label, 67 | self.task_nums, 32, 68 | self.params.val_size, 69 | self.params.ns_type, self.params.ns_factor, 70 | plot=self.params.plot_sample) 71 | elif self.scenario == 'nc': 72 | self.task_labels = create_task_composition(class_nums=200, num_tasks=self.task_nums, fixed_order=self.params.fix_order) 73 | self.test_set = [] 74 | for labels in self.task_labels: 75 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 76 | 
class CelebA(DatasetBase):
    """CelebA wrapper that caches resized images and labels as pickled NumPy arrays.

    On first use the torchvision dataset is downloaded, every image is resized
    to 224x224 and the resulting arrays are pickled under ``./datasets/CelebA``;
    later runs load those pickles instead.
    """

    # Directory holding the four cache files written by ``download_load``.
    _CACHE_DIR = './datasets/CelebA'
    # Size every image is resized to before it is cached.
    _IMAGE_SIZE = (224, 224)

    def __init__(self, scenario, params):
        dataset = 'CelebA'
        if scenario == 'ni':
            num_tasks = len(params.ns_factor)
        else:
            num_tasks = params.num_tasks

        super(CelebA, self).__init__(dataset, scenario, num_tasks, params.num_runs, params)

    def _cache_path(self, split, kind):
        """Return the cache file for ``split`` ('train'/'test') and ``kind`` ('data'/'label')."""
        return '{}/CelebA_{}_{}.pkl'.format(self._CACHE_DIR, split, kind)

    @staticmethod
    def _read_pickle(path):
        """Load one cache file."""
        # pickle.load can execute arbitrary code: only point this at cache
        # files produced by this class, never at files from another source.
        with open(path, 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _write_pickle(obj, path):
        """Write one cache file."""
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    def _to_arrays(self, dataset):
        """Resize every image of ``dataset`` and return ``(images, labels)`` as arrays."""
        images, labels = [], []
        for image, label in dataset:
            image = image.resize(self._IMAGE_SIZE, Image.Resampling.LANCZOS)
            images.append(np.array(image))
            labels.append(label)
        return np.array(images), np.array(labels)

    def download_load(self):
        """Populate ``train_data``/``train_label``/``test_data``/``test_label``.

        The cache is used only when all four files exist; otherwise the
        dataset is downloaded, converted and the cache is (re)written.
        """
        paths = {(split, kind): self._cache_path(split, kind)
                 for split in ('train', 'test') for kind in ('data', 'label')}

        if all(os.path.exists(path) for path in paths.values()):
            self.train_data = self._read_pickle(paths['train', 'data'])
            self.train_label = self._read_pickle(paths['train', 'label'])
            self.test_data = self._read_pickle(paths['test', 'data'])
            self.test_label = self._read_pickle(paths['test', 'label'])
            return

        # The cache directory is not necessarily created by the download,
        # so make sure it exists before the pickles are written.
        os.makedirs(self._CACHE_DIR, exist_ok=True)

        # NOTE(review): torchvision's CelebA yields attribute vectors as targets
        # by default, while setup() assumes 10177 classes, which looks like the
        # identity count -- confirm which target type is intended here.
        train = datasets.CelebA(root=self.root, split='train', download=True)
        test = datasets.CelebA(root=self.root, split='test', download=True)

        self.train_data, self.train_label = self._to_arrays(train)
        self._write_pickle(self.train_data, paths['train', 'data'])
        self._write_pickle(self.train_label, paths['train', 'label'])

        # Test images go through the same conversion as training images, so
        # both splits are cached as plain arrays rather than PIL objects.
        self.test_data, self.test_label = self._to_arrays(test)
        self._write_pickle(self.test_data, paths['test', 'data'])
        self._write_pickle(self.test_label, paths['test', 'label'])

    def setup(self):
        """Build the per-task sets for the configured scenario.

        Raises:
            Exception: if the scenario is neither 'ni' nor 'nc'.
        """
        if self.scenario == 'ni':
            # NOTE(review): 32 is passed positionally although images are cached
            # at 224x224; if this argument is an image size it is stale --
            # confirm against construct_ns_multiple_wrapper.
            self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data,
                                                                                        self.train_label,
                                                                                        self.test_data, self.test_label,
                                                                                        self.task_nums, 32,
                                                                                        self.params.val_size,
                                                                                        self.params.ns_type, self.params.ns_factor,
                                                                                        plot=self.params.plot_sample)
        elif self.scenario == 'nc':
            self.task_labels = create_task_composition(class_nums=10177, num_tasks=self.task_nums, fixed_order=self.params.fix_order)
            self.test_set = []
            for labels in self.task_labels:
                x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels)
                self.test_set.append((x_test, y_test))
        else:
            raise Exception('wrong scenario')

    def new_task(self, cur_task, **kwargs):
        """Return ``(x_train, y_train, labels)`` for task ``cur_task``.

        Raises:
            Exception: if the scenario is neither 'ni' nor 'nc' (previously this
                surfaced as an unbound-local error on the return statement).
        """
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        else:
            raise Exception('wrong scenario')
        return x_train, y_train, labels

    def new_run(self, **kwargs):
        """Rebuild the task sets and return the test sets."""
        self.setup()
        return self.test_set

    def test_plot(self):
        """Run ``test_ns`` on the first ten training samples."""
        test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type,
                self.params.ns_factor)
os.path.exists('./datasets/Place365/Place365_test_data.pkl') and os.path.exists('./datasets/Place365/Place365_test_label.pkl'): 24 | with open('./datasets/Place365/Place365_train_data.pkl', 'rb') as f: 25 | self.train_data = pickle.load(f) 26 | with open('./datasets/Place365/Place365_train_label.pkl', 'rb') as f: 27 | self.train_label = pickle.load(f) 28 | with open('./datasets/Place365/Place365_test_data.pkl', 'rb') as f: 29 | self.test_data = pickle.load(f) 30 | with open('./datasets/Place365/Place365_test_label.pkl', 'rb') as f: 31 | self.test_label = pickle.load(f) 32 | else: 33 | train = datasets.Places365(root=self.root, split='train', download=True) 34 | test = datasets.Places365(root=self.root, split='test', download=True) 35 | self.train_data = [] 36 | self.train_label = [] 37 | self.test_data = [] 38 | self.test_label = [] 39 | for image, label in train: 40 | image = image.resize((224, 224), Image.Resampling.LANCZOS) 41 | image = np.array(image) 42 | self.train_data.append(image) 43 | self.train_label.append(label) 44 | self.train_data = np.array(self.train_data) 45 | self.train_label = np.array(self.train_label) 46 | with open('./datasets/Place365/Place365_train_data.pkl', 'wb') as f: 47 | pickle.dump(self.train_data, f) 48 | with open('./datasets/Place365/Place365_train_label.pkl', 'wb') as f: 49 | pickle.dump(self.train_label, f) 50 | for image, label in test: 51 | image = image.resize((224, 224), Image.Resampling.LANCZOS) 52 | self.test_data.append(image) 53 | self.test_label.append(label) 54 | self.test_data = np.array(self.test_data) 55 | self.test_label = np.array(self.test_label) 56 | with open('./datasets/Place365/Place365_test_data.pkl', 'wb') as f: 57 | pickle.dump(self.test_data, f) 58 | with open('./datasets/Place365/Place365_test_label.pkl', 'wb') as f: 59 | pickle.dump(self.test_label, f) 60 | 61 | 62 | def setup(self): 63 | if self.scenario == 'ni': 64 | self.train_set, self.val_set, self.test_set = 
class Country211(DatasetBase):
    """Country211 wrapper that serves 224x224 images as NumPy arrays.

    The resized train/test splits are cached as pickles under
    ``./datasets/Country211`` so later runs skip the download and resize.
    """

    # Directory that holds the pickled array cache.
    _CACHE_DIR = './datasets/Country211'
    # Attribute names, in the order they are loaded from / written to the cache.
    _CACHE_NAMES = ('train_data', 'train_label', 'test_data', 'test_label')
    # Number of classes handed to create_task_composition in the 'nc' scenario.
    _NUM_CLASSES = 211

    def __init__(self, scenario, params):
        """Pick the task count for ``scenario`` and initialise the base class.

        Args:
            scenario: 'ni' or 'nc'.
            params: run configuration; ``ns_factor`` is read for 'ni',
                ``num_tasks`` otherwise, and ``num_runs`` always.
        """
        dataset = 'Country211'
        if scenario == 'ni':
            num_tasks = len(params.ns_factor)
        else:
            num_tasks = params.num_tasks

        super(Country211, self).__init__(dataset, scenario, num_tasks, params.num_runs, params)

    @classmethod
    def _cache_file(cls, name):
        """Return the pickle path used for the attribute called ``name``."""
        return f'{cls._CACHE_DIR}/Country211_{name}.pkl'

    @staticmethod
    def _resize_split(split):
        """Resize every image of ``split`` to 224x224 and stack into arrays.

        Returns:
            (images, labels) as NumPy arrays.
        """
        images, labels = [], []
        for image, label in split:
            image = image.resize((224, 224), Image.Resampling.LANCZOS)
            # Convert explicitly so train and test are stored the same way.
            images.append(np.array(image))
            labels.append(label)
        return np.array(images), np.array(labels)

    def _dump(self, name):
        """Pickle the attribute called ``name`` into its cache file."""
        with open(self._cache_file(name), 'wb') as f:
            pickle.dump(getattr(self, name), f)

    def download_load(self):
        """Populate train/test data and labels, from cache when available."""
        if all(os.path.exists(self._cache_file(name)) for name in self._CACHE_NAMES):
            for name in self._CACHE_NAMES:
                # NOTE: pickle is acceptable here only because the cache is
                # written locally by this class; never point it at foreign files.
                with open(self._cache_file(name), 'rb') as f:
                    setattr(self, name, pickle.load(f))
            return

        train = datasets.Country211(root=self.root, split='train', download=True)
        test = datasets.Country211(root=self.root, split='test', download=True)
        # The cache directory is not guaranteed to exist before the first dump.
        os.makedirs(self._CACHE_DIR, exist_ok=True)

        self.train_data, self.train_label = self._resize_split(train)
        self._dump('train_data')
        self._dump('train_label')

        self.test_data, self.test_label = self._resize_split(test)
        self._dump('test_data')
        self._dump('test_label')

    def setup(self):
        """Build the per-task train/val/test structures for the scenario.

        Raises:
            Exception: if the scenario is neither 'ni' nor 'nc'.
        """
        if self.scenario == 'ni':
            self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(
                self.train_data,
                self.train_label,
                self.test_data, self.test_label,
                self.task_nums, 32,
                self.params.val_size,
                self.params.ns_type, self.params.ns_factor,
                plot=self.params.plot_sample)
        elif self.scenario == 'nc':
            self.task_labels = create_task_composition(class_nums=self._NUM_CLASSES,
                                                       num_tasks=self.task_nums,
                                                       fixed_order=self.params.fix_order)
            self.test_set = []
            for labels in self.task_labels:
                x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels)
                self.test_set.append((x_test, y_test))
        else:
            raise Exception('wrong scenario')

    def new_task(self, cur_task, **kwargs):
        """Return ``(x_train, y_train, labels)`` for task ``cur_task``.

        Raises:
            Exception: if the scenario is neither 'ni' nor 'nc'.
        """
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        else:
            # Mirrors setup(); previously this path died with UnboundLocalError.
            raise Exception('wrong scenario')
        return x_train, y_train, labels

    def new_run(self, **kwargs):
        """Rebuild the task structures and return the test set."""
        self.setup()
        return self.test_set

    def test_plot(self):
        """Run ``test_ns`` on the first ten training samples."""
        test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type,
                self.params.ns_factor)
class FGVCAIRCRAFT(DatasetBase):
    """FGVC-Aircraft wrapper that serves 224x224 images as NumPy arrays.

    The resized train/test splits are cached as pickles under
    ``./datasets/aircraft`` so later runs skip the download and resize.
    """

    # Directory that holds the pickled array cache.
    _CACHE_DIR = './datasets/aircraft'
    # Attribute names, in the order they are loaded from / written to the cache.
    _CACHE_NAMES = ('train_data', 'train_label', 'test_data', 'test_label')
    # The dataset is loaded with torchvision's default annotation level, and the
    # project-wide n_classes table lists 100 for 'aircraft'; the previous value
    # of 102 would ask the task split for two classes that have no samples.
    _NUM_CLASSES = 100

    def __init__(self, scenario, params):
        """Pick the task count for ``scenario`` and initialise the base class.

        Args:
            scenario: 'ni' or 'nc'.
            params: run configuration; ``ns_factor`` is read for 'ni',
                ``num_tasks`` otherwise, and ``num_runs`` always.
        """
        dataset = 'aircraft'
        if scenario == 'ni':
            num_tasks = len(params.ns_factor)
        else:
            num_tasks = params.num_tasks

        super(FGVCAIRCRAFT, self).__init__(dataset, scenario, num_tasks, params.num_runs, params)

    @classmethod
    def _cache_file(cls, name):
        """Return the pickle path used for the attribute called ``name``."""
        return f'{cls._CACHE_DIR}/aircraft_{name}.pkl'

    @staticmethod
    def _resize_split(split):
        """Resize every image of ``split`` to 224x224 and stack into arrays.

        Returns:
            (images, labels) as NumPy arrays.
        """
        images, labels = [], []
        for image, label in split:
            image = image.resize((224, 224), Image.Resampling.LANCZOS)
            # Explicit conversion instead of relying on np.array() to coerce a
            # list of PIL images.
            images.append(np.array(image))
            labels.append(label)
        return np.array(images), np.array(labels)

    def _dump(self, name):
        """Pickle the attribute called ``name`` into its cache file."""
        with open(self._cache_file(name), 'wb') as f:
            pickle.dump(getattr(self, name), f)

    def download_load(self):
        """Populate train/test data and labels, from cache when available."""
        if all(os.path.exists(self._cache_file(name)) for name in self._CACHE_NAMES):
            for name in self._CACHE_NAMES:
                # NOTE: pickle is acceptable here only because the cache is
                # written locally by this class; never point it at foreign files.
                with open(self._cache_file(name), 'rb') as f:
                    setattr(self, name, pickle.load(f))
            return

        train = datasets.FGVCAircraft(root=self.root, split='train', download=True)
        test = datasets.FGVCAircraft(root=self.root, split='test', download=True)
        # The cache directory is not guaranteed to exist before the first dump.
        os.makedirs(self._CACHE_DIR, exist_ok=True)

        self.train_data, self.train_label = self._resize_split(train)
        self._dump('train_data')
        self._dump('train_label')

        self.test_data, self.test_label = self._resize_split(test)
        self._dump('test_data')
        self._dump('test_label')

    def setup(self):
        """Build the per-task train/val/test structures for the scenario.

        Raises:
            Exception: if the scenario is neither 'ni' nor 'nc'.
        """
        if self.scenario == 'ni':
            self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(
                self.train_data,
                self.train_label,
                self.test_data, self.test_label,
                self.task_nums, 32,
                self.params.val_size,
                self.params.ns_type, self.params.ns_factor,
                plot=self.params.plot_sample)
        elif self.scenario == 'nc':
            self.task_labels = create_task_composition(class_nums=self._NUM_CLASSES,
                                                       num_tasks=self.task_nums,
                                                       fixed_order=self.params.fix_order)
            self.test_set = []
            for labels in self.task_labels:
                x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels)
                self.test_set.append((x_test, y_test))
        else:
            raise Exception('wrong scenario')

    def new_task(self, cur_task, **kwargs):
        """Return ``(x_train, y_train, labels)`` for task ``cur_task``.

        Raises:
            Exception: if the scenario is neither 'ni' nor 'nc'.
        """
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        else:
            # Mirrors setup(); previously this path died with UnboundLocalError.
            raise Exception('wrong scenario')
        return x_train, y_train, labels

    def new_run(self, **kwargs):
        """Rebuild the task structures and return the test set."""
        self.setup()
        return self.test_set

    def test_plot(self):
        """Run ``test_ns`` on the first ten training samples."""
        test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type,
                self.params.ns_factor)
class inaturalist(DatasetBase):
    """iNaturalist wrapper that serves 224x224 images as NumPy arrays.

    Uses the '2021_train_mini' version for training and '2021_validation'
    for testing, and caches the resized arrays as pickles.
    """

    def __init__(self, scenario, params):
        """Derive the task count from ``scenario`` and initialise the base."""
        dataset = 'inaturalist'
        num_tasks = len(params.ns_factor) if scenario == 'ni' else params.num_tasks
        super().__init__(dataset, scenario, num_tasks, params.num_runs, params)

    def download_load(self):
        """Load the four cached arrays, or build and cache them."""
        train_data_pkl = './datasets/inaturalist/inaturalist_train_data.pkl'
        train_label_pkl = './datasets/inaturalist/inaturalist_train_label.pkl'
        test_data_pkl = './datasets/inaturalist/inaturalist_test_data.pkl'
        test_label_pkl = './datasets/inaturalist/inaturalist_test_label.pkl'
        cache_files = (train_data_pkl, train_label_pkl, test_data_pkl, test_label_pkl)
        attr_names = ('train_data', 'train_label', 'test_data', 'test_label')

        if all(os.path.exists(path) for path in cache_files):
            for attr, path in zip(attr_names, cache_files):
                with open(path, 'rb') as f:
                    setattr(self, attr, pickle.load(f))
            return

        def resized_arrays(split):
            # Resize each image to 224x224 and collect pixels and labels.
            pixels, targets = [], []
            for image, label in split:
                pixels.append(np.array(image.resize((224, 224), Image.Resampling.LANCZOS)))
                targets.append(label)
            return np.array(pixels), np.array(targets)

        def save(obj, path):
            with open(path, 'wb') as f:
                pickle.dump(obj, f)

        # Both splits are instantiated (and downloaded) before any processing.
        train = datasets.INaturalist(root=self.root, version='2021_train_mini', download=True)
        test = datasets.INaturalist(root=self.root, version='2021_validation', download=True)

        self.train_data, self.train_label = resized_arrays(train)
        save(self.train_data, train_data_pkl)
        save(self.train_label, train_label_pkl)

        self.test_data, self.test_label = resized_arrays(test)
        save(self.test_data, test_data_pkl)
        save(self.test_label, test_label_pkl)

    def setup(self):
        """Build the task structures; raises Exception on an unknown scenario."""
        if self.scenario not in ('ni', 'nc'):
            raise Exception('wrong scenario')

        if self.scenario == 'ni':
            splits = construct_ns_multiple_wrapper(self.train_data,
                                                   self.train_label,
                                                   self.test_data, self.test_label,
                                                   self.task_nums, 32,
                                                   self.params.val_size,
                                                   self.params.ns_type, self.params.ns_factor,
                                                   plot=self.params.plot_sample)
            self.train_set, self.val_set, self.test_set = splits
            return

        self.task_labels = create_task_composition(class_nums=10000, num_tasks=self.task_nums,
                                                   fixed_order=self.params.fix_order)
        self.test_set = []
        for task_classes in self.task_labels:
            x_test, y_test = load_task_with_labels(self.test_data, self.test_label, task_classes)
            self.test_set.append((x_test, y_test))

    def new_task(self, cur_task, **kwargs):
        """Return ``(x_train, y_train, labels)`` for task ``cur_task``."""
        if self.scenario == 'ni':
            x_train, y_train = self.train_set[cur_task]
            labels = set(y_train)
        elif self.scenario == 'nc':
            labels = self.task_labels[cur_task]
            x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels)
        return x_train, y_train, labels

    def new_run(self, **kwargs):
        """Re-run ``setup`` and hand back the resulting test set."""
        self.setup()
        return self.test_set

    def test_plot(self):
        """Pass the first ten training samples to ``test_ns``."""
        sample_x, sample_y = self.train_data[:10], self.train_label[:10]
        test_ns(sample_x, sample_y, self.params.ns_type,
                self.params.ns_factor)
def _update_by_knn_sv(self, model, buffer, cur_x, cur_y):
    """Replace low-SV buffered samples with higher-SV samples of the batch.

    Shapley values (SV) are computed by ``compute_knn_sv`` for a candidate
    set made of randomly retrieved buffer samples plus the current batch.
    Candidates are ranked by summed SV; current samples that rank in the
    top segment overwrite buffered candidates that rank in the bottom one.
    The buffer and the class-index cache are modified in place.

    Args:
        model (object): neural network.
        buffer (object): buffer object.
        cur_x (tensor): current input data tensor.
        cur_y (tensor): current input label tensor.

    Returns:
        None.
    """
    cur_x = maybe_cuda(cur_x)
    cur_y = maybe_cuda(cur_y)

    # Find minority class samples from current input batch
    minority_batch_x, minority_batch_y = add_minority_class_input(cur_x, cur_y, self.mem_size, self.out_dim)

    # Evaluation set
    eval_x, eval_y, eval_indices = \
        ClassBalancedRandomSampling.sample(buffer.buffer_img, buffer.buffer_label, self.n_smp_cls,
                                           device=self.device)

    # Concatenate minority class samples from current input batch to evaluation set
    eval_x = torch.cat((eval_x, minority_batch_x))
    eval_y = torch.cat((eval_y, minority_batch_y))

    # Candidate set (buffer samples already used for evaluation are excluded)
    cand_excl_indices = set(eval_indices.tolist())
    cand_x, cand_y, cand_ind = random_retrieve(buffer, self.n_total_smp, cand_excl_indices, return_indices=True)

    # Concatenate current input batch to candidate set
    cand_x = torch.cat((cand_x, cur_x))
    cand_y = torch.cat((cand_y, cur_y))

    sv_matrix = compute_knn_sv(model, eval_x, eval_y, cand_x, cand_y, self.k, device=self.device)
    sv = sv_matrix.sum(0)

    n_cur = cur_x.size(0)
    n_cand = cand_x.size(0)

    # Number of previously buffered instances in candidate set
    n_cand_buf = n_cand - n_cur

    sv_arg_sort = sv.argsort(descending=True)

    # Divide SV array into two segments
    # - large: candidate args to be retained; small: candidate args to be discarded
    sv_arg_large = sv_arg_sort[:n_cand_buf]
    sv_arg_small = sv_arg_sort[n_cand_buf:]

    # Extract args relevant to replacement operation
    # If current data instances are in 'large' segment, they are added to buffer
    # If buffered instances are in 'small' segment, they are discarded from buffer
    # Replacement happens between these two sets
    # Retrieve original indices from candidate args
    ind_cur = sv_arg_large[nonzero_indices(sv_arg_large >= n_cand_buf)] - n_cand_buf
    # Use the device chosen in __init__ instead of a hard-coded .cuda(), which
    # raised on CPU-only hosts.
    arg_buffer = sv_arg_small[nonzero_indices(sv_arg_small < n_cand_buf)].to(self.device)
    cand_ind = cand_ind.to(self.device)
    ind_buffer = cand_ind[arg_buffer]
    buffer.n_seen_so_far += n_cur

    # perform overwrite op
    y_upt = cur_y[ind_cur]
    x_upt = cur_x[ind_cur]
    ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim,
                                             new_y=y_upt, ind=ind_buffer, device=self.device)
    buffer.buffer_img[ind_buffer] = x_upt
    buffer.buffer_label[ind_buffer] = y_upt
= len(params.ns_factor) 20 | else: 21 | num_tasks = params.num_tasks 22 | 23 | super(SUN397, self).__init__(dataset, scenario, num_tasks, params.num_runs, params) 24 | 25 | 26 | def download_load(self): 27 | if os.path.exists('./datasets/SUN397/SUN397_train_data.pkl') and os.path.exists('./datasets/SUN397/SUN397_train_label.pkl') and \ 28 | os.path.exists('./datasets/SUN397/SUN397_test_data.pkl') and os.path.exists('./datasets/SUN397/SUN397_test_label.pkl'): 29 | with open('./datasets/SUN397/SUN397_train_data.pkl', 'rb') as f: 30 | self.train_data = pickle.load(f) 31 | with open('./datasets/SUN397/SUN397_train_label.pkl', 'rb') as f: 32 | self.train_label = pickle.load(f) 33 | with open('./datasets/SUN397/SUN397_test_data.pkl', 'rb') as f: 34 | self.test_data = pickle.load(f) 35 | with open('./datasets/SUN397/SUN397_test_label.pkl', 'rb') as f: 36 | self.test_label = pickle.load(f) 37 | else: 38 | dataset = torchvision.datasets.SUN397(root='path/to/dataset', download=True) 39 | 40 | # 划分数据集,例如划分为80%训练数据和20%测试数据 41 | train_size = int(0.8 * len(dataset)) 42 | test_size = len(dataset) - train_size 43 | train, test = random_split(dataset, [train_size, test_size]) 44 | # train = datasets.SUN397(root=self.root, split='train', download=True) 45 | # test = datasets.SUN397(root=self.root, split='test', download=True) 46 | self.train_data = [] 47 | self.train_label = [] 48 | self.test_data = [] 49 | self.test_label = [] 50 | for image, label in train: 51 | image = image.resize((224, 224), Image.Resampling.LANCZOS) 52 | image = np.array(image) 53 | self.train_data.append(image) 54 | self.train_label.append(label) 55 | self.train_data = np.array(self.train_data) 56 | self.train_label = np.array(self.train_label) 57 | with open('./datasets/SUN397/SUN397_train_data.pkl', 'wb') as f: 58 | pickle.dump(self.train_data, f) 59 | with open('./datasets/SUN397/SUN397_train_label.pkl', 'wb') as f: 60 | pickle.dump(self.train_label, f) 61 | for image, label in test: 62 | image = 
image.resize((224, 224), Image.Resampling.LANCZOS) 63 | self.test_data.append(image) 64 | self.test_label.append(label) 65 | self.test_data = np.array(self.test_data) 66 | self.test_label = np.array(self.test_label) 67 | with open('./datasets/SUN397/SUN397_test_data.pkl', 'wb') as f: 68 | pickle.dump(self.test_data, f) 69 | with open('./datasets/SUN397/SUN397_test_label.pkl', 'wb') as f: 70 | pickle.dump(self.test_label, f) 71 | 72 | 73 | def setup(self): 74 | if self.scenario == 'ni': 75 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 76 | self.train_label, 77 | self.test_data, self.test_label, 78 | self.task_nums, 32, 79 | self.params.val_size, 80 | self.params.ns_type, self.params.ns_factor, 81 | plot=self.params.plot_sample) 82 | elif self.scenario == 'nc': 83 | self.task_labels = create_task_composition(class_nums=211, num_tasks=self.task_nums, fixed_order=self.params.fix_order) 84 | self.test_set = [] 85 | for labels in self.task_labels: 86 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 87 | self.test_set.append((x_test, y_test)) 88 | else: 89 | raise Exception('wrong scenario') 90 | 91 | def new_task(self, cur_task, **kwargs): 92 | if self.scenario == 'ni': 93 | x_train, y_train = self.train_set[cur_task] 94 | labels = set(y_train) 95 | elif self.scenario == 'nc': 96 | labels = self.task_labels[cur_task] 97 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 98 | return x_train, y_train, labels 99 | 100 | def new_run(self, **kwargs): 101 | self.setup() 102 | return self.test_set 103 | 104 | def test_plot(self): 105 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 106 | self.params.ns_factor) -------------------------------------------------------------------------------- /models/pretrained.py: -------------------------------------------------------------------------------- 1 | 2 | import timm 3 | import torch 4 | 
class cosLinear(nn.Module):
    """Bias-free linear head that outputs scaled cosine similarities.

    Each input row and each weight row is divided by its L2 norm (plus a
    small epsilon); the resulting cosine matrix is divided by ``scale``.
    """

    def __init__(self, indim, outdim):
        super().__init__()
        self.L = nn.Linear(indim, outdim, bias=False)
        self.scale = 0.09

    def forward(self, x):
        """Return the (batch, outdim) matrix of cosine scores for 2-D ``x``."""
        # Row-wise L2 norms, broadcast back to the shape of the operand.
        input_norms = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x)
        unit_inputs = x.div(input_norms + 0.000001)

        weight = self.L.weight
        weight_norms = torch.norm(weight, p=2, dim=1).unsqueeze(1).expand_as(weight.data)
        unit_weights = weight.div(weight_norms + 0.000001)

        cosine = torch.mm(unit_inputs, unit_weights.transpose(0, 1))
        return cosine / self.scale
class DVCNet(nn.Module):
    """Runs a backbone on a batch and its transformed view in one pass.

    The backbone must return ``(logits, features)``. When ``has_mi_qnet`` is
    set, a ``QNet`` is applied to the concatenated logits of both views.
    """

    def __init__(self,
                 backbone,
                 n_units,
                 n_classes,
                 has_mi_qnet=True):
        super().__init__()

        self.backbone = backbone
        self.has_mi_qnet = has_mi_qnet

        if has_mi_qnet:
            self.qnet = QNet(n_units=n_units,
                             n_classes=n_classes)

    def forward(self, x, xt):
        """Return logits for both views, plus qnet output and feature L1 sums.

        Returns:
            ``(z, zt, None)`` when ``has_mi_qnet`` is false, otherwise
            ``(z, zt, zzt, [l1_z, l1_zt])`` where each ``l1_*`` is a column
            of per-sample sums of absolute feature values.
        """
        n = x.size(0)
        # Single backbone call over both views stacked along the batch axis.
        logits, features = self.backbone(torch.cat((x, xt)))
        z, zt = logits[0:n], logits[n:]

        if not self.has_mi_qnet:
            return z, zt, None

        zzt = self.qnet(torch.cat((z, zt), dim=1))
        l1_sums = [
            torch.sum(torch.abs(part), 1).reshape(-1, 1)
            for part in (features[0:n], features[n:])
        ]
        return z, zt, zzt, l1_sums
class Encoder(nn.Module):
    """Encoder for FOAL.

    Wraps a pretrained ``vit_base_patch16_224`` whose head is replaced by an
    identity. The feature is the class token averaged over all transformer
    blocks, normalised, linearly projected and squashed with a sigmoid; a
    bias-free linear layer maps it to class scores.
    """

    def __init__(self, n_classes, projection_size):
        """
        Args:
            n_classes: number of output classes of ``fc``.
            projection_size: width of the projected feature.
        """
        super(Encoder, self).__init__()
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True)
        self.numclass = n_classes
        self.vit.head = nn.Identity()
        self.linear_projection = nn.Linear(768, projection_size, bias=False)
        self.fc = nn.Linear(projection_size, n_classes, bias=False)
        self.sig = nn.Sigmoid()

    def expansion(self, x):
        """Return the sigmoid-activated projected feature for a batch ``x``."""
        B = x.shape[0]  # Batch size
        x = self.vit.patch_embed(x)

        # Prepend the class token, then add positional embeddings.
        cls_tokens = self.vit.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.vit.pos_embed
        x = self.vit.pos_drop(x)

        # Accumulate the class token (position 0) after every block. The
        # per-block list kept previously was never read and only held the
        # intermediate tensors alive, so it is gone.
        cumulative_output = 0
        for blk in self.vit.blocks:
            x = blk(x)
            cumulative_output += x[:, 0]

        x = cumulative_output / len(self.vit.blocks)
        x = self.vit.norm(x)
        x = self.linear_projection(x)
        x = self.sig(x)
        return x

    def forward(self, x):
        """Return class scores computed from the expanded feature."""
        fusion_feature = self.expansion(x)
        x = self.fc(fusion_feature)
        return x
# Number of output classes per dataset key; read via n_classes[params.data].
n_classes = {
    'cifar100': 100,
    'cifar10': 10,
    'core50': 50,
    'mini_imagenet': 100,
    'openloris': 69,
    'flower102': 102,
    # 'food101' is listed in input_size_match but was missing here, so any
    # n_classes[params.data] lookup for it raised KeyError.
    'food101': 101,
    # NOTE(review): the public DTD release lists 47 categories; 40 is kept
    # unchanged here — confirm that the reduced count is intentional.
    'DTD': 40,
    'aircraft': 100,
    'CelebA': 10177,
    'Country211': 211,
    'StanfordCars': 196,
}
transforms.Resize((224,224)), 94 | transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)), 95 | transforms.RandomHorizontalFlip(), 96 | transforms.RandomApply([ 97 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 98 | ], p=0.8), 99 | transforms.RandomGrayscale(p=0.2), 100 | transforms.ToTensor(), 101 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2615)) 102 | ]), 103 | 'DTD': transforms.Compose([ 104 | transforms.ToPILImage(), 105 | # transforms.RandomCrop(32, padding=4), 106 | transforms.Resize((224, 224)), 107 | transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)), 108 | transforms.RandomHorizontalFlip(), 109 | transforms.RandomApply([ 110 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 111 | ], p=0.8), 112 | transforms.RandomGrayscale(p=0.2), 113 | transforms.ToTensor(), 114 | ]), 115 | 'core50': transforms.Compose([ 116 | transforms.ToPILImage(), 117 | # transforms.RandomCrop(32, padding=4), 118 | transforms.Resize((224,224)), 119 | transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)), 120 | transforms.RandomHorizontalFlip(), 121 | transforms.RandomApply([ 122 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 123 | ], p=0.8), 124 | transforms.RandomGrayscale(p=0.2), 125 | transforms.ToTensor(), 126 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2615)) 127 | ]), 128 | 'aircraft': transforms.Compose([ 129 | transforms.ToPILImage(), 130 | # transforms.RandomCrop(32, padding=4), 131 | transforms.Resize((224,224)), 132 | transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)), 133 | transforms.RandomHorizontalFlip(), 134 | transforms.RandomApply([ 135 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 136 | ], p=0.8), 137 | transforms.RandomGrayscale(p=0.2), 138 | transforms.ToTensor(), 139 | # transforms.Normalize((0.4802, 0.4480, 0.3975), (0.2770, 0.2691, 0.2821)) 140 | ]) 141 | } 142 | 143 | def setup_architecture(params): 144 | nclass = n_classes[params.data] 145 | if params.agent in ['SCR', 'SCP']: 146 | if params.data == 
def setup_opt(optimizer, model, lr, wd):
    """Create the optimizer named ``optimizer`` over ``model.parameters()``.

    Args:
        optimizer: 'SGD' or 'Adam'.
        model: module whose parameters are optimised.
        lr: learning rate.
        wd: weight decay.

    Raises:
        Exception: for any other optimizer name.
    """
    if optimizer == 'SGD':
        return torch.optim.SGD(model.parameters(), lr=lr, weight_decay=wd)
    if optimizer == 'Adam':
        return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    raise Exception('wrong optimizer name')
open(os.path.join(self.root, 'LUP.pkl'), 'rb') as f: 38 | self.LUP = pkl.load(f) 39 | 40 | print("Loading labels...") 41 | with open(os.path.join(self.root, 'labels.pkl'), 'rb') as f: 42 | self.labels = pkl.load(f) 43 | 44 | 45 | 46 | def setup(self, cur_run): 47 | self.val_set = [] 48 | self.test_set = [] 49 | print('Loading test set...') 50 | test_idx_list = self.LUP[self.scenario][cur_run][-1] 51 | 52 | #test paths 53 | test_paths = [] 54 | for idx in test_idx_list: 55 | test_paths.append(os.path.join(self.root, self.paths[idx])) 56 | 57 | # test imgs 58 | self.test_data = self.get_batch_from_paths(test_paths) 59 | self.test_label = np.asarray(self.labels[self.scenario][cur_run][-1]) 60 | 61 | 62 | if self.scenario == 'nc': 63 | self.task_labels = self.labels[self.scenario][cur_run][:-1] 64 | for labels in self.task_labels: 65 | labels = list(set(labels)) 66 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 67 | self.test_set.append((x_test, y_test)) 68 | elif self.scenario == 'ni': 69 | self.test_set = [(self.test_data, self.test_label)] 70 | 71 | def new_task(self, cur_task, **kwargs): 72 | cur_run = kwargs['cur_run'] 73 | s = time.time() 74 | train_idx_list = self.LUP[self.scenario][cur_run][cur_task] 75 | print("Loading data...") 76 | # Getting the actual paths 77 | train_paths = [] 78 | for idx in train_idx_list: 79 | train_paths.append(os.path.join(self.root, self.paths[idx])) 80 | # loading imgs 81 | train_x = self.get_batch_from_paths(train_paths) 82 | train_y = self.labels[self.scenario][cur_run][cur_task] 83 | train_y = np.asarray(train_y) 84 | # get val set 85 | train_x_rdm, train_y_rdm = shuffle_data(train_x, train_y) 86 | val_size = int(len(train_x_rdm) * self.params.val_size) 87 | val_data_rdm, val_label_rdm = train_x_rdm[:val_size], train_y_rdm[:val_size] 88 | train_data_rdm, train_label_rdm = train_x_rdm[val_size:], train_y_rdm[val_size:] 89 | self.val_set.append((val_data_rdm, val_label_rdm)) 90 | e = time.time() 
91 | print('loading time {}'.format(str(e-s))) 92 | return train_data_rdm, train_label_rdm, set(train_label_rdm) 93 | 94 | 95 | def new_run(self, **kwargs): 96 | cur_run = kwargs['cur_run'] 97 | self.setup(cur_run) 98 | 99 | 100 | @staticmethod 101 | def get_batch_from_paths(paths, compress=False, snap_dir='', 102 | on_the_fly=True, verbose=False): 103 | """ Given a number of abs. paths it returns the numpy array 104 | of all the images. """ 105 | 106 | # Getting root logger 107 | log = logging.getLogger('mylogger') 108 | 109 | # If we do not process data on the fly we check if the same train 110 | # filelist has been already processed and saved. If so, we load it 111 | # directly. In either case we end up returning x and y, as the full 112 | # training set and respective labels. 113 | num_imgs = len(paths) 114 | hexdigest = md5(''.join(paths).encode('utf-8')).hexdigest() 115 | log.debug("Paths Hex: " + str(hexdigest)) 116 | loaded = False 117 | x = None 118 | file_path = None 119 | 120 | if compress: 121 | file_path = snap_dir + hexdigest + ".npz" 122 | if os.path.exists(file_path) and not on_the_fly: 123 | loaded = True 124 | with open(file_path, 'rb') as f: 125 | npzfile = np.load(f) 126 | x, y = npzfile['x'] 127 | else: 128 | x_file_path = snap_dir + hexdigest + "_x.bin" 129 | if os.path.exists(x_file_path) and not on_the_fly: 130 | loaded = True 131 | with open(x_file_path, 'rb') as f: 132 | x = np.fromfile(f, dtype=np.uint8) \ 133 | .reshape(num_imgs, 128, 128, 3) 134 | 135 | # Here we actually load the images. 
136 | if not loaded: 137 | # Pre-allocate numpy arrays 138 | x = np.zeros((num_imgs, 128, 128, 3), dtype=np.uint8) 139 | 140 | for i, path in enumerate(paths): 141 | if verbose: 142 | print("\r" + path + " processed: " + str(i + 1), end='') 143 | x[i] = np.array(Image.open(path)) 144 | 145 | if verbose: 146 | print() 147 | 148 | if not on_the_fly: 149 | # Then we save x 150 | if compress: 151 | with open(file_path, 'wb') as g: 152 | np.savez_compressed(g, x=x) 153 | else: 154 | x.tofile(snap_dir + hexdigest + "_x.bin") 155 | 156 | assert (x is not None), 'Problems loading data. x is None!' 157 | 158 | return x 159 | -------------------------------------------------------------------------------- /continuum/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils import data 4 | from utils.setup_elements import transforms_match 5 | import random 6 | 7 | 8 | def create_task_composition(class_nums, num_tasks, fixed_order=False): 9 | classes_per_task = class_nums // num_tasks 10 | total_classes = classes_per_task * num_tasks 11 | label_array = np.arange(0, total_classes) 12 | if not fixed_order: 13 | np.random.shuffle(label_array) 14 | 15 | task_labels = [] 16 | for tt in range(num_tasks): 17 | tt_offset = tt * classes_per_task 18 | task_labels.append(list(label_array[tt_offset:tt_offset + classes_per_task])) 19 | print('Task: {}, Labels:{}'.format(tt, task_labels[tt])) 20 | return task_labels 21 | 22 | 23 | '''All data''' 24 | # def create_task_composition(class_nums, num_tasks, fixed_order=False): 25 | # # Calculate the base number of classes per task and the total classes used for the uniform distribution 26 | # classes_per_task = class_nums // num_tasks 27 | # extra_classes = class_nums % num_tasks # This will give us the remaining classes to distribute 28 | # 29 | # label_array = np.arange(0, class_nums) # Include all class labels 30 | # if not fixed_order: 31 | # 
np.random.shuffle(label_array) 32 | # 33 | # task_labels = [] 34 | # extra_class_indices = 0 # Keep track of how many extra classes have been assigned 35 | # 36 | # for tt in range(num_tasks): 37 | # # Start with the base classes per task 38 | # tt_offset = tt * classes_per_task 39 | # task_classes = list(label_array[tt_offset:tt_offset + classes_per_task]) 40 | # 41 | # # Add an extra class to this task if there are any left to distribute 42 | # if extra_class_indices < extra_classes: 43 | # task_classes.append(label_array[class_nums - extra_classes + extra_class_indices]) 44 | # extra_class_indices += 1 45 | # 46 | # task_labels.append(task_classes) 47 | # print('Task: {}, Labels:{}'.format(tt, task_labels[tt])) 48 | # 49 | # return task_labels 50 | 51 | '''random or personalized''' 52 | # def generate_list(num_tasks,class_nums): 53 | # numbers = [] 54 | # total_sum = 0 55 | # for i in range(num_tasks): 56 | # number = random.randint(1, class_nums) 57 | # total_sum += number 58 | # numbers.append(number) 59 | # 60 | # # normalize the numbers to add up to total classes 61 | # for i in range(num_tasks): 62 | # numbers[i] = int(numbers[i] / total_sum * class_nums) 63 | # 64 | # # adjust for any rounding errors 65 | # remaining = class_nums - sum(numbers) 66 | # for i in range(remaining): 67 | # numbers[i] += 1 68 | # 69 | # random.shuffle(numbers) 70 | # return numbers 71 | # 72 | # 73 | # def create_task_composition(class_nums, num_tasks, fixed_order=False): 74 | # classes_per_task = generate_list(num_tasks,class_nums) 75 | # total_classes = sum(classes_per_task) 76 | # label_array = np.arange(0, total_classes) 77 | # if not fixed_order: 78 | # np.random.shuffle(label_array) 79 | # 80 | # task_labels = [] 81 | # start_idx = 0 82 | # for tt, num_classes in enumerate(classes_per_task): 83 | # task_labels.append(list(label_array[start_idx:start_idx + num_classes])) 84 | # start_idx += num_classes 85 | # print('Task: {}, Labels:{}'.format(tt, task_labels[tt])) 86 | # 
def load_task_with_labels_torch(x, y, labels):
    """Select the samples of tensors ``x``/``y`` whose label is in ``labels``.

    The result is grouped label by label, in the order given by ``labels``.
    An empty ``labels`` yields empty selections instead of the error
    ``torch.cat`` raises for an empty list.

    Args:
        x: Tensor of samples, indexed along the first dimension.
        y: 1-D tensor of labels aligned with ``x``.
        labels: Iterable of label values to keep.

    Returns:
        Tuple ``(x[idx], y[idx])``.
    """
    tmp = [(y == i).nonzero().view(-1) for i in labels]
    if tmp:
        idx = torch.cat(tmp)
    else:
        idx = torch.empty(0, dtype=torch.long, device=y.device)
    return x[idx], y[idx]


def load_task_with_labels(x, y, labels):
    """Select the samples of arrays ``x``/``y`` whose label is in ``labels``.

    The result is grouped label by label, in the order given by ``labels``.
    An empty ``labels`` yields empty selections instead of the error
    ``np.concatenate`` raises for an empty list.

    Args:
        x: NumPy array of samples, indexed along the first axis.
        y: 1-D NumPy array of labels aligned with ``x``.
        labels: Iterable of label values to keep.

    Returns:
        Tuple ``(x[idx], y[idx])``.
    """
    tmp = [np.where(y == i)[0] for i in labels]
    if tmp:
        idx = np.concatenate(tmp, axis=None)
    else:
        idx = np.empty(0, dtype=np.intp)
    return x[idx], y[idx]


class dataset_transform(data.Dataset):
    """Dataset pairing samples ``x`` with integer labels ``y``.

    An optional ``transform`` is applied to each sample on access.
    """

    def __init__(self, x, y, transform=None):
        """Store the data; ``y`` (a NumPy array) is converted to a LongTensor."""
        self.x = x
        self.y = torch.from_numpy(y).type(torch.LongTensor)
        self.transform = transform  # save the transform

    def __len__(self):
        """Return the number of labels (one per sample)."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(sample as float tensor, label)`` for position ``idx``."""
        if self.transform:
            x = self.transform(self.x[idx])
        else:
            x = self.x[idx]

        # Without a tensor-producing transform the sample is still a NumPy
        # array (or similar), which has no ``.float()``; convert it first.
        if not torch.is_tensor(x):
            x = torch.as_tensor(np.asarray(x))

        return x.float(), self.y[idx]


def setup_test_loader(test_data, params):
    """Build one ``DataLoader`` per ``(x_test, y_test)`` pair in ``test_data``.

    The transform is looked up in ``transforms_match`` by ``params.data`` and
    the batch size is ``params.test_batch``.

    NOTE(review): the loaders use ``shuffle=True`` and ``drop_last=True``, so
    the final incomplete batch of every test set is never evaluated (and a
    test set smaller than ``params.test_batch`` yields no batch at all).
    Kept as-is because changing it alters reported metrics -- confirm this
    is intended.
    """
    test_loaders = []

    for (x_test, y_test) in test_data:
        test_dataset = dataset_transform(x_test, y_test, transform=transforms_match[params.data])
        test_loader = data.DataLoader(test_dataset, batch_size=params.test_batch, shuffle=True, num_workers=0,drop_last=True)
        test_loaders.append(test_loader)
    return test_loaders


def shuffle_data(x, y):
    """Return ``x`` and ``y`` reordered by one shared random permutation.

    Uses NumPy's global random state (``np.random.shuffle``).
    """
    perm_inds = np.arange(0, x.shape[0])
    np.random.shuffle(perm_inds)
    rdm_x = x[perm_inds]
    rdm_y = y[perm_inds]
    return rdm_x, rdm_y


def train_val_test_split_ni(train_data, train_label, test_data, test_label, task_nums, img_size, val_size=0.1):
    """Shuffle the data, hold out a validation share and split into tasks.

    Args:
        train_data, train_label: Training samples and labels.
        test_data, test_label: Test samples and labels.
        task_nums: Number of tasks to split every set into.
        img_size: Height and width of the images; data is reshaped to
            ``(task_nums, -1, img_size, img_size, 3)``.
        val_size: Fraction of the shuffled training data used for validation.

    Returns:
        Tuple ``(train_x, train_y, val_x, val_y, test_x, test_y)`` whose
        first axis has length ``task_nums``.

    Raises:
        ValueError: From ``reshape`` when a set cannot be divided evenly
            into ``task_nums`` parts.
    """
    train_data_rdm, train_label_rdm = shuffle_data(train_data, train_label)
    # Turn the fraction into an absolute number of validation samples.
    val_size = int(len(train_data_rdm) * val_size)
    val_data_rdm, val_label_rdm = train_data_rdm[:val_size], train_label_rdm[:val_size]
    train_data_rdm, train_label_rdm = train_data_rdm[val_size:], train_label_rdm[val_size:]
    test_data_rdm, test_label_rdm = shuffle_data(test_data, test_label)
    train_data_rdm_split = train_data_rdm.reshape(task_nums, -1, img_size, img_size, 3)
    train_label_rdm_split = train_label_rdm.reshape(task_nums, -1)
    val_data_rdm_split = val_data_rdm.reshape(task_nums, -1, img_size, img_size, 3)
    val_label_rdm_split = val_label_rdm.reshape(task_nums, -1)
    test_data_rdm_split = test_data_rdm.reshape(task_nums, -1, img_size, img_size, 3)
    test_label_rdm_split = test_label_rdm.reshape(task_nums, -1)
    return train_data_rdm_split, train_label_rdm_split, val_data_rdm_split, val_label_rdm_split, test_data_rdm_split, test_label_rdm_split