├── agents
│   ├── __init__.py
│   ├── cndpm.py
│   ├── lwf.py
│   ├── icarl.py
│   ├── gdumb.py
│   ├── agem.py
│   └── ewc_pp.py
├── utils
│   ├── __init__.py
│   ├── buffer
│   │   ├── __init__.py
│   │   ├── recycle.py
│   │   ├── sc_retrieve.py
│   │   ├── mem_match.py
│   │   ├── random_retrieve.py
│   │   ├── rl_retrieve.py
│   │   ├── tmp_buffer.py
│   │   ├── buffer_logits.py
│   │   ├── replay_times_update.py
│   │   ├── aser_retrieve.py
│   │   ├── test_buffer.py
│   │   ├── reservoir_update.py
│   │   ├── aser_update.py
│   │   └── mir_retrieve.py
│   ├── argparser
│   │   ├── argparser_der.py
│   │   ├── argparser_aug.py
│   │   ├── argparser_scr.py
│   │   └── argparser_replay.py
│   ├── kd_manager.py
│   ├── io.py
│   ├── global_vars.py
│   ├── loss.py
│   ├── name_match.py
│   └── setup_elements.py
├── continuum
│   ├── __init__.py
│   ├── dataset_scripts
│   │   ├── __init__.py
│   │   ├── dataset_base.py
│   │   ├── cifar10.py
│   │   ├── mini_imagenet.py
│   │   ├── openloris.py
│   │   └── cifar100.py
│   ├── continuum.py
│   └── data_utils.py
├── experiment
│   ├── __init__.py
│   ├── metrics.py
│   └── tune_hyperparam.py
├── models
│   ├── ndpm
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── loss.py
│   │   ├── priors.py
│   │   ├── expert.py
│   │   └── component.py
│   ├── __init__.py
│   ├── pretrained.py
│   └── modelfactory.py
├── config
│   ├── global.yml
│   ├── agent
│   │   ├── lwf
│   │   │   ├── lwf.yml
│   │   │   └── lwf_tune.yml
│   │   ├── aser
│   │   │   ├── aser_tune.yml
│   │   │   └── aser.yml
│   │   ├── er
│   │   │   ├── offline_tune.yml
│   │   │   ├── offline.yml
│   │   │   ├── er_1k.yml
│   │   │   ├── er_2k.yml
│   │   │   ├── er_5k.yml
│   │   │   ├── finetune.yml
│   │   │   ├── er_10k.yml
│   │   │   ├── er_rp_1k.yml
│   │   │   ├── er_rp_5k.yml
│   │   │   ├── er_2k_b100.yml
│   │   │   ├── er_2k_b30.yml
│   │   │   ├── er_rp_10k.yml
│   │   │   ├── er_tune.yml
│   │   │   ├── er_tune_miter.yml
│   │   │   ├── er_tune_ratio.yml
│   │   │   ├── er_tune_incoming_ratio.yml
│   │   │   ├── er_tune_aug_iter_0.yml
│   │   │   └── er_tune_aug_iter.yml
│   │   ├── ewc
│   │   │   ├── ewc.yml
│   │   │   └── ewc_tune.yml
│   │   ├── gdumb
│   │   │   ├── gdumb_tune.yml
│   │   │   ├── gdumb_1k.yml
│   │   │   ├── gdumb_5k.yml
│   │   │   └── gdumb_10k.yml
│   │   ├── cndpm
│   │   │   ├── cndpm_tune_core50_1.yml
│   │   │   ├── cndpm_tune_mini_ni_1.yml
│   │   │   ├── cndpm.yml
│   │   │   ├── cndpm_10k.yml
│   │   │   ├── cndpm_1k.yml
│   │   │   ├── cndpm_2k.yml
│   │   │   ├── cndpm_5k.yml
│   │   │   ├── cndpm_tune.yml
│   │   │   ├── cndpm_tune_core50.yml
│   │   │   ├── cndpm_tune_mini_fast.yml
│   │   │   ├── cndpm_tune_mini.yml
│   │   │   └── cndpm_tune_mini_ni.yml
│   │   ├── agem
│   │   │   ├── agem_tune.yml
│   │   │   ├── agem_1k.yml
│   │   │   ├── agem_5k.yml
│   │   │   └── agem_10k.yml
│   │   ├── mir
│   │   │   ├── mir_tune.yml
│   │   │   ├── mir_10k.yml
│   │   │   ├── mir_1k.yml
│   │   │   └── mir_5k.yml
│   │   ├── gss
│   │   │   ├── gss_tune.yml
│   │   │   ├── gss_tune_core50.yml
│   │   │   ├── gss_10k.yml
│   │   │   ├── gss_1k.yml
│   │   │   └── gss_5k.yml
│   │   └── icarl
│   │       ├── icarl_1k.yml
│   │       ├── icarl_5k.yml
│   │       ├── icarl_10k.yml
│   │       └── icarl_tune.yml
│   ├── data
│   │   ├── core50
│   │   │   ├── core50_ni.yml
│   │   │   ├── core50_nic.yml
│   │   │   ├── core50_nicv2_196.yml
│   │   │   ├── core50_nicv2_391.yml
│   │   │   ├── core50_nicv2_79.yml
│   │   │   ├── core50_nc.yml
│   │   │   ├── core50_nc_gpu0.yml
│   │   │   └── core50_nc_gpu7.yml
│   │   ├── cifar10
│   │   │   ├── cifar10_nc.yml
│   │   │   ├── cifar10_noise.yml
│   │   │   ├── cifar10_blur.yml
│   │   │   └── cifar10_occlusion.yml
│   │   ├── clrs25
│   │   │   ├── clrs25_nc.yml
│   │   │   └── clrs25_nc_gpu7.yml
│   │   ├── openloris
│   │   │   ├── openloris_sequence.yml
│   │   │   ├── openloris_clutter.yml
│   │   │   ├── openloris_occlusion.yml
│   │   │   ├── openloris_pixel.yml
│   │   │   └── openloris_illumination.yml
│   │   ├── cifar100
│   │   │   ├── cifar100_nc.yml
│   │   │   ├── cifar100_nc_gpu7.yml
│   │   │   ├── cifar100_noise.yml
│   │   │   ├── cifar100_blur.yml
│   │   │   └── cifar100_occlusion.yml
│   │   └── mini_imagenet
│   │       ├── mini_imagenet_nc.yml
│   │       ├── mini_imagenet_nc_gpu7.yml
│   │       ├── mini_imagenet_noise.yml
│   │       ├── mini_imagenet_blur.yml
│   │       └── mini_imagenet_occlusion.yml
│   ├── general.yml
│   ├── general_1_openloris.yml
│   ├── general_2_openloris.yml
│   ├── general_3_openloris.yml
│   ├── general_2_cndpm_mini.yml
│   ├── general_2.yml
│   ├── general_3.yml
│   ├── general_1_core50.yml
│   ├── general_1_gdumb.yml
│   ├── general_2_core50.yml
│   ├── general_2_gdumb.yml
│   ├── general_3_core50.yml
│   ├── general_3_gdumb.yml
│   ├── general_4_core50.yml
│   ├── general_5_core50.yml
│   ├── general_offline_1.yml
│   ├── general_finetune_1.yml
│   ├── ni_mini_general
│   │   ├── general_1.yml
│   │   ├── general_2.yml
│   │   ├── general_3.yml
│   │   ├── general_4.yml
│   │   ├── general_5.yml
│   │   ├── general_1_gdumb.yml
│   │   └── general_2_gdumb.yml
│   ├── general_1_core50_gdumb.yml
│   ├── general_2_core50_gdumb.yml
│   ├── general_3_core50_gdumb.yml
│   ├── general_4_core50_gdumb.yml
│   ├── general_5_core50_gdumb.yml
│   ├── general_finetune_core50_1.yml
│   ├── general_1_core50_offline.yml
│   ├── general_2_core50_offline.yml
│   ├── general_3_core50_offline.yml
│   ├── general_4_core50_offline.yml
│   ├── general_5_core50_offline.yml
│   ├── general_labels_trick_3.yml
│   ├── general_labels_trick_1.yml
│   ├── general_labels_trick_2.yml
│   ├── general_1.yml
│   ├── general_17.yml
│   ├── general_20.yml
│   ├── general_1_aug_cifar10.yml
│   ├── general_1_aug.yml
│   ├── general_1_aug_seed2.yml
│   └── general_1_aug_seed3.yml
├── requirements.txt
├── run_commands
│   ├── base_commands
│   │   ├── command_mir.sh
│   │   ├── command_er.sh
│   │   ├── command_scr.sh
│   │   ├── command_aser.sh
│   │   ├── command_er_scraug.sh
│   │   ├── command_mir_raug.sh
│   │   ├── command_er_raug_t1.sh
│   │   ├── command_scr_raug.sh
│   │   ├── command_er_raug.sh
│   │   ├── command_mir_raug_t1.sh
│   │   ├── command_scr_raug_t1.sh
│   │   ├── command_aser_raug.sh
│   │   ├── command_aser_raug_t1.sh
│   │   ├── command_der_deraug.sh
│   │   └── command_adaptive_rar_RL.sh
│   └── runs
│       ├── run_rar_with_scr.sh
│       ├── run_test_rar_er_cifar100.sh
│       ├── run_rar_with_er_mir_aser.sh
│       ├── run_adaptive_RAR_RL.sh
│       ├── run_rar_with_der_deraug.sh
│       └── rar_ablation_aug_iter.sh
├── fetch_data_setup.sh
├── checkpoint_model.py
├── main_config.py
├── main_tune.py
├── README.md
└── general_main.py

/agents/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/continuum/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/experiment/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/ndpm/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/utils/buffer/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/continuum/dataset_scripts/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from . import resnet
--------------------------------------------------------------------------------
/config/global.yml:
--------------------------------------------------------------------------------
path:
  tables: tables/
  result: result/
--------------------------------------------------------------------------------
/config/agent/lwf/lwf.yml:
--------------------------------------------------------------------------------
parameters:
  agent: LWF
  temperature: 2
  model_name: LWF
--------------------------------------------------------------------------------
/config/data/core50/core50_ni.yml:
--------------------------------------------------------------------------------
parameters:
  data: core50
  cl_type: ni
  data_name: core50_ni
--------------------------------------------------------------------------------
/config/data/core50/core50_nic.yml:
--------------------------------------------------------------------------------
parameters:
  data: core50
  cl_type: nic
  data_name: core50_nic
--------------------------------------------------------------------------------
/config/agent/aser/aser_tune.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.01, 0.1]
  k: [3, 5, 7]
  n_smp_cls: [2.0, 5.0, 7.0]
--------------------------------------------------------------------------------
/config/agent/er/offline_tune.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.001, 0.01, 0.1]
  weight_decay: [0.0001, 0.001]
--------------------------------------------------------------------------------
/config/agent/ewc/ewc.yml:
--------------------------------------------------------------------------------
parameters:
  agent: EWC
  fisher_update_after: 50
  alpha: 0.9
  model_name: EWC
--------------------------------------------------------------------------------
/config/data/core50/core50_nicv2_196.yml:
--------------------------------------------------------------------------------
parameters:
  data: core50
  cl_type: nicv2_196
  data_name: core50_nicv2_196
--------------------------------------------------------------------------------
/config/data/core50/core50_nicv2_391.yml:
--------------------------------------------------------------------------------
parameters:
  data: core50
  cl_type: nicv2_391
  data_name: core50_nicv2_391
--------------------------------------------------------------------------------
/config/data/core50/core50_nicv2_79.yml:
--------------------------------------------------------------------------------
parameters:
  data: core50
  cl_type: nicv2_79
  data_name: core50_nicv2_79
--------------------------------------------------------------------------------
/config/agent/gdumb/gdumb_tune.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.001, 0.01, 0.1]
  weight_decay: [0.0001, 0.000001]
--------------------------------------------------------------------------------
/config/data/cifar10/cifar10_nc.yml:
--------------------------------------------------------------------------------
parameters:
  data: cifar10
  num_tasks: 5
  cl_type: nc
  data_name: cifar10_nc
--------------------------------------------------------------------------------
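Note: the *_tune.yml files above (e.g. aser_tune.yml, offline_tune.yml) list candidate values per hyperparameter. A minimal sketch of how such a grid can be expanded into individual runs; the real logic lives in main_tune.py and experiment/tune_hyperparam.py, which are not shown here, so this is an assumption about their behavior:

import itertools
import yaml

with open('config/agent/aser/aser_tune.yml') as f:
    grid = yaml.safe_load(f)['parameters']   # e.g. {'learning_rate': [...], 'k': [...], ...}

keys = sorted(grid)
for values in itertools.product(*(grid[k] for k in keys)):
    run_params = dict(zip(keys, values))     # one hyperparameter combination per run
    print(run_params)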
/config/agent/cndpm/cndpm_tune_core50_1.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.1]
  classifier_chill: [0.01]
  log_alpha: [-1200, -800, -300]
--------------------------------------------------------------------------------
/config/agent/cndpm/cndpm_tune_mini_ni_1.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.1]
  classifier_chill: [0.01]
  log_alpha: [-2000, -500]
--------------------------------------------------------------------------------
/config/data/core50/core50_nc.yml:
--------------------------------------------------------------------------------
parameters:
  data: core50
  cl_type: nc
  data_name: core50_nc
  num_tasks: 9
  GPU_ID: 7
--------------------------------------------------------------------------------
/config/data/clrs25/clrs25_nc.yml:
--------------------------------------------------------------------------------
parameters:
  data: clrs25
  num_tasks: 5
  cl_type: nc
  data_name: clrs25_nc
  GPU_ID: 0
--------------------------------------------------------------------------------
/config/data/core50/core50_nc_gpu0.yml:
--------------------------------------------------------------------------------
parameters:
  data: core50
  cl_type: nc
  data_name: core50_nc
  num_tasks: 9
  GPU_ID: 0
--------------------------------------------------------------------------------
/config/data/core50/core50_nc_gpu7.yml:
--------------------------------------------------------------------------------
parameters:
  data: core50
  cl_type: nc
  data_name: core50_nc
  num_tasks: 9
  GPU_ID: 7
--------------------------------------------------------------------------------
/config/data/openloris/openloris_sequence.yml:
--------------------------------------------------------------------------------
parameters:
  data: openloris
  cl_type: ni
  ns_type: sequence
  data_name: openloris_seq
--------------------------------------------------------------------------------
/config/agent/agem/agem_tune.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.03, 0.1]
  weight_decay: [0.0001, 0.001, 0.01, 0.1]
--------------------------------------------------------------------------------
/config/agent/ewc/ewc_tune.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.0001, 0.001, 0.01, 0.1]
  weight_decay: [0.0001, 0.001]
  lambda_: [0, 100, 1000]
--------------------------------------------------------------------------------
/config/agent/mir/mir_tune.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.0001, 0.001, 0.01, 0.1]
  weight_decay: [0.0001, 0.001]
  subsample: [25, 50, 100]
--------------------------------------------------------------------------------
/config/data/clrs25/clrs25_nc_gpu7.yml:
--------------------------------------------------------------------------------
parameters:
  data: clrs25
  num_tasks: 5
  cl_type: nc
  data_name: clrs25_nc
  GPU_ID: 7
--------------------------------------------------------------------------------
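Note: a full experiment configuration is presumably assembled by layering one general_*.yml, one config/data/*.yml, and one config/agent/*.yml; main_config.py is the likely consumer. The merge order in this sketch is an assumption:

import yaml

def load_params(path):
    with open(path) as f:
        return yaml.safe_load(f)['parameters']

params = {}
for layer in ('config/general.yml',
              'config/data/core50/core50_nc.yml',
              'config/agent/ewc/ewc.yml'):
    params.update(load_params(layer))  # later layers override earlier keys
print(params['agent'], params['data'], params['num_tasks'])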
/config/data/openloris/openloris_clutter.yml:
--------------------------------------------------------------------------------
parameters:
  data: openloris
  cl_type: ni
  ns_type: clutter
  data_name: openloris_clutter
--------------------------------------------------------------------------------
/config/data/openloris/openloris_occlusion.yml:
--------------------------------------------------------------------------------
parameters:
  data: openloris
  cl_type: ni
  ns_type: occlusion
  data_name: openloris_occlusion
--------------------------------------------------------------------------------
/config/data/openloris/openloris_pixel.yml:
--------------------------------------------------------------------------------
parameters:
  data: openloris
  cl_type: ni
  ns_type: pixel
  data_name: openloris_pixel
--------------------------------------------------------------------------------
/config/agent/gss/gss_tune.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.0001, 0.001, 0.01, 0.1]
  weight_decay: [0.0001, 0.001]
  gss_mem_strength: [10, 20, 50]
--------------------------------------------------------------------------------
/config/agent/gss/gss_tune_core50.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.001, 0.01, 0.1]
  weight_decay: [0.0001, 0.001]
  gss_mem_strength: [10, 20, 50]
--------------------------------------------------------------------------------
/config/data/cifar100/cifar100_nc.yml:
--------------------------------------------------------------------------------
parameters:
  data: cifar100
  num_tasks: 20
  cl_type: nc
  data_name: cifar100_nc
  GPU_ID: 0
--------------------------------------------------------------------------------
/config/agent/cndpm/cndpm.yml:
--------------------------------------------------------------------------------
parameters:
  agent: CNDPM
  stm_capacity: 1000
  weight_decay: 0.00001
  eps_mem_batch: 10
  model_name: CNDPM
--------------------------------------------------------------------------------
/config/agent/er/offline.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: random
  mem_size: 0
  eps_mem_batch: 0
  model_name: offline
--------------------------------------------------------------------------------
/config/data/cifar100/cifar100_nc_gpu7.yml:
--------------------------------------------------------------------------------
parameters:
  data: cifar100
  num_tasks: 20
  cl_type: nc
  data_name: cifar100_nc
  GPU_ID: 7
--------------------------------------------------------------------------------
/config/data/openloris/openloris_illumination.yml:
--------------------------------------------------------------------------------
parameters:
  data: openloris
  cl_type: ni
  ns_type: illumination
  data_name: openloris_illum
--------------------------------------------------------------------------------
/config/agent/agem/agem_1k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: AGEM
  update: random
  retrieve: random
  mem_size: 1000
  eps_mem_batch: 10
  model_name: AGEM_1k
--------------------------------------------------------------------------------
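Note: the AGEM_* configs above select the A-GEM agent (agents/agem.py). Its core step is a gradient projection (Chaudhry et al., 2019); the sketch below is illustrative and not the repo's exact implementation:

import torch

def agem_project(g, g_ref):
    # g: flattened gradient on the incoming batch
    # g_ref: flattened gradient on the memory batch
    dot = torch.dot(g, g_ref)
    if dot < 0:  # the update would increase loss on the memory
        g = g - (dot / torch.dot(g_ref, g_ref)) * g_ref
    return g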
/config/agent/agem/agem_5k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: AGEM
  update: random
  retrieve: random
  mem_size: 5000
  eps_mem_batch: 10
  model_name: AGEM_5k
--------------------------------------------------------------------------------
/config/agent/cndpm/cndpm_10k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: CNDPM
  stm_capacity: 10000
  weight_decay: 0.00001
  eps_mem_batch: 10
  model_name: CNDPM_10k
--------------------------------------------------------------------------------
/config/agent/cndpm/cndpm_1k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: CNDPM
  stm_capacity: 1000
  weight_decay: 0.00001
  eps_mem_batch: 10
  model_name: CNDPM_1k
--------------------------------------------------------------------------------
/config/agent/cndpm/cndpm_2k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: CNDPM
  stm_capacity: 2000
  weight_decay: 0.00001
  eps_mem_batch: 10
  model_name: CNDPM_2k
--------------------------------------------------------------------------------
/config/agent/cndpm/cndpm_5k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: CNDPM
  stm_capacity: 5000
  weight_decay: 0.00001
  eps_mem_batch: 10
  model_name: CNDPM_5k
--------------------------------------------------------------------------------
/config/agent/cndpm/cndpm_tune.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.0001, 0.001, 0.01, 0.1]
  classifier_chill: [0.001, 0.01, 0.1]
  log_alpha: [-100, -300, -500]
--------------------------------------------------------------------------------
/config/agent/cndpm/cndpm_tune_core50.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.01]
  classifier_chill: [0.0005, 0.001, 0.002]
  log_alpha: [-1200, -1000, -800, -300]
--------------------------------------------------------------------------------
/config/agent/er/er_1k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: random
  mem_size: 1000
  eps_mem_batch: 10
  model_name: Naive_ER_1k
--------------------------------------------------------------------------------
/config/agent/er/er_2k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: random
  mem_size: 2000
  eps_mem_batch: 10
  model_name: Naive_ER_2k
--------------------------------------------------------------------------------
/config/agent/er/er_5k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: random
  mem_size: 5000
  eps_mem_batch: 10
  model_name: Naive_ER_5k
--------------------------------------------------------------------------------
/config/agent/er/finetune.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: random
  mem_size: 0
  eps_mem_batch: 0
  model_name: finetune
--------------------------------------------------------------------------------
/config/agent/gdumb/gdumb_1k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: GDUMB
  mem_size: 1000
  mem_epoch: 30
  minlr: 0.0005
  clip: 10.0
  model_name: GDUMB_1k
--------------------------------------------------------------------------------
/config/agent/gdumb/gdumb_5k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: GDUMB
  mem_size: 5000
  mem_epoch: 30
  minlr: 0.0005
  clip: 10.0
  model_name: GDUMB_5k
--------------------------------------------------------------------------------
/config/agent/mir/mir_10k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: MIR
  mem_size: 10000
  eps_mem_batch: 10
  model_name: MIR_10k
--------------------------------------------------------------------------------
/config/agent/mir/mir_1k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: MIR
  mem_size: 1000
  eps_mem_batch: 10
  model_name: MIR_1k
--------------------------------------------------------------------------------
/config/agent/mir/mir_5k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: MIR
  mem_size: 5000
  eps_mem_batch: 10
  model_name: MIR_5k
--------------------------------------------------------------------------------
/config/agent/agem/agem_10k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: AGEM
  update: random
  retrieve: random
  mem_size: 10000
  eps_mem_batch: 10
  model_name: AGEM_10k
--------------------------------------------------------------------------------
/config/agent/cndpm/cndpm_tune_mini_fast.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.01, 0.1]
  classifier_chill: [0.001, 0.0015, 0.002]
  log_alpha: [-1200, -1000, -800]
--------------------------------------------------------------------------------
/config/agent/er/er_10k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: random
  mem_size: 10000
  eps_mem_batch: 10
  model_name: Naive_ER_10k
--------------------------------------------------------------------------------
/config/agent/er/er_rp_1k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: random
  mem_size: 1000
  eps_mem_batch: 50
  model_name: RP_ER_1k
--------------------------------------------------------------------------------
/config/agent/er/er_rp_5k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: random
  mem_size: 5000
  eps_mem_batch: 50
  model_name: RP_ER_5k
--------------------------------------------------------------------------------
/config/agent/gdumb/gdumb_10k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: GDUMB
  mem_size: 10000
  mem_epoch: 30
  minlr: 0.0005
  clip: 10.0
  model_name: GDUMB_10k
--------------------------------------------------------------------------------
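Note: the `update: random` entries above correspond to reservoir-style buffer updates (utils/buffer/reservoir_update.py, not shown). A standard reservoir-sampling sketch, hypothetical with respect to the repo's exact code: after n_seen examples, each one is kept with probability mem_size / n_seen:

import random

def reservoir_update(buffer, item, mem_size, n_seen):
    # buffer: list of stored examples; n_seen: examples observed so far
    if len(buffer) < mem_size:
        buffer.append(item)
    else:
        j = random.randint(0, n_seen)  # uniform over [0, n_seen]
        if j < mem_size:
            buffer[j] = item           # overwrite a random slot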
/config/agent/icarl/icarl_1k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ICARL
  update: random
  retrieve: random
  mem_size: 1000
  eps_mem_batch: 10
  model_name: ICARL_1k
--------------------------------------------------------------------------------
/config/agent/icarl/icarl_5k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ICARL
  update: random
  retrieve: random
  mem_size: 5000
  eps_mem_batch: 10
  model_name: ICARL_5k
--------------------------------------------------------------------------------
/config/data/cifar10/cifar10_noise.yml:
--------------------------------------------------------------------------------
parameters:
  data: cifar10
  cl_type: ni
  ns_factor: [0, 0.6, 1.4, 2.2, 3]
  ns_type: noise
  data_name: cifar10_noise
--------------------------------------------------------------------------------
/config/data/mini_imagenet/mini_imagenet_nc.yml:
--------------------------------------------------------------------------------
parameters:
  data: mini_imagenet
  num_tasks: 10
  cl_type: nc
  data_name: mini_imagenet_nc
  GPU_ID: 0
--------------------------------------------------------------------------------
/config/general.yml:
--------------------------------------------------------------------------------
parameters:
  num_runs: 2
  seed: 0
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  num_val: 2
  num_runs_val: 2
--------------------------------------------------------------------------------
/config/agent/aser/aser.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: ASER
  retrieve: ASER
  mem_size: 5000
  eps_mem_batch: 10
  aser_type: asv
  model_name: ER_ASER
--------------------------------------------------------------------------------
/config/agent/cndpm/cndpm_tune_mini.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.005, 0.007, 0.01]
  classifier_chill: [0.001, 0.0015, 0.002]
  log_alpha: [-1200, -1000, -800]
--------------------------------------------------------------------------------
/config/agent/er/er_2k_b100.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: random
  mem_size: 2000
  eps_mem_batch: 100
  model_name: Naive_ER_2k
--------------------------------------------------------------------------------
/config/agent/er/er_2k_b30.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: random
  mem_size: 2000
  eps_mem_batch: 30
  model_name: Naive_ER_2k
--------------------------------------------------------------------------------
/config/agent/er/er_rp_10k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: random
  retrieve: random
  mem_size: 10000
  eps_mem_batch: 50
  model_name: RP_ER_10k
--------------------------------------------------------------------------------
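Note: for `cl_type: ni` runs, ns_type and ns_factor (as in cifar10_noise.yml above) define the per-task nonstationarity: task t perturbs its images with intensity ns_factor[t]. A guess at what the noise transform in continuum/data_utils.py looks like; the 0.1 scaling constant is a placeholder, not the repo's value:

import numpy as np

def apply_noise(images, factor):
    # images: float array in [0, 1]; factor: the task's ns_factor entry
    noisy = images + factor * 0.1 * np.random.randn(*images.shape)
    return np.clip(noisy, 0.0, 1.0)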
/config/agent/icarl/icarl_10k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ICARL
  update: random
  retrieve: random
  mem_size: 10000
  eps_mem_batch: 10
  model_name: ICARL_10k
--------------------------------------------------------------------------------
/config/data/cifar10/cifar10_blur.yml:
--------------------------------------------------------------------------------
parameters:
  data: cifar10
  cl_type: ni
  ns_factor: [0, 0.6, 0.9, 1.2, 1.5]
  ns_type: blur
  data_name: cifar10_blur
--------------------------------------------------------------------------------
/config/data/cifar100/cifar100_noise.yml:
--------------------------------------------------------------------------------
parameters:
  data: cifar100
  cl_type: ni
  ns_factor: [0, 0.6, 1.4, 2.2, 3]
  ns_type: noise
  data_name: cifar100_noise
--------------------------------------------------------------------------------
/config/data/mini_imagenet/mini_imagenet_nc_gpu7.yml:
--------------------------------------------------------------------------------
parameters:
  data: mini_imagenet
  num_tasks: 10
  cl_type: nc
  data_name: mini_imagenet_nc
  GPU_ID: 7
--------------------------------------------------------------------------------
/config/agent/cndpm/cndpm_tune_mini_ni.yml:
--------------------------------------------------------------------------------
parameters:
  learning_rate: [0.005, 0.007, 0.01]
  classifier_chill: [0.0005, 0.001, 0.002]
  log_alpha: [-15000, -5000, -500]
--------------------------------------------------------------------------------
/config/data/cifar100/cifar100_blur.yml:
--------------------------------------------------------------------------------
parameters:
  data: cifar100
  cl_type: ni
  ns_factor: [0, 0.6, 0.9, 1.2, 1.5]
  ns_type: blur
  data_name: cifar100_blur
--------------------------------------------------------------------------------
/config/agent/gss/gss_10k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: GSS
  retrieve: random
  mem_size: 10000
  eps_mem_batch: 10
  gss_batch_size: 10
  model_name: GSS_10k
--------------------------------------------------------------------------------
/config/agent/gss/gss_1k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: GSS
  retrieve: random
  mem_size: 1000
  eps_mem_batch: 10
  gss_batch_size: 10
  model_name: GSS_1k
--------------------------------------------------------------------------------
/config/agent/gss/gss_5k.yml:
--------------------------------------------------------------------------------
parameters:
  agent: ER
  update: GSS
  retrieve: random
  mem_size: 5000
  eps_mem_batch: 10
  gss_batch_size: 10
  model_name: GSS_5k
--------------------------------------------------------------------------------
/config/data/cifar10/cifar10_occlusion.yml:
--------------------------------------------------------------------------------
parameters:
  data: cifar10
  cl_type: ni
  ns_factor: [0, 0.3, 0.4, 0.5, 0.6]
  ns_type: occlusion
  data_name: cifar10_occlusion
--------------------------------------------------------------------------------
/config/data/cifar100/cifar100_occlusion.yml:
--------------------------------------------------------------------------------
parameters:
  data: cifar100
  cl_type: ni
  ns_factor: [0, 0.3, 0.4, 0.5, 0.6]
  ns_type: occlusion
  data_name: cifar100_occlusion
--------------------------------------------------------------------------------
/utils/argparser/argparser_der.py:
--------------------------------------------------------------------------------
def parse_der(parser):
    # DER distillation weight
    parser.add_argument("--DER_alpha", default=0.3, type=float)
    return parser
--------------------------------------------------------------------------------
/config/data/mini_imagenet/mini_imagenet_noise.yml:
--------------------------------------------------------------------------------
parameters:
  data: mini_imagenet
  cl_type: ni
  ns_factor: [0.0, 0.4, 0.8, 1.2, 1.6, 2.0, 2.4, 2.8, 3.2, 3.6]
  ns_type: noise
  data_name: mini_noise
--------------------------------------------------------------------------------
/config/data/mini_imagenet/mini_imagenet_blur.yml:
--------------------------------------------------------------------------------
parameters:
  data: mini_imagenet
  cl_type: ni
  ns_factor: [0.0, 0.28, 0.56, 0.83, 1.11, 1.39, 1.67, 1.94, 2.22, 2.5]
  ns_type: blur
  data_name: mini_blur
--------------------------------------------------------------------------------
/config/data/mini_imagenet/mini_imagenet_occlusion.yml:
--------------------------------------------------------------------------------
parameters:
  data: mini_imagenet
  cl_type: ni
  ns_factor: [0.0, 0.07, 0.13, 0.2, 0.27, 0.33, 0.4, 0.47, 0.53, 0.6]
  ns_type: occlusion
  data_name: mini_occlusion
--------------------------------------------------------------------------------
/models/pretrained.py:
--------------------------------------------------------------------------------
import torchvision.models as models
import torch


def ResNet18_pretrained(n_classes):
    classifier = models.resnet18(pretrained=True)
    classifier.fc = torch.nn.Linear(512, n_classes)
    return classifier
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
-f https://download.pytorch.org/whl/torch_stable.html
torch==1.5.1
torchvision==0.6.1
matplotlib==3.2.1
scipy==1.4.1
scikit-image==0.14.2
scikit-learn==0.23.0
pandas==1.0.5
PyYAML==5.3.1
psutil==5.7.0
--------------------------------------------------------------------------------
/config/agent/icarl/icarl_tune.yml:
--------------------------------------------------------------------------------
parameters:
  #test
  # learning_rate: [0.0001, 0.0003]
  # weight_decay: [0.0001,]
  #real
  learning_rate: [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.03, 0.1]
  weight_decay: [0.0001, 0.001, 0.01, 0.1]
--------------------------------------------------------------------------------
/config/agent/lwf/lwf_tune.yml:
--------------------------------------------------------------------------------
parameters:
  #test
  # learning_rate: [0.0001, 0.0003]
  # weight_decay: [0.0001,]
  #real
  learning_rate: [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.03, 0.1]
  weight_decay: [0.0001, 0.001, 0.01, 0.1]
--------------------------------------------------------------------------------
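Note: usage of ResNet18_pretrained from models/pretrained.py above. With the pinned torchvision==0.6.1, pretrained=True is the correct flag (newer torchvision releases replaced it with a weights= argument):

import torch
from models.pretrained import ResNet18_pretrained

model = ResNet18_pretrained(n_classes=100)       # e.g. CIFAR-100
logits = model(torch.randn(2, 3, 224, 224))      # -> shape (2, 100)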
/models/ndpm/utils.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class Lambda(nn.Module):
    def __init__(self, f=None):
        super().__init__()
        self.f = f if f is not None else (lambda x: x)

    def forward(self, *args, **kwargs):
        return self.f(*args, **kwargs)
--------------------------------------------------------------------------------
/config/agent/er/er_tune.yml:
--------------------------------------------------------------------------------
parameters:
  #test
  learning_rate: [0.001, 0.01, 0.05, 0.1]
  #weight_decay: [0.0001,]
  # mem_ratio: [1,]
  # mem_iters: [1,3,7,10]
  #real
  # learning_rate: [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.03, 0.1]
  # weight_decay: [0.0001, 0.001, 0.01, 0.1]
--------------------------------------------------------------------------------
/config/agent/er/er_tune_miter.yml:
--------------------------------------------------------------------------------
parameters:
  #test
  # learning_rate: [0.001, 0.01,0.03,0.06,0.1]
  # weight_decay: [0.0001,]
  #mem_ratio: [1]
  mem_iters: [1, 3, 7, 10]
  #real
  # learning_rate: [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.03, 0.1]
  # weight_decay: [0.0001, 0.001, 0.01, 0.1]
--------------------------------------------------------------------------------
/config/general_1_openloris.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 0
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: -1
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
--------------------------------------------------------------------------------
/config/general_2_openloris.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 1
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: -1
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
--------------------------------------------------------------------------------
/config/general_3_openloris.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 2
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: -1
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
--------------------------------------------------------------------------------
/config/general_2_cndpm_mini.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 1
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 2
  # data general
  fix_order: False
  plot_sample: False
--------------------------------------------------------------------------------
/config/agent/er/er_tune_ratio.yml:
--------------------------------------------------------------------------------
parameters:
  #test
  # learning_rate: [0.001, 0.01,0.03,0.06,0.1]
  # weight_decay: [0.0001,]
  mem_ratio: [0.1, 0.5, 1, 1.5]
  #mem_iters: [1,]
  #real
  # learning_rate: [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.03, 0.1]
  # weight_decay: [0.0001, 0.001, 0.01, 0.1]
--------------------------------------------------------------------------------
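Note: the Lambda wrapper in models/ndpm/utils.py above turns any callable into an nn.Module, so arbitrary functions can sit inside nn.Sequential pipelines:

import torch
import torch.nn as nn
from models.ndpm.utils import Lambda

net = nn.Sequential(
    nn.Linear(8, 16),
    Lambda(lambda x: x * 2),  # any elementwise callable works here
    nn.ReLU(),
)
out = net(torch.randn(4, 8))  # -> shape (4, 16)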
/config/general_2.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 1
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 3
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_3.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 2
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 3
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/agent/er/er_tune_incoming_ratio.yml:
--------------------------------------------------------------------------------
parameters:
  #test
  # learning_rate: [0.001, 0.01,0.03,0.06,0.1]
  # weight_decay: [0.0001,]
  incoming_ratio: [0.1, 0.5, 1, 1.5]
  #mem_iters: [1,]
  #real
  # learning_rate: [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.03, 0.1]
  # weight_decay: [0.0001, 0.001, 0.01, 0.1]
--------------------------------------------------------------------------------
/config/general_1_core50.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [0, 1]
  seed: 0
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_1_gdumb.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 0
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 16
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 3
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_2_core50.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [2, 3]
  seed: 1
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_2_gdumb.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 1
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 16
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 3
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_3_core50.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [4, 5]
  seed: 2
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_3_gdumb.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 2
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 16
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 3
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_4_core50.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [6, 7]
  seed: 3
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_5_core50.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [8, 9]
  seed: 4
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_offline_1.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 10
  seed: 0
  # optimizer
  optimizer: SGD
  epoch: 70
  batch: 128
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: False
--------------------------------------------------------------------------------
/config/general_finetune_1.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 15
  seed: 0
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 3
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/ni_mini_general/general_1.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 2
  seed: 0
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
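Note: the general_*.yml variants above differ mainly in seed (and, for the core50 files, in which run indices num_runs covers); everything else is shared. Loading one shows the flat parameter dict the entry points presumably consume:

import yaml

with open('config/ni_mini_general/general_1.yml') as f:
    g = yaml.safe_load(f)['parameters']
assert g['seed'] == 0 and g['online'] is True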
/config/ni_mini_general/general_2.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 2
  seed: 1
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/ni_mini_general/general_3.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 2
  seed: 2
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/ni_mini_general/general_4.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 2
  seed: 3
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/ni_mini_general/general_5.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 2
  seed: 4
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_1_core50_gdumb.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [0, 1]
  seed: 1
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 16
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_2_core50_gdumb.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [2, 3]
  seed: 1
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 16
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_3_core50_gdumb.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [4, 5]
  seed: 3
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 16
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_4_core50_gdumb.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [6, 7]
  seed: 4
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 16
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_5_core50_gdumb.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [8, 9]
  seed: 5
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 16
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_finetune_core50_1.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 10
  seed: 0
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 3
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/general_1_core50_offline.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [0, 1]
  seed: 0
  # optimizer
  optimizer: SGD
  epoch: 70
  batch: 128
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: False
--------------------------------------------------------------------------------
/config/general_2_core50_offline.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [0, 1]
  seed: 1
  # optimizer
  optimizer: SGD
  epoch: 70
  batch: 128
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: False
--------------------------------------------------------------------------------
/config/general_3_core50_offline.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [0, 1]
  seed: 2
  # optimizer
  optimizer: SGD
  epoch: 70
  batch: 128
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: False
--------------------------------------------------------------------------------
/config/general_4_core50_offline.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [0, 1]
  seed: 3
  # optimizer
  optimizer: SGD
  epoch: 70
  batch: 128
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: False
--------------------------------------------------------------------------------
/config/general_5_core50_offline.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: [0, 1]
  seed: 4
  # optimizer
  optimizer: SGD
  epoch: 70
  batch: 128
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: False
--------------------------------------------------------------------------------
/config/ni_mini_general/general_1_gdumb.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 0
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 16
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/ni_mini_general/general_2_gdumb.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 1
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 16
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 1
  # data general
  fix_order: False
  plot_sample: False
  online: True
--------------------------------------------------------------------------------
/config/agent/er/er_tune_aug_iter_0.yml:
--------------------------------------------------------------------------------
parameters:
  #test
  randaug_N: [0]
  randaug_M: [0]
  #learning_rate: [0.001, 0.01,0.05,0.1]
  #weight_decay: [0.0001,]
  # mem_ratio: [1,]
  mem_iters: [1, 2, 5, 10, 20]
  #real
  # learning_rate: [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.03, 0.1]
  # weight_decay: [0.0001, 0.001, 0.01, 0.1]
--------------------------------------------------------------------------------
/config/general_labels_trick_3.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 2
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 3
  # data general
  fix_order: False
  plot_sample: False
  # tricks
  trick: {labels_trick: True}
--------------------------------------------------------------------------------
/config/agent/er/er_tune_aug_iter.yml:
--------------------------------------------------------------------------------
parameters:
  #test
  randaug_N: [1, 2, 3]
  randaug_M: [5, 14]
  #learning_rate: [0.001, 0.01,0.05,0.1]
  #weight_decay: [0.0001,]
  # mem_ratio: [1,]
  mem_iters: [1, 2, 5, 10, 20]
  #real
  # learning_rate: [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.03, 0.1]
  # weight_decay: [0.0001, 0.001, 0.01, 0.1]
--------------------------------------------------------------------------------
/config/general_labels_trick_1.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 0
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 3
  # data general
  fix_order: False
  plot_sample: False
  # tricks
  trick: {labels_trick: True}
--------------------------------------------------------------------------------
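Note: `trick: {labels_trick: True}` above enables the labels trick: cross-entropy is computed only over the classes present in the current batch, so the logits of classes outside the batch are not pushed down. An illustrative implementation, not necessarily the repo's exact code:

import torch
import torch.nn.functional as F

def labels_trick_loss(logits, y):
    classes = torch.unique(y)                     # classes present in this batch
    remap = {c.item(): i for i, c in enumerate(classes)}
    y_local = torch.tensor([remap[t.item()] for t in y], device=y.device)
    return F.cross_entropy(logits[:, classes], y_local)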
/config/general_labels_trick_2.yml:
--------------------------------------------------------------------------------
parameters:
  #experiment
  num_runs: 5
  seed: 1
  # optimizer
  optimizer: SGD
  epoch: 1
  batch: 10
  test_batch: 128
  # validation
  val_size: 0.0
  num_val: 2
  num_runs_val: 3
  # data general
  fix_order: False
  plot_sample: False
  # tricks
  trick: {labels_trick: True}
--------------------------------------------------------------------------------
/utils/buffer/recycle.py:
--------------------------------------------------------------------------------
class recycle(object):
    """Temporary holder for incoming image/label batches between replay iterations."""

    def __init__(self):
        self.tmp_img = []
        self.tmp_label = []

    def store_tmp(self, img, label):
        self.tmp_img.append(img)
        self.tmp_label.append(label)

    def clear(self):
        self.tmp_img = []
        self.tmp_label = []

    def retrieve_tmp(self):
        return self.tmp_img, self.tmp_label
--------------------------------------------------------------------------------
/run_commands/base_commands/command_mir.sh:
--------------------------------------------------------------------------------
GPU_ID=$1
NUM_TASKS=$2
DATASET_NAME=$3
SEED=$4
MEM_SIZE=$5
MEM_ITER=$6
RES_SIZE=$7

## mir
python general_main.py --data $DATASET_NAME --cl_type nc \
    --agent "ER" --retrieve "MIR" --update random \
    --mem_size $MEM_SIZE --eps_mem_batch 10 \
    --dataset_random_type task_random --num_tasks $NUM_TASKS \
    --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \
    --nmc_trick True --resnet_size $RES_SIZE
--------------------------------------------------------------------------------
/run_commands/base_commands/command_er.sh:
--------------------------------------------------------------------------------
GPU_ID=$1
NUM_TASKS=$2
DATASET_NAME=$3
SEED=$4
MEM_SIZE=$5
MEM_ITER=$6
RES_SIZE=$7

## ER
python general_main.py --data $DATASET_NAME --cl_type nc \
    --agent "ER" --retrieve "random" --update random \
    --mem_size $MEM_SIZE --eps_mem_batch 10 \
    --dataset_random_type task_random --num_tasks $NUM_TASKS \
    --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \
    --nmc_trick True --resnet_size $RES_SIZE
--------------------------------------------------------------------------------
/run_commands/base_commands/command_scr.sh:
--------------------------------------------------------------------------------
GPU_ID=$1
NUM_TASKS=$2
DATASET_NAME=$3
SEED=$4
MEM_SIZE=$5
MEM_ITER=$6
RES_SIZE=$7

## scr
python general_main.py --data $DATASET_NAME --cl_type nc \
    --agent "SCR" --retrieve random --update random \
    --mem_size $MEM_SIZE --head mlp --temp 0.07 --eps_mem_batch 100 \
    --dataset_random_type task_random --num_tasks $NUM_TASKS \
    --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \
    --nmc_trick True \
    --resnet_size $RES_SIZE
--------------------------------------------------------------------------------
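Note: round trip through utils/buffer/recycle.py above; store_tmp accumulates batches, retrieve_tmp returns both lists, and clear resets them:

import torch
from utils.buffer.recycle import recycle

cache = recycle()
cache.store_tmp(torch.randn(10, 3, 32, 32), torch.zeros(10, dtype=torch.long))
imgs, labels = cache.retrieve_tmp()  # lists with one stored batch each
cache.clear()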
GPU_ID=$1 4 | NUM_TASKS=$2 5 | DATASET_NAME=$3 6 | SEED=$4 7 | MEM_SIZE=$5 8 | MEM_ITER=$6 9 | RES_SIZE=$7 10 | 11 | ## ASER 12 | 13 | python general_main.py --data $DATASET_NAME --cl_type nc \ 14 | --agent "ER" --retrieve "ASER" --update "ASER" \ 15 | --mem_size $MEM_SIZE --eps_mem_batch 10 \ 16 | --dataset_random_type task_random --seed $SEED --num_tasks $NUM_TASKS \ 17 | --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \ 18 | --aser_type "asvm" --n_smp_cls 1.5 --k 3 --nmc_trick True \ 19 | --resnet_size $RES_SIZE -------------------------------------------------------------------------------- /utils/buffer/sc_retrieve.py: -------------------------------------------------------------------------------- 1 | from utils.buffer.buffer_utils import match_retrieve 2 | import torch 3 | 4 | class Match_retrieve(object): 5 | def __init__(self, params): 6 | super().__init__() 7 | self.num_retrieve = params.eps_mem_batch 8 | self.warmup = params.warmup 9 | 10 | def retrieve(self, buffer, **kwargs): 11 | if buffer.n_seen_so_far > self.num_retrieve * self.warmup: 12 | cur_x, cur_y = kwargs['x'], kwargs['y'] 13 | 14 | return match_retrieve(buffer, cur_y) 15 | else: 16 | return torch.tensor([]), torch.tensor([]) -------------------------------------------------------------------------------- /run_commands/base_commands/command_er_scraug.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | GPU_ID=$1 4 | NUM_TASKS=$2 5 | DATASET_NAME=$3 6 | SEED=$4 7 | MEM_SIZE=$5 8 | MEM_ITER=$6 9 | #RAUG_N=$7 10 | #RAUG_M=$8 11 | 12 | MEM_BATCH=${7} 13 | RES_SIZE=${8} 14 | 15 | RAUG_TARGET=$9 16 | ## ER--scraug 17 | python general_main.py --data $DATASET_NAME --cl_type nc \ 18 | --agent "ER" --retrieve "random" --update random \ 19 | --mem_size $MEM_SIZE --eps_mem_batch $MEM_BATCH \ 20 | --dataset_random_type task_random --seed $SEED --num_tasks $NUM_TASKS \ 21 | --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \ 22 | --nmc_trick True \ 23 | --resnet_size $RES_SIZE \ 24 | --scraug True \ 25 | --aug_target $RAUG_TARGET -------------------------------------------------------------------------------- /run_commands/base_commands/command_mir_raug.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | GPU_ID=$1 5 | NUM_TASKS=$2 6 | DATASET_NAME=$3 7 | SEED=$4 8 | MEM_SIZE=$5 9 | MEM_ITER=$6 10 | RAUG_N=$7 11 | RAUG_M=$8 12 | RAUG_TARGET=$9 13 | MEM_BATCH=${10} 14 | RES_SIZE=${11} 15 | 16 | ## mir-raug 17 | python general_main.py --data $DATASET_NAME --cl_type nc \ 18 | --agent "ER" --retrieve "MIR" --update random \ 19 | --mem_size $MEM_SIZE --eps_mem_batch $MEM_BATCH \ 20 | --dataset_random_type task_random --seed $SEED --num_tasks $NUM_TASKS \ 21 | --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \ 22 | --nmc_trick True \ 23 | --randaug True --randaug_N $RAUG_N --randaug_M $RAUG_M --aug_target $RAUG_TARGET \ 24 | --resnet_size $RES_SIZE 25 | -------------------------------------------------------------------------------- /run_commands/base_commands/command_er_raug_t1.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | GPU_ID=$1 4 | NUM_TASKS=$2 5 | DATASET_NAME=$3 6 | SEED=$4 7 | MEM_SIZE=$5 8 | MEM_ITER=$6 9 | RAUG_N=$7 10 | RAUG_M=$8 11 | RAUG_TARGET=$9 12 | MEM_BATCH=${10} 13 | RES_SIZE=${11} 14 | 15 | ## ER 16 | python general_main.py --data $DATASET_NAME --cl_type nc \ 17 | --agent "ER" --retrieve "random" --update random \ 18 | --mem_size $MEM_SIZE --eps_mem_batch $MEM_BATCH 
\ 19 | --dataset_random_type task_random --seed $SEED --num_tasks $NUM_TASKS \ 20 | --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \ 21 | --nmc_trick True \ 22 | --randaug True --randaug_N $RAUG_N --randaug_M $RAUG_M --aug_target $RAUG_TARGET \ 23 | --resnet_size $RES_SIZE --aug_start 1 24 | -------------------------------------------------------------------------------- /run_commands/base_commands/command_scr_raug.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | GPU_ID=$1 4 | NUM_TASKS=$2 5 | DATASET_NAME=$3 6 | SEED=$4 7 | MEM_SIZE=$5 8 | MEM_ITER=$6 9 | RAUG_N=$7 10 | RAUG_M=$8 11 | RAUG_TARGET=$9 12 | MEM_BATCH=${10} 13 | RES_SIZE=${11} 14 | ## scr raug 15 | python general_main.py --data $DATASET_NAME --cl_type nc \ 16 | --agent "SCR" --retrieve random --update random \ 17 | --mem_size $MEM_SIZE --head mlp --temp 0.07 --eps_mem_batch 100 \ 18 | --dataset_random_type task_random --seed $SEED --num_tasks $NUM_TASKS \ 19 | --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \ 20 | --nmc_trick True \ 21 | --randaug True --randaug_N $RAUG_N --randaug_M $RAUG_M --aug_target $RAUG_TARGET \ 22 | --resnet_size $RES_SIZE 23 | -------------------------------------------------------------------------------- /run_commands/base_commands/command_er_raug.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | GPU_ID=$1 4 | NUM_TASKS=$2 5 | DATASET_NAME=$3 6 | SEED=$4 7 | MEM_SIZE=$5 8 | MEM_ITER=$6 9 | RAUG_N=$7 10 | RAUG_M=$8 11 | RAUG_TARGET=$9 12 | MEM_BATCH=${10} 13 | RES_SIZE=${11} 14 | 15 | ## ER 16 | python general_main.py --data $DATASET_NAME --cl_type nc \ 17 | --agent "ER" --retrieve "random" --update random \ 18 | --mem_size $MEM_SIZE --eps_mem_batch $MEM_BATCH \ 19 | --dataset_random_type task_random --seed $SEED --num_tasks $NUM_TASKS \ 20 | --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \ 21 | --nmc_trick True \ 22 | --randaug True --randaug_N $RAUG_N --randaug_M $RAUG_M --aug_target $RAUG_TARGET \ 23 | --resnet_size $RES_SIZE #--immediate_evaluate True 24 | -------------------------------------------------------------------------------- /run_commands/base_commands/command_mir_raug_t1.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | GPU_ID=$1 5 | NUM_TASKS=$2 6 | DATASET_NAME=$3 7 | SEED=$4 8 | MEM_SIZE=$5 9 | MEM_ITER=$6 10 | RAUG_N=$7 11 | RAUG_M=$8 12 | RAUG_TARGET=$9 13 | MEM_BATCH=${10} 14 | RES_SIZE=${11} 15 | 16 | ## mir-raug 17 | python general_main.py --data $DATASET_NAME --cl_type nc \ 18 | --agent "ER" --retrieve "MIR" --update random \ 19 | --mem_size $MEM_SIZE --eps_mem_batch $MEM_BATCH \ 20 | --dataset_random_type task_random --seed $SEED --num_tasks $NUM_TASKS \ 21 | --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \ 22 | --nmc_trick True \ 23 | --randaug True --randaug_N $RAUG_N --randaug_M $RAUG_M --aug_target $RAUG_TARGET \ 24 | --resnet_size $RES_SIZE --aug_start 1 25 | -------------------------------------------------------------------------------- /run_commands/base_commands/command_scr_raug_t1.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | GPU_ID=$1 4 | NUM_TASKS=$2 5 | DATASET_NAME=$3 6 | SEED=$4 7 | MEM_SIZE=$5 8 | MEM_ITER=$6 9 | RAUG_N=$7 10 | RAUG_M=$8 11 | RAUG_TARGET=$9 12 | MEM_BATCH=${10} 13 | RES_SIZE=${11} 14 | ## scr raug 15 | python general_main.py --data $DATASET_NAME --cl_type nc \ 16 | --agent "SCR" --retrieve random --update random \ 17 | --mem_size 
$MEM_SIZE --head mlp --temp 0.07 --eps_mem_batch 100 \ 18 | --dataset_random_type task_random --seed $SEED --num_tasks $NUM_TASKS \ 19 | --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \ 20 | --nmc_trick True \ 21 | --randaug True --randaug_N $RAUG_N --randaug_M $RAUG_M --aug_target $RAUG_TARGET \ 22 | --resnet_size $RES_SIZE --aug_start 1 23 | -------------------------------------------------------------------------------- /run_commands/base_commands/command_aser_raug.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | GPU_ID=$1 4 | NUM_TASKS=$2 5 | DATASET_NAME=$3 6 | SEED=$4 7 | MEM_SIZE=$5 8 | MEM_ITER=$6 9 | RAUG_N=$7 10 | RAUG_M=$8 11 | RAUG_TARGET=$9 12 | MEM_BATCH=${10} 13 | RES_SIZE=${11} 14 | 15 | ## ASER 16 | 17 | python general_main.py --data $DATASET_NAME --cl_type nc \ 18 | --agent "ER" --retrieve "ASER" --update "ASER" \ 19 | --mem_size $MEM_SIZE --eps_mem_batch $MEM_BATCH \ 20 | --dataset_random_type task_random --seed $SEED --num_tasks $NUM_TASKS \ 21 | --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \ 22 | --aser_type "asvm" --n_smp_cls 1.5 --k 3 --nmc_trick True \ 23 | --randaug True --randaug_N $RAUG_N --randaug_M $RAUG_M --aug_target $RAUG_TARGET \ 24 | --resnet_size $RES_SIZE 25 | -------------------------------------------------------------------------------- /fetch_data_setup.sh: -------------------------------------------------------------------------------- 1 | DIR="$( cd "$( dirname "$0" )" && pwd )" 2 | mkdir -p $DIR/datasets 3 | 4 | echo "Downloading Core50 (128x128 version)..." 5 | echo $DIR'/datasets/core50/' 6 | wget --directory-prefix=$DIR'/datasets/core50/' http://bias.csr.unibo.it/maltoni/download/core50/core50_128x128.zip 7 | 8 | echo "Unzipping data..." 9 | unzip $DIR/datasets/core50/core50_128x128.zip -d $DIR/datasets/core50/ 10 | 11 | mv $DIR/datasets/core50/core50_128x128/* $DIR/datasets/core50/ 12 | 13 | wget --directory-prefix=$DIR'/datasets/core50/' https://vlomonaco.github.io/core50/data/paths.pkl 14 | wget --directory-prefix=$DIR'/datasets/core50/' https://vlomonaco.github.io/core50/data/LUP.pkl 15 | wget --directory-prefix=$DIR'/datasets/core50/' https://vlomonaco.github.io/core50/data/labels.pkl -------------------------------------------------------------------------------- /run_commands/base_commands/command_aser_raug_t1.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | GPU_ID=$1 4 | NUM_TASKS=$2 5 | DATASET_NAME=$3 6 | SEED=$4 7 | MEM_SIZE=$5 8 | MEM_ITER=$6 9 | RAUG_N=$7 10 | RAUG_M=$8 11 | RAUG_TARGET=$9 12 | MEM_BATCH=${10} 13 | RES_SIZE=${11} 14 | 15 | ## ASER 16 | 17 | python general_main.py --data $DATASET_NAME --cl_type nc \ 18 | --agent "ER" --retrieve "ASER" --update "ASER" \ 19 | --mem_size $MEM_SIZE --eps_mem_batch $MEM_BATCH \ 20 | --dataset_random_type task_random --seed $SEED --num_tasks $NUM_TASKS \ 21 | --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \ 22 | --aser_type "asvm" --n_smp_cls 1.5 --k 3 --nmc_trick True \ 23 | --randaug True --randaug_N $RAUG_N --randaug_M $RAUG_M --aug_target $RAUG_TARGET \ 24 | --resnet_size $RES_SIZE --aug_start 1 25 | -------------------------------------------------------------------------------- /checkpoint_model.py: -------------------------------------------------------------------------------- 1 | 2 | #from models.resnet import Reduced_ResNet18 3 | #import torch.nn as nn 4 | from RL.pytorch_util import build_mlp 5 | import torch 6 | 7 | 8 | PATH = 
"results/701/ER_random_random_NMC_testbatch100_RLmemIter_31_11_numRuns1_20_5000_cifar100_model" 9 | model_dict = torch.load(PATH) 10 | 11 | #nclass=100 12 | #model = Reduced_ResNet18(nclass) 13 | model = build_mlp(4,3,n_layers=2,size=32) 14 | 15 | 16 | #optimizer = TheOptimizerClass(*args, **kwargs) 17 | 18 | checkpoint = torch.load(PATH) 19 | model.load_state_dict(checkpoint) 20 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 21 | # epoch = checkpoint['epoch'] 22 | # loss = checkpoint['loss'] 23 | 24 | model.eval() 25 | 26 | x = torch.zeros((1,4)) 27 | with torch.no_grad(): 28 | y = model.forward(x) 29 | print(y) -------------------------------------------------------------------------------- /run_commands/base_commands/command_der_deraug.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | GPU_ID=$1 4 | NUM_TASKS=$2 5 | DATASET_NAME=$3 6 | SEED=$4 7 | MEM_SIZE=$5 8 | MEM_ITER=$6 9 | RAUG_N=$7 10 | RAUG_M=$8 11 | RAUG_TARGET=$9 12 | MEM_BATCH=${10} 13 | RES_SIZE=${11} 14 | EPOCH=${12} 15 | DER_alpha=${13} 16 | 17 | ## ER 18 | python general_main.py --data $DATASET_NAME --cl_type nc \ 19 | --agent "DER" --retrieve "random" --update random \ 20 | --mem_size $MEM_SIZE --eps_mem_batch $MEM_BATCH \ 21 | --dataset_random_type task_random --seed $SEED --num_tasks $NUM_TASKS \ 22 | --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \ 23 | --nmc_trick True \ 24 | --randaug False --randaug_N $RAUG_N --randaug_M $RAUG_M --aug_target $RAUG_TARGET \ 25 | --deraug True \ 26 | --resnet_size $RES_SIZE --dataset_random_type "task_random" --epoch $EPOCH \ 27 | #--DER_alpha $DER_alpha 28 | #--immediate_evaluate True 29 | -------------------------------------------------------------------------------- /run_commands/runs/run_rar_with_scr.sh: -------------------------------------------------------------------------------- 1 | #SEED=1259051 2 | MEM_SIZE=2000 3 | NUM_TASKS=0 4 | #DATASET_NAME="cifar100" 5 | GPU_ID=0 6 | 7 | 8 | NAME_PREFIX="run_commands/base_commands/command_" 9 | 10 | for SEED in 1259051 1259052 1259053 11 | do 12 | for DATASET_NAME in "mini_imagenet" "cifar100" "clrs25" "core50" 13 | do 14 | for ALGO_NAME in "scr" 15 | do 16 | 17 | RES_SIZE="reduced" 18 | 19 | MEM_ITER=1 20 | FILE_NAME=$NAME_PREFIX$ALGO_NAME".sh" 21 | source $FILE_NAME $GPU_ID $NUM_TASKS $DATASET_NAME $SEED $MEM_SIZE $MEM_ITER $RES_SIZE 22 | 23 | 24 | MEM_ITER=10 25 | RAUG_N=1 26 | RAUG_M=14 27 | RAUG_TARGET="both" ## mem incoming none 28 | MEM_BATCH=100 29 | FILE_NAME=$NAME_PREFIX$ALGO_NAME"_raug.sh" 30 | source $FILE_NAME $GPU_ID $NUM_TASKS $DATASET_NAME $SEED $MEM_SIZE $MEM_ITER \ 31 | $RAUG_N $RAUG_M $RAUG_TARGET $MEM_BATCH $RES_SIZE 32 | 33 | # 34 | # 35 | 36 | done 37 | 38 | done 39 | 40 | done -------------------------------------------------------------------------------- /run_commands/runs/run_test_rar_er_cifar100.sh: -------------------------------------------------------------------------------- 1 | #SEED=1259051 2 | MEM_SIZE=2000 3 | NUM_TASKS=0 4 | #DATASET_NAME="cifar10" 5 | GPU_ID=0 6 | NAME_PREFIX="run_commands/base_commands/command_" 7 | 8 | RES_SIZE="reduced" 9 | for DATASET_NAME in "cifar100" 10 | do 11 | for SEED in 1259051 #1259052 1259053 12 | do 13 | for ALGO_NAME in "er" 14 | do 15 | # ########### baseline: without RAR 16 | # MEM_ITER=1 17 | # FILE_NAME=$NAME_PREFIX$ALGO_NAME".sh" 18 | # source $FILE_NAME $GPU_ID $NUM_TASKS $DATASET_NAME $SEED $MEM_SIZE $MEM_ITER $RES_SIZE 19 | 20 | 21 | # 22 | ########### with RAR 23 | MEM_ITER=10 24 | 
RAUG_N=1 25 | RAUG_M=14 26 | RAUG_TARGET="both" ## mem incoming none 27 | MEM_BATCH=10 28 | FILE_NAME=$NAME_PREFIX$ALGO_NAME"_raug.sh" 29 | source $FILE_NAME $GPU_ID $NUM_TASKS $DATASET_NAME $SEED $MEM_SIZE $MEM_ITER \ 30 | $RAUG_N $RAUG_M $RAUG_TARGET $MEM_BATCH $RES_SIZE 31 | 32 | 33 | 34 | done 35 | # 36 | # 37 | 38 | 39 | done 40 | 41 | done -------------------------------------------------------------------------------- /run_commands/runs/run_rar_with_er_mir_aser.sh: -------------------------------------------------------------------------------- 1 | #SEED=1259051 2 | MEM_SIZE=2000 3 | GPU_ID=0 4 | NUM_TASKS=0 ## 0 = default task split, as in the sibling run scripts ($NUM_TASKS is passed below) 5 | 6 | NAME_PREFIX="run_commands/base_commands/command_" 7 | RES_SIZE="reduced" 8 | for DATASET_NAME in "cifar100" "mini_imagenet" "clrs25" "core50" 9 | do 10 | for SEED in 1259051 1259052 1259053 11 | do 12 | for ALGO_NAME in "er" "aser" "mir" 13 | do 14 | ########### baseline: without RAR 15 | MEM_ITER=1 16 | FILE_NAME=$NAME_PREFIX$ALGO_NAME".sh" 17 | source $FILE_NAME $GPU_ID $NUM_TASKS $DATASET_NAME $SEED $MEM_SIZE $MEM_ITER $RES_SIZE 18 | 19 | 20 | # 21 | ########### with RAR 22 | MEM_ITER=1 23 | RAUG_N=1 24 | RAUG_M=14 25 | RAUG_TARGET="both" ## mem incoming none 26 | MEM_BATCH=10 27 | FILE_NAME=$NAME_PREFIX$ALGO_NAME"_raug.sh" 28 | source $FILE_NAME $GPU_ID $NUM_TASKS $DATASET_NAME $SEED $MEM_SIZE $MEM_ITER \ 29 | $RAUG_N $RAUG_M $RAUG_TARGET $MEM_BATCH $RES_SIZE 30 | # 31 | 32 | 33 | done 34 | # 35 | # 36 | 37 | 38 | done 39 | 40 | done -------------------------------------------------------------------------------- /run_commands/runs/run_adaptive_RAR_RL.sh: -------------------------------------------------------------------------------- 1 | #SEED=1259051 2 | MEM_SIZE=2000 3 | 4 | NUM_TASKS=0 5 | GPU_ID=0 6 | NAME_PREFIX="run_commands/base_commands/command_" 7 | 8 | 9 | for SEED in 1259051 #1259052 1259053 10 | do 11 | ALGO_NAME="er" ## er, scr, aser 12 | 13 | for DATASET_NAME in "cifar100" #"mini_imagenet" "clrs25" "core50" 14 | do 15 | 16 | 17 | 18 | MEM_ITER=10 19 | RAUG_N=1 20 | RAUG_M=14 21 | RAUG_TARGET="both" ## mem incoming none 22 | MEM_BATCH=100 23 | RES_SIZE="reduced" 24 | MEM_MAX=15 25 | STOP_RATIO=10 26 | LR_large=10 27 | LR_small=5 28 | ACC_MAX=0.9 29 | ACC_MIN=0.8 30 | AUG_NUM=5 31 | AUG_MAX=0.8 32 | AUG_MIN=0.7 33 | 34 | FILE_NAME=$NAME_PREFIX"adaptive_rar_RL.sh" 35 | source $FILE_NAME $GPU_ID $NUM_TASKS $DATASET_NAME $SEED $MEM_SIZE $MEM_ITER \ 36 | $RAUG_N $RAUG_M $RAUG_TARGET $MEM_BATCH $RES_SIZE $MEM_MAX $STOP_RATIO \ 37 | $LR_large $LR_small $ACC_MAX $ACC_MIN $AUG_NUM $AUG_MAX $AUG_MIN 38 | # 39 | # 40 | 41 | done 42 | 43 | done -------------------------------------------------------------------------------- /utils/kd_manager.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def loss_fn_kd(scores, target_scores, T=2.): 7 | log_scores_norm = F.log_softmax(scores / T, dim=1) 8 | targets_norm = F.softmax(target_scores / T, dim=1) 9 | # Calculate distillation loss (see e.g., Li and Hoiem, 2017) 10 | kd_loss = (-1 * targets_norm * log_scores_norm).sum(dim=1).mean() * T ** 2 11 | return kd_loss 12 | 13 | 14 | class KdManager: 15 | def __init__(self): 16 | self.teacher_model = None 17 | 18 | def update_teacher(self, model): 19 | self.teacher_model = copy.deepcopy(model) 20 | 21 | def get_kd_loss(self, cur_model_logits, x): 22 | if self.teacher_model is not None: 23 | with torch.no_grad(): 24 | prev_model_logits = self.teacher_model.forward(x) 25 |
dist_loss = loss_fn_kd(cur_model_logits, prev_model_logits) 26 | else: 27 | dist_loss = 0 28 | return dist_loss 29 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import pandas as pd 3 | import os 4 | import psutil 5 | 6 | def load_yaml(path, key='parameters'): 7 | with open(path, 'r') as stream: 8 | try: 9 | return yaml.load(stream, Loader=yaml.FullLoader)[key] 10 | except yaml.YAMLError as exc: 11 | print(exc) 12 | 13 | def save_dataframe_csv(df, path, name): 14 | df.to_csv(path + '/' + name, index=False) 15 | 16 | 17 | def load_dataframe_csv(path, name=None, delimiter=None, names=None): 18 | if not name: 19 | return pd.read_csv(path, delimiter=delimiter, names=names) 20 | else: 21 | return pd.read_csv(path+name, delimiter=delimiter, names=names) 22 | 23 | def check_ram_usage(): 24 | """ 25 | Compute the RAM usage of the current process. 26 | Returns: 27 | mem (float): Memory occupation in Megabytes 28 | """ 29 | 30 | process = psutil.Process(os.getpid()) 31 | mem = process.memory_info().rss / (1024 * 1024) 32 | 33 | return mem -------------------------------------------------------------------------------- /utils/buffer/mem_match.py: -------------------------------------------------------------------------------- 1 | from utils.buffer.buffer_utils import match_retrieve 2 | from utils.buffer.buffer_utils import random_retrieve 3 | import torch 4 | 5 | class MemMatch_retrieve(object): 6 | def __init__(self, params): 7 | super().__init__() 8 | self.num_retrieve = params.eps_mem_batch 9 | self.warmup = params.warmup 10 | 11 | 12 | def retrieve(self, buffer, **kwargs): 13 | match_x, match_y = torch.tensor([]), torch.tensor([]) 14 | candidate_x, candidate_y = torch.tensor([]), torch.tensor([]) 15 | if buffer.n_seen_so_far > self.num_retrieve * self.warmup: 16 | while match_x.size(0) == 0: 17 | candidate_x, candidate_y, indices = random_retrieve(buffer, self.num_retrieve,return_indices=True) 18 | if candidate_x.size(0) == 0: 19 | return candidate_x, candidate_y, match_x, match_y 20 | match_x, match_y = match_retrieve(buffer, candidate_y, indices) 21 | return candidate_x, candidate_y, match_x, match_y -------------------------------------------------------------------------------- /config/general_1.yml: -------------------------------------------------------------------------------- 1 | parameters: 2 | #experiment 3 | num_runs: 1 #5 4 | seed: 0 5 | # optimizer 6 | optimizer: SGD 7 | epoch: 1 8 | batch: 10 9 | test_batch: 128 10 | # validation 11 | val_size: 0.0 12 | num_val: 3 #17 13 | num_runs_val: 1 #3 14 | # data general 15 | fix_order: False 16 | plot_sample: False 17 | online: True 18 | 19 | exp_name: para17_7_4 20 | 21 | dataset_random_type: "task_random" 22 | switch_buffer_type: "one_buffer" 23 | RL_type: "NoRL" 24 | use_test_buffer: False 25 | test_mem_type: "after" 26 | use_tmp_buffer: False 27 | frozen_old_fc: False 28 | mem_iters: 1 29 | incoming_ratio: 1 30 | mem_ratio: 1 31 | save_prefix: "joint_training" 32 | replay_old_only: False 33 | only_task_seen: False 34 | use_softmaxloss: False 35 | buffer_tracker: False 36 | error_analysis: False 37 | start_mem_iters: -1 38 | close_loop_mem_type: "random" 39 | joint_replay_type: "together" 40 | online_hyper_tune: False 41 | 42 | lambda_: 100 43 | nmc_trick : False 44 | learning_rate: 0.1 45 | weight_decay: 0.0001 46 | 47 | 48 | 49 | 
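(Usage sketch: a "general" YAML such as general_1.yml above carries the experiment-wide defaults; main_config.py, listed further below, merges it with one data YAML and one agent YAML into a single parameter namespace. The paths come from this config tree; the agent file is one plausible choice.)

    python main_config.py --general config/general_1.yml \
        --data config/data/cifar100/cifar100_nc.yml \
        --agent config/agent/er/er_5k.yml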
-------------------------------------------------------------------------------- /config/general_17.yml: -------------------------------------------------------------------------------- 1 | parameters: 2 | #experiment 3 | num_runs: 1 #5 4 | seed: 0 5 | # optimizer 6 | optimizer: SGD 7 | epoch: 1 8 | batch: 10 9 | test_batch: 128 10 | # validation 11 | val_size: 0.0 12 | num_val: 17 #17 13 | num_runs_val: 1 #3 14 | # data general 15 | fix_order: False 16 | plot_sample: False 17 | online: True 18 | 19 | exp_name: para17_7_4 20 | 21 | dataset_random_type: "task_random" 22 | switch_buffer_type: "one_buffer" 23 | RL_type: "NoRL" 24 | use_test_buffer: False 25 | test_mem_type: "after" 26 | use_tmp_buffer: False 27 | frozen_old_fc: False 28 | mem_iters: 1 29 | incoming_ratio: 1 30 | mem_ratio: 1 31 | save_prefix: "joint_training" 32 | replay_old_only: False 33 | only_task_seen: False 34 | use_softmaxloss: False 35 | buffer_tracker: False 36 | error_analysis: False 37 | start_mem_iters: -1 38 | close_loop_mem_type: "random" 39 | joint_replay_type: "together" 40 | online_hyper_tune: False 41 | 42 | lambda_: 100 43 | nmc_trick : False 44 | learning_rate: 0.1 45 | weight_decay: 0.0001 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /config/general_20.yml: -------------------------------------------------------------------------------- 1 | parameters: 2 | #experiment 3 | num_runs: 1 #5 4 | seed: 0 5 | # optimizer 6 | optimizer: SGD 7 | epoch: 1 8 | batch: 10 9 | test_batch: 128 10 | # validation 11 | val_size: 0.0 12 | num_val: 20 #17 13 | num_runs_val: 1 #3 14 | # data general 15 | fix_order: False 16 | plot_sample: False 17 | online: True 18 | 19 | exp_name: para17_7_4 20 | 21 | dataset_random_type: "task_random" 22 | switch_buffer_type: "one_buffer" 23 | RL_type: "NoRL" 24 | use_test_buffer: False 25 | test_mem_type: "after" 26 | use_tmp_buffer: False 27 | frozen_old_fc: False 28 | mem_iters: 1 29 | incoming_ratio: 1 30 | mem_ratio: 1 31 | save_prefix: "joint_training" 32 | replay_old_only: False 33 | only_task_seen: False 34 | use_softmaxloss: False 35 | buffer_tracker: False 36 | error_analysis: False 37 | start_mem_iters: -1 38 | close_loop_mem_type: "random" 39 | joint_replay_type: "together" 40 | online_hyper_tune: False 41 | 42 | lambda_: 100 43 | nmc_trick : False 44 | learning_rate: 0.1 45 | weight_decay: 0.0001 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /utils/buffer/random_retrieve.py: -------------------------------------------------------------------------------- 1 | from utils.buffer.buffer_utils import random_retrieve 2 | 3 | class Random_retrieve(object): 4 | def __init__(self, params): 5 | super().__init__() 6 | self.num_retrieve = params.eps_mem_batch 7 | 8 | # def retrieve(self, buffer, **kwargs): 9 | # x,y,indices = random_retrieve(buffer, self.num_retrieve,return_indices=True) 10 | # buffer.update_replay_times(indices) 11 | # return x,y 12 | def set_retrieve_num(self,num): 13 | self.num_retrieve = num 14 | ## todo: make retrieve compatitble with ER 15 | def retrieve(self, buffer,retrieve_num=None, **kwargs): 16 | if(retrieve_num == None): 17 | retrieve_num = self.num_retrieve 18 | if hasattr(buffer, "buffer_logits"): 19 | 20 | x,y,logits,indices = random_retrieve(buffer, retrieve_num,return_indices=True) 21 | buffer.update_replay_times(indices) 22 | return x,y,logits 23 | else: 24 | x,y,indices = random_retrieve(buffer, retrieve_num,return_indices=True) 25 | 
buffer.update_replay_times(indices) 26 | return x,y 27 | -------------------------------------------------------------------------------- /run_commands/runs/run_rar_with_der_deraug.sh: -------------------------------------------------------------------------------- 1 | #SEED=1259051 2 | NUM_TASKS=0 3 | #DATASET_NAME="cifar10" 4 | GPU_ID=0 5 | NAME_PREFIX="run_commands/base_commands/command_" 6 | 7 | RES_SIZE="reduced" 8 | for DATASET_NAME in "cifar100" #"mini_imagenet" #"cifar100" "clrs25" "core50" 9 | do 10 | for SEED in 1259051 #1259052 1259053 11 | do 12 | for ALGO_NAME in "der" 13 | do 14 | # ########### baseline: without RAR 15 | # MEM_ITER=1 16 | # MEM_BATCH=100 17 | # FILE_NAME=$NAME_PREFIX$ALGO_NAME".sh" 18 | # source $FILE_NAME $GPU_ID $NUM_TASKS $DATASET_NAME $SEED $MEM_SIZE $MEM_ITER $RES_SIZE \ 19 | # $MEM_BATCH 20 | 21 | for MEM_ITER in 50 #1 22 | do 23 | # 24 | MEM_SIZE=2000 25 | ########### with RAR 26 | #MEM_ITER=1 27 | RAUG_N=1 28 | RAUG_M=14 29 | RAUG_TARGET="both" ## mem incoming none 30 | MEM_BATCH=10 31 | EPOCH=1 32 | FILE_NAME=$NAME_PREFIX$ALGO_NAME"_deraug.sh" 33 | source $FILE_NAME $GPU_ID $NUM_TASKS $DATASET_NAME $SEED $MEM_SIZE $MEM_ITER \ 34 | $RAUG_N $RAUG_M $RAUG_TARGET $MEM_BATCH $RES_SIZE $EPOCH 35 | done 36 | # 37 | 38 | 39 | done 40 | # 41 | # 42 | 43 | 44 | done 45 | 46 | done -------------------------------------------------------------------------------- /config/general_1_aug_cifar10.yml: -------------------------------------------------------------------------------- 1 | parameters: 2 | #experiment 3 | 4 | num_runs: 1 #5 5 | seed: 1259051 6 | # optimizer 7 | optimizer: SGD 8 | epoch: 1 9 | batch: 10 10 | test_batch: 128 11 | # validation 12 | val_size: 0.0 13 | num_val: 2 #17 14 | num_runs_val: 1 #3 15 | # data general 16 | fix_order: False 17 | plot_sample: False 18 | online: True 19 | 20 | exp_name: para17_7_4 21 | 22 | dataset_random_type: "task_random" 23 | switch_buffer_type: "one_buffer" 24 | RL_type: "NoRL" 25 | use_test_buffer: False 26 | test_mem_type: "after" 27 | use_tmp_buffer: False 28 | frozen_old_fc: False 29 | mem_iters: 1 30 | incoming_ratio: 1 31 | mem_ratio: 1 32 | save_prefix: "joint_training" 33 | replay_old_only: False 34 | only_task_seen: False 35 | use_softmaxloss: False 36 | buffer_tracker: False 37 | error_analysis: False 38 | start_mem_iters: -1 39 | close_loop_mem_type: "random" 40 | joint_replay_type: "together" 41 | online_hyper_tune: False 42 | 43 | lambda_: 100 44 | nmc_trick : False 45 | learning_rate: 0.1 46 | weight_decay: 0.0001 47 | 48 | randaug: True 49 | randaug_N: 1 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /continuum/continuum.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # other imports 4 | from utils.name_match import data_objects 5 | 6 | class continuum(object): 7 | def __init__(self, dataset, scenario, params): 8 | """" Initialize Object """ 9 | self.data_object = data_objects[dataset](scenario, params) 10 | self.run = params.num_runs 11 | self.task_nums = self.data_object.task_nums 12 | self.cur_task = 0 13 | self.cur_run = -1 14 | 15 | def __iter__(self): 16 | return self 17 | 18 | def __next__(self): 19 | if self.cur_task == self.data_object.task_nums: 20 | raise StopIteration 21 | x_train, y_train, labels = self.data_object.new_task(self.cur_task, cur_run=self.cur_run) 22 | self.cur_task += 1 23 | return x_train, y_train, labels 24 | 25 | def 
test_data(self): 26 | return self.data_object.get_test_set() 27 | 28 | def clean_mem_test_set(self): 29 | self.data_object.clean_mem_test_set() 30 | 31 | def reset_run(self): 32 | self.cur_task = 0 33 | 34 | def new_run(self): 35 | self.cur_task = 0 36 | self.cur_run += 1 37 | self.data_object.new_run(cur_run=self.cur_run) 38 | 39 | 40 | -------------------------------------------------------------------------------- /utils/global_vars.py: -------------------------------------------------------------------------------- 1 | MODLES_NDPM_VAE_NF_BASE = 32 2 | MODELS_NDPM_VAE_NF_EXT = 4 3 | MODELS_NDPM_VAE_PRECURSOR_CONDITIONED_DECODER = False 4 | MODELS_NDPM_VAE_Z_DIM = 64 5 | MODELS_NDPM_VAE_RECON_LOSS = 'gaussian' 6 | MODELS_NDPM_VAE_LEARN_X_LOG_VAR = False 7 | MODELS_NDPM_VAE_X_LOG_VAR_PARAM = 0 8 | MODELS_NDPM_VAE_Z_SAMPLES = 16 9 | MODELS_NDPM_CLASSIFIER_NUM_BLOCKS = [1, 1, 1, 1] 10 | MODELS_NDPM_CLASSIFIER_NORM_LAYER = 'InstanceNorm2d' 11 | MODELS_NDPM_CLASSIFIER_CLS_NF_BASE = 20 12 | MODELS_NDPM_CLASSIFIER_CLS_NF_EXT = 4 13 | MODELS_NDPM_NDPM_DISABLE_D = False 14 | MODELS_NDPM_NDPM_SEND_TO_STM_ALWAYS = False 15 | MODELS_NDPM_NDPM_SLEEP_BATCH_SIZE = 50 16 | MODELS_NDPM_NDPM_SLEEP_NUM_WORKERS = 0 17 | MODELS_NDPM_NDPM_SLEEP_STEP_G = 4000 18 | MODELS_NDPM_NDPM_SLEEP_STEP_D = 1000 19 | MODELS_NDPM_NDPM_SLEEP_SLEEP_VAL_SIZE = 0 20 | MODELS_NDPM_NDPM_SLEEP_SUMMARY_STEP = 500 21 | MODELS_NDPM_NDPM_WEIGHT_DECAY = 0.00001 22 | MODELS_NDPM_NDPM_IMPLICIT_LR_DECAY = False 23 | MODELS_NDPM_COMPONENT_LR_SCHEDULER_G = {'type': 'MultiStepLR', 'options': {'milestones': [1], 'gamma': 0.2}} 24 | MODELS_NDPM_COMPONENT_LR_SCHEDULER_D = {'type': 'MultiStepLR', 'options': {'milestones': [1], 'gamma': 0.2}} 25 | MODELS_NDPM_COMPONENT_CLIP_GRAD = {'type': 'value', 'options': {'clip_value': 0.5}} 26 | -------------------------------------------------------------------------------- /run_commands/base_commands/command_adaptive_rar_RL.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | GPU_ID=$1 4 | NUM_TASKS=$2 5 | DATASET_NAME=$3 6 | SEED=$4 7 | MEM_SIZE=$5 8 | MEM_ITER=$6 9 | RAUG_N=$7 10 | RAUG_M=$8 11 | RAUG_TARGET=$9 12 | MEM_BATCH=${10} 13 | RES_SIZE=${11} 14 | MEM_MAX=${12} 15 | STOP_RATIO=${13} 16 | BPG_LR_large=${14} 17 | BPG_LR_small=${15} 18 | ACC_MAX=${16} 19 | ACC_MIN=${17} 20 | AUG_NUM=${18} 21 | AUG_MAX=${19} 22 | AUG_MIN=${20} 23 | ## ER-aug-dyna 24 | 25 | python general_main.py --data $DATASET_NAME --cl_type nc \ 26 | --agent "ER_dyna_iter_aug_dbpg_joint" --retrieve "random" --update random \ 27 | --mem_size $MEM_SIZE --eps_mem_batch $MEM_BATCH \ 28 | --dataset_random_type task_random --seed $SEED --num_tasks $NUM_TASKS \ 29 | --seed $SEED --GPU_ID $GPU_ID --mem_iters $MEM_ITER \ 30 | --nmc_trick True \ 31 | --randaug True --randaug_N $RAUG_N --randaug_M $RAUG_M --aug_target $RAUG_TARGET \ 32 | --resnet_size $RES_SIZE --dyna_type "bpg" --mem_iter_max $MEM_MAX \ 33 | --dyna_mem_iter "STOP_loss" --stop_ratio $STOP_RATIO --bpg_restart True \ 34 | --adjust_aug_flag True --bpg_lr_large $BPG_LR_large --bpg_lr_small $BPG_LR_small \ 35 | --train_acc_max $ACC_MAX --train_acc_min $ACC_MIN --save_prefix "nostopflag" \ 36 | --aug_action_num $AUG_NUM --train_acc_max_aug $AUG_MAX --train_acc_min_aug $AUG_MIN 37 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/dataset_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import 
os 3 | 4 | class DatasetBase(ABC): 5 | def __init__(self, dataset, scenario, task_nums, run, params): 6 | super(DatasetBase, self).__init__() 7 | self.params = params 8 | self.scenario = scenario 9 | self.dataset = dataset 10 | self.task_nums = task_nums 11 | self.run = run 12 | self.root = os.path.join('./datasets', self.dataset) 13 | self.test_set = [] 14 | self.val_set = [] 15 | self._is_properly_setup() 16 | self.download_load() 17 | 18 | 19 | @abstractmethod 20 | def download_load(self): 21 | pass 22 | 23 | @abstractmethod 24 | def setup(self, **kwargs): 25 | pass 26 | 27 | @abstractmethod 28 | def new_task(self, cur_task, **kwargs): 29 | pass 30 | 31 | def _is_properly_setup(self): 32 | pass 33 | 34 | @abstractmethod 35 | def new_run(self, **kwargs): 36 | pass 37 | 38 | @property 39 | def dataset_info(self): 40 | return self.dataset 41 | 42 | def get_test_set(self): 43 | return self.test_set 44 | 45 | def clean_mem_test_set(self): 46 | self.test_set = None 47 | self.test_data = None 48 | self.test_label = None -------------------------------------------------------------------------------- /run_commands/runs/rar_ablation_aug_iter.sh: -------------------------------------------------------------------------------- 1 | #SEED=1259051 2 | MEM_SIZE=2000 3 | NUM_TASKS=0 4 | GPU_ID=0 5 | 6 | NAME_PREFIX="run_commands/base_commands/command_" 7 | 8 | for SEED in 1259051 1259052 1259053 9 | do 10 | for DATASET_NAME in "mini_imagenet" "cifar100" "clrs25" "core50" 11 | do 12 | ALGO_NAME="er" 13 | 14 | RES_SIZE="reduced" 15 | 16 | 17 | for MEM_ITER in 1 2 5 10 20 18 | do 19 | ############### no augmentation 20 | FILE_NAME=$NAME_PREFIX$ALGO_NAME".sh" 21 | source $FILE_NAME $GPU_ID $NUM_TASKS $DATASET_NAME $SEED $MEM_SIZE $MEM_ITER $RES_SIZE 22 | 23 | ############ augmentation(P=1,Q=14) ######### 24 | RAUG_N=1 25 | RAUG_M=14 26 | RAUG_TARGET="both" ## mem incoming none 27 | MEM_BATCH=10 28 | FILE_NAME=$NAME_PREFIX$ALGO_NAME"_raug.sh" 29 | source $FILE_NAME $GPU_ID $NUM_TASKS $DATASET_NAME $SEED $MEM_SIZE $MEM_ITER \ 30 | $RAUG_N $RAUG_M $RAUG_TARGET $MEM_BATCH $RES_SIZE 31 | 32 | ############ augmentation(P=2,Q=14) ######### 33 | RAUG_N=2 34 | RAUG_M=14 35 | RAUG_TARGET="both" ## mem incoming none 36 | MEM_BATCH=10 37 | FILE_NAME=$NAME_PREFIX$ALGO_NAME"_raug.sh" 38 | source $FILE_NAME $GPU_ID $NUM_TASKS $DATASET_NAME $SEED $MEM_SIZE $MEM_ITER \ 39 | $RAUG_N $RAUG_M $RAUG_TARGET $MEM_BATCH $RES_SIZE 40 | 41 | done 42 | 43 | 44 | 45 | 46 | done 47 | 48 | done -------------------------------------------------------------------------------- /config/general_1_aug.yml: -------------------------------------------------------------------------------- 1 | parameters: 2 | #experiment 3 | num_runs: 1 #5 4 | seed: 1259051 5 | # optimizer 6 | optimizer: SGD 7 | epoch: 1 8 | batch: 10 9 | test_batch: 128 10 | # validation 11 | val_size: 0.0 12 | num_val: 2 #17 13 | num_runs_val: 1 #3 14 | # data general 15 | fix_order: False 16 | plot_sample: False 17 | online: True 18 | 19 | exp_name: para17_7_4 20 | 21 | dataset_random_type: "task_random" 22 | switch_buffer_type: "one_buffer" 23 | RL_type: "NoRL" 24 | use_test_buffer: False 25 | test_mem_type: "after" 26 | use_tmp_buffer: False 27 | frozen_old_fc: False 28 | mem_iters: 1 29 | incoming_ratio: 1 30 | mem_ratio: 1 31 | save_prefix: "joint_training" 32 | replay_old_only: False 33 | only_task_seen: False 34 | use_softmaxloss: False 35 | buffer_tracker: False 36 | error_analysis: False 37 | start_mem_iters: -1 38 | close_loop_mem_type: "random" 39 |
joint_replay_type: "together" 40 | online_hyper_tune: False 41 | 42 | lambda_: 100 43 | nmc_trick : False 44 | learning_rate: 0.1 45 | weight_decay: 0.0001 46 | 47 | randaug: True 48 | #randaug_M: 14 49 | resnet_size: "reduced" 50 | aug_target: "both" 51 | aug_start: 0 52 | randaug_type: "static" 53 | scraug: False 54 | immediate_evaluate: False 55 | test_add_buffer: False 56 | dyna_mem_iter: "STOP_loss" 57 | stop_ratio : 3 58 | offline: False 59 | 60 | 61 | -------------------------------------------------------------------------------- /config/general_1_aug_seed2.yml: -------------------------------------------------------------------------------- 1 | parameters: 2 | #experiment 3 | num_runs: 1 #5 4 | seed: 1259052 5 | # optimizer 6 | optimizer: SGD 7 | epoch: 1 8 | batch: 10 9 | test_batch: 128 10 | # validation 11 | val_size: 0.0 12 | num_val: 2 #17 13 | num_runs_val: 1 #3 14 | # data general 15 | fix_order: False 16 | plot_sample: False 17 | online: True 18 | 19 | exp_name: para17_7_4 20 | 21 | dataset_random_type: "task_random" 22 | switch_buffer_type: "one_buffer" 23 | RL_type: "NoRL" 24 | use_test_buffer: False 25 | test_mem_type: "after" 26 | use_tmp_buffer: False 27 | frozen_old_fc: False 28 | mem_iters: 1 29 | incoming_ratio: 1 30 | mem_ratio: 1 31 | save_prefix: "joint_training" 32 | replay_old_only: False 33 | only_task_seen: False 34 | use_softmaxloss: False 35 | buffer_tracker: False 36 | error_analysis: False 37 | start_mem_iters: -1 38 | close_loop_mem_type: "random" 39 | joint_replay_type: "together" 40 | online_hyper_tune: False 41 | 42 | lambda_: 100 43 | nmc_trick : False 44 | learning_rate: 0.1 45 | weight_decay: 0.0001 46 | 47 | randaug: True 48 | #randaug_M: 14 49 | resnet_size: "reduced" 50 | aug_target: "both" 51 | aug_start: 0 52 | randaug_type: "static" 53 | scraug: False 54 | immediate_evaluate: False 55 | test_add_buffer: False 56 | dyna_mem_iter: "STOP_loss" 57 | stop_ratio : 3 58 | offline: False 59 | 60 | 61 | -------------------------------------------------------------------------------- /config/general_1_aug_seed3.yml: -------------------------------------------------------------------------------- 1 | parameters: 2 | #experiment 3 | num_runs: 1 #5 4 | seed: 1259053 5 | # optimizer 6 | optimizer: SGD 7 | epoch: 1 8 | batch: 10 9 | test_batch: 128 10 | # validation 11 | val_size: 0.0 12 | num_val: 2 #17 13 | num_runs_val: 1 #3 14 | # data general 15 | fix_order: False 16 | plot_sample: False 17 | online: True 18 | 19 | exp_name: para17_7_4 20 | 21 | dataset_random_type: "task_random" 22 | switch_buffer_type: "one_buffer" 23 | RL_type: "NoRL" 24 | use_test_buffer: False 25 | test_mem_type: "after" 26 | use_tmp_buffer: False 27 | frozen_old_fc: False 28 | mem_iters: 1 29 | incoming_ratio: 1 30 | mem_ratio: 1 31 | save_prefix: "joint_training" 32 | replay_old_only: False 33 | only_task_seen: False 34 | use_softmaxloss: False 35 | buffer_tracker: False 36 | error_analysis: False 37 | start_mem_iters: -1 38 | close_loop_mem_type: "random" 39 | joint_replay_type: "together" 40 | online_hyper_tune: False 41 | 42 | lambda_: 100 43 | nmc_trick : False 44 | learning_rate: 0.1 45 | weight_decay: 0.0001 46 | 47 | randaug: True 48 | #randaug_M: 14 49 | resnet_size: "reduced" 50 | aug_target: "both" 51 | aug_start: 0 52 | randaug_type: "static" 53 | scraug: False 54 | immediate_evaluate: False 55 | test_add_buffer: False 56 | dyna_mem_iter: "STOP_loss" 57 | stop_ratio : 3 58 | offline: False 59 | 60 | 61 | 
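(Note: general_1_aug.yml, general_1_aug_seed2.yml and general_1_aug_seed3.yml differ only in seed — 1259051, 1259052, 1259053 — mirroring the three-seed loops in run_commands/runs/*.sh. A hypothetical sweep over them through main_tune.py, whose flags are listed at the end of this dump:)

    for CFG in general_1_aug.yml general_1_aug_seed2.yml general_1_aug_seed3.yml; do
        python main_tune.py --general config/$CFG \
            --data config/data/cifar100/cifar100_nc.yml \
            --default config/agent/er/er_10k.yml \
            --tune config/agent/er/er_tune.yml
    done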
-------------------------------------------------------------------------------- /models/ndpm/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import Tensor 4 | from torch.nn.functional import binary_cross_entropy 5 | 6 | 7 | def gaussian_nll(x, mean, log_var, min_noise=0.001): 8 | return ( 9 | ((x - mean) ** 2 + min_noise) / (2 * log_var.exp() + 1e-8) 10 | + 0.5 * log_var + 0.5 * np.log(2 * np.pi) 11 | ) 12 | 13 | 14 | def laplace_nll(x, median, log_scale, min_noise=0.01): 15 | return ( 16 | ((x - median).abs() + min_noise) / (log_scale.exp() + 1e-8) 17 | + log_scale + np.log(2) 18 | ) 19 | 20 | 21 | def bernoulli_nll(x, p): 22 | # Broadcast 23 | x_exp, p_exp = [], [] 24 | for x_size, p_size in zip(x.size(), p.size()): 25 | if x_size > p_size: 26 | x_exp.append(-1) 27 | p_exp.append(x_size) 28 | elif x_size < p_size: 29 | x_exp.append(p_size) 30 | p_exp.append(-1) 31 | else: 32 | x_exp.append(-1) 33 | p_exp.append(-1) 34 | x = x.expand(*x_exp) 35 | p = p.expand(*p_exp) 36 | 37 | return binary_cross_entropy(p, x, reduction='none') 38 | 39 | 40 | def logistic_nll(x, mean, log_scale): 41 | bin_size = 1 / 256 42 | scale = log_scale.exp() 43 | x_centered = x - mean 44 | cdf1 = x_centered / scale 45 | cdf2 = (x_centered + bin_size) / scale 46 | p = torch.sigmoid(cdf2) - torch.sigmoid(cdf1) + 1e-12 47 | return -p.log() 48 | -------------------------------------------------------------------------------- /utils/argparser/argparser_aug.py: -------------------------------------------------------------------------------- 1 | 2 | from utils.utils import boolean_string 3 | 4 | 5 | def parse_aug(parser): 6 | 7 | ########################## RAR ##################### 8 | parser.add_argument("--adjust_aug_flag",default=False,type=boolean_string) 9 | parser.add_argument("--randaug",default=False,type=boolean_string) 10 | parser.add_argument("--deraug",default=False,type=boolean_string) 11 | parser.add_argument("--aug_normal",default=False,type=boolean_string) 12 | parser.add_argument("--randaug_type",default="static",choices=["dynamic","static"]) 13 | parser.add_argument("--aug_target",default="both",choices=["mem","incoming","both","none"]) 14 | parser.add_argument("--scraug",default=False,type=boolean_string)  # parse as bool like the flags above (a bare string "False" would otherwise be truthy) 15 | parser.add_argument("--scrview",default="scraug",choices=["None","randaug","scraug"]) 16 | parser.add_argument("--randaug_N", default=0,type=int) 17 | parser.add_argument("--randaug_N_mem", default=0,type=int) 18 | parser.add_argument("--randaug_N_incoming", default=0,type=int) 19 | parser.add_argument("--randaug_M", default=1,type=int) 20 | parser.add_argument("--aug_start",default=0,type=int) 21 | 22 | parser.add_argument("--quality",default=100,type=int) 23 | # parser.add_argument("--do_cutmix", dest="do_cutmix", default=False, type=boolean_string) 24 | # parser.add_argument("--cutmix_prob", default=0.5, type=float) 25 | # parser.add_argument("--cutmix_batch", default=10, type=int) 26 | # parser.add_argument("--cutmix_type", default="random", choices=["most_confused","train_mem","random","cross_task","mixed"]) 27 | 28 | return parser 29 | -------------------------------------------------------------------------------- /main_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils.io import load_yaml 3 | from types import SimpleNamespace 4 | from utils.utils import boolean_string 5 | import time 6 | import torch 7 | import random 8 | import numpy as
np 9 | from experiment.run import multiple_run 10 | 11 | def main(args): 12 | general_params = load_yaml(args.general) 13 | data_params = load_yaml(args.data) 14 | agent_params = load_yaml(args.agent) 15 | general_params['verbose'] = args.verbose 16 | general_params['cuda'] = torch.cuda.is_available() 17 | final_params = SimpleNamespace(**general_params, **data_params, **agent_params) 18 | time_start = time.time() 19 | print(final_params) 20 | 21 | #reproduce 22 | np.random.seed(final_params.seed) 23 | random.seed(final_params.seed) 24 | torch.manual_seed(final_params.seed) 25 | if final_params.cuda: 26 | torch.cuda.manual_seed(final_params.seed) 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = False 29 | 30 | #run 31 | multiple_run(final_params) 32 | 33 | 34 | 35 | if __name__ == "__main__": 36 | # Commandline arguments 37 | parser = argparse.ArgumentParser('CVPR Continual Learning Challenge') 38 | parser.add_argument('--general', dest='general', default='config/general.yml') 39 | parser.add_argument('--data', dest='data', default='config/data/cifar100/cifar100_nc.yml') 40 | parser.add_argument('--agent', dest='agent', default='config/agent/er.yml') 41 | 42 | parser.add_argument('--verbose', type=boolean_string, default=True, 43 | help='print information or not') 44 | args = parser.parse_args() 45 | main(args) -------------------------------------------------------------------------------- /agents/cndpm.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from models.ndpm.ndpm import Ndpm 4 | from utils.setup_elements import transforms_match 5 | from torch.utils import data 6 | from utils.utils import maybe_cuda, AverageMeter 7 | import torch 8 | 9 | 10 | class Cndpm(ContinualLearner): 11 | def __init__(self, model, opt, params): 12 | super(Cndpm, self).__init__(model, opt, params) 13 | self.model = model 14 | 15 | 16 | def train_learner(self, x_train, y_train): 17 | # set up loader 18 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 19 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 20 | drop_last=True) 21 | # setup tracker 22 | losses_batch = AverageMeter() 23 | acc_batch = AverageMeter() 24 | 25 | self.model.train() 26 | 27 | for ep in range(self.epoch): 28 | for i, batch_data in enumerate(train_loader): 29 | # batch update 30 | batch_x, batch_y = batch_data 31 | batch_x = maybe_cuda(batch_x, self.cuda) 32 | batch_y = maybe_cuda(batch_y, self.cuda) 33 | self.model.learn(batch_x, batch_y) 34 | if self.params.verbose: 35 | print('\r[Step {:4}] STM: {:5}/{} | #Expert: {}'.format( 36 | i, 37 | len(self.model.stm_x), self.params.stm_capacity, 38 | len(self.model.experts) - 1 39 | ), end='') 40 | print() 41 | -------------------------------------------------------------------------------- /utils/argparser/argparser_scr.py: -------------------------------------------------------------------------------- 1 | from utils.utils import boolean_string 2 | 3 | 4 | def parse_scr(parser): 5 | ####################SupContrast###################### 6 | parser.add_argument('--temp', type=float, default=0.07, 7 | help='temperature for loss function') 8 | parser.add_argument('--examine_train', type=boolean_string, default=False, 9 | ) 10 | parser.add_argument('--no_aug', type=boolean_string, default=False, 11 | ) 12 |
parser.add_argument('--single_aug', type=boolean_string, default=False, 13 | ) 14 | parser.add_argument('--aug_type', default="") 15 | parser.add_argument('--softmaxhead_lr', type=float, default=0.1) 16 | 17 | parser.add_argument('--buffer_tracker', type=boolean_string, default=False, 18 | help='Keep track of buffer with a dictionary') 19 | parser.add_argument('--warmup', type=int, default=4, 20 | help='warmup of buffer before retrieve') 21 | parser.add_argument('--head', type=str, default='mlp', 22 | help='projection head') 23 | 24 | parser.add_argument('--use_softmaxloss', type=boolean_string, default=False) 25 | parser.add_argument('--softmax_nlayers', type=int, default=1, help="softmax head for scr") 26 | parser.add_argument('--softmax_nsize', type=int, default=1024, help="softmax head size for scr") 27 | parser.add_argument('--softmax_membatch', type=int, default=100, help="softmax mem batchsize for scr") 28 | parser.add_argument('--softmax_dropout', type=boolean_string, default=False, 29 | help="whether to use dropout in softmax head") 30 | parser.add_argument('--softmax_type', type=str, default='None', choices=['None', 'seperate', 'meta']) 31 | 32 | 33 | return parser -------------------------------------------------------------------------------- /utils/buffer/rl_retrieve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.buffer.buffer_utils import random_retrieve 3 | 4 | import copy 5 | 6 | 7 | class RL_retrieve(object): 8 | def __init__(self, params,RL_agent,RL_env): 9 | 10 | super().__init__() 11 | self.RL_agent = RL_agent 12 | self.RL_env = RL_env 13 | 14 | self.params = params 15 | self.subsample = params.subsample 16 | self.num_retrieve = params.eps_mem_batch 17 | 18 | def retrieve(self, buffer, **kwargs): 19 | sub_x, sub_y, mem_indices = random_retrieve(buffer, self.subsample, return_indices=True) 20 | #sub_x, sub_y = random_retrieve(buffer, self.subsample) 21 | 22 | 23 | 24 | if sub_x.size(0) > 0: 25 | 26 | ## TODO-zyq: rl 27 | # mem_indices [0,5000], big_ind [0,50] 28 | #state = env.compute_state(mem_indices) # 50*3 29 | action = self.RL_agent.sample_action() 30 | big_ind = self.RL_env.from_action_to_indices(action,buffer,mem_indices) 31 | buffer.update_replay_times(mem_indices[big_ind]) 32 | #print("RL retreive",) 33 | return sub_x[big_ind], sub_y[big_ind] 34 | else: 35 | return sub_x, sub_y 36 | 37 | # 38 | # if sub_x.size(0) > 0: 39 | # ## TODO-zyq: rl 40 | # # mem_indices [0,5000], big_ind [0,50] 41 | # state = env.compute_state(mem_indices) # 50*3 42 | # action = RL_agent.sample(state) 43 | # big_ind = env.intepret(action) 44 | # reward, next_state = env.step(action,) 45 | # RL_agent.update_agent(state, action, reward, next_state) # todo 46 | # buffer.update_replay_times(mem_indices[big_ind]) 47 | # return sub_x[big_ind], sub_y[big_ind] 48 | # else: 49 | # return sub_x, sub_y 50 | 51 | -------------------------------------------------------------------------------- /models/ndpm/priors.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | 4 | from utils.utils import maybe_cuda 5 | 6 | 7 | class Prior(ABC): 8 | def __init__(self, params): 9 | self.params = params 10 | 11 | @abstractmethod 12 | def add_expert(self): 13 | pass 14 | 15 | @abstractmethod 16 | def record_usage(self, usage, index=None): 17 | pass 18 | 19 | @abstractmethod 20 | def nl_prior(self, normalize=False): 21 | pass 22 | 23 | 24 | class 
CumulativePrior(Prior): 25 | def __init__(self, params): 26 | super().__init__(params) 27 | self.log_counts = maybe_cuda(torch.tensor( 28 | params.log_alpha 29 | )).float().unsqueeze(0) 30 | 31 | def add_expert(self): 32 | self.log_counts = torch.cat( 33 | [self.log_counts, maybe_cuda(torch.zeros(1))], 34 | dim=0 35 | ) 36 | 37 | def record_usage(self, usage, index=None): 38 | """Record expert usage 39 | 40 | Args: 41 | usage: Tensor of shape [K+1] if index is None else scalar 42 | index: expert index 43 | """ 44 | if index is None: 45 | self.log_counts = torch.logsumexp(torch.stack([ 46 | self.log_counts, 47 | usage.log() 48 | ], dim=1), dim=1) 49 | else: 50 | self.log_counts[index] = torch.logsumexp(torch.stack([ 51 | self.log_counts[index], 52 | maybe_cuda(torch.tensor(usage)).float().log() 53 | ], dim=0), dim=0) 54 | 55 | def nl_prior(self, normalize=False): 56 | nl_prior = -self.log_counts 57 | if normalize: 58 | nl_prior += torch.logsumexp(self.log_counts, dim=0) 59 | return nl_prior 60 | 61 | @property 62 | def counts(self): 63 | return self.log_counts.exp() 64 | -------------------------------------------------------------------------------- /models/ndpm/expert.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from models.ndpm.classifier import ResNetSharingClassifier 4 | from models.ndpm.vae import CnnSharingVae 5 | from utils.utils import maybe_cuda 6 | 7 | from utils.global_vars import * 8 | 9 | 10 | class Expert(nn.Module): 11 | def __init__(self, params, experts=()): 12 | super().__init__() 13 | self.id = len(experts) 14 | self.experts = experts 15 | 16 | self.g = maybe_cuda(CnnSharingVae(params, experts)) 17 | self.d = maybe_cuda(ResNetSharingClassifier(params, experts)) if not MODELS_NDPM_NDPM_DISABLE_D else None 18 | 19 | 20 | # use random initialized g if it's a placeholder 21 | if self.id == 0: 22 | self.eval() 23 | for p in self.g.parameters(): 24 | p.requires_grad = False 25 | 26 | # use random initialized d if it's a placeholder 27 | if self.id == 0 and self.d is not None: 28 | for p in self.d.parameters(): 29 | p.requires_grad = False 30 | 31 | def forward(self, x): 32 | return self.d(x) 33 | 34 | def nll(self, x, y, step=None): 35 | """Negative log likelihood""" 36 | nll = self.g.nll(x, step) 37 | if self.d is not None: 38 | d_nll = self.d.nll(x, y, step) 39 | nll = nll + d_nll 40 | return nll 41 | 42 | def collect_nll(self, x, y, step=None): 43 | if self.id == 0: 44 | nll = self.nll(x, y, step) 45 | return nll.unsqueeze(1) 46 | 47 | nll = self.g.collect_nll(x, step) 48 | if self.d is not None: 49 | d_nll = self.d.collect_nll(x, y, step) 50 | nll = nll + d_nll 51 | 52 | return nll 53 | 54 | def lr_scheduler_step(self): 55 | if self.g.lr_scheduler is not NotImplemented: 56 | self.g.lr_scheduler.step() 57 | if self.d is not None and self.d.lr_scheduler is not NotImplemented: 58 | self.d.lr_scheduler.step() 59 | 60 | def clip_grad(self): 61 | self.g.clip_grad() 62 | if self.d is not None: 63 | self.d.clip_grad() 64 | 65 | def optimizer_step(self): 66 | self.g.optimizer.step() 67 | if self.d is not None: 68 | self.d.optimizer.step() 69 | -------------------------------------------------------------------------------- /utils/buffer/tmp_buffer.py: -------------------------------------------------------------------------------- 1 | from utils.setup_elements import input_size_match 2 | from utils import name_match # import update_methods, retrieve_methods 3 | from utils.utils import maybe_cuda 4 | import 
torch 5 | import numpy as np 6 | 7 | 8 | 9 | class Tmp_Buffer(torch.nn.Module): 10 | def __init__(self, model, params,buffer): 11 | super().__init__() 12 | self.buffer= buffer 13 | self.params = params 14 | self.model = model 15 | input_size = input_size_match[params.data] 16 | self.tmp_buffer_size = params.mem_size # to-do: zyq set tmp buffer size 17 | self.tmp_x = maybe_cuda(torch.FloatTensor(self.tmp_buffer_size, *input_size).fill_(0)) 18 | self.tmp_y = maybe_cuda(torch.LongTensor(self.tmp_buffer_size).fill_(0)) 19 | self.current_n = 0 20 | 21 | 22 | 23 | # define buffer 24 | 25 | 26 | 27 | 28 | def tmp_store(self,batch_x,batch_y): 29 | 30 | new_num = batch_x.size(0) 31 | self.tmp_x[self.current_n:self.current_n+new_num]=batch_x 32 | self.tmp_y[self.current_n:self.current_n+new_num]=batch_y 33 | self.current_n += new_num 34 | #print("tmp store",self.tmp_x.size(0),self.current_n,new_num,batch_x.size(0),batch_y.size(0)) 35 | 36 | def reset(self): 37 | self.tmp_x = torch.zeros_like(self.tmp_x) 38 | self.tmp_y = torch.zeros_like(self.tmp_y) 39 | print("reset tmp memory, space used in tmp memory",self.current_n) 40 | self.current_n = 0 41 | 42 | 43 | 44 | 45 | def update_true_buffer(self): 46 | #idx_buffer = torch.FloatTensor(self.current_n).to(self.tmp_x.device).uniform_(0, self.params.mem_size).long() 47 | idx_buffer = torch.FloatTensor(self.current_n).uniform_(0, self.params.mem_size).long() 48 | 49 | idx_new_data = torch.arange(self.current_n)  # torch.range is deprecated and would yield current_n + 1 float indices 50 | 51 | idx_map = {idx_buffer[i].item(): idx_new_data[i].item() for i in range(idx_buffer.size(0))} 52 | print(self.tmp_x.device,self.tmp_y.device,idx_buffer.device,idx_new_data.device) 53 | self.buffer.overwrite(idx_map, self.tmp_x.cpu(), self.tmp_y.cpu()) 54 | # print("to be store",self.tmp_y[self.current_n]) 55 | # print("buffer indices",idx_buffer) 56 | # print("after replacement",self.buffer.buffer_label[idx_buffer]) 57 | 58 | self.reset() 59 | -------------------------------------------------------------------------------- /agents/lwf.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from utils.setup_elements import transforms_match 4 | from torch.utils import data 5 | from utils.utils import maybe_cuda, AverageMeter 6 | import torch 7 | import copy 8 | 9 | 10 | class Lwf(ContinualLearner): 11 | def __init__(self, model, opt, params): 12 | super(Lwf, self).__init__(model, opt, params) 13 | 14 | def train_learner(self, x_train, y_train): 15 | self.before_train(x_train, y_train) 16 | 17 | # set up loader 18 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 19 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 20 | drop_last=True) 21 | 22 | # set up model 23 | self.model = self.model.train() 24 | 25 | # setup tracker 26 | losses_batch = AverageMeter() 27 | acc_batch = AverageMeter() 28 | 29 | for ep in range(self.epoch): 30 | for i, batch_data in enumerate(train_loader): 31 | # batch update 32 | batch_x, batch_y = batch_data 33 | batch_x = maybe_cuda(batch_x, self.cuda) 34 | batch_y = maybe_cuda(batch_y, self.cuda) 35 | 36 | logits = self.forward(batch_x) 37 | loss_old = self.kd_manager.get_kd_loss(logits, batch_x) 38 | loss_new = self.criterion(logits, batch_y) 39 | loss = 1/(self.task_seen + 1) * loss_new + (1 - 1/(self.task_seen + 1)) * loss_old 40 | _, pred_label = torch.max(logits, 1) 41 | correct_cnt =
(pred_label == batch_y).sum().item() / batch_y.size(0) 42 | # update tracker 43 | acc_batch.update(correct_cnt, batch_y.size(0)) 44 | losses_batch.update(loss, batch_y.size(0)) 45 | # backward 46 | self.opt.zero_grad() 47 | loss.backward() 48 | self.opt.step() 49 | 50 | if i % 100 == 1 and self.verbose: 51 | print( 52 | '==>>> it: {}, avg. loss: {:.6f}, ' 53 | 'running train acc: {:.3f}' 54 | .format(i, losses_batch.avg(), acc_batch.avg()) 55 | ) 56 | self.after_train() 57 | -------------------------------------------------------------------------------- /experiment/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import sem 3 | import scipy.stats as stats 4 | 5 | def compute_performance(end_task_acc_arr): 6 | """ 7 | Given test accuracy results from multiple runs saved in end_task_acc_arr, 8 | compute the average accuracy, forgetting, and task accuracies as well as their confidence intervals. 9 | 10 | :param end_task_acc_arr: (np.ndarray) shape (num_runs, num_tasks, num_tasks); 11 | entry [r, i, j] is the run-r test accuracy on task j after training on task i 12 | :return: (avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt), each a (mean, 95% CI half-width) tuple 13 | """ 14 | n_run, n_tasks = end_task_acc_arr.shape[:2] 15 | t_coef = stats.t.ppf((1+0.95) / 2, n_run-1)     # t coefficient used to compute 95% CIs: mean +- t * SEM 16 | 17 | # compute average test accuracy and CI 18 | end_acc = end_task_acc_arr[:, -1, :]                         # shape: (num_run, num_task) 19 | avg_acc_per_run = np.mean(end_acc, axis=1)      # mean of end task accuracies per run 20 | avg_end_acc = (np.mean(avg_acc_per_run), t_coef * sem(avg_acc_per_run)) 21 | 22 | # compute forgetting 23 | best_acc = np.max(end_task_acc_arr, axis=1) 24 | final_forgets = best_acc - end_acc 25 | avg_fgt = np.mean(final_forgets, axis=1) 26 | avg_end_fgt = (np.mean(avg_fgt), t_coef * sem(avg_fgt)) 27 | 28 | # compute ACC 29 | acc_per_run = np.mean((np.sum(np.tril(end_task_acc_arr), axis=2) / 30 | (np.arange(n_tasks) + 1)), axis=1) 31 | avg_acc = (np.mean(acc_per_run), t_coef * sem(acc_per_run)) 32 | 33 | 34 | # compute BWT+ 35 | bwt_per_run = (np.sum(np.tril(end_task_acc_arr, -1), axis=(1,2)) - 36 | np.sum(np.diagonal(end_task_acc_arr, axis1=1, axis2=2) * 37 | (np.arange(n_tasks, 0, -1) - 1), axis=1)) / (n_tasks * (n_tasks - 1) / 2) 38 | bwtp_per_run = np.maximum(bwt_per_run, 0) 39 | avg_bwtp = (np.mean(bwtp_per_run), t_coef * sem(bwtp_per_run)) 40 | 41 | # compute FWT 42 | fwt_per_run = np.sum(np.triu(end_task_acc_arr, 1), axis=(1,2)) / (n_tasks * (n_tasks - 1) / 2) 43 | avg_fwt = (np.mean(fwt_per_run), t_coef * sem(fwt_per_run)) 44 | return avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt 45 | 46 | 47 | 48 | 49 | def single_run_avg_end_fgt(acc_array): 50 | best_acc = np.max(acc_array, axis=1) 51 | end_acc = acc_array[-1] 52 | final_forgets = best_acc - end_acc 53 | avg_fgt = np.mean(final_forgets) 54 | return avg_fgt 55 | -------------------------------------------------------------------------------- /utils/buffer/buffer_logits.py: -------------------------------------------------------------------------------- 1 | from utils.setup_elements import input_size_match 2 | from utils import name_match #import update_methods, retrieve_methods 3 | from utils.utils import maybe_cuda 4 | import torch 5 | import numpy as np 6 | from utils.buffer.buffer_utils import BufferClassTracker 7 | from utils.setup_elements import n_classes 8 | from utils.buffer.buffer import Buffer 9 | 10 | class Buffer_logits(Buffer): 11 | def __init__(self, model, params,mem_size=None,RL_agent=None,
RL_env=None): 12 | super().__init__(model, params) 13 | if mem_size is None: 14 | mem_size = self.params.mem_size 15 | buffer_size = mem_size 16 | self.buffer_size = mem_size 17 | print('buffer has %d slots' % buffer_size) 18 | input_size = input_size_match[params.data] 19 | class_num = n_classes[self.params.data] 20 | self.buffer_logits = torch.FloatTensor(buffer_size, class_num).fill_(0) 21 | self.buffer_label = torch.LongTensor(buffer_size).fill_(0) 22 | 23 | 24 | 25 | def update(self, x, y, logits=None, tmp_buffer=None): 26 | self.buffer_used_steps += 1 27 | return self.update_method.update(buffer=self, x=x, y=y, logits=logits, tmp_buffer=tmp_buffer) 28 | 29 | def retrieve(self, **kwargs): 30 | # if(self.retrieve_method.num_retrieve==-1): 31 | # print("dynamic mem batch size") 32 | # 33 | # self.retrieve_method.num_retrieve = self.task_seen_so_far * 10 # to-do: change 10 to the batch size of new data 34 | return self.retrieve_method.retrieve(buffer=self, **kwargs) 35 | def overwrite(self, idx_map, x, y, logits): 36 | ## zyq: save replay_times 37 | #print("----buffer overwrite") 38 | # for i in list(idx_map.keys()): 39 | # replay_times = self.buffer_replay_times[i].detach().cpu().numpy() 40 | # self.unique_replay_list.append(int(replay_times)) 41 | # self.buffer_replay_times[i]=0 42 | # self.buffer_last_replay[i]=0 43 | # sample_label = int(self.buffer_label[i].detach().cpu().numpy()) 44 | # self.replay_sample_label.append(sample_label) 45 | 46 | self.buffer_img[list(idx_map.keys())] = x[list(idx_map.values())] 47 | self.buffer_label[list(idx_map.keys())] = y[list(idx_map.values())] 48 | self.buffer_logits[list(idx_map.keys())] = logits[list(idx_map.values())] 49 | self.buffer_new_old[list(idx_map.keys())] = 1 50 | 51 | 52 | -------------------------------------------------------------------------------- /experiment/tune_hyperparam.py: -------------------------------------------------------------------------------- 1 | from types import SimpleNamespace 2 | from sklearn.model_selection import ParameterGrid 3 | from utils.setup_elements import setup_opt, setup_architecture 4 | from utils.utils import maybe_cuda 5 | from utils.name_match import agents 6 | import numpy as np 7 | from experiment.metrics import compute_performance 8 | from experiment.save_prefix import get_prefix_time 9 | 10 | 11 | 12 | 13 | def tune_hyper(tune_data, tune_test_loaders, default_params, tune_params): 14 | prefix = get_prefix_time(default_params) 15 | param_grid_list = list(ParameterGrid(tune_params)) 16 | print(len(param_grid_list)) 17 | para_name = list(tune_params.keys()) 18 | para_str = "" 19 | for para in para_name: 20 | para_str += para 21 | print(para_str) 22 | 23 | 24 | tune_accs = [] 25 | tune_fgt = [] 26 | for param_set in param_grid_list: 27 | final_params = vars(default_params) 28 | print(param_set) 29 | final_params.update(param_set) 30 | final_params = SimpleNamespace(**final_params) 31 | accuracy_list = [] 32 | for run in range(final_params.num_runs_val): 33 | tmp_acc = [] 34 | model = setup_architecture(final_params) 35 | model = maybe_cuda(model, final_params.cuda) 36 | opt = setup_opt(final_params.optimizer, model, final_params.learning_rate, final_params.weight_decay) 37 | agent = agents[final_params.agent](model, opt, final_params) 38 | for i, (x_train, y_train, labels) in enumerate(tune_data): 39 | print("-----------tune run {} task {}-------------".format(run, i)) 40 | print('size: {}, {}'.format(x_train.shape, y_train.shape)) 41 | agent.train_learner(x_train, y_train) 42 | acc_array, loss_array = 
agent.evaluate(tune_test_loaders) 43 | tmp_acc.append(acc_array) 44 | print( 45 | "-----------tune run {}-----------avg_end_acc {}-----------".format(run, np.mean(tmp_acc[-1]))) 46 | accuracy_list.append(np.array(tmp_acc)) 47 | accuracy_list = np.array(accuracy_list) 48 | 49 | avg_end_acc, avg_end_fgt, avg_acc, avg_bwtp, avg_fwt = compute_performance(accuracy_list) 50 | tune_accs.append(avg_end_acc[0]) 51 | tune_fgt.append(avg_end_fgt[0]) 52 | best_tune = param_grid_list[tune_accs.index(max(tune_accs))] 53 | 54 | print("saving tune accuracies") 55 | 56 | np.save(prefix + para_str + "_tune_acc.npy", np.array(tune_accs)) 57 | np.save(prefix + para_str + "_tune_para.npy", np.array(param_grid_list)) 58 | return best_tune -------------------------------------------------------------------------------- /main_tune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils.io import load_yaml 3 | from types import SimpleNamespace 4 | from utils.utils import boolean_string 5 | import time 6 | import torch 7 | import random 8 | import numpy as np 9 | from experiment.run import multiple_run_tune_separate 10 | from utils.setup_elements import default_trick 11 | 12 | def main(args): 13 | general_params = load_yaml(args.general) 14 | data_params = load_yaml(args.data) 15 | default_params = load_yaml(args.default) 16 | tune_params = load_yaml(args.tune) 17 | general_params['verbose'] = args.verbose 18 | general_params['cuda'] = torch.cuda.is_available() 19 | general_params['train_val'] = args.train_val 20 | if args.trick: 21 | default_trick[args.trick] = True 22 | general_params['trick'] = default_trick 23 | final_default_params = SimpleNamespace(**general_params, **data_params, **default_params) 24 | 25 | time_start = time.time() 26 | print(final_default_params) 27 | print() 28 | 29 | #reproduce 30 | np.random.seed(final_default_params.seed) 31 | random.seed(final_default_params.seed) 32 | torch.manual_seed(final_default_params.seed) 33 | if final_default_params.cuda: 34 | torch.cuda.manual_seed(final_default_params.seed) 35 | torch.backends.cudnn.deterministic = True 36 | torch.backends.cudnn.benchmark = False 37 | 38 | #run 39 | multiple_run_tune_separate(final_default_params, tune_params, args.save_path) 40 | 41 | 42 | 43 | if __name__ == "__main__": 44 | # Commandline arguments 45 | parser = argparse.ArgumentParser('Continual Learning') 46 | parser.add_argument('--general', dest='general', default='config/general_1.yml') 47 | parser.add_argument('--data', dest='data', default='config/data/cifar100/cifar100_nc.yml') 48 | parser.add_argument('--default', dest='default', default='config/agent/er/er_10k.yml') 49 | parser.add_argument('--tune', dest='tune', default='config/agent/er/er_tune.yml') 50 | parser.add_argument('--save-path', dest='save_path', default=None) 51 | parser.add_argument('--verbose', type=boolean_string, default=False, 52 | help='print information or not') 53 | parser.add_argument('--train_val', type=boolean_string, default=False, 54 | help='use the val batches to train') 55 | parser.add_argument('--trick', type=str, default=None) 56 | parser.add_argument('--exp_name', type=str, default="") 57 | parser.add_argument('--GPU_ID', dest='GPU_ID', default=6, 58 | type=int, 59 | help='GPU id to use') 60 | args = parser.parse_args() 61 | # args.cuda = torch.cuda.is_available() 62 | # torch.cuda.set_device(args.GPU_ID) 63 | main(args) -------------------------------------------------------------------------------- 
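The next file, `replay_times_update.py`, implements a reservoir-style buffer update. As a reading aid, here is a minimal, self-contained sketch of the reservoir-sampling rule it follows; the function name and signature below are illustrative only, not part of this repo's API:

```python
import torch

def reservoir_slots(batch_size, mem_size, current_index, n_seen_so_far):
    """Pick buffer slots for an incoming batch under reservoir sampling."""
    # 1) fill any empty buffer slots first
    n_fill = min(max(0, mem_size - current_index), batch_size)
    fill_slots = list(range(current_index, current_index + n_fill))

    # 2) for the remaining samples, draw a random position in [0, n_seen);
    #    only draws that land inside the buffer trigger a replacement, so an
    #    incoming sample survives with probability mem_size / n_seen
    n_rest = batch_size - n_fill
    draws = torch.randint(0, max(1, n_seen_so_far + n_fill), (n_rest,))
    kept = (draws < mem_size).nonzero().squeeze(-1)  # incoming samples that survive
    replace_map = {int(draws[j]): int(j) + n_fill for j in kept.tolist()}
    return fill_slots, replace_map
```

Duplicate slot draws simply collapse in the map, exactly as they do in the `idx_map` dictionary built in the file below.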
/utils/buffer/replay_times_update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class Replay_times_update(object): 6 | def __init__(self, params): 7 | super().__init__() 8 | 9 | def update(self, buffer, x, y, **kwargs): 10 | batch_size = x.size(0) 11 | 12 | # add whatever still fits in the buffer 13 | place_left = max(0, buffer.buffer_img.size(0) - buffer.current_index) 14 | if place_left: 15 | offset = min(place_left, batch_size) 16 | buffer.buffer_img[buffer.current_index: buffer.current_index + offset].data.copy_(x[:offset]) 17 | buffer.buffer_label[buffer.current_index: buffer.current_index + offset].data.copy_(y[:offset]) 18 | 19 | 20 | buffer.current_index += offset 21 | buffer.n_seen_so_far += offset 22 | 23 | # everything was added 24 | if offset == x.size(0): 25 | return list(range(buffer.current_index - offset, buffer.current_index)) # slots just filled; current_index has already been advanced 26 | 27 | # drop the part of the batch that was just added above 28 | x, y = x[place_left:], y[place_left:] 29 | 30 | indices = torch.FloatTensor(x.size(0)).to(x.device).uniform_(0, buffer.n_seen_so_far).long() 31 | 32 | valid_indices = (indices < buffer.buffer_img.size(0)).long() 33 | 34 | 35 | idx_new_data = valid_indices.nonzero().squeeze(-1) 36 | idx_buffer = indices[idx_new_data] 37 | 38 | ## zyq: choose the samples with least replay times to be replaced 39 | 40 | 41 | # store_sample_num = torch.sum(valid_indices.detach().cpu().numpy()) 42 | # idx_buffer = torch.argsort(buffer.buffer_replay_times.detach().cpu().numpy())[:store_sample_num] 43 | # idx_buffer = idx_buffer.cuda() 44 | ############################ 45 | 46 | buffer.n_seen_so_far += x.size(0) 47 | 48 | if idx_buffer.numel() == 0: 49 | return [] 50 | 51 | assert idx_buffer.max() < buffer.buffer_img.size(0) 52 | assert idx_buffer.max() < buffer.buffer_label.size(0) 53 | # assert idx_buffer.max() < self.buffer_task.size(0) 54 | 55 | assert idx_new_data.max() < x.size(0) 56 | assert idx_new_data.max() < y.size(0) 57 | 58 | idx_map = {idx_buffer[i].item(): idx_new_data[i].item() for i in range(idx_buffer.size(0))} 59 | ## zyq: save replay_times 60 | for i in list(idx_map.keys()): 61 | replay_times = buffer.buffer_replay_times[i].detach().cpu().numpy() 62 | buffer.unique_replay_list.append(int(replay_times)) 63 | buffer.buffer_replay_times[i] = 0 64 | buffer.buffer_last_replay[i] = 0 65 | label = int(buffer.buffer_label[i].detach().cpu().numpy()) 66 | buffer.replay_sample_label.append(label) 67 | # perform overwrite op 68 | buffer.buffer_img[list(idx_map.keys())] = x[list(idx_map.values())] 69 | buffer.buffer_label[list(idx_map.keys())] = y[list(idx_map.values())] 70 | return list(idx_map.keys()) -------------------------------------------------------------------------------- /continuum/dataset_scripts/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import datasets 3 | from continuum.data_utils import create_task_composition, create_task_composition_order, load_task_with_labels, shuffle_data 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns 6 | 7 | 8 | class CIFAR10(DatasetBase): 9 | def __init__(self, scenario, params): 10 | dataset = 'cifar10' 11 | if scenario == 'ni': 12 | num_tasks = len(params.ns_factor) 13 | else: 14 | num_tasks = params.num_tasks 15 | super(CIFAR10, self).__init__(dataset, scenario, num_tasks, params.num_runs, 
params) 16 | 17 | 18 | def download_load(self): 19 | dataset_train = datasets.CIFAR10(root=self.root, train=True, download=True) 20 | self.train_data = dataset_train.data 21 | self.train_label = np.array(dataset_train.targets) 22 | dataset_test = datasets.CIFAR10(root=self.root, train=False, download=True) 23 | self.test_data = dataset_test.data 24 | self.test_label = np.array(dataset_test.targets) 25 | 26 | def setup(self): 27 | if self.scenario == 'ni': 28 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 29 | self.train_label, 30 | self.test_data, self.test_label, 31 | self.task_nums, 32, 32 | self.params.val_size, 33 | self.params.ns_type, self.params.ns_factor, 34 | plot=self.params.plot_sample) 35 | elif self.scenario == 'nc':## todo: task order 36 | #self.task_labels = create_task_composition(class_nums=10, num_tasks=self.task_nums, fixed_order=self.params.fix_order) 37 | self.task_labels = create_task_composition_order(class_nums=10, num_tasks=self.task_nums,) 38 | 39 | self.test_set = [] 40 | for labels in self.task_labels: 41 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 42 | self.test_set.append((x_test, y_test)) 43 | else: 44 | raise Exception('wrong scenario') 45 | 46 | def new_task(self, cur_task, **kwargs): 47 | if self.scenario == 'ni': 48 | x_train, y_train = self.train_set[cur_task] 49 | labels = set(y_train) 50 | elif self.scenario == 'nc': 51 | labels = self.task_labels[cur_task] 52 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 53 | return x_train, y_train, labels 54 | 55 | def new_run(self, **kwargs): 56 | self.setup() 57 | return self.test_set 58 | 59 | def test_plot(self): 60 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 61 | self.params.ns_factor) 62 | -------------------------------------------------------------------------------- /agents/icarl.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from utils import utils 4 | from utils.buffer.buffer_utils import random_retrieve 5 | from utils.setup_elements import transforms_match 6 | from torch.utils import data 7 | import numpy as np 8 | from torch.nn import functional as F 9 | from utils.utils import maybe_cuda, AverageMeter 10 | from utils.buffer.buffer import Buffer 11 | import torch 12 | import copy 13 | 14 | 15 | class Icarl(ContinualLearner): 16 | def __init__(self, model, opt, params): 17 | super(Icarl, self).__init__(model, opt, params) 18 | self.model = model 19 | self.mem_size = params.mem_size 20 | self.buffer = Buffer(model, params) 21 | self.prev_model = None 22 | 23 | def train_learner(self, x_train, y_train,labels=None): 24 | self.before_train(x_train, y_train) 25 | # set up loader 26 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 27 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 28 | drop_last=True) 29 | self.model.train() 30 | self.update_representation(train_loader) 31 | self.prev_model = copy.deepcopy(self.model) 32 | self.after_train() 33 | 34 | def update_representation(self, train_loader): 35 | updated_idx = [] 36 | for ep in range(self.epoch): 37 | for i, train_data in enumerate(train_loader): 38 | # batch update 39 | train_x, train_y = train_data 40 | train_x = maybe_cuda(train_x, self.cuda) 41 | train_y = 
maybe_cuda(train_y, self.cuda) 42 | train_y_copy = train_y.clone() 43 | for k, y in enumerate(train_y_copy): 44 | train_y_copy[k] = len(self.old_labels) + self.new_labels.index(y) 45 | all_cls_num = len(self.new_labels) + len(self.old_labels) 46 | target_labels = utils.ohe_label(train_y_copy, all_cls_num, device=train_y_copy.device).float() 47 | if self.prev_model is not None: 48 | mem_x, mem_y,indices = random_retrieve(self.buffer, self.batch, 49 | excl_indices=updated_idx,return_indices=True) 50 | #zyq: icarl replay update 51 | self.buffer.update_replay_times(indices) 52 | mem_x = maybe_cuda(mem_x, self.cuda) 53 | batch_x = torch.cat([train_x, mem_x]) 54 | target_labels = torch.cat([target_labels, torch.zeros_like(target_labels)]) 55 | else: 56 | batch_x = train_x 57 | logits = self.forward(batch_x) 58 | 59 | self.opt.zero_grad() 60 | if self.prev_model is not None: 61 | with torch.no_grad(): 62 | q = torch.sigmoid(self.prev_model.forward(batch_x)) 63 | for k, y in enumerate(self.old_labels): 64 | target_labels[:, k] = q[:, k] 65 | loss = F.binary_cross_entropy_with_logits(logits[:, :all_cls_num], target_labels, reduction='none').sum(dim=1).mean() 66 | loss.backward() 67 | self.opt.step() 68 | updated_idx += self.buffer.update(train_x, train_y) 69 | 70 | ## todo zyq: save replay times and label of all the samples ever enter the memory 71 | #self.buffer.save_buffer_info() 72 | 73 | 74 | -------------------------------------------------------------------------------- /utils/argparser/argparser_replay.py: -------------------------------------------------------------------------------- 1 | 2 | from utils.utils import boolean_string 3 | 4 | def parse_replay(parser): 5 | 6 | ################## Adaptive CL ########################### 7 | 8 | parser.add_argument("--adjust_iter_flag",default=False,type=boolean_string) 9 | parser.add_argument("--dyna_mem_iter",dest='dyna_mem_iter',default="None",type=str,choices=["random","STOP_loss","None","STOP_acc_loss","STOP_acc"], 10 | help='If True, adjust mem iter') 11 | parser.add_argument("--train_acc_max",default=0.95,type=float) 12 | parser.add_argument("--train_acc_max_aug",default=0.90,type=float) 13 | parser.add_argument("--train_acc_min_aug",default=0.80,type=float) 14 | parser.add_argument("--train_acc_min",default=0.85,type=float) 15 | 16 | 17 | parser.add_argument('--mem_iter_max', dest='mem_iter_max', default=20, type=int, 18 | help='') 19 | 20 | parser.add_argument('--mem_iter_min', dest='mem_iter_min', default=1, type=int, 21 | help='') 22 | 23 | parser.add_argument("--dyna_type",default="train_acc",choices=["bpg","random","train_acc"]) 24 | 25 | 26 | 27 | # #################### replay dynamics #################### 28 | parser.add_argument("--joint_replay_type",default="together",choices=["together","seperate"], 29 | help="implementation type of joint training of incoming batch and memory batch") 30 | parser.add_argument("--online_hyper_tune", default=False, type=boolean_string) 31 | parser.add_argument("--online_hyper_valid_type", default="test_data", type=str, choices=["real_data","test_mem"]) 32 | parser.add_argument("--online_hyper_freq", default=1, type=int) 33 | parser.add_argument("--online_hyper_lr_list_type",default="basic",choices=["scr","basic","4lr","5lr"]) 34 | parser.add_argument("--online_hyper_RL",default=False,type=boolean_string) 35 | parser.add_argument("--scr_memIter", default=False, type=boolean_string) 36 | parser.add_argument("--scr_memIter_type",default="c_MAB",choices=["c_MAB","MAB"]) 37 | 
parser.add_argument("--scr_memIter_state_type", default="4dim", choices=["7dim","6dim","3dim","4dim","train"]) 38 | parser.add_argument("--scr_memIter_action_type", default="4", choices=["4","8"]) 39 | 40 | # parser.add_argument("--temperature_scaling",default=False,type=boolean_string) 41 | # parser.add_argument("--frozen_old_fc", dest="frozen_old_fc", default=False, type=boolean_string) 42 | parser.add_argument("--close_loop_mem_type", default="random", 43 | choices=["low_acc", "random", ]) 44 | 45 | parser.add_argument('--mem_ratio_max', default=1.5, 46 | help='') 47 | 48 | parser.add_argument('--mem_ratio_min', default=0.1, 49 | help='') 50 | parser.add_argument('--incoming_ratio', dest='incoming_ratio', default=1.0, type=float, 51 | help='incoming gradient update ratio') 52 | parser.add_argument('--mem_ratio', dest='mem_ratio', default=1.0, type=float, 53 | help='mem gradient update ratio') 54 | 55 | parser.add_argument('--task_start_mem_ratio', dest='task_start_mem_ratio', default=0.5, type=float, 56 | help='mem gradient update ratio') 57 | parser.add_argument('--task_start_incoming_ratio', dest='task_start_incoming_ratio', default=0.1, type=float, 58 | help='mem gradient update ratio') 59 | parser.add_argument("--dyna_ratio", dest='dyna_ratio', type=str, default="None", choices=['dyna','random','None'], 60 | help='adjust dyna_ratio') 61 | 62 | parser.add_argument("--adaptive_ratio_type",type=str,default="offline",choices=["online","offline",]) 63 | return parser -------------------------------------------------------------------------------- /models/modelfactory.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | 3 | class ModelFactory(): 4 | def __init__(self): 5 | pass 6 | 7 | @staticmethod 8 | def get_model(model_type, sizes, dataset='mnist', hidden_size =320,nlayers=1,args=None): 9 | 10 | net_list = [] 11 | if "mnist" in dataset: 12 | if model_type=="linear": 13 | for i in range(0, len(sizes) - 1): 14 | net_list.append(('linear', [sizes[i+1], sizes[i]], '')) 15 | if i < (len(sizes) - 2): 16 | net_list.append(('relu', [True], '')) 17 | if i == (len(sizes) - 2): 18 | net_list.append(('rep', [], '')) 19 | return net_list 20 | 21 | elif dataset == "tinyimagenet": 22 | 23 | if model_type == 'pc_cnn': 24 | channels = 160 25 | return [ 26 | ('conv2d', [channels, 3, 3, 3, 2, 1], ''), 27 | ('relu', [True], ''), 28 | 29 | ('conv2d', [channels, channels, 3, 3, 2, 1], ''), 30 | ('relu', [True], ''), 31 | 32 | ('conv2d', [channels, channels, 3, 3, 2, 1], ''), 33 | ('relu', [True], ''), 34 | 35 | ('conv2d', [channels, channels, 3, 3, 2, 1], ''), 36 | ('relu', [True], ''), 37 | 38 | ('flatten', [], ''), 39 | ('rep', [], ''), 40 | 41 | ('linear', [640, 16 * channels], ''), 42 | ('relu', [True], ''), 43 | 44 | ('linear', [640, 640], ''), 45 | ('relu', [True], ''), 46 | ('linear', [sizes[-1], 640], '') 47 | ] 48 | 49 | elif dataset == "cifar100": 50 | 51 | 52 | if model_type == 'pc_cnn': 53 | channels = 160 54 | return [ 55 | ('conv2d', [channels, 3, 3, 3, 2, 1], ''), 56 | ('relu', [True], ''), 57 | 58 | ('conv2d', [channels, channels, 3, 3, 2, 1], ''), 59 | ('relu', [True], ''), 60 | 61 | ('conv2d', [channels, channels, 3, 3, 2, 1], ''), 62 | ('relu', [True], ''), 63 | 64 | ('flatten', [], ''), 65 | ('rep', [], ''), 66 | 67 | ('linear', [320, 16 * channels], ''), 68 | ('relu', [True], ''), 69 | 70 | ('linear', [320, 320], ''), 71 | ('relu', [True], ''), 72 | ('linear', [sizes[-1], 320], '') 73 | ] 74 | elif model_type == 'linear_softmax': 75 | 
structure = [ ('linear', [hidden_size,sizes], ''), 76 | ('relu', [True], ''), ] 77 | for i in range(nlayers-1): 78 | structure += [ ('linear', [hidden_size, hidden_size], ''), 79 | ('relu', [True], ''),] 80 | structure += [('linear', [100, hidden_size], '')] 81 | return structure 82 | 83 | 84 | # return [ 85 | # 86 | # 87 | # ('linear', [hidden_size,sizes], ''), 88 | # ('relu', [True], ''), 89 | # 90 | # # ('linear', [hidden_size, hidden_size], ''), 91 | # # ('relu', [True], ''), 92 | # ('linear', [100, hidden_size], '') 93 | # ] 94 | 95 | else: 96 | print("Unsupported model; either implement the model in model/ModelFactory or choose a different model") 97 | assert (False) 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /agents/gdumb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import math 4 | from agents.base import ContinualLearner 5 | from continuum.data_utils import dataset_transform 6 | from utils.setup_elements import transforms_match, setup_architecture, setup_opt 7 | from utils.utils import maybe_cuda, EarlyStopping 8 | import numpy as np 9 | import random 10 | 11 | 12 | class Gdumb(ContinualLearner): 13 | def __init__(self, model, opt, params): 14 | super(Gdumb, self).__init__(model, opt, params) 15 | self.mem_img = {} 16 | self.mem_c = {} 17 | #self.early_stopping = EarlyStopping(self.params.min_delta, self.params.patience, self.params.cumulative_delta) 18 | 19 | def greedy_balancing_update(self, x, y): 20 | k_c = self.params.mem_size // max(1, len(self.mem_img)) 21 | if y not in self.mem_img or self.mem_c[y] < k_c: 22 | if sum(self.mem_c.values()) >= self.params.mem_size: 23 | cls_max = max(self.mem_c.items(), key=lambda k:k[1])[0] 24 | idx = random.randrange(self.mem_c[cls_max]) 25 | self.mem_img[cls_max].pop(idx) 26 | self.mem_c[cls_max] -= 1 27 | if y not in self.mem_img: 28 | self.mem_img[y] = [] 29 | self.mem_c[y] = 0 30 | self.mem_img[y].append(x) 31 | self.mem_c[y] += 1 32 | 33 | def train_learner(self, x_train, y_train): 34 | self.before_train(x_train, y_train) 35 | # set up loader 36 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 37 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 38 | drop_last=True) 39 | 40 | for i, batch_data in enumerate(train_loader): 41 | # batch update 42 | batch_x, batch_y = batch_data 43 | batch_x = maybe_cuda(batch_x, self.cuda) 44 | batch_y = maybe_cuda(batch_y, self.cuda) 45 | # update mem 46 | for j in range(len(batch_x)): 47 | self.greedy_balancing_update(batch_x[j], batch_y[j].item()) 48 | #self.early_stopping.reset() 49 | self.train_mem() 50 | self.after_train() 51 | 52 | def train_mem(self): 53 | mem_x = [] 54 | mem_y = [] 55 | for i in self.mem_img.keys(): 56 | mem_x += self.mem_img[i] 57 | mem_y += [i] * self.mem_c[i] 58 | 59 | mem_x = torch.stack(mem_x) 60 | mem_y = torch.LongTensor(mem_y) 61 | self.model = setup_architecture(self.params) 62 | self.model = maybe_cuda(self.model, self.cuda) 63 | opt = setup_opt(self.params.optimizer, self.model, self.params.learning_rate, self.params.weight_decay) 64 | #scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=1, T_mult=2, eta_min=self.params.minlr) 65 | 66 | #loss = math.inf 67 | for i in range(self.params.mem_epoch): 68 | idx = np.random.permutation(len(mem_x)).tolist() 69 | mem_x = maybe_cuda(mem_x[idx], self.cuda) 70 | 
mem_y = maybe_cuda(mem_y[idx], self.cuda) 71 | self.model = self.model.train() 72 | batch_size = self.params.batch 73 | #scheduler.step() 74 | #if opt.param_groups[0]['lr'] == self.params.learning_rate: 75 | # if self.early_stopping.step(-loss): 76 | # return 77 | for j in range(len(mem_y) // batch_size): 78 | opt.zero_grad() 79 | logits = self.model.forward(mem_x[batch_size * j:batch_size * (j + 1)]) 80 | loss = self.criterion(logits, mem_y[batch_size * j:batch_size * (j + 1)]) 81 | loss.backward() 82 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip) 83 | opt.step() 84 | 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Repeated Augmented Rehearsal (RAR) for online continual learning 2 | 3 | This is the official code repository for [Repeated Augmented Rehearsal (NeurIPS 2022)](https://arxiv.org/abs/2209.13917). 4 | If you use any content of this repo for your work, please cite the following bib entry: 5 | 6 | ## Citation 7 | ``` 8 | @inproceedings{NEURIPS2022_5ebbbac6, 9 | author = {Zhang, Yaqian and Pfahringer, Bernhard and Frank, Eibe and Bifet, Albert and Lim, Nick Jin Sean and Jia, Yunzhe}, 10 | booktitle = {Advances in Neural Information Processing Systems}, 11 | pages = {14771--14783}, 12 | title = {A simple but strong baseline for online continual learning: Repeated Augmented Rehearsal}, 13 | url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/5ebbbac62b968254093023f1c95015d3-Paper-Conference.pdf}, 14 | volume = {35}, 15 | year = {2022} 16 | } 17 | ``` 18 | 19 | 20 | ## Requirements 21 | ![](https://img.shields.io/badge/python-3.7-green.svg) 22 | 23 | ![](https://img.shields.io/badge/torch-1.5.1-blue.svg) 24 | ![](https://img.shields.io/badge/torchvision-0.6.1-blue.svg) 25 | ![](https://img.shields.io/badge/PyYAML-5.3.1-blue.svg) 26 | ![](https://img.shields.io/badge/scikit--learn-0.23.0-blue.svg) 27 | ---- 28 | Create a virtual environment 29 | ```sh 30 | virtualenv rar 31 | ``` 32 | Activate the virtual environment 33 | ```sh 34 | source rar/bin/activate 35 | ``` 36 | Install the required packages 37 | ```sh 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | ## Run commands 42 | 43 | ### bash runs 44 | A test run of Repeated Augmented Rehearsal (RAR) with experience replay can be performed with the following command: 45 | ``` 46 | bash run_commands/runs/run_test_rar_er_cifar100.sh 47 | ``` 48 | Other experiment commands can be found in the [run_commands/runs](run_commands/runs) folder. 49 | 50 | Detailed descriptions of options can be found in [general_main.py](general_main.py) and [utils/argparser](utils/argparser). 51 | 52 | For example: 53 | 54 | The number of repeated rehearsal iterations is set via: 55 | ``` 56 | --mem_iters $MEM_ITER 57 | ``` 58 | The augmentation strength is set via: 59 | ``` 60 | --randaug True --randaug_N $RAUG_N --randaug_M $RAUG_M 61 | ``` 62 | 63 | ## Evaluating the results 64 | Algorithm outputs are stored in the [results](results/) folder. 65 | 66 | The Jupyter notebook [visualize_results.ipynb](visualize_results.ipynb) is used to visualize and analyze the results, as illustrated below. 
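As a rough illustration (the file name below is hypothetical; actual names depend on the prefix generated by the experiment scripts), a saved accuracy array can be inspected like this:

```python
import numpy as np

# acc[r, i, j] = accuracy on task j after training on task i in run r,
# i.e. shape (n_runs, n_tasks, n_tasks)
acc = np.load("results/example_run_acc.npy")

end_acc = acc[:, -1, :]  # accuracy on every task after the final task
print("avg end accuracy:", end_acc.mean())
print("avg forgetting:", (acc.max(axis=1) - end_acc).mean())
```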
67 | 68 | 69 | ## Algorithms 70 | 71 | ### Baselines 72 | * LwF: Learning without forgetting (**ECCV, 2016**) [[Paper]](https://link.springer.com/chapter/10.1007/978-3-319-46493-0_37) 73 | * AGEM: Averaged Gradient Episodic Memory (**ICLR, 2019**) [[Paper]](https://openreview.net/forum?id=Hkf2_sC5FX) 74 | * ER: Experience Replay (**ICML Workshop, 2019**) [[Paper]](https://arxiv.org/abs/1902.10486) 75 | * ASER: Adversarial Shapley Value Experience Replay (**AAAI, 2021**) [[Paper]](https://arxiv.org/abs/2009.00093) 76 | * MIR: Maximally Interfered Retrieval (**NeurIPS, 2019**) [[Paper]](https://proceedings.neurips.cc/paper/2019/hash/15825aee15eb335cc13f9b559f166ee8-Abstract.html) 77 | * SCR: Supervised Contrastive Replay (**CVPR Workshop, 2021**) [[Paper]](https://arxiv.org/abs/2103.13885) 78 | * DER: Dark Experience Replay (**NeurIPS, 2020**) [[Paper]](https://proceedings.neurips.cc/paper/2020/file/b704ea2c39778f07c617f6b7ce480e9e-Paper.pdf) 79 | 80 | 81 | ## Datasets 82 | 83 | ### Online Class Incremental 84 | 85 | - Split CIFAR100 86 | - Split Mini-ImageNet 87 | - CORe50-NC 88 | - CLRS-NC (Continual Learning Benchmark for Remote 89 | Sensing Image Scene Classification) 90 | ### Data preparation 91 | - CIFAR100 will be downloaded during the first run 92 | - CORE50 download: `source fetch_data_setup.sh` 93 | - Mini-ImageNet: Download from https://www.kaggle.com/whitemoon/miniimagenet/download, and place it in datasets/mini_imagenet/ 94 | - CLRS: Download from https://github.com/lehaifeng/CLRS 95 | 96 | 97 | 98 | ## Acknowledgments 99 | Thanks to the following great code bases: 100 | - [SCR/ASER](https://github.com/RaptorMai/online-continual-learning) 101 | - [Rehearsal Revealed](https://github.com/Mattdl/RehearsalRevealed) 102 | - [DER](https://github.com/aimagelab/mammoth) 103 | - [MIR](https://github.com/optimass/Maximally_Interfered_Retrieval) 104 | - [AGEM](https://github.com/facebookresearch/agem) 105 | 106 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Yonglong Tian (yonglong@mit.edu) 3 | Date: May 07, 2020 4 | """ 5 | #from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class SupConLoss(nn.Module): 12 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 13 | It also supports the unsupervised contrastive loss in SimCLR""" 14 | def __init__(self, temperature=0.07, contrast_mode='all'): 15 | super(SupConLoss, self).__init__() 16 | self.temperature = temperature 17 | self.contrast_mode = contrast_mode 18 | 19 | def forward(self, features, labels=None, mask=None, need_full=False): 20 | """Compute loss for model. If both `labels` and `mask` are None, 21 | it degenerates to SimCLR unsupervised loss: 22 | https://arxiv.org/pdf/2002.05709.pdf 23 | Args: 24 | features: hidden vector of shape [bsz, n_views, ...]. 25 | labels: ground truth of shape [bsz]. 26 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 27 | has the same class as sample i. Can be asymmetric. 28 | Returns: 29 | A scalar loss, together with the per-anchor loss vector. 
30 | """ 31 | device = (torch.device('cuda') 32 | if features.is_cuda 33 | else torch.device('cpu')) 34 | 35 | if len(features.shape) < 3: 36 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 37 | 'at least 3 dimensions are required') 38 | if len(features.shape) > 3: 39 | features = features.view(features.shape[0], features.shape[1], -1) 40 | 41 | batch_size = features.shape[0] 42 | if labels is not None and mask is not None: 43 | raise ValueError('Cannot define both `labels` and `mask`') 44 | elif labels is None and mask is None: 45 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 46 | elif labels is not None: 47 | labels = labels.contiguous().view(-1, 1) 48 | if labels.shape[0] != batch_size: 49 | raise ValueError('Num of labels does not match num of features') 50 | mask = torch.eq(labels, labels.T).float().to(device) 51 | else: 52 | mask = mask.float().to(device) 53 | 54 | 55 | 56 | 57 | 58 | contrast_count = features.shape[1] 59 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 60 | 61 | if self.contrast_mode == 'one': 62 | anchor_feature = features[:, 0] 63 | anchor_count = 1 64 | elif self.contrast_mode == 'all': 65 | anchor_feature = contrast_feature 66 | anchor_count = contrast_count 67 | else: 68 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 69 | 70 | # compute logits 71 | anchor_dot_contrast = torch.div( 72 | torch.matmul(anchor_feature, contrast_feature.T), 73 | self.temperature) 74 | 75 | 76 | # for numerical stability 77 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 78 | logits = anchor_dot_contrast - logits_max.detach() 79 | 80 | 81 | # tile mask 82 | mask = mask.repeat(anchor_count, contrast_count) 83 | 84 | # mask-out self-contrast cases 85 | logits_mask = torch.scatter( 86 | torch.ones_like(mask), 87 | 1, 88 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 89 | 0 90 | ) 91 | 92 | mask = mask * logits_mask 93 | 94 | 95 | 96 | 97 | # compute log_prob 98 | exp_logits = torch.exp(logits) * logits_mask 99 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 100 | 101 | 102 | # compute mean of log-likelihood over positive 103 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 104 | 105 | 106 | #print(mean_log_prob_pos.shape) 107 | 108 | # loss 109 | loss = -1 * mean_log_prob_pos 110 | loss = loss.view(anchor_count, batch_size).mean() 111 | loss_full = -1 * mean_log_prob_pos 112 | 113 | # if(need_full): 114 | # 115 | # return loss,loss_full 116 | # else: 117 | return loss,loss_full 118 | 119 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from continuum.data_utils import create_task_composition, load_task_with_labels, shuffle_data 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns 6 | 7 | TEST_SPLIT = 1 / 6 8 | 9 | 10 | class Mini_ImageNet(DatasetBase): 11 | def __init__(self, scenario, params): 12 | dataset = 'mini_imagenet' 13 | if scenario == 'ni': 14 | num_tasks = len(params.ns_factor) 15 | else: 16 | num_tasks = params.num_tasks 17 | super(Mini_ImageNet, self).__init__(dataset, scenario, num_tasks, params.num_runs, params) 18 | 19 | 20 | def download_load(self): 21 | train_in = open("datasets/mini_imagenet/mini-imagenet-cache-train.pkl", "rb") 22 | 
train = pickle.load(train_in) 23 | train_x = train["image_data"].reshape([64, 600, 84, 84, 3]) 24 | val_in = open("datasets/mini_imagenet/mini-imagenet-cache-val.pkl", "rb") 25 | val = pickle.load(val_in) 26 | val_x = val['image_data'].reshape([16, 600, 84, 84, 3]) 27 | test_in = open("datasets/mini_imagenet/mini-imagenet-cache-test.pkl", "rb") 28 | test = pickle.load(test_in) 29 | test_x = test['image_data'].reshape([20, 600, 84, 84, 3]) 30 | all_data = np.vstack((train_x, val_x, test_x)) 31 | train_data = [] 32 | train_label = [] 33 | test_data = [] 34 | test_label = [] 35 | for i in range(len(all_data)): 36 | cur_x = all_data[i] 37 | cur_y = np.ones((600,)) * i 38 | rdm_x, rdm_y = shuffle_data(cur_x, cur_y) 39 | x_test = rdm_x[: int(600 * TEST_SPLIT)] 40 | y_test = rdm_y[: int(600 * TEST_SPLIT)] 41 | x_train = rdm_x[int(600 * TEST_SPLIT):] 42 | y_train = rdm_y[int(600 * TEST_SPLIT):] 43 | train_data.append(x_train) 44 | train_label.append(y_train) 45 | test_data.append(x_test) 46 | test_label.append(y_test) 47 | self.train_data = np.concatenate(train_data) 48 | self.train_label = np.concatenate(train_label) 49 | self.test_data = np.concatenate(test_data) 50 | self.test_label = np.concatenate(test_label) 51 | 52 | def new_run(self, **kwargs): 53 | self.setup() 54 | return self.test_set 55 | 56 | def new_task(self, cur_task, **kwargs): 57 | if self.scenario == 'ni': 58 | x_train, y_train = self.train_set[cur_task] 59 | labels = set(y_train) 60 | elif self.scenario == 'nc': 61 | labels = self.task_labels[cur_task] 62 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 63 | else: 64 | raise Exception('unrecognized scenario') 65 | return x_train, y_train, labels 66 | 67 | def setup(self): 68 | if self.scenario == 'ni': 69 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 70 | self.train_label, 71 | self.test_data, self.test_label, 72 | self.task_nums, 84, 73 | self.params.val_size, 74 | self.params.ns_type, self.params.ns_factor, 75 | plot=self.params.plot_sample) 76 | 77 | elif self.scenario == 'nc': 78 | self.task_labels = create_task_composition(class_nums=100, num_tasks=self.task_nums, 79 | fixed_order=self.params.fix_order) 80 | self.test_set = [] 81 | for labels in self.task_labels: 82 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 83 | self.test_set.append((x_test, y_test)) 84 | 85 | def test_plot(self): 86 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 87 | self.params.ns_factor) 88 | -------------------------------------------------------------------------------- /models/ndpm/component.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | import torch.nn as nn 4 | from typing import Tuple 5 | 6 | from utils.utils import maybe_cuda 7 | from utils.global_vars import * 8 | 9 | 10 | class Component(nn.Module, ABC): 11 | def __init__(self, params, experts: Tuple): 12 | super().__init__() 13 | self.params = params 14 | self.experts = experts 15 | 16 | self.optimizer = NotImplemented 17 | self.lr_scheduler = NotImplemented 18 | 19 | @abstractmethod 20 | def nll(self, x, y, step=None): 21 | """Return NLL""" 22 | pass 23 | 24 | @abstractmethod 25 | def collect_nll(self, x, y, step=None): 26 | """Return NLLs including previous experts""" 27 | pass 28 | 29 | def _clip_grad_value(self, clip_value): 30 | for group in self.optimizer.param_groups: 31 | 
nn.utils.clip_grad_value_(group['params'], clip_value) 32 | 33 | def _clip_grad_norm(self, max_norm, norm_type=2): 34 | for group in self.optimizer.param_groups: 35 | nn.utils.clip_grad_norm_(group['params'], max_norm, norm_type) 36 | 37 | def clip_grad(self): 38 | clip_grad_config = MODELS_NDPM_COMPONENT_CLIP_GRAD 39 | if clip_grad_config['type'] == 'value': 40 | self._clip_grad_value(**clip_grad_config['options']) 41 | elif clip_grad_config['type'] == 'norm': 42 | self._clip_grad_norm(**clip_grad_config['options']) 43 | else: 44 | raise ValueError('Invalid clip_grad type: {}' 45 | .format(clip_grad_config['type'])) 46 | 47 | @staticmethod 48 | def build_optimizer(optim_config, params): 49 | return getattr(torch.optim, optim_config['type'])( 50 | params, **optim_config['options']) 51 | 52 | @staticmethod 53 | def build_lr_scheduler(lr_config, optimizer): 54 | return getattr(torch.optim.lr_scheduler, lr_config['type'])( 55 | optimizer, **lr_config['options']) 56 | 57 | def weight_decay_loss(self): 58 | loss = maybe_cuda(torch.zeros([])) 59 | for param in self.parameters(): 60 | loss += torch.norm(param) ** 2 61 | return loss 62 | 63 | 64 | class ComponentG(Component, ABC): 65 | def setup_optimizer(self): 66 | self.optimizer = self.build_optimizer( 67 | {'type': self.params.optimizer, 'options': {'lr': self.params.learning_rate}}, self.parameters()) 68 | self.lr_scheduler = self.build_lr_scheduler( 69 | MODELS_NDPM_COMPONENT_LR_SCHEDULER_G, self.optimizer) 70 | 71 | def collect_nll(self, x, y=None, step=None): 72 | """Default `collect_nll` 73 | 74 | Warning: Parameter-sharing components should implement their own 75 | `collect_nll` 76 | 77 | Returns: 78 | nll: Tensor of shape [B, 1+K] 79 | """ 80 | outputs = [expert.g.nll(x, y, step) for expert in self.experts] 81 | nll = outputs 82 | output = self.nll(x, y, step) 83 | nll.append(output) 84 | return torch.stack(nll, dim=1) 85 | 86 | 87 | 88 | class ComponentD(Component, ABC): 89 | def setup_optimizer(self): 90 | self.optimizer = self.build_optimizer( 91 | {'type': self.params.optimizer, 'options': {'lr': self.params.learning_rate}}, self.parameters()) 92 | self.lr_scheduler = self.build_lr_scheduler( 93 | MODELS_NDPM_COMPONENT_LR_SCHEDULER_D, self.optimizer) 94 | 95 | def collect_forward(self, x): 96 | """Default `collect_forward` 97 | 98 | Warning: Parameter-sharing components should implement their own 99 | `collect_forward` 100 | 101 | Returns: 102 | output: Tensor of shape [B, 1+K, C] 103 | """ 104 | outputs = [expert.d(x) for expert in self.experts] 105 | outputs.append(self.forward(x)) 106 | return torch.stack(outputs, 1) 107 | 108 | def collect_nll(self, x, y, step=None): 109 | """Default `collect_nll` 110 | 111 | Warning: Parameter-sharing components should implement their own 112 | `collect_nll` 113 | 114 | Returns: 115 | nll: Tensor of shape [B, 1+K] 116 | """ 117 | outputs = [expert.d.nll(x, y, step) for expert in self.experts] 118 | nll = outputs 119 | output = self.nll(x, y, step) 120 | nll.append(output) 121 | return torch.stack(nll, dim=1) 122 | -------------------------------------------------------------------------------- /utils/buffer/aser_retrieve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.buffer.buffer_utils import random_retrieve, ClassBalancedRandomSampling 3 | from utils.buffer.aser_utils import compute_knn_sv 4 | from utils.utils import maybe_cuda 5 | from utils.setup_elements import n_classes 6 | 7 | 8 | class ASER_retrieve(object): 9 | def 
__init__(self, params, **kwargs): 10 | super().__init__() 11 | self.num_retrieve = params.eps_mem_batch 12 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 13 | self.k = params.k 14 | self.mem_size = params.mem_size 15 | self.aser_type = params.aser_type 16 | self.n_smp_cls = int(params.n_smp_cls) 17 | self.out_dim = n_classes[params.data] 18 | self.is_aser_upt = params.update == "ASER" 19 | ClassBalancedRandomSampling.class_index_cache = None 20 | 21 | def retrieve(self, buffer, **kwargs): 22 | model = buffer.model 23 | 24 | if buffer.n_seen_so_far <= self.mem_size: 25 | # Use random retrieval until buffer is filled 26 | ret_x, ret_y = random_retrieve(buffer, self.num_retrieve) 27 | else: 28 | # Use ASER retrieval if buffer is filled 29 | cur_x, cur_y = kwargs['x'], kwargs['y'] 30 | buffer_x, buffer_y = buffer.buffer_img, buffer.buffer_label 31 | ret_x, ret_y = self._retrieve_by_knn_sv(model, buffer_x, buffer_y, cur_x, cur_y, self.num_retrieve) 32 | return ret_x, ret_y 33 | 34 | def _retrieve_by_knn_sv(self, model, buffer_x, buffer_y, cur_x, cur_y, num_retrieve): 35 | """ 36 | Retrieves data instances with top-N Shapley Values from candidate set. 37 | Args: 38 | model (object): neural network. 39 | buffer_x (tensor): data buffer. 40 | buffer_y (tensor): label buffer. 41 | cur_x (tensor): current input data tensor. 42 | cur_y (tensor): current input label tensor. 43 | num_retrieve (int): number of data instances to be retrieved. 44 | Returns 45 | ret_x (tensor): retrieved data tensor. 46 | ret_y (tensor): retrieved label tensor. 47 | """ 48 | cur_x = maybe_cuda(cur_x) 49 | cur_y = maybe_cuda(cur_y) 50 | 51 | # Reset and update ClassBalancedRandomSampling cache if ASER update is not enabled 52 | if not self.is_aser_upt: 53 | ClassBalancedRandomSampling.update_cache(buffer_y, self.out_dim) 54 | 55 | # Get candidate data for retrieval (i.e., cand <- class balanced subsamples from memory) 56 | cand_x, cand_y, cand_ind = \ 57 | ClassBalancedRandomSampling.sample(buffer_x, buffer_y, self.n_smp_cls, device=self.device) 58 | 59 | # Type 1 - Adversarial SV 60 | # Get evaluation data for type 1 (i.e., eval <- current input) 61 | eval_adv_x, eval_adv_y = cur_x, cur_y 62 | # Compute adversarial Shapley value of candidate data 63 | # (i.e., sv wrt current input) 64 | sv_matrix_adv = compute_knn_sv(model, eval_adv_x, eval_adv_y, cand_x, cand_y, self.k, device=self.device) 65 | 66 | if self.aser_type != "neg_sv": 67 | # Type 2 - Cooperative SV 68 | # Get evaluation data for type 2 69 | # (i.e., eval <- class balanced subsamples from memory excluding those already in candidate set) 70 | excl_indices = set(cand_ind.tolist()) 71 | eval_coop_x, eval_coop_y, _ = \ 72 | ClassBalancedRandomSampling.sample(buffer_x, buffer_y, self.n_smp_cls, 73 | excl_indices=excl_indices, device=self.device) 74 | # Compute Shapley value 75 | sv_matrix_coop = \ 76 | compute_knn_sv(model, eval_coop_x, eval_coop_y, cand_x, cand_y, self.k, device=self.device) 77 | if self.aser_type == "asv": 78 | # Use extremal SVs for computation 79 | sv = sv_matrix_coop.max(0).values - sv_matrix_adv.min(0).values 80 | else: 81 | # Use mean variation for aser_type == "asvm" or anything else 82 | sv = sv_matrix_coop.mean(0) - sv_matrix_adv.mean(0) 83 | else: 84 | # aser_type == "neg_sv" 85 | # No Type 1 - Cooperative SV; Use sum of Adversarial SV only 86 | sv = sv_matrix_adv.sum(0) * -1 87 | 88 | ret_ind = sv.argsort(descending=True) 89 | 90 | ret_x = cand_x[ret_ind][:num_retrieve] 91 | ret_y = cand_y[ret_ind][:num_retrieve] 92 | 
return ret_x, ret_y 93 | -------------------------------------------------------------------------------- /agents/agem.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from utils.setup_elements import transforms_match 4 | from torch.utils import data 5 | from utils.buffer.buffer import Buffer 6 | from utils.utils import maybe_cuda, AverageMeter 7 | import torch 8 | 9 | 10 | class AGEM(ContinualLearner): 11 | def __init__(self, model, opt, params): 12 | super(AGEM, self).__init__(model, opt, params) 13 | self.buffer = Buffer(model, params) 14 | self.mem_size = params.mem_size 15 | self.eps_mem_batch = params.eps_mem_batch 16 | self.mem_iters = params.mem_iters 17 | 18 | def train_learner(self, x_train, y_train): 19 | self.before_train(x_train, y_train) 20 | 21 | # set up loader 22 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 23 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 24 | drop_last=True) 25 | # set up model 26 | self.model = self.model.train() 27 | 28 | # setup tracker 29 | losses_batch = AverageMeter() 30 | acc_batch = AverageMeter() 31 | 32 | for ep in range(self.epoch): 33 | for i, batch_data in enumerate(train_loader): 34 | # batch update 35 | batch_x, batch_y = batch_data 36 | batch_x = maybe_cuda(batch_x, self.cuda) 37 | batch_y = maybe_cuda(batch_y, self.cuda) 38 | for j in range(self.mem_iters): 39 | logits = self.forward(batch_x) 40 | loss = self.criterion(logits, batch_y) 41 | if self.params.trick['kd_trick']: 42 | loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \ 43 | self.kd_manager.get_kd_loss(logits, batch_x) 44 | if self.params.trick['kd_trick_star']: 45 | loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \ 46 | (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x) 47 | _, pred_label = torch.max(logits, 1) 48 | correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0) 49 | # update tracker 50 | acc_batch.update(correct_cnt, batch_y.size(0)) 51 | losses_batch.update(loss, batch_y.size(0)) 52 | # backward 53 | self.opt.zero_grad() 54 | loss.backward() 55 | 56 | if self.task_seen > 0: 57 | # sample from memory of previous tasks 58 | mem_x, mem_y = self.buffer.retrieve() 59 | if mem_x.size(0) > 0: 60 | params = [p for p in self.model.parameters() if p.requires_grad] 61 | # gradient computed using current batch 62 | grad = [p.grad.clone() for p in params] 63 | mem_x = maybe_cuda(mem_x, self.cuda) 64 | mem_y = maybe_cuda(mem_y, self.cuda) 65 | mem_logits = self.forward(mem_x) 66 | loss_mem = self.criterion(mem_logits, mem_y) 67 | self.opt.zero_grad() 68 | loss_mem.backward() 69 | # gradient computed using memory samples 70 | grad_ref = [p.grad.clone() for p in params] 71 | 72 | # inner product of grad and grad_ref 73 | prod = sum([torch.sum(g * g_r) for g, g_r in zip(grad, grad_ref)]) 74 | if prod < 0: 75 | prod_ref = sum([torch.sum(g_r ** 2) for g_r in grad_ref]) 76 | # do projection 77 | grad = [g - prod / prod_ref * g_r for g, g_r in zip(grad, grad_ref)] 78 | # replace params' grad 79 | for g, p in zip(grad, params): 80 | p.grad.data.copy_(g) 81 | self.opt.step() 82 | # update mem 83 | self.buffer.update(batch_x, batch_y) 84 | 85 | if i % 100 == 1 and self.verbose: 86 | print( 87 | '==>>> it: {}, avg. 
loss: {:.6f}, ' 88 | 'running train acc: {:.3f}' 89 | .format(i, losses_batch.avg(), acc_batch.avg()) 90 | ) 91 | self.after_train() -------------------------------------------------------------------------------- /utils/name_match.py: -------------------------------------------------------------------------------- 1 | from agents.gdumb import Gdumb 2 | from continuum.dataset_scripts.cifar100 import CIFAR100 3 | from continuum.dataset_scripts.cifar10 import CIFAR10 4 | from continuum.dataset_scripts.core50 import CORE50 5 | from continuum.dataset_scripts.CLRS import CLRS25 6 | from continuum.dataset_scripts.imagenet import IMAGENET1000 7 | from continuum.dataset_scripts.mini_imagenet import Mini_ImageNet 8 | from continuum.dataset_scripts.openloris import OpenLORIS 9 | from agents.exp_replay import ExperienceReplay 10 | # from agents.unused.exp_replay_dyna_aug import ExperienceReplay_aug 11 | # from agents.unused.exp_replay_dyna_ratio import ExperienceReplay_ratio 12 | # from agents.unused.exp_replay_offline import ExperienceReplay_offline 13 | # from agents.unused.exp_replay_cl import ExperienceReplay_cl 14 | # from agents.unused.exp_replay_cl_bandits import ExperienceReplay_cl_bandits 15 | # from agents.unused.exp_replay_cl_bandits_ts import ExperienceReplay_cl_bandits_ts 16 | # from agents.unused.exp_replay_batchsize import ExperienceReplay_batchsize 17 | #from unused.rl_exp_replay import RL_ExperienceReplay 18 | from agents.agem import AGEM 19 | from agents.ewc_pp import EWC_pp 20 | from agents.cndpm import Cndpm 21 | from agents.lwf import Lwf 22 | from agents.icarl import Icarl 23 | #from agents.lamaml import LAMAML 24 | from utils.buffer.random_retrieve import Random_retrieve 25 | from utils.buffer.reservoir_update import Reservoir_update 26 | 27 | from utils.buffer.mir_retrieve import MIR_retrieve 28 | from utils.buffer.rl_retrieve import RL_retrieve 29 | from utils.buffer.gss_greedy_update import GSSGreedyUpdate 30 | from utils.buffer.aser_retrieve import ASER_retrieve 31 | from utils.buffer.aser_update import ASER_update 32 | 33 | from agents.scr import SupContrastReplay 34 | # from agents.scr_ratio import SCR_RL_ratio 35 | # from agents.scr_rl_addIter import SCR_RL_iter 36 | # from agents.ER_RL_iter import ER_RL_iter 37 | # from agents.unused.ER_RL_addIter import ER_RL_addIter 38 | # from agents.unused.ER_RL_addIter_stop import ER_RL_addIter_stop 39 | # from agents.unused.ER_RL_addIter_stop_new import ER_RL_addIter_stop_new 40 | from agents.ER_dyna_iter import ER_dyna_iter 41 | #from agents.ER_dyna_rnd import ER_dyna_rnd 42 | from agents.ER_dyna_iter_aug import ER_dyna_iter_aug 43 | #from agents.ER_dyna_iter_aug_only import ER_dyna_iter_aug_only 44 | #from agents.unused.ER_dyna_iter_aug_dbpg import ER_dyna_iter_aug_dbpg 45 | from agents.ER_dyna_iter_aug_dbpg_joint import ER_dyna_iter_aug_dbpg_joint 46 | from utils.buffer.sc_retrieve import Match_retrieve 47 | from utils.buffer.mem_match import MemMatch_retrieve 48 | from agents.DER import DER 49 | 50 | 51 | data_objects = { 52 | 'imagenet1000':IMAGENET1000, 53 | 'cifar100': CIFAR100, 54 | 'cifar10': CIFAR10, 55 | 'core50': CORE50, 56 | 'mini_imagenet': Mini_ImageNet, 57 | 'openloris': OpenLORIS, 58 | 'clrs25':CLRS25 59 | } 60 | agents = { 61 | "DER":DER, 62 | 'ER': ExperienceReplay, 63 | # 'ER_cl': ExperienceReplay_cl, 64 | # "ER_cl_bandits":ExperienceReplay_cl_bandits, 65 | # "ER_cl_bandits_ts":ExperienceReplay_cl_bandits_ts, 66 | # "ER_batchsize":ExperienceReplay_batchsize, 67 | # 'ER_aug': ExperienceReplay_aug, 68 | # 
"ER_ratio":ExperienceReplay_ratio, 69 | # "ER_offline":ExperienceReplay_offline, 70 | # "ER_RL_ratio":ER_RL_ratio, 71 | #"ER_RL_iter":ER_RL_iter, 72 | # "ER_RL_addIter":ER_RL_addIter, 73 | # "ER_RL_addIter_stop": ER_RL_addIter_stop, 74 | # "ER_RL_addIter_stop_new": ER_RL_addIter_stop_new, 75 | # "ER_dyna_rnd":ER_dyna_rnd, 76 | "ER_dyna_iter":ER_dyna_iter, 77 | "ER_dyna_iter_aug":ER_dyna_iter_aug, 78 | # "ER_dyna_iter_aug_only":ER_dyna_iter_aug_only, 79 | # "ER_dyna_iter_aug_dbpg":ER_dyna_iter_aug_dbpg, 80 | "ER_dyna_iter_aug_dbpg_joint":ER_dyna_iter_aug_dbpg_joint, 81 | #'RLER': RL_ExperienceReplay, 82 | #'LAMAML': LAMAML, 83 | 'EWC': EWC_pp, 84 | 'AGEM': AGEM, 85 | 'CNDPM': Cndpm, 86 | 'LWF': Lwf, 87 | 'ICARL': Icarl, 88 | 'GDUMB': Gdumb, 89 | 'SCR': SupContrastReplay, 90 | # 'SCR_RL_ratio':SCR_RL_ratio, 91 | # 'SCR_RL_iter':SCR_RL_iter, 92 | #'SCR_META':SupContrastReplay_meta, 93 | } 94 | 95 | retrieve_methods = { 96 | 'MIR': MIR_retrieve, 97 | 'random': Random_retrieve, 98 | 'ASER': ASER_retrieve, 99 | 'match': Match_retrieve, 100 | 'mem_match': MemMatch_retrieve, 101 | 'RL': RL_retrieve 102 | 103 | } 104 | 105 | update_methods = { 106 | 'random': Reservoir_update, 107 | 'GSS': GSSGreedyUpdate, 108 | 'ASER': ASER_update, 109 | 'rt':Reservoir_update, 110 | 'rt2':Reservoir_update, 111 | 'timestamp':Reservoir_update, 112 | } 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /continuum/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils import data 4 | from utils.setup_elements import transforms_match 5 | 6 | def create_task_composition(class_nums, num_tasks, fixed_order=False,start_class = 0): 7 | classes_per_task = class_nums // num_tasks 8 | total_classes = classes_per_task * num_tasks 9 | label_array = np.arange(start_class, total_classes) 10 | if not fixed_order: 11 | np.random.shuffle(label_array) 12 | 13 | task_labels = [] 14 | for tt in range(num_tasks): 15 | tt_offset = tt * classes_per_task 16 | task_labels.append(list(label_array[tt_offset:tt_offset + classes_per_task])) 17 | print('Task: {}, Labels:{}'.format(tt, task_labels[tt])) 18 | return task_labels 19 | 20 | def create_task_composition_order(class_nums, num_tasks, start_class = 0): 21 | classes_per_task = class_nums // num_tasks 22 | total_classes = classes_per_task * num_tasks 23 | 24 | label = np.arange(start_class, total_classes) 25 | task_org = np.arange(0, num_tasks) 26 | np.random.shuffle(task_org) 27 | 28 | task_labels = [label[task_org[i] * classes_per_task:task_org[i] * classes_per_task + classes_per_task] for i in range(num_tasks)] 29 | 30 | for tt in range(num_tasks): 31 | print('Task: {}, Labels:{}'.format(tt, task_labels[tt])) 32 | print("task order:",task_org) 33 | return task_labels 34 | 35 | 36 | def load_task_with_labels_torch(x, y, labels): 37 | tmp = [] 38 | for i in labels: 39 | tmp.append((y == i).nonzero().view(-1)) 40 | idx = torch.cat(tmp) 41 | return x[idx], y[idx] 42 | 43 | 44 | def load_task_with_labels(x, y, labels): 45 | tmp = [] 46 | for i in labels: 47 | tmp.append((np.where(y == i)[0])) 48 | idx = np.concatenate(tmp, axis=None) 49 | return x[idx], y[idx] 50 | 51 | def load_task_with_labels_correct(x, y, labels): 52 | idx = [] 53 | for i,y_sample in enumerate(y): 54 | if(y_sample in labels): 55 | idx.append(i) 56 | return x[idx],y[idx] 57 | 58 | # idx = y==labels 59 | # return x[idx],y 60 | # tmp = [] 61 | # for i in labels: 62 | # 
tmp.append((np.where(y == i)[0])) 63 | # idx = np.concatenate(tmp, axis=None) 64 | # return x[idx], y[idx] 65 | 66 | 67 | 68 | class dataset_transform(data.Dataset): 69 | def __init__(self, x, y, transform=None): 70 | self.x = x 71 | self.y = torch.from_numpy(y).type(torch.LongTensor) 72 | self.transform = transform # save the transform 73 | 74 | def __len__(self): 75 | return len(self.y) # number of samples in the dataset 76 | 77 | def __getitem__(self, idx): 78 | # apply the transform (if any) and return the image 79 | if self.transform: 80 | x = self.transform(self.x[idx]) 81 | else: 82 | x = self.x[idx] 83 | 84 | return x.float(), self.y[idx] 85 | 86 | 87 | def setup_test_loader(test_data, params): 88 | test_loaders = [] 89 | 90 | for (x_test, y_test) in test_data: 91 | test_dataset = dataset_transform(x_test, y_test, transform=transforms_match[params.data]) 92 | test_loader = data.DataLoader(test_dataset, batch_size=params.test_batch, shuffle=True, num_workers=0) 93 | test_loaders.append(test_loader) 94 | return test_loaders 95 | 96 | 97 | def shuffle_data(x, y): 98 | perm_inds = np.arange(0, x.shape[0]) 99 | np.random.shuffle(perm_inds) 100 | rdm_x = x[perm_inds] 101 | rdm_y = y[perm_inds] 102 | return rdm_x, rdm_y 103 | 104 | 105 | def train_val_test_split_ni(train_data, train_label, test_data, test_label, task_nums, img_size, val_size=0.1): 106 | train_data_rdm, train_label_rdm = shuffle_data(train_data, train_label) 107 | val_size = int(len(train_data_rdm) * val_size) 108 | val_data_rdm, val_label_rdm = train_data_rdm[:val_size], train_label_rdm[:val_size] 109 | train_data_rdm, train_label_rdm = train_data_rdm[val_size:], train_label_rdm[val_size:] 110 | test_data_rdm, test_label_rdm = shuffle_data(test_data, test_label) 111 | train_data_rdm_split = train_data_rdm.reshape(task_nums, -1, img_size, img_size, 3) 112 | train_label_rdm_split = train_label_rdm.reshape(task_nums, -1) 113 | val_data_rdm_split = val_data_rdm.reshape(task_nums, -1, img_size, img_size, 3) 114 | val_label_rdm_split = val_label_rdm.reshape(task_nums, -1) 115 | test_data_rdm_split = test_data_rdm.reshape(task_nums, -1, img_size, img_size, 3) 116 | test_label_rdm_split = test_label_rdm.reshape(task_nums, -1) 117 | return train_data_rdm_split, train_label_rdm_split, val_data_rdm_split, val_label_rdm_split, test_data_rdm_split, test_label_rdm_split -------------------------------------------------------------------------------- /utils/buffer/test_buffer.py: -------------------------------------------------------------------------------- 1 | from utils.setup_elements import input_size_match 2 | from utils import name_match # import update_methods, retrieve_methods 3 | from utils.utils import maybe_cuda 4 | import torch 5 | import random 6 | import numpy as np 7 | from utils.buffer.recycle import recycle 8 | 9 | 10 | 11 | 12 | class Test_Buffer(torch.nn.Module): 13 | def __init__(self, params,): 14 | super().__init__() 15 | self.params = params 16 | #self.model = model 17 | self.cuda = self.params.cuda 18 | self.current_index = 0 19 | self.n_seen_so_far = 0 20 | self.training_steps = 0 21 | 22 | # define buffer 23 | self.buffer_size = params.test_mem_size 24 | print('test buffer has %d slots' % self.buffer_size) 25 | self.input_size = input_size_match[params.data] 26 | 27 | 28 | # self.buffer_img = torch.FloatTensor(buffer_size, *input_size).fill_(0) 29 | # self.buffer_label = torch.LongTensor(buffer_size).fill_(0) 30 | 31 | self.mem_img = {} 32 | self.mem_c = {} 33 | 34 | # 
define update and retrieve method 35 | 36 | self.update_method = name_match.update_methods[params.update](params,) 37 | self.retrieve_method = name_match.retrieve_methods[params.retrieve](params) 38 | 39 | if (self.params.test_mem_recycle): 40 | self.recycler = recycle() 41 | 42 | # def update(self, x, y,tmp_buffer=None): 43 | # return self.update_method.update(buffer=self, x=x, y=y,tmp_buffer=tmp_buffer) 44 | 45 | def update(self, x, y,tmp_buffer=None): 46 | for i in range(len(y)): 47 | self.greedy_balancing_update(x[i],y[i].item()) 48 | index = 0 49 | for k in self.mem_c: 50 | index += self.mem_c[k] 51 | self.current_index = index 52 | 53 | def retrieve_strict_blc(self): 54 | 55 | min_class = min(self.mem_c.keys(), key=(lambda k: self.mem_c[k])) 56 | 57 | num_per_class = self.mem_c[min_class] 58 | mem_x = [] 59 | mem_y = [] 60 | for i in self.mem_img.keys(): 61 | perm_idx = np.array(np.random.permutation(len(self.mem_img[i]))) 62 | perm_arr = [self.mem_img[i][k] for k in perm_idx] 63 | selected_img = perm_arr[:num_per_class] 64 | 65 | mem_x += selected_img 66 | mem_y += [i] * num_per_class #self.mem_c[i] 67 | 68 | mem_x = torch.stack(mem_x) 69 | mem_y = torch.LongTensor(mem_y) 70 | mem_x = maybe_cuda(mem_x) 71 | mem_y = maybe_cuda(mem_y) 72 | return mem_x,mem_y 73 | # 74 | 75 | 76 | 77 | def retrieve(self,retrieve_num=None): 78 | mem_x = [] 79 | mem_y = [] 80 | for i in self.mem_img.keys(): 81 | mem_x += self.mem_img[i] 82 | mem_y += [i] * self.mem_c[i] 83 | 84 | mem_x = torch.stack(mem_x) 85 | mem_y = torch.LongTensor(mem_y) 86 | mem_x = maybe_cuda(mem_x) 87 | mem_y = maybe_cuda(mem_y) 88 | 89 | 90 | return mem_x,mem_y 91 | 92 | # def overwrite(self,idx_map,x,y): 93 | # 94 | # self.buffer_img[list(idx_map.keys())] = x[list(idx_map.values())] 95 | # self.buffer_label[list(idx_map.keys())] = y[list(idx_map.values())] 96 | # 97 | # def reset(self): 98 | # buffer_size = self.params.mem_size 99 | # print('buffer has %d slots' % buffer_size) 100 | # input_size = input_size_match[self.params.data] 101 | # self.buffer_img = torch.FloatTensor(buffer_size, *input_size).fill_(0) 102 | # self.buffer_label = torch.LongTensor(buffer_size).fill_(0) 103 | # self.buffer_replay_times =maybe_cuda(torch.LongTensor(buffer_size).fill_(0)) 104 | # self.buffer_last_replay = maybe_cuda(torch.LongTensor(buffer_size).fill_(0)) 105 | 106 | 107 | def greedy_balancing_update(self, x, y): 108 | k_c = self.buffer_size // max(1, len(self.mem_img))  # per-class quota over the classes seen so far 109 | if y not in self.mem_img or self.mem_c[y] < k_c: 110 | if sum(self.mem_c.values()) >= self.buffer_size: 111 | cls_max = max(self.mem_c.items(), key=lambda kv: kv[1])[0]  # evict from the currently largest class 112 | idx = random.randrange(self.mem_c[cls_max]) 113 | img = self.mem_img[cls_max].pop(idx) 114 | if(self.params.test_mem_recycle): 115 | self.recycler.store_tmp(img, cls_max)  # use the recycler instance created in __init__ 116 | self.mem_c[cls_max] -= 1 117 | if y not in self.mem_img: 118 | self.mem_img[y] = [] 119 | self.mem_c[y] = 0 120 | self.mem_img[y].append(x) 121 | self.mem_c[y] += 1 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /general_main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import numpy as np 4 | import torch 5 | from experiment.run import multiple_run,multiple_RLtrainig_run 6 | from utils.utils import boolean_string 7 | import warnings 8 | from utils.argparser.argparser_RL import parse_RL_para 9 | from utils.argparser.argparser_basic import parse_cl_basic 10 | from utils.argparser.argparser_scr 
import parse_scr 11 | from utils.argparser.argparser_der import parse_der 12 | from utils.argparser.argparser_aug import parse_aug 13 | from utils.argparser.argparser_replay import parse_replay 14 | 15 | 16 | warnings.filterwarnings("ignore", category=DeprecationWarning) 17 | 18 | 19 | 20 | def main(args): 21 | print(args) 22 | # set up seed 23 | np.random.seed(args.seed) 24 | random.seed(args.seed) 25 | torch.manual_seed(args.seed) 26 | if args.cuda: 27 | torch.cuda.manual_seed(args.seed) 28 | torch.backends.cudnn.deterministic = True 29 | torch.backends.cudnn.benchmark = False 30 | args.trick = {'labels_trick': args.labels_trick, 'separated_softmax': args.separated_softmax, 31 | 'kd_trick': args.kd_trick, 'kd_trick_star': args.kd_trick_star, 'review_trick': args.review_trick, 32 | 'nmc_trick': args.nmc_trick} 33 | if(args.num_runs>1): 34 | multiple_RLtrainig_run(args) 35 | else: 36 | multiple_run(args) 37 | 38 | 39 | 40 | if __name__ == "__main__": 41 | # Commandline arguments 42 | parser = argparse.ArgumentParser(description="Online Continual Learning PyTorch") 43 | 44 | parser = parse_cl_basic(parser) 45 | parser = parse_scr(parser) 46 | parser = parse_der(parser) 47 | 48 | parser = parse_RL_para(parser) 49 | parser = parse_aug(parser) 50 | parser = parse_replay(parser) 51 | 52 | 53 | parser.add_argument('--aug_action_num',default=8,type=int) 54 | ######################### misc ######################## 55 | parser.add_argument("--immediate_evaluate",default = False, type = boolean_string) 56 | parser.add_argument('--dataset_random_type', dest='dataset_random_type', default= "order_random", 57 | type=str,choices=["order_random","task_random"], 58 | help="") 59 | parser.add_argument("--resnet_size",default="reduced",choices=["normal","reduced"]) 60 | parser.add_argument('--save_prefix', dest='save_prefix', default="", help='') 61 | 62 | parser.add_argument('--new_folder', dest='new_folder', default="", help='') 63 | 64 | parser.add_argument('--bpg_restart',default=False,type=boolean_string) 65 | parser.add_argument('--test', dest='test', default=" ", type=str,choices=["not_reset"], 66 | help='') 67 | parser.add_argument('--debug_mode',default=False, type=boolean_string) 68 | parser.add_argument('--acc_no_aug',default=True) 69 | 70 | parser.add_argument('--GPU_ID', dest='GPU_ID', default= 0, 71 | type=int, 72 | help="") 73 | parser.add_argument('--drift_detection',type=boolean_string,default=False) 74 | parser.add_argument("--test_add_buffer",default=False,type=boolean_string) 75 | 76 | parser.add_argument('--switch_buffer_type', dest='switch_buffer_type', default="one_buffer", 77 | type=str, choices=["one_buffer", "two_buffer", "dyna_buffer"], 78 | help="whether and how to switch replay buffer") 79 | #parser.add_argument("--adjust_aug_flag",default=False,type=boolean_string) 80 | ############ thompson sampling ########### 81 | parser.add_argument("--slide_window_size",default=10,type=int) 82 | 83 | parser.add_argument("--set_task_flag",default=-1,type=int) 84 | parser.add_argument("--bpg_lr",default=5.0,type=float) 85 | parser.add_argument("--bpg_lr_large",default=10.0,type=float) 86 | parser.add_argument("--bpg_lr_small",default=5.0,type=float) 87 | 88 | args = parser.parse_args() 89 | args.cuda = torch.cuda.is_available() 90 | torch.cuda.set_device(args.GPU_ID)#args.GPU_ID 91 | 92 | if(args.data=="cifar100"): 93 | if(args.set_task_flag>-1): 94 | args.num_tasks = args.set_task_flag 95 | else: 96 | args.num_tasks = 20 97 | elif(args.data=="cifar10"): 98 | args.num_tasks = 5 99 | 
elif(args.data=="mini_imagenet"): 100 | args.num_tasks=10 101 | elif(args.data=="clrs25"): 102 | if(args.cl_type == "nc"): 103 | args.num_tasks=5 104 | else: 105 | args.num_tasks=3 106 | elif(args.data=="core50"): 107 | args.num_tasks=9 108 | else: 109 | raise NotImplementedError("unknown dataset", args.data) 110 | 111 | main(args) 112 | -------------------------------------------------------------------------------- /continuum/dataset_scripts/openloris.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from PIL import Image 3 | import numpy as np 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | import time 6 | from continuum.data_utils import shuffle_data 7 | 8 | 9 | class OpenLORIS(DatasetBase): 10 | """ 11 | task_nums is predefined and depends on the ns_type. 12 | """ 13 | def __init__(self, scenario, params): # scenario refers to "ni" or "nc" 14 | dataset = 'openloris' 15 | self.ns_type = params.ns_type 16 | task_nums = openloris_ntask[self.ns_type] # ns_type can be (illumination, occlusion, pixel, clutter, sequence) 17 | super(OpenLORIS, self).__init__(dataset, scenario, task_nums, params.num_runs, params) 18 | 19 | 20 | def download_load(self): 21 | s = time.time() 22 | self.train_set = [] 23 | for batch_num in range(1, self.task_nums+1): 24 | train_x = [] 25 | train_y = [] 26 | test_x = [] 27 | test_y = [] 28 | for i in range(len(datapath)): 29 | train_temp = glob.glob('datasets/openloris/' + self.ns_type + '/train/task{}/{}/*.jpg'.format(batch_num, datapath[i])) 30 | 31 | train_x.extend([np.array(Image.open(x).convert('RGB').resize((50, 50))) for x in train_temp]) 32 | train_y.extend([i] * len(train_temp)) 33 | 34 | test_temp = glob.glob( 35 | 'datasets/openloris/' + self.ns_type + '/test/task{}/{}/*.jpg'.format(batch_num, datapath[i])) 36 | 37 | test_x.extend([np.array(Image.open(x).convert('RGB').resize((50, 50))) for x in test_temp]) 38 | test_y.extend([i] * len(test_temp)) 39 | 40 | print(" --> batch{} train-dataset consisting of {} samples".format(batch_num, len(train_x))) 41 | print(" --> test-dataset consisting of {} samples".format(len(test_x))) 42 | self.train_set.append((np.array(train_x), np.array(train_y))) 43 | self.test_set.append((np.array(test_x), np.array(test_y))) 44 | e = time.time() 45 | print('loading time: {}'.format(str(e - s))) 46 | 47 | def new_run(self, **kwargs): 48 | pass 49 | 50 | def new_task(self, cur_task, **kwargs): 51 | train_x, train_y = self.train_set[cur_task] 52 | # get val set 53 | train_x_rdm, train_y_rdm = shuffle_data(train_x, train_y) 54 | val_size = int(len(train_x_rdm) * self.params.val_size) 55 | val_data_rdm, val_label_rdm = train_x_rdm[:val_size], train_y_rdm[:val_size] 56 | train_data_rdm, train_label_rdm = train_x_rdm[val_size:], train_y_rdm[val_size:] 57 | self.val_set.append((val_data_rdm, val_label_rdm)) 58 | labels = set(train_label_rdm) 59 | return train_data_rdm, train_label_rdm, labels 60 | 61 | def setup(self, **kwargs): 62 | pass 63 | 64 | 65 | 66 | openloris_ntask = { 67 | 'illumination': 9, 68 | 'occlusion': 9, 69 | 'pixel': 9, 70 | 'clutter': 9, 71 | 'sequence': 12 72 | } 73 | 74 | datapath = ['bottle_01', 'bottle_02', 'bottle_03', 'bottle_04', 'bowl_01', 'bowl_02', 'bowl_03', 'bowl_04', 'bowl_05', 75 | 'corkscrew_01', 'cottonswab_01', 'cottonswab_02', 'cup_01', 'cup_02', 'cup_03', 'cup_04', 'cup_05', 76 | 'cup_06', 'cup_07', 'cup_08', 'cup_10', 'cushion_01', 'cushion_02', 'cushion_03', 'glasses_01', 77 | 'glasses_02', 'glasses_03', 'glasses_04', 
'knife_01', 'ladle_01', 'ladle_02', 'ladle_03', 'ladle_04', 78 | 'mask_01', 'mask_02', 'mask_03', 'mask_04', 'mask_05', 'paper_cutter_01', 'paper_cutter_02', 79 | 'paper_cutter_03', 'paper_cutter_04', 'pencil_01', 'pencil_02', 'pencil_03', 'pencil_04', 'pencil_05', 80 | 'plasticbag_01', 'plasticbag_02', 'plasticbag_03', 'plug_01', 'plug_02', 'plug_03', 'plug_04', 'pot_01', 81 | 'scissors_01', 'scissors_02', 'scissors_03', 'stapler_01', 'stapler_02', 'stapler_03', 'thermometer_01', 82 | 'thermometer_02', 'thermometer_03', 'toy_01', 'toy_02', 'toy_03', 'toy_04', 'toy_05','nail_clippers_01','nail_clippers_02', 83 | 'nail_clippers_03', 'bracelet_01', 'bracelet_02','bracelet_03', 'comb_01','comb_02', 84 | 'comb_03', 'umbrella_01','umbrella_02','umbrella_03','socks_01','socks_02','socks_03', 85 | 'toothpaste_01','toothpaste_02','toothpaste_03','wallet_01','wallet_02','wallet_03', 86 | 'headphone_01','headphone_02','headphone_03', 'key_01','key_02','key_03', 87 | 'battery_01', 'battery_02', 'mouse_01', 'pencilcase_01', 'pencilcase_02', 'tape_01', 88 | 'chopsticks_01', 'chopsticks_02', 'chopsticks_03', 89 | 'notebook_01', 'notebook_02', 'notebook_03', 90 | 'spoon_01', 'spoon_02', 'spoon_03', 91 | 'tissue_01', 'tissue_02', 'tissue_03', 92 | 'clamp_01', 'clamp_02', 'hat_01', 'hat_02', 'u_disk_01', 'u_disk_02', 'swimming_glasses_01' 93 | ] 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /utils/setup_elements.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.resnet import ResNet18,Reduced_ResNet18,SupConResNet,SupConResNet_normal 3 | from models.pretrained import ResNet18_pretrained 4 | from torchvision import transforms 5 | import torch.nn as nn 6 | 7 | 8 | default_trick = {'labels_trick': False, 'kd_trick': False, 'separated_softmax': False, 9 | 'review_trick': False, 'nmc_trick': False} 10 | 11 | 12 | input_size_match = { 13 | 'cifar100': [3, 32, 32], 14 | 'cifar10': [3, 32, 32], 15 | 'core50': [3, 128, 128], 16 | 'clrs25': [3, 128,128],#[3, 256, 256], 17 | 'mini_imagenet': [3, 84, 84], 18 | 'openloris': [3, 50, 50] 19 | } 20 | 21 | 22 | n_classes = { 23 | 'cifar100': 100, 24 | 'cifar10': 10, 25 | 'core50': 50, 26 | 'clrs25':25, 27 | 'mini_imagenet': 100, 28 | 'openloris': 69 29 | } 30 | 31 | 32 | transforms_match = { 33 | 'core50': transforms.Compose([ 34 | transforms.ToTensor(), 35 | ]), 36 | 'clrs25': transforms.Compose([ 37 | transforms.ToTensor(), 38 | ]), 39 | 'cifar100': transforms.Compose([ 40 | transforms.ToTensor(), 41 | ]), 42 | 'cifar10': transforms.Compose([ 43 | transforms.ToTensor(), 44 | ]), 45 | 'mini_imagenet': transforms.Compose([ 46 | transforms.ToTensor()]), 47 | 'openloris': transforms.Compose([ 48 | transforms.ToTensor()]) 49 | } 50 | 51 | 52 | def setup_architecture(params): 53 | nclass = n_classes[params.data] 54 | 55 | if params.agent in ['SCR','SCR_META', 'SCP',"SCR_RL_ratio","SCR_RL_iter"]: 56 | if params.data == 'mini_imagenet': 57 | if(params.resnet_size == "normal"): 58 | return SupConResNet_normal(2048, head=params.head) 59 | else: 60 | return SupConResNet(640, head=params.head) 61 | if params.data == 'clrs25': 62 | 63 | if(params.resnet_size == "normal"): 64 | return SupConResNet_normal(8192, head=params.head) 65 | else: 66 | return SupConResNet(2560, head=params.head) 67 | 68 | if params.data == 'core50': 69 | if(params.resnet_size == "normal"): 70 | return SupConResNet_normal(8192, head=params.head) 71 | else: 72 | return SupConResNet(2560, 
head=params.head) 73 | 74 | 75 | if(params.resnet_size == "normal"): 76 | return SupConResNet_normal(512, head=params.head) 77 | else: 78 | return SupConResNet( head=params.head) 79 | 80 | #return SupConResNet(head=params.head) 81 | if params.agent == 'CNDPM': 82 | from models.ndpm.ndpm import Ndpm 83 | return Ndpm(params) 84 | if params.data == 'cifar100': 85 | if(params.resnet_size == "normal"): 86 | return ResNet18(nclass) 87 | else: 88 | return Reduced_ResNet18(nclass) 89 | elif params.data == 'clrs25': 90 | if(params.resnet_size == "normal"): 91 | model= ResNet18(nclass) 92 | model.linear = nn.Linear(8192, nclass, bias=True) 93 | else: 94 | model= Reduced_ResNet18(nclass) 95 | model.linear = nn.Linear(2560, nclass, bias=True) 96 | return model 97 | elif params.data == 'cifar10': 98 | 99 | if(params.resnet_size == "normal"): 100 | return ResNet18(nclass) 101 | else: 102 | return Reduced_ResNet18(nclass) 103 | 104 | elif params.data == 'core50': 105 | if(params.resnet_size == "normal"): 106 | model= ResNet18(nclass) 107 | model.linear = nn.Linear(8192, nclass, bias=True) 108 | else: 109 | model= Reduced_ResNet18(nclass) 110 | model.linear = nn.Linear(2560, nclass, bias=True) 111 | 112 | return model 113 | elif params.data == 'mini_imagenet': 114 | if(params.resnet_size == "normal"): 115 | model= ResNet18(nclass) 116 | model.linear = nn.Linear(2048, nclass, bias=True) 117 | else: 118 | model= Reduced_ResNet18(nclass) 119 | 120 | model.linear = nn.Linear(640, nclass, bias=True) 121 | return model 122 | elif params.data == 'openloris': 123 | return Reduced_ResNet18(nclass) 124 | else: 125 | raise NotImplementedError("undefined dataset",params.data) 126 | 127 | 128 | def setup_opt(optimizer, model, lr, wd): 129 | if optimizer == 'SGD': 130 | optim = torch.optim.SGD(model.parameters(), 131 | lr=lr, 132 | weight_decay=wd) 133 | elif optimizer == 'Adam': 134 | optim = torch.optim.Adam(model.parameters(), 135 | lr=lr, 136 | weight_decay=wd) 137 | else: 138 | raise Exception('wrong optimizer name') 139 | return optim 140 | -------------------------------------------------------------------------------- /utils/buffer/reservoir_update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Reservoir_update(object): 5 | def __init__(self, params): 6 | super().__init__() 7 | 8 | 9 | 10 | def choose_replace_indice(self,buffer,valid_indices): 11 | 12 | ## replace samples with maximum replay times 13 | store_sample_num = torch.sum(valid_indices) 14 | idx_buffer = torch.argsort(buffer.buffer_replay_times,descending=True)[:store_sample_num] 15 | return idx_buffer 16 | 17 | def choose_replace_indice_timestamp(self, buffer, valid_indices): 18 | 19 | ## replace samples with longest last replay 20 | store_sample_num = torch.sum(valid_indices) 21 | idx_buffer = torch.argsort(buffer.buffer_last_replay,descending=True)[:store_sample_num] 22 | return idx_buffer 23 | 24 | def update(self, buffer, x, y,logits=None,tmp_buffer=None, **kwargs): 25 | batch_size = x.size(0) 26 | x = x.cpu() 27 | y = y.cpu() 28 | 29 | if hasattr(buffer, 'buffer_logits'): 30 | logits = logits.cpu() 31 | 32 | # add whatever still fits in the buffer 33 | place_left = max(0, buffer.buffer_img.size(0) - buffer.current_index) 34 | if place_left: 35 | offset = min(place_left, batch_size) 36 | buffer.buffer_img[buffer.current_index: buffer.current_index + offset].data.copy_(x[:offset]) 37 | buffer.buffer_label[buffer.current_index: buffer.current_index + 
offset].data.copy_(y[:offset]) 38 | buffer.buffer_new_old[buffer.current_index: buffer.current_index + offset]=1 39 | if hasattr(buffer,'buffer_logits'): 40 | buffer.buffer_logits[buffer.current_index: buffer.current_index + offset].data.copy_(logits[:offset]) 41 | 42 | buffer.current_index += offset 43 | buffer.n_seen_so_far += offset 44 | 45 | # everything was added 46 | if offset == x.size(0): 47 | filled_idx = list(range(buffer.current_index - offset, buffer.current_index, )) 48 | if buffer.params.buffer_tracker: 49 | buffer.buffer_tracker.update_cache(buffer.buffer_label, y[:offset], filled_idx) 50 | return filled_idx 51 | 52 | 53 | # TODO: the buffer tracker will have a bug when the mem size is not divisible by the batch size 54 | 55 | # remove what is already in the buffer 56 | x, y = x[place_left:], y[place_left:] 57 | if hasattr(buffer, 'buffer_logits'): 58 | logits = logits[place_left:] 59 | 60 | indices = torch.FloatTensor(x.size(0)).to(x.device).uniform_(0, buffer.n_seen_so_far).long()  # reservoir sampling: draw a slot in [0, n_seen_so_far) for each incoming sample 61 | valid_indices = (indices < buffer.buffer_img.size(0)).long()  # only draws that land inside the buffer overwrite a slot 62 | 63 | idx_new_data = valid_indices.nonzero().squeeze(-1) 64 | idx_buffer = indices[idx_new_data] 65 | 66 | ## zyq: choose the samples with the most replay times to be replaced (see choose_replace_indice above) 67 | if(buffer.params.update[:2] =="rt"): 68 | idx_buffer = self.choose_replace_indice(buffer,valid_indices) 69 | elif(buffer.params.update=="timestamp"): 70 | idx_buffer = self.choose_replace_indice_timestamp(buffer,valid_indices) 71 | else: 72 | pass 73 | 74 | buffer.n_seen_so_far += x.size(0) 75 | 76 | if idx_buffer.numel() == 0: 77 | return [] 78 | 79 | assert idx_buffer.max() < buffer.buffer_img.size(0) 80 | assert idx_buffer.max() < buffer.buffer_label.size(0) 81 | if hasattr(buffer, 'buffer_logits'): 82 | assert idx_buffer.max() < buffer.buffer_logits.size(0) 83 | # assert idx_buffer.max() < self.buffer_task.size(0) 84 | 85 | assert idx_new_data.max() < x.size(0) 86 | assert idx_new_data.max() < y.size(0) 87 | 88 | 89 | idx_map = {idx_buffer[i].item(): idx_new_data[i].item() for i in range(idx_buffer.size(0))} 90 | #buffer.overwrite(idx_map,x,y) 91 | # ## zyq: save replay_times 92 | # for i in list(idx_map.keys()): 93 | # replay_times = buffer.buffer_replay_times[i].detach().cpu().numpy() 94 | # buffer.unique_replay_list.append(int(replay_times)) 95 | # buffer.buffer_replay_times[i]=0 96 | # buffer.buffer_last_replay[i]=0 97 | # sample_label = int(buffer.buffer_label[i].detach().cpu().numpy()) 98 | # buffer.replay_sample_label.append(sample_label) 99 | # # perform overwrite op 100 | # buffer.buffer_img[list(idx_map.keys())] = x[list(idx_map.values())] 101 | # buffer.buffer_label[list(idx_map.keys())] = y[list(idx_map.values())] 102 | 103 | # if (buffer.params.use_tmp_buffer): 104 | # 105 | # tmp_buffer.tmp_store(x[idx_new_data], y[idx_new_data]) 106 | # 107 | # else: 108 | if hasattr(buffer, 'buffer_logits'): 109 | buffer.overwrite(idx_map,x,y,logits) 110 | else: 111 | buffer.overwrite(idx_map, x, y) 112 | 113 | return list(idx_map.keys()) 114 | 115 | -------------------------------------------------------------------------------- /agents/ewc_pp.py: -------------------------------------------------------------------------------- 1 | from agents.base import ContinualLearner 2 | from continuum.data_utils import dataset_transform 3 | from utils.setup_elements import transforms_match 4 | from torch.utils import data 5 | from utils.utils import maybe_cuda, AverageMeter 6 | import torch 7 | 8 | class EWC_pp(ContinualLearner): 9 | def __init__(self, model, opt, 
params): 10 | super(EWC_pp, self).__init__(model, opt, params) 11 | self.weights = {n: p for n, p in self.model.named_parameters() if p.requires_grad} 12 | self.lambda_ = params.lambda_ 13 | self.alpha = params.alpha 14 | self.fisher_update_after = params.fisher_update_after 15 | self.prev_params = {} 16 | self.running_fisher = self.init_fisher() 17 | self.tmp_fisher = self.init_fisher() 18 | self.normalized_fisher = self.init_fisher() 19 | 20 | def train_learner(self, x_train, y_train): 21 | self.before_train(x_train, y_train) 22 | # set up loader 23 | train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data]) 24 | train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0, 25 | drop_last=True) 26 | # setup tracker 27 | losses_batch = AverageMeter() 28 | acc_batch = AverageMeter() 29 | 30 | # set up model 31 | self.model.train() 32 | 33 | for ep in range(self.epoch): 34 | for i, batch_data in enumerate(train_loader): 35 | # batch update 36 | batch_x, batch_y = batch_data 37 | batch_x = maybe_cuda(batch_x, self.cuda) 38 | batch_y = maybe_cuda(batch_y, self.cuda) 39 | 40 | # update the running fisher 41 | if (ep * len(train_loader) + i + 1) % self.fisher_update_after == 0: 42 | self.update_running_fisher() 43 | 44 | out = self.forward(batch_x) 45 | loss = self.total_loss(out, batch_y) 46 | if self.params.trick['kd_trick']: 47 | loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \ 48 | self.kd_manager.get_kd_loss(out, batch_x) 49 | if self.params.trick['kd_trick_star']: 50 | loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \ 51 | (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(out, batch_x) 52 | # update tracker 53 | losses_batch.update(loss.item(), batch_y.size(0)) 54 | _, pred_label = torch.max(out, 1) 55 | acc = (pred_label == batch_y).sum().item() / batch_y.size(0) 56 | acc_batch.update(acc, batch_y.size(0)) 57 | # backward 58 | self.opt.zero_grad() 59 | loss.backward() 60 | 61 | # accumulate the fisher of current batch 62 | self.accum_fisher() 63 | self.opt.step() 64 | 65 | if i % 100 == 1 and self.verbose: 66 | print( 67 | '==>>> it: {}, avg. loss: {:.6f}, ' 68 | 'running train acc: {:.3f}' 69 | .format(i, losses_batch.avg(), acc_batch.avg()) 70 | ) 71 | 72 | # save params for current task 73 | for n, p in self.weights.items(): 74 | self.prev_params[n] = p.clone().detach() 75 | 76 | # update normalized fisher of current task 77 | max_fisher = max([torch.max(m) for m in self.running_fisher.values()]) 78 | min_fisher = min([torch.min(m) for m in self.running_fisher.values()]) 79 | for n, p in self.running_fisher.items(): 80 | self.normalized_fisher[n] = (p - min_fisher) / (max_fisher - min_fisher + 1e-32) 81 | self.after_train() 82 | 83 | def total_loss(self, inputs, targets): 84 | # cross entropy loss 85 | loss = self.criterion(inputs, targets) 86 | if len(self.prev_params) > 0: 87 | # add regularization loss 88 | reg_loss = 0 89 | for n, p in self.weights.items(): 90 | reg_loss += (self.normalized_fisher[n] * (p - self.prev_params[n]) ** 2).sum() 91 | loss += self.lambda_ * reg_loss 92 | return loss 93 | 94 | def init_fisher(self): 95 | return {n: p.clone().detach().fill_(0) for n, p in self.model.named_parameters() if p.requires_grad} 96 | 97 | def update_running_fisher(self): 98 | for n, p in self.running_fisher.items(): 99 | self.running_fisher[n] = (1. - self.alpha) * p \ 100 | + 1. 
/ self.fisher_update_after * self.alpha * self.tmp_fisher[n] 101 | # reset the accumulated fisher 102 | self.tmp_fisher = self.init_fisher() 103 | 104 | def accum_fisher(self): 105 | for n, p in self.tmp_fisher.items(): 106 | p += self.weights[n].grad ** 2 -------------------------------------------------------------------------------- /continuum/dataset_scripts/cifar100.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import datasets 3 | from continuum.data_utils import create_task_composition, load_task_with_labels,create_task_composition_order,load_task_with_labels_correct 4 | from continuum.dataset_scripts.dataset_base import DatasetBase 5 | from continuum.non_stationary import construct_ns_multiple_wrapper, test_ns 6 | 7 | 8 | class CIFAR100(DatasetBase): 9 | def __init__(self, scenario, params): 10 | dataset = 'cifar100' 11 | if scenario == 'ni': 12 | num_tasks = len(params.ns_factor) 13 | else: 14 | num_tasks = params.num_tasks 15 | super(CIFAR100, self).__init__(dataset, scenario, num_tasks, params.num_runs, params) 16 | self.task_labels=[] 17 | 18 | 19 | 20 | def download_load(self): 21 | dataset_train = datasets.CIFAR100(root=self.root, train=True, download=True) 22 | self.train_data = dataset_train.data 23 | self.train_label = np.array(dataset_train.targets) 24 | dataset_test = datasets.CIFAR100(root=self.root, train=False, download=True) 25 | self.test_data = dataset_test.data 26 | self.test_label = np.array(dataset_test.targets) 27 | 28 | def setup(self): 29 | if self.scenario == 'ni': 30 | self.train_set, self.val_set, self.test_set = construct_ns_multiple_wrapper(self.train_data, 31 | self.train_label, 32 | self.test_data, self.test_label, 33 | self.task_nums, 32, 34 | self.params.val_size, 35 | self.params.ns_type, self.params.ns_factor, 36 | plot=self.params.plot_sample) 37 | elif self.scenario == 'nc': 38 | 39 | if(self.params.dataset_random_type == "task_random"): 40 | 41 | self.task_labels = create_task_composition(class_nums=100, num_tasks=self.task_nums, fixed_order=self.params.fix_order) 42 | print(self.task_labels) 43 | elif(self.params.dataset_random_type == "order_random"): 44 | self.task_labels = create_task_composition_order(class_nums=100, num_tasks=self.task_nums,) 45 | else: 46 | raise NotImplementedError("undefined dataset_random_type",self.params.dataset_random_type) 47 | 48 | self.test_set = [] 49 | for labels in self.task_labels: 50 | 51 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 52 | #print(labels,np.unique(y_test)) 53 | self.test_set.append((x_test, y_test)) 54 | 55 | elif self.scenario == 'nc_first_half': 56 | if(self.params.dataset_random_type == "task_random"): 57 | 58 | self.task_labels = create_task_composition(class_nums=50, num_tasks=self.task_nums, fixed_order=self.params.fix_order,start_class=0) 59 | elif(self.params.dataset_random_type == "order_random"): 60 | self.task_labels = create_task_composition_order(class_nums=50, num_tasks=self.task_nums,start_class=0) 61 | else: 62 | raise NotImplementedError("undefined dataset_random_type",self.params.dataset_random_type) 63 | 64 | self.test_set = [] 65 | for labels in self.task_labels: 66 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 67 | self.test_set.append((x_test, y_test)) 68 | elif self.scenario == 'nc_second_half': 69 | if(self.params.dataset_random_type == "task_random"): 70 | 71 | self.task_labels = create_task_composition(class_nums=50, 
num_tasks=self.task_nums, fixed_order=self.params.fix_order,start_class=50) 72 | elif(self.params.dataset_random_type == "order_random"): 73 | self.task_labels = create_task_composition_order(class_nums=50, num_tasks=self.task_nums,start_class=50) 74 | else: 75 | raise NotImplementedError("undefined dataset_random_type",self.params.dataset_random_type) 76 | 77 | self.test_set = [] 78 | for labels in self.task_labels: 79 | x_test, y_test = load_task_with_labels(self.test_data, self.test_label, labels) 80 | self.test_set.append((x_test, y_test)) 81 | else: 82 | raise Exception('wrong scenario') 83 | 84 | def new_task(self, cur_task, **kwargs): 85 | if self.scenario == 'ni': 86 | x_train, y_train = self.train_set[cur_task] 87 | labels = set(y_train) 88 | elif self.scenario[:2] == 'nc':  # covers 'nc', 'nc_first_half' and 'nc_second_half' 89 | labels = self.task_labels[cur_task] 90 | x_train, y_train = load_task_with_labels(self.train_data, self.train_label, labels) 91 | return x_train, y_train, labels 92 | 93 | def new_run(self, **kwargs): 94 | self.setup() 95 | return self.test_set 96 | 97 | def test_plot(self): 98 | test_ns(self.train_data[:10], self.train_label[:10], self.params.ns_type, 99 | self.params.ns_factor) 100 | -------------------------------------------------------------------------------- /utils/buffer/aser_update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.buffer.reservoir_update import Reservoir_update 3 | from utils.buffer.buffer_utils import ClassBalancedRandomSampling, random_retrieve 4 | from utils.buffer.aser_utils import compute_knn_sv, add_minority_class_input 5 | from utils.setup_elements import n_classes 6 | from utils.utils import nonzero_indices, maybe_cuda 7 | 8 | 9 | class ASER_update(object): 10 | def __init__(self, params, **kwargs): 11 | super().__init__() 12 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 13 | self.k = params.k 14 | self.mem_size = params.mem_size 15 | self.num_tasks = params.num_tasks 16 | self.out_dim = n_classes[params.data] 17 | self.n_smp_cls = int(params.n_smp_cls) 18 | self.n_total_smp = int(params.n_smp_cls * self.out_dim) 19 | self.reservoir_update = Reservoir_update(params) 20 | ClassBalancedRandomSampling.class_index_cache = None 21 | 22 | def update(self, buffer, x, y, **kwargs): 23 | model = buffer.model 24 | 25 | place_left = self.mem_size - buffer.current_index 26 | 27 | # If buffer is not filled, use available space to store whole or part of batch 28 | if place_left: 29 | x_fit = x[:place_left] 30 | y_fit = y[:place_left] 31 | 32 | ind = torch.arange(start=buffer.current_index, end=buffer.current_index + x_fit.size(0), device=self.device) 33 | ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim, 34 | new_y=y_fit, ind=ind, device=self.device) 35 | self.reservoir_update.update(buffer, x_fit, y_fit) 36 | 37 | # If buffer is filled, update buffer by sv 38 | if buffer.current_index == self.mem_size: 39 | # remove what is already in the buffer 40 | cur_x, cur_y = x[place_left:], y[place_left:] 41 | self._update_by_knn_sv(model, buffer, cur_x, cur_y) 42 | 43 | def _update_by_knn_sv(self, model, buffer, cur_x, cur_y): 44 | """ 45 | Returns indices for replacement. 46 | Buffered instances with smallest SV are replaced by current input with higher SV. 47 | Args: 48 | model (object): neural network. 49 | buffer (object): buffer object. 50 | cur_x (tensor): current input data tensor. 51 | cur_y (tensor): current input label tensor. 
52 | Returns 53 | ind_buffer (tensor): indices of buffered instances to be replaced. 54 | ind_cur (tensor): indices of current data to do replacement. 55 | """ 56 | cur_x = maybe_cuda(cur_x) 57 | cur_y = maybe_cuda(cur_y) 58 | 59 | # Find minority class samples from current input batch 60 | minority_batch_x, minority_batch_y = add_minority_class_input(cur_x, cur_y, self.mem_size, self.out_dim) 61 | 62 | # Evaluation set 63 | eval_x, eval_y, eval_indices = \ 64 | ClassBalancedRandomSampling.sample(buffer.buffer_img, buffer.buffer_label, self.n_smp_cls, 65 | device=self.device) 66 | 67 | # Concatenate minority class samples from current input batch to evaluation set 68 | eval_x = torch.cat((eval_x, minority_batch_x)) 69 | eval_y = torch.cat((eval_y, minority_batch_y)) 70 | 71 | # Candidate set 72 | cand_excl_indices = set(eval_indices.tolist()) 73 | cand_x, cand_y, cand_ind = random_retrieve(buffer, self.n_total_smp, cand_excl_indices, return_indices=True) 74 | 75 | # Concatenate current input batch to candidate set 76 | cand_x = torch.cat((cand_x, cur_x)) 77 | cand_y = torch.cat((cand_y, cur_y)) 78 | 79 | sv_matrix = compute_knn_sv(model, eval_x, eval_y, cand_x, cand_y, self.k, device=self.device) 80 | sv = sv_matrix.sum(0) 81 | 82 | n_cur = cur_x.size(0) 83 | n_cand = cand_x.size(0) 84 | 85 | # Number of previously buffered instances in candidate set 86 | n_cand_buf = n_cand - n_cur 87 | 88 | sv_arg_sort = sv.argsort(descending=True) 89 | 90 | # Divide SV array into two segments 91 | # - large: candidate args to be retained; small: candidate args to be discarded 92 | sv_arg_large = sv_arg_sort[:n_cand_buf] 93 | sv_arg_small = sv_arg_sort[n_cand_buf:] 94 | 95 | # Extract args relevant to replacement operation 96 | # If current data instances are in 'large' segment, they are added to buffer 97 | # If buffered instances are in 'small' segment, they are discarded from buffer 98 | # Replacement happens between these two sets 99 | # Retrieve original indices from candidate args 100 | ind_cur = sv_arg_large[nonzero_indices(sv_arg_large >= n_cand_buf)] - n_cand_buf 101 | arg_buffer = sv_arg_small[nonzero_indices(sv_arg_small < n_cand_buf)] 102 | ind_buffer = cand_ind[arg_buffer] 103 | 104 | buffer.n_seen_so_far += n_cur 105 | 106 | # perform overwrite op 107 | y_upt = cur_y[ind_cur] 108 | x_upt = cur_x[ind_cur] 109 | ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim, 110 | new_y=y_upt, ind=ind_buffer, device=self.device) 111 | buffer.buffer_img[ind_buffer] = x_upt.detach().cpu() 112 | buffer.buffer_label[ind_buffer] = y_upt.detach().cpu() 113 | -------------------------------------------------------------------------------- /utils/buffer/mir_retrieve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.utils import maybe_cuda 3 | import torch.nn.functional as F 4 | from utils.buffer.buffer_utils import random_retrieve, get_grad_vector 5 | import copy 6 | 7 | 8 | class MIR_retrieve(object): 9 | def __init__(self, params, **kwargs): 10 | super().__init__() 11 | self.params = params 12 | self.subsample = params.subsample 13 | self.num_retrieve = params.eps_mem_batch 14 | self.incoming_influence = None 15 | 16 | def compute_incoming_influence(self, buffer, **kwargs): 17 | 18 | sub_x, sub_y, mem_indices = random_retrieve(buffer, self.subsample, return_indices=True) 19 | #sub_x, sub_y = random_retrieve(buffer, self.subsample) 20 | grad_dims = [] 21 | for param in buffer.model.parameters(): 22 | 
grad_dims.append(param.data.numel()) 23 | grad_vector = get_grad_vector(buffer.model.parameters, grad_dims) 24 | model_temp = self.get_future_step_parameters(buffer.model, grad_vector, grad_dims) 25 | if sub_x.size(0) > 0: 26 | with torch.no_grad(): 27 | logits_pre = buffer.model.forward(sub_x) 28 | logits_post = model_temp.forward(sub_x) 29 | pre_loss = F.cross_entropy(logits_pre, sub_y, reduction='none') 30 | post_loss = F.cross_entropy(logits_post, sub_y, reduction='none') 31 | scores = post_loss - pre_loss 32 | self.incoming_influence = torch.mean(scores) 33 | # big_ind = scores.sort(descending=True)[1][:self.num_retrieve] 34 | #buffer.update_replay_times(mem_indices[big_ind]) 35 | return self.incoming_influence 36 | else: 37 | raise NotImplementedError("MIR random subsample number is 0",sub_x.size(0)) 38 | 39 | def retrieve(self, buffer, **kwargs): 40 | sub_x, sub_y, mem_indices = random_retrieve(buffer, self.subsample, return_indices=True) 41 | #sub_x, sub_y = random_retrieve(buffer, self.subsample) 42 | grad_dims = [] 43 | for param in buffer.model.parameters(): 44 | grad_dims.append(param.data.numel()) 45 | grad_vector = get_grad_vector(buffer.model.parameters, grad_dims) 46 | model_temp = self.get_future_step_parameters(buffer.model, grad_vector, grad_dims) 47 | if sub_x.size(0) > 0: 48 | with torch.no_grad(): 49 | logits_pre = buffer.model.forward(sub_x) 50 | logits_post = model_temp.forward(sub_x) 51 | pre_loss = F.cross_entropy(logits_pre, sub_y, reduction='none') 52 | post_loss = F.cross_entropy(logits_post, sub_y, reduction='none') 53 | scores = post_loss - pre_loss 54 | self.incoming_influence = torch.mean(scores) 55 | big_ind = scores.sort(descending=True)[1][:self.num_retrieve] 56 | buffer.update_replay_times(mem_indices[big_ind]) 57 | return sub_x[big_ind], sub_y[big_ind] 58 | else: 59 | return sub_x, sub_y 60 | # def retrieve(self, buffer, **kwargs): 61 | # sub_x, sub_y ,mem_indices = random_retrieve(buffer, self.subsample,return_indices=True) 62 | # grad_dims = [] 63 | # for param in buffer.model.parameters(): 64 | # grad_dims.append(param.data.numel()) 65 | # grad_vector = get_grad_vector(buffer.model.parameters, grad_dims) 66 | # model_temp = self.get_future_step_parameters(buffer.model, grad_vector, grad_dims) 67 | # if sub_x.size(0) > 0: 68 | # with torch.no_grad(): 69 | # logits_pre = buffer.model.forward(sub_x) 70 | # logits_post = model_temp.forward(sub_x) 71 | # pre_loss = F.cross_entropy(logits_pre, sub_y, reduction='none') 72 | # post_loss = F.cross_entropy(logits_post, sub_y, reduction='none') 73 | # scores = post_loss - pre_loss 74 | # big_ind = scores.sort(descending=True)[1][:self.num_retrieve] 75 | # buffer.update_replay_times(mem_indices[big_ind]) 76 | # return sub_x[big_ind], sub_y[big_ind] 77 | # else: 78 | # return sub_x, sub_y 79 | 80 | def get_future_step_parameters(self, model, grad_vector, grad_dims): 81 | """ 82 | computes theta - lr * grad, i.e. the parameters after one virtual SGD step 83 | :param model: the current network 84 | :param grad_vector: flattened gradients of the incoming batch 85 | :return: a deep copy of the model with the virtually updated parameters 86 | """ 87 | new_model = copy.deepcopy(model) 88 | self.overwrite_grad(new_model.parameters, grad_vector, grad_dims) 89 | with torch.no_grad(): 90 | for param in new_model.parameters(): 91 | if param.grad is not None: 92 | param.data = param.data - self.params.learning_rate * param.grad.data 93 | return new_model 94 | 95 | def overwrite_grad(self, pp, new_grad, grad_dims): 96 | """ 97 | This is used to overwrite the gradients with a new gradient 98 | vector. 
99 | pp: parameters 100 | new_grad: corrected gradient 101 | grad_dims: list storing the number of parameters at each layer 102 | """ 103 | cnt = 0 104 | for param in pp(): 105 | param.grad = torch.zeros_like(param.data) 106 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) 107 | en = sum(grad_dims[:cnt + 1]) 108 | this_grad = new_grad[beg: en].contiguous().view( 109 | param.data.size()) 110 | param.grad.data.copy_(this_grad) 111 | cnt += 1 --------------------------------------------------------------------------------
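
The Reservoir_update logic above is classic reservoir sampling: while the buffer has free slots the incoming sample is stored directly; once full, the i-th sample seen draws a random index in [0, i] and overwrites the buffer only when the draw lands below the buffer capacity, so every sample seen so far stays in memory with equal probability. A minimal standalone sketch of that rule (function and variable names here are illustrative and do not appear in the repo):

import random

def reservoir_insert(buffer, item, n_seen, capacity):
    # fill empty slots first
    if len(buffer) < capacity:
        buffer.append(item)
        return
    # this item is the (n_seen + 1)-th seen: draw a slot in [0, n_seen]
    j = random.randrange(n_seen + 1)
    if j < capacity:
        buffer[j] = item  # overwrite; otherwise the item is discarded

buf = []
for n_seen in range(1000):
    reservoir_insert(buf, n_seen, n_seen, capacity=100)
assert len(buf) == 100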
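
Test_Buffer.greedy_balancing_update above keeps the buffer class-balanced in the greedy, GDumb-like way: each class may occupy at most buffer_size // num_classes_seen slots, and when the buffer is full a random sample is evicted from the currently largest class before the new one is stored. A condensed sketch under those assumptions (dict-of-lists memory, illustrative names):

import random

def greedy_balance(mem, x, y, capacity):
    # per-class quota given the classes seen so far
    k_c = capacity // max(1, len(mem))
    if y in mem and len(mem[y]) >= k_c:
        return  # class already at quota: drop the incoming sample
    if sum(len(v) for v in mem.values()) >= capacity:
        biggest = max(mem, key=lambda c: len(mem[c]))
        mem[biggest].pop(random.randrange(len(mem[biggest])))  # evict at random
    mem.setdefault(y, []).append(x)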
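
MIR_retrieve.retrieve above implements maximally-interfered retrieval: take one virtual SGD step on the incoming batch, then replay the memory samples whose loss that step would increase the most. A condensed sketch of the lookahead score, assuming the caller supplies model, lr, the incoming batch (inc_x, inc_y) and a memory subsample (mem_x, mem_y); all names are placeholders:

import copy
import torch
import torch.nn.functional as F

def mir_scores(model, lr, inc_x, inc_y, mem_x, mem_y):
    # gradient of the incoming batch
    model.zero_grad()
    F.cross_entropy(model(inc_x), inc_y).backward()
    # virtually updated copy of the model (cf. get_future_step_parameters)
    lookahead = copy.deepcopy(model)
    with torch.no_grad():
        for p_new, p_old in zip(lookahead.parameters(), model.parameters()):
            if p_old.grad is not None:
                p_new.data -= lr * p_old.grad.data
        pre = F.cross_entropy(model(mem_x), mem_y, reduction='none')
        post = F.cross_entropy(lookahead(mem_x), mem_y, reduction='none')
    return post - pre  # largest scores = most interfered; replay the top-k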
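
EWC_pp.total_loss above adds the standard EWC quadratic penalty, loss = CE + lambda * sum_n F_n * (theta_n - theta*_n)^2, where F is the normalized running Fisher and theta* are the parameters saved after the previous task. A toy numeric check of just the penalty term (illustrative names):

import torch

def ewc_penalty(weights, prev_params, fisher, lam):
    # sum over parameter tensors: lam * F * (theta - theta_prev)^2
    reg = sum((fisher[n] * (p - prev_params[n]) ** 2).sum()
              for n, p in weights.items())
    return lam * reg

w      = {'w': torch.tensor([1.0, 2.0])}   # current parameters
w_star = {'w': torch.tensor([0.0, 0.0])}   # parameters after previous task
fi     = {'w': torch.tensor([0.5, 1.0])}   # normalized Fisher
# 0.5 * 1^2 + 1.0 * 2^2 = 4.5, times lam = 2 gives 9.0
print(ewc_penalty(w, w_star, fi, lam=2.0))  # tensor(9.)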