├── approaches
│   ├── __init__.py
│   ├── random_init.py
│   ├── ewc.py
│   ├── ft.py
│   ├── mas.py
│   ├── rwalk.py
│   ├── si.py
│   ├── afec_ewc.py
│   └── afec_mas.py
├── LargeScale
│   ├── networks
│   │   ├── __init__.py
│   │   ├── model_factory.py
│   │   ├── alexnet.py
│   │   └── alexnet_hat.py
│   ├── trainer
│   │   ├── __init__.py
│   │   ├── evaluator.py
│   │   ├── trainer_factory.py
│   │   ├── mas.py
│   │   ├── ewc.py
│   │   ├── rwalk.py
│   │   ├── si.py
│   │   └── afec_mas.py
│   ├── data_handler
│   │   ├── __init__.py
│   │   ├── dataset_factory.py
│   │   ├── incremental_loader.py
│   │   └── dataset.py
│   ├── utils.py
│   ├── arguments.py
│   └── main.py
├── networks
│   ├── conv_net_omniglot.py
│   ├── alexnet.py
│   ├── conv_net.py
│   ├── conv_net_omniglot_hat.py
│   ├── resnet
│   │   └── conv_net.py
│   ├── alexnet_hat.py
│   └── conv_net_hat.py
├── weishts
│   └── logs
├── dataloaders
│   ├── split_cifar100.py
│   ├── split_cifar100_SC.py
│   └── split_cifar10_100.py
├── result_data
├── trained_model
├── LICENSE
├── README.md
├── arguments.py
└── main.py
/approaches/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/LargeScale/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from networks.model_factory import *
--------------------------------------------------------------------------------
/LargeScale/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from trainer.evaluator import *
2 | from trainer.trainer_factory import *
--------------------------------------------------------------------------------
/LargeScale/data_handler/__init__.py:
--------------------------------------------------------------------------------
1 | from data_handler.dataset_factory import *
2 | from data_handler.incremental_loader import *
--------------------------------------------------------------------------------
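The `__init__.py` files above re-export the factory and loader classes at package level, so the training entry point can import them without referencing the individual modules. A minimal sketch of that usage, assuming the entry point imports them this way (LargeScale/main.py is not reproduced in this dump, so the exact wiring is an assumption):

```python
# Hypothetical consumption of the package-level re-exports above; the classes
# themselves (ModelFactory, TrainerFactory, EvaluatorFactory, DatasetFactory)
# are defined in the modules reproduced below.
from networks import ModelFactory
from trainer import TrainerFactory, EvaluatorFactory
from data_handler import DatasetFactory
```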
/LargeScale/data_handler/dataset_factory.py:
--------------------------------------------------------------------------------
1 | import data_handler.dataset as data
2 | 
3 | 
4 | class DatasetFactory:
5 |     def __init__(self):
6 |         pass
7 | 
8 |     @staticmethod
9 |     def get_dataset(name):
10 |         if name == "CUB200":
11 |             return data.CUB200()
12 | 
13 |         elif name == "ImageNet":
14 |             return data.ImageNet()
15 | 
--------------------------------------------------------------------------------
/LargeScale/networks/model_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | class ModelFactory():
3 |     def __init__(self):
4 |         pass
5 | 
6 |     @staticmethod
7 |     def get_model(dataset, trainer, taskcla):
8 | 
9 | 
10 |         if dataset == 'CUB200' or dataset == 'ImageNet':
11 |             if trainer == 'hat':
12 |                 import networks.alexnet_hat as alex
13 |                 return alex.alexnet(taskcla, pretrained=False)
14 |             else:
15 |                 import networks.alexnet as alex
16 |                 return alex.alexnet(taskcla, pretrained=False)
17 | 
--------------------------------------------------------------------------------
/LargeScale/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import torch
4 | import pandas as pd
5 | 
6 | def human_format(num):
7 |     magnitude=0
8 |     while abs(num)>=1000:
9 |         magnitude+=1
10 |         num/=1000.0
11 |     return '%.1f%s'%(num,['','K','M','G','T','P'][magnitude])
12 | 
13 | def print_model_report(model):
14 |     print('-'*100)
15 |     print(model)
16 |     print('Dimensions =',end=' ')
17 |     count=0
18 |     for p in model.parameters():
19 |         print(p.size(),end=' ')
20 |         count+=np.prod(p.size())
21 |     print()
22 |     print('Num parameters = %s'%(human_format(count)))
23 |     print('-'*100)
24 |     return count
25 | 
26 | 
27 | def print_optimizer_config(optim):
28 |     if optim is None:
29 |         print(optim)
30 |     else:
31 |         print(optim,'=',end=' ')
32 |         opt=optim.param_groups[0]
33 |         for n in opt.keys():
34 |             if not n.startswith('param'):
35 |                 print(n+':',opt[n],end=', ')
36 |         print()
37 |     return
38 | 
39 | 
--------------------------------------------------------------------------------
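The factories and the reporting helper above are typically combined along the following lines. This is a minimal sketch under stated assumptions, not the repository's actual entry point (main.py is not reproduced in this section); the 10-task split of CUB-200 and the optimizer settings are illustrative only.

```python
import torch
from networks import ModelFactory
from utils import print_model_report

# Hypothetical wiring: taskcla is a list of (task_id, num_classes) pairs,
# which is the format the multi-head networks below iterate over.
taskcla = [(t, 20) for t in range(10)]   # assumed: CUB-200 split into 10 tasks of 20 classes
model = ModelFactory.get_model('CUB200', 'afec_ewc', taskcla)
print_model_report(model)                # prints each parameter's shape and the total count
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
```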
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 lywang3081
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [AFEC: Active Forgetting of Negative Transfer in Continual Learning (NeurIPS 2021)](https://arxiv.org/abs/2110.12187)
2 | 
3 | ------
4 | This code is the official implementation of our paper.
5 | 
6 | ## **Execution Details**
7 | 
8 | ### Requirements
9 | 
10 | - Python 3
11 | - GPU 1080Ti / PyTorch 1.3.1+cu9.2 / CUDA 9.2
12 | 
13 | ### Execution command
14 | We provide demo commands to run AFEC on visual classification tasks.
15 | To reproduce other baselines and the adaptation of AFEC to representative weight regularization approaches,
16 | please check arguments.py for the commands, and Appendix C.1 (Table 4) for the hyperparameters.
17 | 18 | For small-scale images: 19 | 20 | ``` 21 | # CIFAR-100-SC 22 | $ python3 ./main.py --experiment split_cifar100_SC --approach afec_ewc --lamb 40000 --lamb_emp 1 23 | 24 | # CIFAR-100 25 | $ python3 ./main.py --experiment split_cifar100 --approach afec_ewc --lamb 10000 --lamb_emp 1 26 | 27 | # CIFAR-10/100 28 | $ python3 ./main.py --experiment split_cifar10_100 --approach afec_ewc --lamb 25000 --lamb_emp 1 29 | 30 | ``` 31 | 32 | For large-scale images: 33 | 34 | ``` 35 | $ cd LargeScale 36 | 37 | # CUB-200 38 | $ python3 ./main.py --dataset CUB200 --trainer afec_ewc --lamb 40 --lamb_emp 0.001 39 | 40 | # ImageNet-100 41 | $ python3 ./main.py --dataset ImageNet --trainer afec_ewc --lamb 80 --lamb_emp 0.001 42 | 43 | ``` 44 | 45 | ## Citation 46 | 47 | Please cite our paper if it is helpful to your work: 48 | 49 | ```bibtex 50 | @article{wang2021afec, 51 | title={AFEC: Active Forgetting of Negative Transfer in Continual Learning}, 52 | author={Wang, Liyuan and Zhang, Mingtian and Jia, Zhongfan and Li, Qian and Bao, Chenglong and Ma, Kaisheng and Zhu, Jun and Zhong, Yi}, 53 | journal={arXiv preprint arXiv:2110.12187}, 54 | year={2021} 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /LargeScale/trainer/evaluator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from arguments import get_args 8 | args = get_args() 9 | 10 | 11 | class EvaluatorFactory(): 12 | ''' 13 | This class is used to get different versions of evaluators 14 | ''' 15 | 16 | def __init__(self): 17 | pass 18 | 19 | @staticmethod 20 | def get_evaluator(testType="trainedClassifier"): 21 | if testType == "trainedClassifier": 22 | return softmax_evaluator() 23 | 24 | 25 | class softmax_evaluator(): 26 | ''' 27 | Evaluator class for softmax classification 28 | ''' 29 | 30 | def __init__(self): 31 | self.ce=torch.nn.CrossEntropyLoss() 32 | 33 | def evaluate(self, model, iterator, t): 34 | with torch.no_grad(): 35 | total_loss=0 36 | total_acc=0 37 | total_num=0 38 | model.eval() 39 | 40 | # Loop batches 41 | for data, target in iterator: 42 | data, target = data.cuda(), target.cuda() 43 | 44 | if args.trainer == 'hat': 45 | task=torch.autograd.Variable(torch.LongTensor([t]).cuda(),volatile=True) 46 | output = model(data,task,args.smax)[t] 47 | else: 48 | output = model(data)[t] 49 | loss=self.ce(output,target) 50 | _,pred=output.max(1) 51 | hits=(pred==target).float() 52 | 53 | # Log 54 | 55 | total_loss+=loss.data.cpu().numpy()*data.shape[0] 56 | total_acc+=hits.sum().data.cpu().numpy() 57 | total_num+=data.shape[0] 58 | 59 | return total_loss/total_num,total_acc/total_num 60 | -------------------------------------------------------------------------------- /LargeScale/trainer/trainer_factory.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from arguments import get_args 3 | args = get_args() 4 | import torch 5 | 6 | class TrainerFactory(): 7 | def __init__(self): 8 | pass 9 | 10 | @staticmethod 11 | def get_trainer(myModel, args, optimizer, evaluator, taskcla): 12 | 13 | if args.trainer == 'ewc': 14 | import trainer.ewc as trainer 15 | elif args.trainer == 'afec_ewc': 16 | import trainer.afec_ewc as trainer 17 | elif args.trainer == 'afec_mas': 18 | import trainer.afec_mas as trainer 19 | elif args.trainer == 'mas': 20 | import trainer.mas 
as trainer 21 | elif args.trainer == 'afec_rwalk': 22 | import trainer.afec_rwalk as trainer 23 | elif args.trainer == 'rwalk': 24 | import trainer.rwalk as trainer 25 | elif args.trainer == 'afec_si': 26 | import trainer.afec_si as trainer 27 | elif args.trainer == 'si': 28 | import trainer.si as trainer 29 | elif args.trainer == 'gs': 30 | import trainer.gs as trainer 31 | return trainer.Trainer(myModel, args, optimizer, evaluator, taskcla) 32 | 33 | class GenericTrainer: 34 | ''' 35 | Base class for trainer; to implement a new training routine, inherit from this. 36 | ''' 37 | 38 | def __init__(self, model, args, optimizer, evaluator, taskcla): 39 | 40 | self.model = model 41 | self.args = args 42 | self.optimizer = optimizer 43 | self.evaluator=evaluator 44 | self.taskcla=taskcla 45 | self.model_fixed = copy.deepcopy(self.model) 46 | for param in self.model_fixed.parameters(): 47 | param.requires_grad = False 48 | self.current_lr = args.lr 49 | self.ce=torch.nn.CrossEntropyLoss() 50 | self.model_single = copy.deepcopy(self.model) 51 | self.optimizer_single = None -------------------------------------------------------------------------------- /networks/conv_net_omniglot.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | from utils import * 5 | 6 | class Net(nn.Module): 7 | def __init__(self, inputsize, taskcla): 8 | super().__init__() 9 | 10 | ncha,size,_=inputsize #28 11 | self.taskcla = taskcla 12 | 13 | self.conv1 = nn.Conv2d(ncha,64,kernel_size=3) 14 | s = compute_conv_output_size(size,3) #26 15 | self.conv2 = nn.Conv2d(64,64,kernel_size=3) 16 | s = compute_conv_output_size(s,3) #24 17 | s = s//2 #12 18 | self.conv3 = nn.Conv2d(64,64,kernel_size=3) 19 | s = compute_conv_output_size(s,3) #10 20 | self.conv4 = nn.Conv2d(64,64,kernel_size=3) 21 | s = compute_conv_output_size(s,3) #8 22 | s = s//2 #4 23 | 24 | self.MaxPool = torch.nn.MaxPool2d(2) 25 | 26 | self.last=torch.nn.ModuleList() 27 | 28 | for t,n in self.taskcla: 29 | self.last.append(torch.nn.Linear(s*s*64,n)) #4*4*64 = 1024 30 | self.relu = torch.nn.ReLU() 31 | 32 | def forward(self, x, avg_act = False): 33 | act1=self.relu(self.conv1(x)) 34 | act2=self.relu(self.conv2(act1)) 35 | h=self.MaxPool(act2) 36 | act3=self.relu(self.conv3(h)) 37 | act4=self.relu(self.conv4(act3)) 38 | h=self.MaxPool(act4) 39 | h=h.view(x.shape[0],-1) 40 | y = [] 41 | for t,i in self.taskcla: 42 | y.append(self.last[t](h)) 43 | 44 | self.grads={} 45 | def save_grad(name): 46 | def hook(grad): 47 | self.grads[name] = grad 48 | return hook 49 | 50 | if avg_act == True: 51 | names = [0, 1, 2, 3] 52 | act = [act1, act2, act3, act4] 53 | 54 | self.act = [] 55 | for i in act: 56 | self.act.append(i.detach()) 57 | for idx, name in enumerate(names): 58 | act[idx].register_hook(save_grad(name)) 59 | 60 | return y -------------------------------------------------------------------------------- /LargeScale/data_handler/incremental_loader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as td 7 | from sklearn.utils import shuffle 8 | from PIL import Image 9 | from torch.autograd import Variable 10 | import torchvision.transforms.functional as trnF 11 | 12 | 13 | class ResultLoader(td.Dataset): 14 | def __init__(self, data, labels, transform=None, loader=None, data_dict=None): 15 | 16 | self.data = data 17 | self.labels = 
labels
18 |         self.transform=transform
19 |         self.loader = loader
20 |         self.data_dict = data_dict
21 | 
22 |     def __len__(self):
23 |         return self.labels.shape[0]
24 | 
25 |     def __getitem__(self, index):
26 | 
27 |         img = self.data[index]
28 |         try:
29 |             img = Image.fromarray(img)
30 |         except:
31 |             try:
32 |                 img = self.data_dict[img]
33 |             except:
34 |                 img = self.loader(img)
35 | 
36 |         if self.transform is not None:
37 |             img = self.transform(img)
38 | 
39 |         return img, self.labels[index]
40 | 
41 | def make_ResultLoaders(data, labels, taskcla, transform=None, shuffle_idx=None, loader=None, data_dict=None):
42 |     if shuffle_idx is not None:
43 |         labels = shuffle_idx[labels]
44 |     sort_index = np.argsort(labels)
45 |     data = data[sort_index]
46 |     labels = np.array(labels)
47 |     labels = labels[sort_index]
48 | 
49 |     loaders = []
50 |     start = 0
51 |     for t, ncla in taskcla:
52 |         start_idx = np.argmin(labels < start) # start data index
53 |         end_idx = np.argmax(labels > (start+ncla-1)) # end data index
54 |         if end_idx == 0:
55 |             end_idx = data.shape[0]
56 | 
57 |         loaders.append(ResultLoader(data[start_idx:end_idx],
58 |                                     labels[start_idx:end_idx]%ncla,
59 |                                     transform=transform,
60 |                                     loader=loader,
61 |                                     data_dict=data_dict))
62 | 
63 |         start += ncla
64 | 
65 |     return loaders
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | 
4 | def get_args():
5 |     parser = argparse.ArgumentParser(description='Continual')
6 |     # Arguments
7 |     parser.add_argument('--seed', type=int, default=0, help='(default=%(default)d)')
8 |     parser.add_argument('--experiment', default='pmnist', type=str, required=False,
9 |                         choices=['split_cifar10_100',
10 |                                  'split_cifar100',
11 |                                  'split_cifar100_SC'],
12 |                         help='(default=%(default)s)')
13 |     parser.add_argument('--approach', default='lrp', type=str, required=False,
14 |                         choices=['afec_ewc',
15 |                                  'ewc',
16 |                                  'si', 'afec_si',
17 |                                  'gs',
18 |                                  'rwalk', 'afec_rwalk',
19 |                                  'mas', 'afec_mas',],
20 |                         help='(default=%(default)s)')
21 | 
22 | 
23 |     parser.add_argument('--output', default='', type=str, required=False, help='(default=%(default)s)')
24 |     parser.add_argument('--nepochs', default=100, type=int, required=False, help='(default=%(default)d)')
25 |     parser.add_argument('--batch-size', default=256, type=int, required=False, help='(default=%(default)d)')
26 |     parser.add_argument('--lr', default=0.001, type=float, required=False, help='(default=%(default)f)')
27 |     parser.add_argument('--rho', default=0.3, type=float, help='(default=%(default)f)')
28 |     parser.add_argument('--gamma', default=0.75, type=float, help='(default=%(default)f)')
29 |     parser.add_argument('--eta', default=0.8, type=float, help='(default=%(default)f)')
30 |     parser.add_argument('--smax', default=400, type=float, help='(default=%(default)f)')
31 |     parser.add_argument('--lamb', default='1', type=float, help='(default=%(default)f)')
32 |     parser.add_argument('--lamb_emp', default='0', type=float, help='(default=%(default)f)')
33 |     parser.add_argument('--nu', default='0.1', type=float, help='(default=%(default)f)')
34 |     parser.add_argument('--mu', default=0, type=float, help='groupsparse parameter')
35 | 
36 |     parser.add_argument('--img', default=0, type=float, help='image id to visualize')
37 | 
38 |     parser.add_argument('--date', type=str, default='', help='(default=%(default)s)')
39 |     parser.add_argument('--tasknum', default=10, type=int, help='(default=%(default)s)')
40 |     parser.add_argument('--parameter',type=str,default='',help='(default=%(default)s)')
41 | parser.add_argument('--sample', type = int, default=1, help='Using sigma max to support coefficient') 42 | 43 | args=parser.parse_args() 44 | return args 45 | 46 | -------------------------------------------------------------------------------- /LargeScale/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser(description='Continual Learning') 6 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 7 | help='input batch size for training (default: 64)') 8 | # CUB: 0.005 9 | parser.add_argument('--lr', type=float, default=0.005, metavar='LR', 10 | help='learning rate (default: 0.1. Note that lr is decayed by args.gamma parameter args.schedule ') 11 | parser.add_argument('--decay', type=float, default=0, help='Weight decay (L2 penalty).') 12 | parser.add_argument('--lamb', type=float, default=0, help='Lambda for gs, mas, rwalk, ewc, si, hat') 13 | parser.add_argument('--lamb_emp', type=float, default=0, help='Lambda for afec') 14 | parser.add_argument('--lamb_emp_fo', type=float, default=0, help='Lambda for afec') 15 | parser.add_argument('--mu', type=float, default=1, help='Mu for gs') 16 | parser.add_argument('--eta', type=float, default=1, help='Gracefully forgetting') 17 | parser.add_argument('--gamma', type=float, default=0.75, help='HAT reg strength or AGS rand-init') 18 | parser.add_argument('--smax', type=int, default=400, help='HAT reg strength') 19 | parser.add_argument('--rho', type=float, default=0.1, help='Rho for GS') 20 | parser.add_argument('--schedule', type=int, nargs='+', default=[30], 21 | help='Decrease learning rate at these epochs.') 22 | parser.add_argument('--gammas', type=float, nargs='+', default=[0.1], 23 | help='LR is multiplied by gamma on schedule, number of gammas should be equal to schedule') 24 | parser.add_argument('--seed', type=int, default=0, 25 | help='Seeds values to be used; seed introduces randomness by changing order of classes') 26 | parser.add_argument('--nepochs', type=int, default=40, help='Number of epochs for each increment') 27 | parser.add_argument('--tasknum', default=10, type=int, help='(default=%(default)s)') 28 | parser.add_argument('--date', type=str, default='', help='(default=%(default)s)') 29 | parser.add_argument('--output', default='', type=str, required=False, help='(default=%(default)s)') 30 | parser.add_argument('--dataset', default='CUB200', type=str, 31 | choices=['CUB200', 'ImageNet'], 32 | help='(default=%(default)s)') 33 | 34 | parser.add_argument('--trainer', default='gs', type=str, 35 | choices=['mas', 'afec_mas', 36 | 'ewc', 'afec_ewc', 37 | 'gs', 38 | 'si', 'afec_si', 39 | 'rwalk', 'afec_rwalk' ], 40 | help='(default=%(default)s)') 41 | 42 | args = parser.parse_args() 43 | return args 44 | -------------------------------------------------------------------------------- /networks/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | __all__ = ['AlexNet', 'alexnet'] 7 | 8 | 9 | model_urls = { 10 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 11 | } 12 | 13 | 14 | class Net(nn.Module): 15 | 16 | def __init__(self, inputsize, taskcla): 17 | super(Net, self).__init__() 18 | self.taskcla = taskcla 19 | self.relu = nn.ReLU(inplace=True) 20 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) 21 | self.dropout = nn.Dropout() 22 | 
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1) 23 | self.conv2 = nn.Conv2d(64, 192, kernel_size=3, padding=1) 24 | self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1) 25 | self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1) 26 | self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 27 | self.fc1 = nn.Linear(256 * 1 * 1, 4096) 28 | self.fc2 = nn.Linear(4096, 4096) 29 | 30 | self.last=torch.nn.ModuleList() 31 | for t,n in self.taskcla: 32 | self.last.append(torch.nn.Linear(4096,n)) 33 | 34 | def forward(self, x, avg_act = False): 35 | act1 = self.relu(self.conv1(x)) 36 | x = self.maxpool(act1) 37 | act2 = self.relu(self.conv2(x)) 38 | x = self.maxpool(act2) 39 | act3 = self.relu(self.conv3(x)) 40 | act4 = self.relu(self.conv4(act3)) 41 | act5 = self.relu(self.conv5(act4)) 42 | x = self.maxpool(act5) 43 | 44 | x = torch.flatten(x, 1) 45 | act6 = self.relu(self.fc1(self.dropout(x))) 46 | act7 = self.relu(self.fc2(self.dropout(act6))) 47 | 48 | y = [] 49 | for t,i in self.taskcla: 50 | y.append(self.last[t](act7)) 51 | 52 | self.grads={} 53 | self.act = [] 54 | def save_grad(name): 55 | def hook(grad): 56 | self.grads[name] = grad 57 | return hook 58 | 59 | if avg_act == True: 60 | names = [0, 1, 2, 3, 4, 5, 6] 61 | act = [act1, act2, act3, act4, act5, act6, act7] 62 | 63 | self.act = [] 64 | for i in act: 65 | self.act.append(i.detach()) 66 | for idx, name in enumerate(names): 67 | act[idx].register_hook(save_grad(name)) 68 | 69 | return y 70 | 71 | 72 | def alexnet(taskcla, pretrained=False): 73 | r"""AlexNet model architecture from the 74 | `"One weird trick..." `_ paper. 75 | Args: 76 | pretrained (bool): If True, returns a model pre-trained on ImageNet 77 | progress (bool): If True, displays a progress bar of the download to stderr 78 | """ 79 | model = AlexNet(taskcla) 80 | 81 | if pretrained: 82 | pre_model = torchvision.models.alexnet(pretrained=True) 83 | for key1, key2 in zip(model.state_dict().keys(), pre_model.state_dict().keys()): 84 | if 'last' in key1: 85 | break 86 | if model.state_dict()[key1].shape == torch.tensor(1).shape: 87 | model.state_dict()[key1] = pre_model.state_dict()[key2] 88 | else: 89 | model.state_dict()[key1][:] = pre_model.state_dict()[key2][:] 90 | 91 | return model 92 | 93 | -------------------------------------------------------------------------------- /LargeScale/networks/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | __all__ = ['AlexNet', 'alexnet'] 7 | 8 | 9 | model_urls = { 10 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 11 | } 12 | 13 | 14 | class AlexNet(nn.Module): 15 | 16 | def __init__(self, taskcla): 17 | super(AlexNet, self).__init__() 18 | self.taskcla = taskcla 19 | self.relu = nn.ReLU(inplace=True) 20 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) 21 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 22 | self.dropout = nn.Dropout() 23 | self.conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2) 24 | self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2) 25 | self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1) 26 | self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1) 27 | self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 28 | self.fc1 = nn.Linear(256 * 6 * 6, 4096) 29 | self.fc2 = nn.Linear(4096, 4096) 30 | 31 | self.last=torch.nn.ModuleList() 32 | for t,n in self.taskcla: 33 | 
self.last.append(torch.nn.Linear(4096,n)) 34 | 35 | def forward(self, x, avg_act = False): 36 | act1 = self.relu(self.conv1(x)) 37 | x = self.maxpool(act1) 38 | act2 = self.relu(self.conv2(x)) 39 | x = self.maxpool(act2) 40 | act3 = self.relu(self.conv3(x)) 41 | act4 = self.relu(self.conv4(act3)) 42 | act5 = self.relu(self.conv5(act4)) 43 | x = self.maxpool(act5) 44 | x = self.avgpool(x) 45 | 46 | x = torch.flatten(x, 1) 47 | act6 = self.relu(self.fc1(self.dropout(x))) 48 | act7 = self.relu(self.fc2(self.dropout(act6))) 49 | 50 | y = [] 51 | for t,i in self.taskcla: 52 | y.append(self.last[t](act7)) 53 | 54 | self.grads={} 55 | self.act = [] 56 | def save_grad(name): 57 | def hook(grad): 58 | self.grads[name] = grad 59 | return hook 60 | 61 | if avg_act == True: 62 | names = [0, 1, 2, 3, 4, 5, 6] 63 | act = [act1, act2, act3, act4, act5, act6, act7] 64 | 65 | self.act = [] 66 | for i in act: 67 | self.act.append(i.detach()) 68 | for idx, name in enumerate(names): 69 | act[idx].register_hook(save_grad(name)) 70 | 71 | return y 72 | 73 | 74 | def alexnet(taskcla, pretrained=False): 75 | r"""AlexNet model architecture from the 76 | `"One weird trick..." `_ paper. 77 | Args: 78 | pretrained (bool): If True, returns a model pre-trained on ImageNet 79 | progress (bool): If True, displays a progress bar of the download to stderr 80 | """ 81 | model = AlexNet(taskcla) 82 | 83 | if pretrained: 84 | pre_model = torchvision.models.alexnet(pretrained=True) 85 | for key1, key2 in zip(model.state_dict().keys(), pre_model.state_dict().keys()): 86 | if 'last' in key1: 87 | break 88 | if model.state_dict()[key1].shape == torch.tensor(1).shape: 89 | model.state_dict()[key1] = pre_model.state_dict()[key2] 90 | else: 91 | model.state_dict()[key1][:] = pre_model.state_dict()[key2][:] 92 | 93 | return model 94 | -------------------------------------------------------------------------------- /networks/conv_net.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | from utils import * 5 | 6 | class Net(nn.Module): 7 | def __init__(self, inputsize, taskcla): 8 | super().__init__() 9 | 10 | ncha,size,_=inputsize 11 | self.taskcla = taskcla 12 | 13 | self.conv1 = nn.Conv2d(ncha,32,kernel_size=3,padding=1) 14 | s = compute_conv_output_size(size,3, padding=1) # 32 15 | self.conv2 = nn.Conv2d(32,32,kernel_size=3,padding=1) 16 | s = compute_conv_output_size(s,3, padding=1) # 32 17 | s = s//2 # 16 18 | self.conv3 = nn.Conv2d(32,64,kernel_size=3,padding=1) 19 | s = compute_conv_output_size(s,3, padding=1) # 16 20 | self.conv4 = nn.Conv2d(64,64,kernel_size=3,padding=1) 21 | s = compute_conv_output_size(s,3, padding=1) # 16 22 | s = s//2 # 8 23 | self.conv5 = nn.Conv2d(64,128,kernel_size=3,padding=1) 24 | s = compute_conv_output_size(s,3, padding=1) # 8 25 | self.conv6 = nn.Conv2d(128,128,kernel_size=3,padding=1) 26 | s = compute_conv_output_size(s,3, padding=1) # 8 27 | # self.conv7 = nn.Conv2d(128,128,kernel_size=3,padding=1) 28 | # s = compute_conv_output_size(s,3, padding=1) # 8 29 | s = s//2 # 4 30 | self.fc1 = nn.Linear(s*s*128,256) # 2048 31 | self.drop1 = nn.Dropout(0.25) 32 | self.drop2 = nn.Dropout(0.5) 33 | self.MaxPool = torch.nn.MaxPool2d(2) 34 | self.avg_neg = [] 35 | self.last=torch.nn.ModuleList() 36 | 37 | for t,n in self.taskcla: 38 | self.last.append(torch.nn.Linear(256,n)) 39 | self.relu = torch.nn.ReLU() 40 | 41 | def forward(self, x, avg_act = False): 42 | act1=self.relu(self.conv1(x)) 43 | 
act2=self.relu(self.conv2(act1)) 44 | h=self.drop1(self.MaxPool(act2)) 45 | act3=self.relu(self.conv3(h)) 46 | act4=self.relu(self.conv4(act3)) 47 | h=self.drop1(self.MaxPool(act4)) 48 | act5=self.relu(self.conv5(h)) 49 | act6=self.relu(self.conv6(act5)) 50 | h=self.drop1(self.MaxPool(act6)) 51 | h=h.view(x.shape[0],-1) 52 | act7 = self.relu(self.fc1(h)) 53 | h = self.drop2(act7) 54 | y = [] 55 | for t,i in self.taskcla: 56 | y.append(self.last[t](h)) 57 | 58 | self.grads={} 59 | def save_grad(name): 60 | def hook(grad): 61 | self.grads[name] = grad 62 | return hook 63 | 64 | """ 65 | act1=self.conv1(x) 66 | act2=self.conv2(self.relu(act1)) 67 | h=self.drop1(self.MaxPool(self.relu(act2))) 68 | act3=self.conv3(h) 69 | act4=self.conv4(self.relu(act3)) 70 | h=self.drop1(self.MaxPool(self.relu(act4))) 71 | act5=self.conv5(h) 72 | act6=self.conv6(self.relu(act5)) 73 | # h=self.relu(self.conv7(h)) 74 | h=self.drop1(self.MaxPool(self.relu(act6))) 75 | h=h.view(x.shape[0],-1) 76 | act7 = self.fc1(h) 77 | h = self.drop2(self.relu(act7)) 78 | y = [] 79 | for t,i in self.taskcla: 80 | y.append(self.last[t](h)) 81 | """ 82 | 83 | if avg_act == True: 84 | names = [0, 1, 2, 3, 4, 5, 6] 85 | act = [act1, act2, act3, act4, act5, act6, act7] 86 | 87 | self.act = [] 88 | for i in act: 89 | self.act.append(i.detach()) 90 | for idx, name in enumerate(names): 91 | act[idx].register_hook(save_grad(name)) 92 | return y 93 | -------------------------------------------------------------------------------- /networks/conv_net_omniglot_hat.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | from utils import * 5 | 6 | class Net(nn.Module): 7 | def __init__(self, inputsize, taskcla): 8 | super().__init__() 9 | 10 | ncha,size,_=inputsize #28 11 | self.taskcla = taskcla 12 | 13 | self.c1 = nn.Conv2d(ncha,64,kernel_size=3) 14 | s = compute_conv_output_size(size,3) #26 15 | self.c2 = nn.Conv2d(64,64,kernel_size=3) 16 | s = compute_conv_output_size(s,3) #24 17 | s = s//2 #12 18 | self.c3 = nn.Conv2d(64,64,kernel_size=3) 19 | s = compute_conv_output_size(s,3) #10 20 | self.c4 = nn.Conv2d(64,64,kernel_size=3) 21 | s = compute_conv_output_size(s,3) #8 22 | s = s//2 #4 23 | 24 | self.MaxPool = torch.nn.MaxPool2d(2) 25 | 26 | self.last=torch.nn.ModuleList() 27 | 28 | for t,n in self.taskcla: 29 | self.last.append(torch.nn.Linear(s*s*64,n)) #4*4*64 = 1024 30 | self.relu = torch.nn.ReLU() 31 | 32 | self.gate=torch.nn.Sigmoid() 33 | # All embedding stuff should start with 'e' 34 | self.ec1=torch.nn.Embedding(len(self.taskcla),64) 35 | self.ec2=torch.nn.Embedding(len(self.taskcla),64) 36 | self.ec3=torch.nn.Embedding(len(self.taskcla),64) 37 | self.ec4=torch.nn.Embedding(len(self.taskcla),64) 38 | 39 | """ (e.g., used in the compression experiments) 40 | lo,hi=0,2 41 | self.ec1.weight.data.uniform_(lo,hi) 42 | self.ec2.weight.data.uniform_(lo,hi) 43 | self.ec3.weight.data.uniform_(lo,hi) 44 | self.ec4.weight.data.uniform_(lo,hi) 45 | #""" 46 | 47 | def forward(self,t,x,s=1): 48 | # Gates 49 | masks=self.mask(t,s=s) 50 | gc1,gc2,gc3,gc4=masks 51 | 52 | #Gated 53 | h=self.relu(self.c1(x)) 54 | h=h*gc1.view(1,-1,1,1).expand_as(h) 55 | 56 | h=self.relu(self.c2(h)) 57 | h=h*gc2.view(1,-1,1,1).expand_as(h) 58 | h=self.MaxPool(h) 59 | 60 | h=self.relu(self.c3(h)) 61 | h=h*gc3.view(1,-1,1,1).expand_as(h) 62 | 63 | h=self.relu(self.c4(h)) 64 | h=h*gc4.view(1,-1,1,1).expand_as(h) 65 | h=self.MaxPool(h) 66 | 67 | h=h.view(x.shape[0],-1) 68 | y = [] 
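        # one linear head per task; all heads are evaluated and the caller indexes the returned list with the current task id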
69 | for t,i in self.taskcla: 70 | y.append(self.last[t](h)) 71 | 72 | return y,masks 73 | 74 | def mask(self,t,s=1): 75 | gc1=self.gate(s*self.ec1(t)) 76 | gc2=self.gate(s*self.ec2(t)) 77 | gc3=self.gate(s*self.ec3(t)) 78 | gc4=self.gate(s*self.ec4(t)) 79 | return [gc1,gc2,gc3,gc4] 80 | 81 | def get_view_for(self,n,masks): 82 | gc1,gc2,gc3,gc4=masks 83 | 84 | if n=='c1.weight': 85 | return gc1.data.view(-1,1,1,1).expand_as(self.c1.weight) 86 | elif n=='c1.bias': 87 | return gc1.data.view(-1) 88 | elif n=='c2.weight': 89 | post=gc2.data.view(-1,1,1,1).expand_as(self.c2.weight) 90 | pre=gc1.data.view(1,-1,1,1).expand_as(self.c2.weight) 91 | return torch.min(post,pre) 92 | elif n=='c2.bias': 93 | return gc2.data.view(-1) 94 | elif n=='c3.weight': 95 | post=gc3.data.view(-1,1,1,1).expand_as(self.c3.weight) 96 | pre=gc2.data.view(1,-1,1,1).expand_as(self.c3.weight) 97 | return torch.min(post,pre) 98 | elif n=='c3.bias': 99 | return gc3.data.view(-1) 100 | elif n=='c4.weight': 101 | post=gc4.data.view(-1,1,1,1).expand_as(self.c4.weight) 102 | pre=gc3.data.view(1,-1,1,1).expand_as(self.c4.weight) 103 | return torch.min(post,pre) 104 | elif n=='c4.bias': 105 | return gc4.data.view(-1) 106 | 107 | return None 108 | -------------------------------------------------------------------------------- /dataloaders/split_cifar100.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import numpy as np 3 | import torch 4 | import utils 5 | from torchvision import datasets,transforms 6 | from sklearn.utils import shuffle 7 | 8 | def get(seed=0,pc_valid=0.10, tasknum = 20): 9 | data={} 10 | taskcla=[] 11 | size=[3,32,32] 12 | tasknum = 20 13 | 14 | if not os.path.isdir('../dat/binary_split_cifar100_5/'): 15 | os.makedirs('../dat/binary_split_cifar100_5') 16 | 17 | mean = [0.5071, 0.4867, 0.4408] 18 | std = [0.2675, 0.2565, 0.2761] 19 | 20 | # CIFAR100 21 | dat={} 22 | 23 | dat['train']=datasets.CIFAR100('../dat/',train=True,download=True, 24 | transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])) 25 | dat['test']=datasets.CIFAR100('../dat/',train=False,download=True, 26 | transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])) 27 | for n in range(tasknum): 28 | data[n]={} 29 | data[n]['name']='cifar100' 30 | data[n]['ncla']= 5 31 | data[n]['train']={'x': [],'y': []} 32 | data[n]['test']={'x': [],'y': []} 33 | for s in ['train','test']: 34 | loader=torch.utils.data.DataLoader(dat[s],batch_size=1,shuffle=False) 35 | for image,target in loader: 36 | 37 | task_idx = target.numpy()[0] // 5 #num_task 38 | #print("task_idx", task_idx) 39 | data[task_idx][s]['x'].append(image) 40 | data[task_idx][s]['y'].append(target.numpy()[0] % 5) 41 | 42 | # "Unify" and save 43 | for t in range(tasknum): 44 | for s in ['train','test']: 45 | data[t][s]['x']=torch.stack(data[t][s]['x']).view(-1,size[0],size[1],size[2]) 46 | data[t][s]['y']=torch.LongTensor(np.array(data[t][s]['y'],dtype=int)).view(-1) 47 | torch.save(data[t][s]['x'], os.path.join(os.path.expanduser('../dat/binary_split_cifar100_5'), 48 | 'data'+str(t+1)+s+'x.bin')) 49 | torch.save(data[t][s]['y'], os.path.join(os.path.expanduser('../dat/binary_split_cifar100_5'), 50 | 'data'+str(t+1)+s+'y.bin')) 51 | 52 | # Load binary files 53 | data={} 54 | data[0] = dict.fromkeys(['name','ncla','train','test']) 55 | ids=list(shuffle(np.arange(tasknum),random_state=seed)+1) 56 | print('Task order =',ids) 57 | for i in range(tasknum): 58 | data[i] = 
dict.fromkeys(['name','ncla','train','test']) 59 | for s in ['train','test']: 60 | data[i][s]={'x':[],'y':[]} 61 | data[i][s]['x']=torch.load(os.path.join(os.path.expanduser('../dat/binary_split_cifar100_5'), 62 | 'data'+str(ids[i])+s+'x.bin')) 63 | data[i][s]['y']=torch.load(os.path.join(os.path.expanduser('../dat/binary_split_cifar100_5'), 64 | 'data'+str(ids[i])+s+'y.bin')) 65 | data[i]['ncla']=len(np.unique(data[i]['train']['y'].numpy())) 66 | data[i]['name']='cifar100-'+str(ids[i-1]) 67 | 68 | # Validation 69 | for t in range(tasknum): 70 | r=np.arange(data[t]['train']['x'].size(0)) 71 | r=np.array(shuffle(r,random_state=seed),dtype=int) 72 | nvalid=int(pc_valid*len(r)) 73 | ivalid=torch.LongTensor(r[:nvalid]) 74 | itrain=torch.LongTensor(r[nvalid:]) 75 | data[t]['valid']={} 76 | data[t]['valid']['x']=data[t]['train']['x'][ivalid].clone() 77 | data[t]['valid']['y']=data[t]['train']['y'][ivalid].clone() 78 | data[t]['train']['x']=data[t]['train']['x'][itrain].clone() 79 | data[t]['train']['y']=data[t]['train']['y'][itrain].clone() 80 | 81 | # Others 82 | n=0 83 | for t in range(tasknum): 84 | taskcla.append((t,data[t]['ncla'])) 85 | n+=data[t]['ncla'] 86 | data['ncla']=n 87 | 88 | return data,taskcla,size 89 | -------------------------------------------------------------------------------- /LargeScale/trainer/mas.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import copy 4 | import logging 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | import torch.utils.data as td 11 | from PIL import Image 12 | from tqdm import tqdm 13 | import trainer 14 | 15 | import networks 16 | 17 | 18 | class Trainer(trainer.GenericTrainer): 19 | def __init__(self, model, args, optimizer, evaluator, taskcla): 20 | super().__init__(model, args, optimizer, evaluator, taskcla) 21 | 22 | self.lamb=args.lamb 23 | self.omega = {} 24 | for n,_ in self.model.named_parameters(): 25 | self.omega[n] = 0 26 | 27 | def update_lr(self, epoch, schedule): 28 | for temp in range(0, len(schedule)): 29 | if schedule[temp] == epoch: 30 | for param_group in self.optimizer.param_groups: 31 | self.current_lr = param_group['lr'] 32 | param_group['lr'] = self.current_lr * self.args.gammas[temp] 33 | print("Changing learning rate from %0.4f to %0.4f"%(self.current_lr, 34 | self.current_lr * self.args.gammas[temp])) 35 | self.current_lr *= self.args.gammas[temp] 36 | 37 | 38 | def setup_training(self, lr): 39 | 40 | for param_group in self.optimizer.param_groups: 41 | print("Setting LR to %0.4f"%lr) 42 | param_group['lr'] = lr 43 | self.current_lr = lr 44 | 45 | def update_frozen_model(self): 46 | self.model.eval() 47 | self.model_fixed = copy.deepcopy(self.model) 48 | self.model_fixed.eval() 49 | for param in self.model_fixed.parameters(): 50 | param.requires_grad = False 51 | 52 | def train(self, train_loader, test_loader, t): 53 | 54 | lr = self.args.lr 55 | self.setup_training(lr) 56 | # Do not update self.t 57 | if t>0: 58 | self.update_frozen_model() 59 | self.omega_update() 60 | 61 | # Now, you can update self.t 62 | self.t = t 63 | #kwargs = {'num_workers': 8, 'pin_memory': True} 64 | kwargs = {'num_workers': 0, 'pin_memory': False} 65 | self.train_iterator = torch.utils.data.DataLoader(train_loader, batch_size=self.args.batch_size, shuffle=True, **kwargs) 66 | self.test_iterator = torch.utils.data.DataLoader(test_loader, 100, shuffle=False, **kwargs) 67 | self.omega_iterator = 
torch.utils.data.DataLoader(train_loader, batch_size=20, shuffle=True, **kwargs) 68 | for epoch in range(self.args.nepochs): 69 | self.model.train() 70 | self.update_lr(epoch, self.args.schedule) 71 | for samples in tqdm(self.train_iterator): 72 | data, target = samples 73 | data, target = data.cuda(), target.cuda() 74 | 75 | output = self.model(data)[t] 76 | loss_CE = self.criterion(output,target) 77 | 78 | self.optimizer.zero_grad() 79 | (loss_CE).backward() 80 | self.optimizer.step() 81 | 82 | 83 | train_loss,train_acc = self.evaluator.evaluate(self.model, self.train_iterator, t) 84 | num_batch = len(self.train_iterator) 85 | print('| Epoch {:3d} | Train: loss={:.3f}, acc={:5.1f}% |'.format(epoch+1,train_loss,100*train_acc),end='') 86 | valid_loss,valid_acc=self.evaluator.evaluate(self.model, self.test_iterator, t) 87 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 88 | print() 89 | 90 | 91 | 92 | def criterion(self,output,targets): 93 | # Regularization for all previous tasks 94 | loss_reg=0 95 | for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_fixed.named_parameters()): 96 | loss_reg+=torch.sum(self.omega[name]*(param_old-param).pow(2))/2 97 | 98 | return self.ce(output,targets)+self.lamb*loss_reg 99 | 100 | 101 | def omega_update(self): 102 | sbatch = 20 103 | 104 | # Compute 105 | self.model.train() 106 | for samples in tqdm(self.omega_iterator): 107 | data, target = samples 108 | data, target = data.cuda(), target.cuda() 109 | # Forward and backward 110 | self.model.zero_grad() 111 | outputs = self.model.forward(data)[self.t] 112 | 113 | # Sum of L2 norm of output scores 114 | loss = torch.sum(outputs.norm(2, dim = -1)) 115 | loss.backward() 116 | 117 | # Get gradients 118 | for n,p in self.model.named_parameters(): 119 | if p.grad is not None: 120 | self.omega[n]+= p.grad.data.abs() / len(self.train_iterator) 121 | 122 | return 123 | 124 | -------------------------------------------------------------------------------- /LargeScale/data_handler/dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | import torch 3 | import numpy as np 4 | from arguments import get_args 5 | import math 6 | import time 7 | args = get_args() 8 | 9 | 10 | class Dataset(): 11 | ''' 12 | Base class to reprenent a Dataset 13 | ''' 14 | 15 | def __init__(self, classes, name, tasknum): 16 | self.classes = classes 17 | self.name = name 18 | self.tasknum = tasknum 19 | self.train_data = None 20 | self.test_data = None 21 | self.loader = None 22 | 23 | class CUB200(Dataset): 24 | def __init__(self): 25 | super().__init__(200, "CUB200", args.tasknum) 26 | 27 | mean = [0.485, 0.500, 0.432] 28 | std = [0.232, 0.227, 0.266] 29 | 30 | 31 | self.train_transform = transforms.Compose([ 32 | transforms.RandomResizedCrop(224), 33 | transforms.RandomHorizontalFlip(), 34 | transforms.ToTensor(), 35 | transforms.Normalize(mean, std), 36 | ]) 37 | 38 | self.test_transform = transforms.Compose([ 39 | transforms.Resize(256), 40 | transforms.CenterCrop(224), 41 | transforms.ToTensor(), 42 | transforms.Normalize(mean, std), 43 | ]) 44 | 45 | print('Load start!') 46 | clock1 = time.time() 47 | data = datasets.ImageFolder("../dat/CUB_200_2011/images", transform=self.train_transform) 48 | self.loader = data.loader 49 | 50 | self.train_data = [] 51 | self.train_labels = [] 52 | self.test_data = [] 53 | self.test_labels = [] 54 | self.data_dict = {} 55 | class_cnt = [0]*200 56 
| class_num = [ 57 | 60, 60, 58, 60, 44, 41, 53, 48, 59, 60, 60, 56, 60, 60, 58, 58, 57, 45, 59, 59, 58 | 60, 56, 59, 52, 60, 60, 60, 59, 60, 60, 60, 53, 59, 59, 60, 60, 59, 60, 59, 60, 59 | 60, 60, 59, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 58, 60, 59, 60 | 60, 60, 60, 60, 50, 60, 60, 60, 60, 60, 60, 60, 60, 60, 57, 60, 60, 59, 60, 60, 61 | 60, 60, 60, 53, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 59, 60, 60, 60, 62 | 50, 60, 60, 60, 49, 60, 59, 60, 60, 60, 60, 60, 50, 60, 59, 60, 59, 60, 59, 60, 63 | 60, 60, 60, 59, 59, 59, 60, 60, 60, 60, 60, 60, 60, 60, 59, 60, 60, 60, 60, 60, 64 | 58, 60, 60, 60, 60, 60, 60, 60, 59, 60, 51, 60, 59, 60, 60, 60, 59, 60, 60, 59, 65 | 60, 60, 60, 60, 60, 59, 60, 59, 59, 60, 60, 60, 60, 60, 60, 60, 60, 56, 59, 60, 66 | 59, 60, 60, 60, 60, 60, 50, 60, 60, 58, 60, 60, 60, 60, 60, 59, 60, 60, 60, 60 67 | ] 68 | for i in range(len(data.imgs)): 69 | path, target = data.imgs[i] 70 | # self.data_dict[path] = self.loader(path) 71 | if class_cnt[target] < math.ceil(class_num[target]*0.8): 72 | self.train_data.append(path) 73 | self.train_labels.append(target) 74 | else: 75 | self.test_data.append(path) 76 | self.test_labels.append(target) 77 | class_cnt[target] += 1 78 | 79 | self.train_data = np.stack(self.train_data, axis=0) 80 | self.test_data = np.stack(self.test_data, axis=0) 81 | 82 | self.taskcla = [] 83 | 84 | clock2 = time.time() 85 | print('Load finished!') 86 | print('Time elapse: %d'%(clock2-clock1)) 87 | 88 | for t in range(self.tasknum): 89 | self.taskcla.append((t, self.classes // self.tasknum)) 90 | 91 | 92 | class ImageNet(Dataset): 93 | def __init__(self): 94 | super().__init__(1000, "ImageNet", args.tasknum) 95 | 96 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 97 | std=[0.229, 0.224, 0.225]) 98 | 99 | traindir = '/data/LargeData/Large/ImageNet/train' 100 | valdir = '/data/LargeData/Large/ImageNet/val' 101 | 102 | self.train_transform = transforms.Compose([ 103 | transforms.RandomResizedCrop(224), 104 | transforms.RandomHorizontalFlip(), 105 | transforms.ToTensor(), 106 | normalize,]) 107 | 108 | self.test_transform = transforms.Compose([ 109 | transforms.Resize(256), 110 | transforms.CenterCrop(224), 111 | transforms.ToTensor(), 112 | normalize,]) 113 | 114 | trainset = datasets.ImageFolder(traindir, self.train_transform) 115 | 116 | testset = datasets.ImageFolder(valdir, self.train_transform) 117 | 118 | self.train_data = [] 119 | self.train_labels = [] 120 | self.test_data = [] 121 | self.test_labels = [] 122 | self.data_dict = {} 123 | 124 | for i in range(len(trainset.imgs)): 125 | path, target = trainset.imgs[i] 126 | self.train_data.append(path) 127 | self.train_labels.append(target) 128 | 129 | for i in range(len(testset.imgs)): 130 | path, target = testset.imgs[i] 131 | self.test_data.append(path) 132 | self.test_labels.append(target) 133 | 134 | self.train_data = np.stack(self.train_data, axis=0) 135 | self.test_data = np.stack(self.test_data, axis=0) 136 | 137 | 138 | self.loader = trainset.loader 139 | self.taskcla = [] 140 | 141 | for t in range(self.tasknum): 142 | self.taskcla.append((t, self.classes // self.tasknum)) 143 | 144 | print('Load finished!') 145 | -------------------------------------------------------------------------------- /LargeScale/trainer/ewc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import copy 4 | import logging 5 | 6 | import numpy as np 7 | import torch 8 | import 
torch.nn.functional as F 9 | import torch.nn as nn 10 | import torch.utils.data as td 11 | from PIL import Image 12 | from tqdm import tqdm 13 | import trainer 14 | 15 | import networks 16 | 17 | class Trainer(trainer.GenericTrainer): 18 | def __init__(self, model, args, optimizer, evaluator, taskcla): 19 | super().__init__(model, args, optimizer, evaluator, taskcla) 20 | 21 | self.lamb=args.lamb 22 | 23 | def update_lr(self, epoch, schedule): 24 | for temp in range(0, len(schedule)): 25 | if schedule[temp] == epoch: 26 | for param_group in self.optimizer.param_groups: 27 | self.current_lr = param_group['lr'] 28 | param_group['lr'] = self.current_lr * self.args.gammas[temp] 29 | print("Changing learning rate from %0.4f to %0.4f"%(self.current_lr, 30 | self.current_lr * self.args.gammas[temp])) 31 | self.current_lr *= self.args.gammas[temp] 32 | 33 | 34 | def setup_training(self, lr): 35 | 36 | for param_group in self.optimizer.param_groups: 37 | print("Setting LR to %0.4f"%lr) 38 | param_group['lr'] = lr 39 | self.current_lr = lr 40 | 41 | def update_frozen_model(self): 42 | self.model.eval() 43 | self.model_fixed = copy.deepcopy(self.model) 44 | self.model_fixed.eval() 45 | for param in self.model_fixed.parameters(): 46 | param.requires_grad = False 47 | 48 | def train(self, train_loader, test_loader, t): 49 | 50 | lr = self.args.lr 51 | self.setup_training(lr) 52 | # Do not update self.t 53 | if t>0: 54 | self.update_frozen_model() 55 | self.update_fisher() 56 | 57 | # Now, you can update self.t 58 | self.t = t 59 | 60 | #kwargs = {'num_workers': 0, 'pin_memory': True} 61 | kwargs = {'num_workers': 0, 'pin_memory': False} 62 | self.train_iterator = torch.utils.data.DataLoader(train_loader, batch_size=self.args.batch_size, shuffle=True, **kwargs) 63 | self.test_iterator = torch.utils.data.DataLoader(test_loader, 100, shuffle=False, **kwargs) 64 | self.fisher_iterator = torch.utils.data.DataLoader(train_loader, batch_size=20, shuffle=True, **kwargs) 65 | for epoch in range(self.args.nepochs): 66 | self.model.train() 67 | self.update_lr(epoch, self.args.schedule) 68 | for samples in tqdm(self.train_iterator): 69 | data, target = samples 70 | data, target = data.cuda(), target.cuda() 71 | batch_size = data.shape[0] 72 | 73 | output = self.model(data)[t] 74 | loss_CE = self.criterion(output,target) 75 | 76 | self.optimizer.zero_grad() 77 | (loss_CE).backward() 78 | self.optimizer.step() 79 | 80 | train_loss,train_acc = self.evaluator.evaluate(self.model, self.train_iterator, t) 81 | num_batch = len(self.train_iterator) 82 | print('| Epoch {:3d} | Train: loss={:.3f}, acc={:5.1f}% |'.format(epoch+1,train_loss,100*train_acc),end='') 83 | valid_loss,valid_acc=self.evaluator.evaluate(self.model, self.test_iterator, t) 84 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 85 | print() 86 | 87 | def criterion(self,output,targets): 88 | # Regularization for all previous tasks 89 | loss_reg=0 90 | if self.t>0: 91 | for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_fixed.named_parameters()): 92 | loss_reg+=torch.sum(self.fisher[name]*(param_old-param).pow(2))/2 93 | return self.ce(output,targets)+self.lamb*loss_reg 94 | 95 | def fisher_matrix_diag(self): 96 | # Init 97 | fisher={} 98 | for n,p in self.model.named_parameters(): 99 | fisher[n]=0*p.data 100 | # Compute 101 | self.model.train() 102 | criterion = torch.nn.CrossEntropyLoss() 103 | for samples in tqdm(self.fisher_iterator): 104 | data, target = samples 105 | data, target = 
data.cuda(), target.cuda() 106 | 107 | # Forward and backward 108 | self.model.zero_grad() 109 | outputs = self.model.forward(data)[self.t] 110 | loss=self.criterion(outputs, target) 111 | loss.backward() 112 | 113 | # Get gradients 114 | for n,p in self.model.named_parameters(): 115 | if p.grad is not None: 116 | fisher[n]+=self.args.batch_size*p.grad.data.pow(2) 117 | # Mean 118 | with torch.no_grad(): 119 | for n,_ in self.model.named_parameters(): 120 | fisher[n]=fisher[n]/len(self.train_iterator) 121 | return fisher 122 | 123 | 124 | def update_fisher(self): 125 | if self.t>0: 126 | fisher_old={} 127 | for n,_ in self.model.named_parameters(): 128 | fisher_old[n]=self.fisher[n].clone() 129 | self.fisher=self.fisher_matrix_diag() 130 | if self.t>0: 131 | for n,_ in self.model.named_parameters(): 132 | self.fisher[n]=(self.fisher[n]+fisher_old[n]*self.t)/(self.t+1) 133 | -------------------------------------------------------------------------------- /dataloaders/split_cifar100_SC.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import numpy as np 3 | import torch 4 | import utils 5 | from torchvision import datasets,transforms 6 | from sklearn.utils import shuffle 7 | 8 | def get(seed=0,pc_valid=0.10, tasknum = 20): 9 | data={} 10 | taskcla=[] 11 | size=[3,32,32] 12 | tasknum = 20 13 | 14 | if not os.path.isdir('../dat/binary_split_cifar100_5_spcls/'): 15 | os.makedirs('../dat/binary_split_cifar100_5_spcls') 16 | 17 | mean = [0.5071, 0.4867, 0.4408] 18 | std = [0.2675, 0.2565, 0.2761] 19 | 20 | superclass = np.array([ 4, 1, 14, 8, 0, 6, 7, 7, 18, 3, 21 | 3, 14, 9, 18, 7, 11, 3, 9, 7, 11, 22 | 6, 11, 5, 10, 7, 6, 13, 15, 3, 15, 23 | 0, 11, 1, 10, 12, 14, 16, 9, 11, 5, 24 | 5, 19, 8, 8, 15, 13, 14, 17, 18, 10, 25 | 16, 4, 17, 4, 2, 0, 17, 4, 18, 17, 26 | 10, 3, 2, 12, 12, 16, 12, 1, 9, 19, 27 | 2, 10, 0, 1, 16, 12, 9, 13, 15, 13, 28 | 16, 19, 2, 4, 6, 19, 5, 5, 8, 19, 29 | 18, 1, 2, 15, 6, 0, 17, 8, 14, 13]) 30 | 31 | # CIFAR100 32 | dat={} 33 | 34 | dat['train']=datasets.CIFAR100('../dat/',train=True,download=True, 35 | transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])) 36 | dat['test']=datasets.CIFAR100('../dat/',train=False,download=True, 37 | transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])) 38 | for n in range(tasknum): 39 | data[n]={} 40 | data[n]['name']='cifar100' 41 | data[n]['ncla']= 5 42 | data[n]['train']={'x': [],'y': []} 43 | data[n]['test']={'x': [],'y': []} 44 | 45 | for s in ['train','test']: 46 | loader=torch.utils.data.DataLoader(dat[s],batch_size=1,shuffle=False) 47 | for image,target in loader: 48 | task_idx = superclass[target] 49 | #task_idx = target.numpy()[0] // 5 #num_task 50 | #print("task_idx", task_idx) 51 | data[task_idx][s]['x'].append(image) 52 | data[task_idx][s]['y'].append(target.numpy()[0]) 53 | 54 | for s in ['train','test']: 55 | for task_idx in range(tasknum): #20 tasks 56 | unique_label_t = np.unique(data[task_idx][s]['y']) 57 | #print("unique_label_t", unique_label_t) 58 | for i in range(len(data[task_idx][s]['y'])): 59 | for j in range(len(unique_label_t)): 60 | if data[task_idx][s]['y'][i] == unique_label_t[j]: 61 | #print('data[task_idx][s][y][i]', data[task_idx][s]['y'][i]) 62 | data[task_idx][s]['y'][i] = j 63 | #print('data[task_idx][s][y][i]', data[task_idx][s]['y'][i]) 64 | 65 | # "Unify" and save 66 | for t in range(tasknum): 67 | for s in ['train','test']: 68 | 
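            # stack each task's image list into one tensor and cache x/y as .bin files, so the
            # per-sample CIFAR-100 pass above only has to run once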
data[t][s]['x']=torch.stack(data[t][s]['x']).view(-1,size[0],size[1],size[2]) 69 | data[t][s]['y']=torch.LongTensor(np.array(data[t][s]['y'],dtype=int)).view(-1) 70 | torch.save(data[t][s]['x'], os.path.join(os.path.expanduser('../dat/binary_split_cifar100_5_spcls'), 71 | 'data'+str(t+1)+s+'x.bin')) 72 | torch.save(data[t][s]['y'], os.path.join(os.path.expanduser('../dat/binary_split_cifar100_5_spcls'), 73 | 'data'+str(t+1)+s+'y.bin')) 74 | 75 | # Load binary files 76 | data={} 77 | data[0] = dict.fromkeys(['name','ncla','train','test']) 78 | ids=list(shuffle(np.arange(tasknum),random_state=seed)+1) 79 | print('Task order =',ids) 80 | for i in range(tasknum): 81 | data[i] = dict.fromkeys(['name','ncla','train','test']) 82 | for s in ['train','test']: 83 | data[i][s]={'x':[],'y':[]} 84 | data[i][s]['x']=torch.load(os.path.join(os.path.expanduser('../dat/binary_split_cifar100_5_spcls'), 85 | 'data'+str(ids[i])+s+'x.bin')) 86 | data[i][s]['y']=torch.load(os.path.join(os.path.expanduser('../dat/binary_split_cifar100_5_spcls'), 87 | 'data'+str(ids[i])+s+'y.bin')) 88 | data[i]['ncla']=len(np.unique(data[i]['train']['y'].numpy())) 89 | data[i]['name']='cifar100-'+str(ids[i-1]) 90 | 91 | # Validation 92 | for t in range(tasknum): 93 | r=np.arange(data[t]['train']['x'].size(0)) 94 | r=np.array(shuffle(r,random_state=seed),dtype=int) 95 | nvalid=int(pc_valid*len(r)) 96 | ivalid=torch.LongTensor(r[:nvalid]) 97 | itrain=torch.LongTensor(r[nvalid:]) 98 | data[t]['valid']={} 99 | data[t]['valid']['x']=data[t]['train']['x'][ivalid].clone() 100 | data[t]['valid']['y']=data[t]['train']['y'][ivalid].clone() 101 | data[t]['train']['x']=data[t]['train']['x'][itrain].clone() 102 | data[t]['train']['y']=data[t]['train']['y'][itrain].clone() 103 | 104 | # Others 105 | n=0 106 | for t in range(tasknum): 107 | taskcla.append((t,data[t]['ncla'])) 108 | n+=data[t]['ncla'] 109 | data['ncla']=n 110 | 111 | return data,taskcla,size 112 | -------------------------------------------------------------------------------- /approaches/random_init.py: -------------------------------------------------------------------------------- 1 | import sys,time,os 2 | import numpy as np 3 | import torch 4 | from copy import deepcopy 5 | import utils 6 | from utils import * 7 | sys.path.append('..') 8 | from arguments import get_args 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | args = get_args() 12 | 13 | 14 | class Appr(object): 15 | """ Class implementing the fine tuning """ 16 | def __init__(self,model,nepochs=100,sbatch=256,lr=0.001,lr_min=1e-6,lr_factor=3,lr_patience=5,clipgrad=100,args=None,log_name = None): 17 | self.model=model 18 | self.model_old=model 19 | self.fisher=None 20 | self.model_emp = model 21 | 22 | self.nepochs = nepochs 23 | self.sbatch = sbatch 24 | self.lr = lr 25 | self.lr_min = lr_min * 1/3 26 | self.lr_factor = lr_factor 27 | self.lr_patience = lr_patience 28 | self.clipgrad = clipgrad 29 | 30 | self.ce=torch.nn.CrossEntropyLoss() 31 | self.optimizer=self._get_optimizer() 32 | self.lamb=args.lamb 33 | if len(args.parameter)>=1: 34 | params=args.parameter.split(',') 35 | print('Setting parameters to',params) 36 | self.lamb=float(params[0]) 37 | 38 | return 39 | 40 | def _get_optimizer(self,lr=None): 41 | if lr is None: lr=self.lr 42 | 43 | optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) 44 | return optimizer 45 | 46 | def train(self, t, xtrain, ytrain, xvalid, yvalid, data, input_size, taskcla): 47 | best_loss = np.inf 48 | best_model = utils.get_model(self.model) 49 | lr = 
self.lr 50 | self.optimizer = self._get_optimizer(lr) 51 | 52 | if t == 0: 53 | self.model_emp = deepcopy(self.model) 54 | else: 55 | self.model = deepcopy(self.model_emp) #random initialize model 56 | 57 | # Loop epochs 58 | for e in range(self.nepochs): 59 | # Train 60 | clock0=time.time() 61 | 62 | num_batch = xtrain.size(0) 63 | 64 | self.train_epoch(t,xtrain,ytrain, e) 65 | 66 | clock1=time.time() 67 | train_loss,train_acc=self.eval(t,xtrain,ytrain) 68 | clock2=time.time() 69 | print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format( 70 | e+1,1000*self.sbatch*(clock1-clock0)/num_batch, 71 | 1000*self.sbatch*(clock2-clock1)/num_batch,train_loss,100*train_acc),end='') 72 | # Valid 73 | valid_loss,valid_acc=self.eval(t,xvalid,yvalid) 74 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 75 | print(' lr : {:.6f}'.format(self.optimizer.param_groups[0]['lr'])) 76 | #save log for current task & old tasks at every epoch 77 | 78 | # Adapt lr 79 | if valid_loss < best_loss: 80 | best_loss = valid_loss 81 | best_model = utils.get_model(self.model) 82 | patience = self.lr_patience 83 | print(' *', end='') 84 | 85 | else: 86 | patience -= 1 87 | if patience <= 0: 88 | lr /= self.lr_factor 89 | print(' lr={:.1e}'.format(lr), end='') 90 | if lr < self.lr_min: 91 | print() 92 | patience = self.lr_patience 93 | self.optimizer = self._get_optimizer(lr) 94 | print() 95 | 96 | # Restore best 97 | utils.set_model_(self.model, best_model) 98 | 99 | # Update old 100 | self.model_old = deepcopy(self.model) 101 | self.model_old.train() 102 | utils.freeze_model(self.model_old) # Freeze the weights 103 | 104 | return 105 | 106 | def train_epoch(self,t,x,y, epoch): 107 | self.model.train() 108 | 109 | r=np.arange(x.size(0)) 110 | np.random.shuffle(r) 111 | r=torch.LongTensor(r).cuda() 112 | 113 | # Loop batches 114 | for i in range(0,len(r),self.sbatch): 115 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 116 | else: b=r[i:] 117 | images=x[b] 118 | targets=y[b] 119 | 120 | # Forward current model 121 | outputs = self.model.forward(images)[t] 122 | loss=self.criterion(t,outputs,targets) 123 | 124 | # Backward 125 | self.optimizer.zero_grad() 126 | loss.backward() 127 | self.optimizer.step() 128 | return 129 | 130 | def eval(self,t,x,y): 131 | total_loss=0 132 | total_acc=0 133 | total_num=0 134 | self.model.eval() 135 | 136 | r = np.arange(x.size(0)) 137 | r = torch.LongTensor(r).cuda() 138 | 139 | # Loop batches 140 | for i in range(0,len(r),self.sbatch): 141 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 142 | else: b=r[i:] 143 | images=x[b] 144 | targets=y[b] 145 | 146 | # Forward 147 | 148 | output = self.model.forward(images)[t] 149 | 150 | loss=self.criterion(t,output,targets) 151 | _,pred=output.max(1) 152 | hits=(pred==targets).float() 153 | 154 | # Log 155 | total_loss+=loss.data.cpu().numpy()*len(b) 156 | total_acc+=hits.sum().data.cpu().numpy() 157 | total_num+=len(b) 158 | 159 | return total_loss/total_num,total_acc/total_num 160 | 161 | def criterion(self,t,output,targets): 162 | return self.ce(output,targets) -------------------------------------------------------------------------------- /LargeScale/trainer/rwalk.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import copy 4 | import logging 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | import torch.utils.data as td 11 | from 
PIL import Image 12 | from tqdm import tqdm 13 | import trainer 14 | 15 | import networks 16 | 17 | class Trainer(trainer.GenericTrainer): 18 | def __init__(self, model, args, optimizer, evaluator, taskcla): 19 | super().__init__(model, args, optimizer, evaluator, taskcla) 20 | 21 | self.lamb=args.lamb 22 | self.alpha = 0.9 23 | self.s = {} 24 | self.s_running = {} 25 | self.fisher = {} 26 | self.fisher_running = {} 27 | self.p_old = {} 28 | 29 | self.eps = 0.01 30 | 31 | for n, p in self.model.named_parameters(): 32 | if p.requires_grad: 33 | self.s[n] = 0 34 | self.s_running[n] = 0 35 | self.fisher[n] = 0 36 | self.fisher_running[n] = 0 37 | self.p_old[n] = p.data.clone() 38 | 39 | def update_lr(self, epoch, schedule): 40 | for temp in range(0, len(schedule)): 41 | if schedule[temp] == epoch: 42 | for param_group in self.optimizer.param_groups: 43 | self.current_lr = param_group['lr'] 44 | param_group['lr'] = self.current_lr * self.args.gammas[temp] 45 | print("Changing learning rate from %0.4f to %0.4f"%(self.current_lr, 46 | self.current_lr * self.args.gammas[temp])) 47 | self.current_lr *= self.args.gammas[temp] 48 | 49 | 50 | def setup_training(self, lr): 51 | 52 | for param_group in self.optimizer.param_groups: 53 | print("Setting LR to %0.4f"%lr) 54 | param_group['lr'] = lr 55 | self.current_lr = lr 56 | 57 | def update_frozen_model(self): 58 | self.model.eval() 59 | self.model_fixed = copy.deepcopy(self.model) 60 | self.model_fixed.eval() 61 | for param in self.model_fixed.parameters(): 62 | param.requires_grad = False 63 | 64 | def train(self, train_loader, test_loader, t): 65 | 66 | lr = self.args.lr 67 | 68 | self.setup_training(lr) 69 | # Do not update self.t 70 | if t>0: 71 | self.update_frozen_model() 72 | self.freeze_fisher_and_s() 73 | 74 | # Now, you can update self.t 75 | self.t = t 76 | #kwargs = {'num_workers': 8, 'pin_memory': True} 77 | kwargs = {'num_workers': 0, 'pin_memory': False} 78 | self.train_iterator = torch.utils.data.DataLoader(train_loader, batch_size=self.args.batch_size, shuffle=True, **kwargs) 79 | self.test_iterator = torch.utils.data.DataLoader(test_loader, 100, shuffle=False, **kwargs) 80 | for epoch in range(self.args.nepochs): 81 | self.model.train() 82 | self.update_lr(epoch, self.args.schedule) 83 | for samples in tqdm(self.train_iterator): 84 | data, target = samples 85 | data, target = data.cuda(), target.cuda() 86 | 87 | output = self.model(data)[t] 88 | loss_CE = self.criterion(output,target) 89 | 90 | self.optimizer.zero_grad() 91 | (loss_CE).backward() 92 | self.optimizer.step() 93 | self.update_fisher_and_s() 94 | 95 | 96 | train_loss,train_acc = self.evaluator.evaluate(self.model, self.train_iterator, t) 97 | num_batch = len(self.train_iterator) 98 | print('| Epoch {:3d} | Train: loss={:.3f}, acc={:5.1f}% |'.format(epoch+1,train_loss,100*train_acc),end='') 99 | valid_loss,valid_acc=self.evaluator.evaluate(self.model, self.test_iterator, t) 100 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 101 | print() 102 | 103 | def criterion(self,output,targets): 104 | # Regularization for all previous tasks 105 | loss_reg=0 106 | if self.t>0: 107 | for (n,param),(_,param_old) in zip(self.model.named_parameters(),self.model_fixed.named_parameters()): 108 | loss_reg+=torch.sum((self.fisher[n] + self.s[n])*(param_old-param).pow(2)) 109 | return self.ce(output,targets)+self.lamb*loss_reg 110 | 111 | def update_fisher_and_s(self): 112 | for n, p in self.model.named_parameters(): 113 | if p.requires_grad: 114 | if 
p.grad is not None: 115 | # Compute running fisher 116 | fisher_current = p.grad.data.pow(2) 117 | self.fisher_running[n] = self.alpha*fisher_current + (1-self.alpha)*self.fisher_running[n] 118 | 119 | # Compute running s 120 | loss_diff = -p.grad * (p.detach() - self.p_old[n]) 121 | fisher_distance = (1/2) * (self.fisher_running[n]*(p.detach() - self.p_old[n])**2) 122 | s = loss_diff / (fisher_distance+self.eps) 123 | self.s_running[n] = self.s_running[n] + s 124 | 125 | self.p_old[n] = p.detach().clone() 126 | 127 | 128 | def freeze_fisher_and_s(self): 129 | for n,p in self.model.named_parameters(): 130 | if p.requires_grad: 131 | if p.grad is not None: 132 | self.fisher[n] = self.fisher_running[n].clone() 133 | self.s[n] = (1/2) * self.s_running[n].clone() 134 | self.s_running[n] = self.s[n].clone() 135 | 136 | -------------------------------------------------------------------------------- /networks/resnet/conv_net.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | from utils import * 5 | import torch.nn.functional as F 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 9 | 10 | class PreActBlock(nn.Module): 11 | '''Pre-activation version of the BasicBlock.''' 12 | expansion = 1 13 | 14 | def __init__(self, in_planes, planes, stride=1, droprate=0): 15 | super(PreActBlock, self).__init__() 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv1 = conv3x3(in_planes, planes, stride) 18 | self.drop = nn.Dropout(p=droprate) if droprate > 0 else None 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | self.conv2 = conv3x3(planes, planes) 21 | 22 | if stride != 1 or in_planes != self.expansion*planes: 23 | self.shortcut = nn.Sequential( 24 | nn.Conv2d(in_planes, self.expansion*planes, 25 | kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | x = self.bn1(x) 30 | out = F.relu(x) 31 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 32 | out = self.conv1(out) 33 | if self.drop is not None: 34 | out = self.drop(out) 35 | out = self.bn2(out) 36 | out = self.conv2(F.relu(out)) 37 | out += shortcut 38 | return out 39 | 40 | class Net(nn.Module): 41 | def __init__(self, inputsize, taskcla, block=PreActBlock, num_blocks=[2, 2, 2, 2], num_classes=10, in_channels=3): 42 | super().__init__() 43 | #super(PreActResNet, self).__init__() 44 | self.in_planes = 64 45 | last_planes = 512*block.expansion 46 | 47 | ncha,size,_=inputsize 48 | self.taskcla = taskcla 49 | 50 | self.conv1 = conv3x3(in_channels, 64) 51 | self.stage1 = self._make_layer(block, 64, num_blocks[0], stride=1) 52 | self.stage2 = self._make_layer(block, 128, num_blocks[1], stride=2) 53 | self.stage3 = self._make_layer(block, 256, num_blocks[2], stride=2) 54 | self.stage4 = self._make_layer(block, 512, num_blocks[3], stride=2) 55 | self.bn_last = nn.BatchNorm2d(last_planes) 56 | #self.last = nn.Linear(last_planes, num_classes) # last layer 57 | 58 | self.last = torch.nn.ModuleList() 59 | for t, n in self.taskcla: 60 | self.last.append(torch.nn.Linear(last_planes, n, bias=False)) 61 | 62 | def _make_layer(self, block, planes, num_blocks, stride): 63 | strides = [stride] + [1]*(num_blocks-1) 64 | layers = [] 65 | for stride in strides: 66 | layers.append(block(self.in_planes, planes, stride)) 67 | self.in_planes = planes * block.expansion 68 | return nn.Sequential(*layers) 69 | 70 | def features(self, x): 71 
| out = self.conv1(x) 72 | out = self.stage1(out) 73 | out = self.stage2(out) 74 | out = self.stage3(out) 75 | out = self.stage4(out) 76 | return out 77 | 78 | def forward(self, x): 79 | x = self.features(x) 80 | x = self.bn_last(x) 81 | x = F.relu(x) 82 | x = F.adaptive_avg_pool2d(x, 1) 83 | 84 | y = [] 85 | for t,i in self.taskcla: 86 | y.append(self.last[t](x.view(x.size(0), -1))) 87 | return y 88 | 89 | 90 | def resnet18(): 91 | return PreActResNet(PreActBlock, [2, 2, 2, 2], num_classes=10) 92 | 93 | ''' 94 | class Net(nn.Module): 95 | def __init__(self, inputsize, taskcla): 96 | super().__init__() 97 | 98 | ncha,size,_=inputsize 99 | self.taskcla = taskcla 100 | 101 | self.conv1 = nn.Conv2d(ncha,32,kernel_size=3,padding=1) 102 | s = compute_conv_output_size(size,3, padding=1) # 32 103 | self.conv2 = nn.Conv2d(32,32,kernel_size=3,padding=1) 104 | s = compute_conv_output_size(s,3, padding=1) # 32 105 | s = s//2 # 16 106 | self.conv3 = nn.Conv2d(32,64,kernel_size=3,padding=1) 107 | s = compute_conv_output_size(s,3, padding=1) # 16 108 | self.conv4 = nn.Conv2d(64,64,kernel_size=3,padding=1) 109 | s = compute_conv_output_size(s,3, padding=1) # 16 110 | s = s//2 # 8 111 | self.conv5 = nn.Conv2d(64,128,kernel_size=3,padding=1) 112 | s = compute_conv_output_size(s,3, padding=1) # 8 113 | self.conv6 = nn.Conv2d(128,128,kernel_size=3,padding=1) 114 | s = compute_conv_output_size(s,3, padding=1) # 8 115 | # self.conv7 = nn.Conv2d(128,128,kernel_size=3,padding=1) 116 | # s = compute_conv_output_size(s,3, padding=1) # 8 117 | s = s//2 # 4 118 | self.fc1 = nn.Linear(s*s*128,256) # 2048 119 | self.drop1 = nn.Dropout(0.25) 120 | self.drop2 = nn.Dropout(0.5) 121 | self.MaxPool = torch.nn.MaxPool2d(2) 122 | self.avg_neg = [] 123 | self.last=torch.nn.ModuleList() 124 | 125 | for t,n in self.taskcla: 126 | self.last.append(torch.nn.Linear(256,n)) 127 | self.relu = torch.nn.ReLU() 128 | 129 | def forward(self, x, avg_act = False): 130 | act1=self.relu(self.conv1(x)) 131 | act2=self.relu(self.conv2(act1)) 132 | h=self.drop1(self.MaxPool(act2)) 133 | act3=self.relu(self.conv3(h)) 134 | act4=self.relu(self.conv4(act3)) 135 | h=self.drop1(self.MaxPool(act4)) 136 | act5=self.relu(self.conv5(h)) 137 | act6=self.relu(self.conv6(act5)) 138 | h=self.drop1(self.MaxPool(act6)) 139 | h=h.view(x.shape[0],-1) 140 | act7 = self.relu(self.fc1(h)) 141 | h = self.drop2(act7) 142 | y = [] 143 | for t,i in self.taskcla: 144 | y.append(self.last[t](h)) 145 | 146 | return y 147 | ''' -------------------------------------------------------------------------------- /approaches/ewc.py: -------------------------------------------------------------------------------- 1 | import sys,time,os 2 | import numpy as np 3 | import torch 4 | from copy import deepcopy 5 | import utils 6 | from utils import * 7 | sys.path.append('..') 8 | from arguments import get_args 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | args = get_args() 12 | 13 | 14 | class Appr(object): 15 | """ Class implementing the Elastic Weight Consolidation approach described in http://arxiv.org/abs/1612.00796 """ 16 | def __init__(self,model,nepochs=100,sbatch=256,lr=0.001,lr_min=1e-6,lr_factor=3,lr_patience=5,clipgrad=100,args=None,log_name = None): 17 | self.model=model 18 | self.model_old=model 19 | self.fisher=None 20 | 21 | self.nepochs = nepochs 22 | self.sbatch = sbatch 23 | self.lr = lr 24 | self.lr_min = lr_min * 1/3 25 | self.lr_factor = lr_factor 26 | self.lr_patience = lr_patience 27 | self.clipgrad = clipgrad 28 | 29 | 
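        # plain cross-entropy handles the current task; the quadratic EWC penalty is added on top in criterion()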
self.ce=torch.nn.CrossEntropyLoss() 30 | self.optimizer=self._get_optimizer() 31 | self.lamb=args.lamb 32 | if len(args.parameter)>=1: 33 | params=args.parameter.split(',') 34 | print('Setting parameters to',params) 35 | self.lamb=float(params[0]) 36 | 37 | return 38 | 39 | def _get_optimizer(self,lr=None): 40 | if lr is None: lr=self.lr 41 | 42 | optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) 43 | return optimizer 44 | 45 | def train(self, t, xtrain, ytrain, xvalid, yvalid, data, input_size, taskcla): 46 | best_loss = np.inf 47 | best_model = utils.get_model(self.model) 48 | lr = self.lr 49 | self.optimizer = self._get_optimizer(lr) 50 | 51 | # Loop epochs 52 | for e in range(self.nepochs): 53 | # Train 54 | clock0=time.time() 55 | 56 | num_batch = xtrain.size(0) 57 | 58 | self.train_epoch(t,xtrain,ytrain, e) 59 | 60 | clock1=time.time() 61 | train_loss,train_acc=self.eval(t,xtrain,ytrain) 62 | clock2=time.time() 63 | print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format( 64 | e+1,1000*self.sbatch*(clock1-clock0)/num_batch, 65 | 1000*self.sbatch*(clock2-clock1)/num_batch,train_loss,100*train_acc),end='') 66 | # Valid 67 | valid_loss,valid_acc=self.eval(t,xvalid,yvalid) 68 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 69 | print(' lr : {:.6f}'.format(self.optimizer.param_groups[0]['lr'])) 70 | #save log for current task & old tasks at every epoch 71 | 72 | # Adapt lr 73 | if valid_loss < best_loss: 74 | best_loss = valid_loss 75 | best_model = utils.get_model(self.model) 76 | patience = self.lr_patience 77 | print(' *', end='') 78 | 79 | else: 80 | patience -= 1 81 | if patience <= 0: 82 | lr /= self.lr_factor 83 | print(' lr={:.1e}'.format(lr), end='') 84 | if lr < self.lr_min: 85 | print() 86 | patience = self.lr_patience 87 | self.optimizer = self._get_optimizer(lr) 88 | print() 89 | 90 | # Restore best 91 | utils.set_model_(self.model, best_model) 92 | 93 | # Update old 94 | self.model_old = deepcopy(self.model) 95 | self.model_old.train() 96 | utils.freeze_model(self.model_old) # Freeze the weights 97 | 98 | # Fisher ops 99 | if t>0: 100 | fisher_old={} 101 | for n,_ in self.model.named_parameters(): 102 | fisher_old[n]=self.fisher[n].clone() 103 | self.fisher=utils.fisher_matrix_diag(t,xtrain,ytrain,self.model,self.criterion) 104 | if t>0: 105 | # Watch out! 
We do not want to keep t models (or fisher diagonals) in memory, therefore we have to merge fisher diagonals 106 | for n,_ in self.model.named_parameters(): 107 | self.fisher[n]=(self.fisher[n]+fisher_old[n]*t)/(t+1) # Checked: it is better than the other option 108 | 109 | return 110 | 111 | def train_epoch(self,t,x,y, epoch): 112 | self.model.train() 113 | 114 | r=np.arange(x.size(0)) 115 | np.random.shuffle(r) 116 | r=torch.LongTensor(r).cuda() 117 | 118 | # Loop batches 119 | for i in range(0,len(r),self.sbatch): 120 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 121 | else: b=r[i:] 122 | images=x[b] 123 | targets=y[b] 124 | 125 | # Forward current model 126 | outputs = self.model.forward(images)[t] 127 | loss=self.criterion(t,outputs,targets) 128 | 129 | # Backward 130 | self.optimizer.zero_grad() 131 | loss.backward() 132 | self.optimizer.step() 133 | return 134 | 135 | def eval(self,t,x,y): 136 | total_loss=0 137 | total_acc=0 138 | total_num=0 139 | self.model.eval() 140 | 141 | r = np.arange(x.size(0)) 142 | r = torch.LongTensor(r).cuda() 143 | 144 | # Loop batches 145 | for i in range(0,len(r),self.sbatch): 146 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 147 | else: b=r[i:] 148 | images=x[b] 149 | targets=y[b] 150 | 151 | # Forward 152 | 153 | output = self.model.forward(images)[t] 154 | 155 | loss=self.criterion(t,output,targets) 156 | _,pred=output.max(1) 157 | hits=(pred==targets).float() 158 | 159 | # Log 160 | total_loss+=loss.data.cpu().numpy()*len(b) 161 | total_acc+=hits.sum().data.cpu().numpy() 162 | total_num+=len(b) 163 | 164 | return total_loss/total_num,total_acc/total_num 165 | 166 | def criterion(self,t,output,targets): 167 | # Regularization for all previous tasks 168 | loss_reg=0 169 | if t>0: 170 | for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_old.named_parameters()): 171 | loss_reg+=torch.sum(self.fisher[name]*(param_old-param).pow(2))/2 172 | return self.ce(output,targets)+self.lamb*loss_reg -------------------------------------------------------------------------------- /approaches/ft.py: -------------------------------------------------------------------------------- 1 | import sys,time,os 2 | import numpy as np 3 | import torch 4 | from copy import deepcopy 5 | import utils 6 | from utils import * 7 | sys.path.append('..') 8 | from arguments import get_args 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | args = get_args() 12 | 13 | 14 | class Appr(object): 15 | """ Class implementing the fine tuning """ 16 | def __init__(self,model,nepochs=100,sbatch=256,lr=0.001,lr_min=1e-6,lr_factor=3,lr_patience=5,clipgrad=100,args=None,log_name = None): 17 | self.model=model 18 | self.model_old=model 19 | self.fisher=None 20 | 21 | self.nepochs = nepochs 22 | self.sbatch = sbatch 23 | self.lr = lr 24 | self.lr_min = lr_min * 1/3 25 | self.lr_factor = lr_factor 26 | self.lr_patience = lr_patience 27 | self.clipgrad = clipgrad 28 | 29 | self.ce=torch.nn.CrossEntropyLoss() 30 | self.optimizer=self._get_optimizer() 31 | self.lamb=args.lamb 32 | if len(args.parameter)>=1: 33 | params=args.parameter.split(',') 34 | print('Setting parameters to',params) 35 | self.lamb=float(params[0]) 36 | 37 | return 38 | 39 | def _get_optimizer(self,lr=None): 40 | if lr is None: lr=self.lr 41 | 42 | optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) 43 | return optimizer 44 | 45 | def train(self, t, xtrain, ytrain, xvalid, yvalid, data, input_size, taskcla): 46 | best_loss = np.inf 47 | best_model = 
utils.get_model(self.model) 48 | lr = self.lr 49 | self.optimizer = self._get_optimizer(lr) 50 | 51 | # Loop epochs 52 | for e in range(self.nepochs): 53 | # Train 54 | clock0=time.time() 55 | 56 | num_batch = xtrain.size(0) 57 | 58 | self.train_epoch(t,xtrain,ytrain, e) 59 | 60 | clock1=time.time() 61 | train_loss,train_acc=self.eval(t,xtrain,ytrain) 62 | clock2=time.time() 63 | print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format( 64 | e+1,1000*self.sbatch*(clock1-clock0)/num_batch, 65 | 1000*self.sbatch*(clock2-clock1)/num_batch,train_loss,100*train_acc),end='') 66 | # Valid 67 | valid_loss,valid_acc=self.eval(t,xvalid,yvalid) 68 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 69 | print(' lr : {:.6f}'.format(self.optimizer.param_groups[0]['lr'])) 70 | #save log for current task & old tasks at every epoch 71 | 72 | # Adapt lr 73 | if valid_loss < best_loss: 74 | best_loss = valid_loss 75 | best_model = utils.get_model(self.model) 76 | patience = self.lr_patience 77 | print(' *', end='') 78 | 79 | else: 80 | patience -= 1 81 | if patience <= 0: 82 | lr /= self.lr_factor 83 | print(' lr={:.1e}'.format(lr), end='') 84 | if lr < self.lr_min: 85 | print() 86 | patience = self.lr_patience 87 | self.optimizer = self._get_optimizer(lr) 88 | print() 89 | 90 | # Restore best 91 | utils.set_model_(self.model, best_model) 92 | 93 | # Update old 94 | self.model_old = deepcopy(self.model) 95 | self.model_old.train() 96 | utils.freeze_model(self.model_old) # Freeze the weights 97 | 98 | ''' 99 | # Fisher ops 100 | if t>0: 101 | fisher_old={} 102 | for n,_ in self.model.named_parameters(): 103 | fisher_old[n]=self.fisher[n].clone() 104 | self.fisher=utils.fisher_matrix_diag(t,xtrain,ytrain,self.model,self.criterion) 105 | if t>0: 106 | # Watch out! 
We do not want to keep t models (or fisher diagonals) in memory, therefore we have to merge fisher diagonals 107 | for n,_ in self.model.named_parameters(): 108 | self.fisher[n]=(self.fisher[n]+fisher_old[n]*t)/(t+1) # Checked: it is better than the other option 109 | ''' 110 | 111 | return 112 | 113 | def train_epoch(self,t,x,y, epoch): 114 | self.model.train() 115 | 116 | r=np.arange(x.size(0)) 117 | np.random.shuffle(r) 118 | r=torch.LongTensor(r).cuda() 119 | 120 | # Loop batches 121 | for i in range(0,len(r),self.sbatch): 122 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 123 | else: b=r[i:] 124 | images=x[b] 125 | targets=y[b] 126 | 127 | # Forward current model 128 | outputs = self.model.forward(images)[t] 129 | loss=self.criterion(t,outputs,targets) 130 | 131 | # Backward 132 | self.optimizer.zero_grad() 133 | loss.backward() 134 | self.optimizer.step() 135 | return 136 | 137 | def eval(self,t,x,y): 138 | total_loss=0 139 | total_acc=0 140 | total_num=0 141 | self.model.eval() 142 | 143 | r = np.arange(x.size(0)) 144 | r = torch.LongTensor(r).cuda() 145 | 146 | # Loop batches 147 | for i in range(0,len(r),self.sbatch): 148 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 149 | else: b=r[i:] 150 | images=x[b] 151 | targets=y[b] 152 | 153 | # Forward 154 | 155 | output = self.model.forward(images)[t] 156 | 157 | loss=self.criterion(t,output,targets) 158 | _,pred=output.max(1) 159 | hits=(pred==targets).float() 160 | 161 | # Log 162 | total_loss+=loss.data.cpu().numpy()*len(b) 163 | total_acc+=hits.sum().data.cpu().numpy() 164 | total_num+=len(b) 165 | 166 | return total_loss/total_num,total_acc/total_num 167 | 168 | def criterion(self,t,output,targets): 169 | # Regularization for all previous tasks 170 | loss_reg=0 171 | ''' 172 | if t>0: 173 | for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_old.named_parameters()): 174 | loss_reg+=torch.sum(self.fisher[name]*(param_old-param).pow(2))/2 175 | ''' 176 | return self.ce(output,targets)+self.lamb*loss_reg -------------------------------------------------------------------------------- /approaches/mas.py: -------------------------------------------------------------------------------- 1 | import sys,time,os 2 | import numpy as np 3 | import torch 4 | from copy import deepcopy 5 | import utils 6 | from utils import * 7 | sys.path.append('..') 8 | from arguments import get_args 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | from tqdm import tqdm 12 | args = get_args() 13 | 14 | class Appr(object): 15 | """ Class implementing the Elastic Weight Consolidation approach described in http://arxiv.org/abs/1612.00796 """ 16 | 17 | def __init__(self,model,nepochs=100,sbatch=256,lr=0.001,lr_min=1e-6,lr_factor=3,lr_patience=5,clipgrad=100,args=None,log_name = None): 18 | self.model=model 19 | self.model_old=model 20 | self.fisher=None 21 | 22 | self.nepochs = nepochs 23 | self.sbatch = sbatch 24 | self.lr = lr 25 | self.lr_min = lr_min * 1/3 26 | self.lr_factor = lr_factor 27 | self.lr_patience = lr_patience 28 | self.clipgrad = clipgrad 29 | self.lamb=args.lamb 30 | 31 | self.ce=torch.nn.CrossEntropyLoss() 32 | self.optimizer=self._get_optimizer() 33 | 34 | self.omega = {} 35 | 36 | for n,_ in self.model.named_parameters(): 37 | self.omega[n] = 0 38 | 39 | return 40 | 41 | def _get_optimizer(self,lr=None): 42 | if lr is None: lr=self.lr 43 | optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) 44 | return optimizer 45 | 46 | def train(self, t, xtrain, ytrain, xvalid, yvalid, data, 
input_size, taskcla): 47 | best_loss = np.inf 48 | best_model = utils.get_model(self.model) 49 | lr = self.lr 50 | self.optimizer = self._get_optimizer(lr) 51 | 52 | # Loop epochs 53 | for e in range(self.nepochs): 54 | # Train 55 | clock0=time.time() 56 | 57 | num_batch = xtrain.size(0) 58 | 59 | self.train_epoch(t,xtrain,ytrain) 60 | 61 | clock1=time.time() 62 | train_loss,train_acc=self.eval(t,xtrain,ytrain) 63 | clock2=time.time() 64 | print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format( 65 | e+1,1000*self.sbatch*(clock1-clock0)/num_batch, 66 | 1000*self.sbatch*(clock2-clock1)/num_batch,train_loss,100*train_acc),end='') 67 | # Valid 68 | valid_loss,valid_acc=self.eval(t,xvalid,yvalid) 69 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 70 | print(' lr : {:.6f}'.format(self.optimizer.param_groups[0]['lr'])) 71 | #save log for current task & old tasks at every epoch 72 | 73 | # Adapt lr 74 | if valid_loss < best_loss: 75 | best_loss = valid_loss 76 | best_model = utils.get_model(self.model) 77 | patience = self.lr_patience 78 | print(' *', end='') 79 | 80 | else: 81 | patience -= 1 82 | if patience <= 0: 83 | lr /= self.lr_factor 84 | print(' lr={:.1e}'.format(lr), end='') 85 | if lr < self.lr_min: 86 | print() 87 | patience = self.lr_patience 88 | self.optimizer = self._get_optimizer(lr) 89 | print() 90 | 91 | # Restore best 92 | utils.set_model_(self.model, best_model) 93 | 94 | # Update old 95 | self.model_old = deepcopy(self.model) 96 | utils.freeze_model(self.model_old) # Freeze the weights 97 | self.omega_update(t,xtrain) 98 | 99 | return 100 | 101 | def train_epoch(self,t,x,y): 102 | self.model.train() 103 | 104 | r=np.arange(x.size(0)) 105 | np.random.shuffle(r) 106 | r=torch.LongTensor(r).cuda() 107 | 108 | # Loop batches 109 | for i in range(0,len(r),self.sbatch): 110 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 111 | else: b=r[i:] 112 | images=x[b] 113 | targets=y[b] 114 | 115 | # Forward current model 116 | outputs = self.model.forward(images)[t] 117 | loss=self.criterion(t,outputs,targets) 118 | 119 | # Backward 120 | self.optimizer.zero_grad() 121 | loss.backward() 122 | self.optimizer.step() 123 | 124 | return 125 | 126 | def eval(self,t,x,y): 127 | total_loss=0 128 | total_acc=0 129 | total_num=0 130 | self.model.eval() 131 | 132 | r = np.arange(x.size(0)) 133 | r = torch.LongTensor(r).cuda() 134 | 135 | # Loop batches 136 | for i in range(0,len(r),self.sbatch): 137 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 138 | else: b=r[i:] 139 | images=x[b] 140 | targets=y[b] 141 | 142 | # Forward 143 | output = self.model.forward(images)[t] 144 | 145 | loss=self.criterion(t,output,targets) 146 | _,pred=output.max(1) 147 | hits=(pred==targets).float() 148 | 149 | # Log 150 | total_loss+=loss.data.cpu().numpy()*len(b) 151 | total_acc+=hits.sum().data.cpu().numpy() 152 | total_num+=len(b) 153 | 154 | return total_loss/total_num,total_acc/total_num 155 | 156 | def criterion(self,t,output,targets): 157 | # Regularization for all previous tasks 158 | loss_reg=0 159 | for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_old.named_parameters()): 160 | loss_reg+=torch.sum(self.omega[name]*(param_old-param).pow(2))/2 161 | 162 | return self.ce(output,targets)+self.lamb*loss_reg 163 | 164 | def omega_update(self,t,x): 165 | sbatch = 20 166 | 167 | # Compute 168 | self.model.train() 169 | for i in tqdm(range(0,x.size(0),sbatch),desc='Omega',ncols=100,ascii=True): 170 | 
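            # MAS importance: run 20-sample batches, backprop the summed L2 norm of the outputs, and
            # accumulate the absolute gradients into self.omega, normalised by the number of training samples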
b=torch.LongTensor(np.arange(i,np.min([i+sbatch,x.size(0)]))).cuda() 171 | images = x[b] 172 | # Forward and backward 173 | self.model.zero_grad() 174 | outputs = self.model.forward(images)[t] 175 | 176 | # Sum of L2 norm of output scores 177 | loss = torch.sum(outputs.norm(2, dim = -1)) 178 | 179 | loss.backward() 180 | 181 | # Get gradients 182 | for n,p in self.model.named_parameters(): 183 | if p.grad is not None: 184 | self.omega[n]+= p.grad.data.abs() / x.size(0) 185 | 186 | return -------------------------------------------------------------------------------- /networks/alexnet_hat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | __all__ = ['AlexNet', 'alexnet'] 7 | 8 | 9 | model_urls = { 10 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 11 | } 12 | 13 | 14 | class Net(nn.Module): 15 | 16 | def __init__(self, inputsize, taskcla): 17 | super(Net, self).__init__() 18 | self.taskcla = taskcla 19 | self.relu = nn.ReLU(inplace=True) 20 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) 21 | self.dropout = nn.Dropout() 22 | self.c1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1) 23 | self.c2 = nn.Conv2d(64, 192, kernel_size=3, padding=1) 24 | self.c3 = nn.Conv2d(192, 384, kernel_size=3, padding=1) 25 | self.c4 = nn.Conv2d(384, 256, kernel_size=3, padding=1) 26 | self.c5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 27 | self.fc1 = nn.Linear(256 * 1 * 1, 4096) 28 | self.fc2 = nn.Linear(4096, 4096) 29 | 30 | self.last=torch.nn.ModuleList() 31 | for t,n in self.taskcla: 32 | self.last.append(torch.nn.Linear(4096,n)) 33 | 34 | 35 | self.gate=torch.nn.Sigmoid() 36 | # All embedding stuff should start with 'e' 37 | self.ec1=torch.nn.Embedding(len(self.taskcla),64) 38 | self.ec2=torch.nn.Embedding(len(self.taskcla),192) 39 | self.ec3=torch.nn.Embedding(len(self.taskcla),384) 40 | self.ec4=torch.nn.Embedding(len(self.taskcla),256) 41 | self.ec5=torch.nn.Embedding(len(self.taskcla),256) 42 | self.efc1=torch.nn.Embedding(len(self.taskcla),4096) 43 | self.efc2=torch.nn.Embedding(len(self.taskcla),4096) 44 | 45 | """ (e.g., used in the compression experiments) 46 | lo,hi=0,2 47 | self.ec1.weight.data.uniform_(lo,hi) 48 | self.ec2.weight.data.uniform_(lo,hi) 49 | self.ec3.weight.data.uniform_(lo,hi) 50 | self.ec4.weight.data.uniform_(lo,hi) 51 | self.ec5.weight.data.uniform_(lo,hi) 52 | self.efc1.weight.data.uniform_(lo,hi) 53 | self.efc2.weight.data.uniform_(lo,hi) 54 | #""" 55 | 56 | def forward(self,t,x,s=1): 57 | # Gates 58 | masks=self.mask(t,s=s) 59 | gc1,gc2,gc3,gc4,gc5,gfc1,gfc2=masks 60 | 61 | # Gated 62 | h=self.relu(self.c1(x)) 63 | h=self.maxpool(h) 64 | h=h*gc1.view(1,-1,1,1).expand_as(h) 65 | 66 | h=self.relu(self.c2(h)) 67 | h=self.maxpool(h) 68 | h=h*gc2.view(1,-1,1,1).expand_as(h) 69 | 70 | h=self.relu(self.c3(h)) 71 | h=h*gc3.view(1,-1,1,1).expand_as(h) 72 | h=self.relu(self.c4(h)) 73 | h=h*gc4.view(1,-1,1,1).expand_as(h) 74 | h=self.relu(self.c5(h)) 75 | h=self.maxpool(h) 76 | h=h*gc5.view(1,-1,1,1).expand_as(h) 77 | 78 | h=h.view(x.shape[0],-1) 79 | 80 | h=self.dropout(self.relu(self.fc1(h))) 81 | h=h*gfc1.expand_as(h) 82 | h=self.dropout(self.relu(self.fc2(h))) 83 | h=h*gfc2.expand_as(h) 84 | y=[] 85 | for i,_ in self.taskcla: 86 | y.append(self.last[i](h)) 87 | 88 | 89 | return y,masks 90 | 91 | def mask(self,t,s=1): 92 | gc1=self.gate(s*self.ec1(t)) 93 | gc2=self.gate(s*self.ec2(t)) 94 | gc3=self.gate(s*self.ec3(t)) 95 | 
gc4=self.gate(s*self.ec4(t)) 96 | gc5=self.gate(s*self.ec5(t)) 97 | gfc1=self.gate(s*self.efc1(t)) 98 | gfc2=self.gate(s*self.efc2(t)) 99 | return [gc1,gc2,gc3,gc4,gc5,gfc1,gfc2] 100 | 101 | def get_view_for(self,n,masks): 102 | gc1,gc2,gc3,gc4,gc5,gfc1,gfc2=masks 103 | if n=='fc1.weight': 104 | post=gfc1.data.view(-1,1).expand_as(self.fc1.weight) 105 | pre=gc5.data.view(-1,1,1).expand((self.ec5.weight.size(1), 106 | 1, 107 | 1)).contiguous().view(1,-1).expand_as(self.fc1.weight) 108 | return torch.min(post,pre) 109 | elif n=='fc1.bias': 110 | return gfc1.data.view(-1) 111 | elif n=='fc2.weight': 112 | post=gfc2.data.view(-1,1).expand_as(self.fc2.weight) 113 | pre=gfc1.data.view(1,-1).expand_as(self.fc2.weight) 114 | return torch.min(post,pre) 115 | elif n=='fc2.bias': 116 | return gfc2.data.view(-1) 117 | elif n=='c1.weight': 118 | return gc1.data.view(-1,1,1,1).expand_as(self.c1.weight) 119 | elif n=='c1.bias': 120 | return gc1.data.view(-1) 121 | elif n=='c2.weight': 122 | post=gc2.data.view(-1,1,1,1).expand_as(self.c2.weight) 123 | pre=gc1.data.view(1,-1,1,1).expand_as(self.c2.weight) 124 | return torch.min(post,pre) 125 | elif n=='c2.bias': 126 | return gc2.data.view(-1) 127 | elif n=='c3.weight': 128 | post=gc3.data.view(-1,1,1,1).expand_as(self.c3.weight) 129 | pre=gc2.data.view(1,-1,1,1).expand_as(self.c3.weight) 130 | return torch.min(post,pre) 131 | elif n=='c3.bias': 132 | return gc3.data.view(-1) 133 | elif n=='c4.weight': 134 | post=gc4.data.view(-1,1,1,1).expand_as(self.c4.weight) 135 | pre=gc3.data.view(1,-1,1,1).expand_as(self.c4.weight) 136 | return torch.min(post,pre) 137 | elif n=='c4.bias': 138 | return gc4.data.view(-1) 139 | elif n=='c5.weight': 140 | post=gc5.data.view(-1,1,1,1).expand_as(self.c5.weight) 141 | pre=gc4.data.view(1,-1,1,1).expand_as(self.c5.weight) 142 | return torch.min(post,pre) 143 | elif n=='c5.bias': 144 | return gc5.data.view(-1) 145 | 146 | return None 147 | 148 | 149 | def alexnet(taskcla, pretrained=False): 150 | r"""AlexNet model architecture from the 151 | `"One weird trick..." `_ paper. 
152 | Args: 153 | pretrained (bool): If True, returns a model pre-trained on ImageNet 154 | progress (bool): If True, displays a progress bar of the download to stderr 155 | """ 156 | model = AlexNet(taskcla) 157 | 158 | if pretrained: 159 | pre_model = torchvision.models.alexnet(pretrained=True) 160 | for key1, key2 in zip(model.state_dict().keys(), pre_model.state_dict().keys()): 161 | if 'last' in key1: 162 | break 163 | if model.state_dict()[key1].shape == torch.tensor(1).shape: 164 | model.state_dict()[key1] = pre_model.state_dict()[key2] 165 | else: 166 | model.state_dict()[key1][:] = pre_model.state_dict()[key2][:] 167 | 168 | return model 169 | 170 | 171 | -------------------------------------------------------------------------------- /LargeScale/networks/alexnet_hat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | __all__ = ['AlexNet', 'alexnet'] 7 | 8 | 9 | model_urls = { 10 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 11 | } 12 | 13 | 14 | class AlexNet(nn.Module): 15 | 16 | def __init__(self, taskcla): 17 | super(AlexNet, self).__init__() 18 | self.taskcla = taskcla 19 | self.relu = nn.ReLU(inplace=True) 20 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) 21 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 22 | self.dropout = nn.Dropout() 23 | self.conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2) 24 | self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2) 25 | self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1) 26 | self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1) 27 | self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 28 | self.fc1 = nn.Linear(256 * 6 * 6, 4096) 29 | self.fc2 = nn.Linear(4096, 4096) 30 | 31 | self.last=torch.nn.ModuleList() 32 | for t,n in self.taskcla: 33 | self.last.append(torch.nn.Linear(4096,n)) 34 | 35 | self.smid=6 36 | self.gate=torch.nn.Sigmoid() 37 | # All embedding stuff should start with 'e' 38 | self.ec1=torch.nn.Embedding(len(self.taskcla),64) 39 | self.ec2=torch.nn.Embedding(len(self.taskcla),192) 40 | self.ec3=torch.nn.Embedding(len(self.taskcla),384) 41 | self.ec4=torch.nn.Embedding(len(self.taskcla),256) 42 | self.ec5=torch.nn.Embedding(len(self.taskcla),256) 43 | 44 | self.efc1=torch.nn.Embedding(len(self.taskcla),4096) 45 | self.efc2=torch.nn.Embedding(len(self.taskcla),4096) 46 | 47 | def forward(self,x,t,s, mask_return=False): 48 | # Gates 49 | masks=self.mask(t,s=s) 50 | gc1,gc2,gc3,gc4,gc5,gfc1,gfc2=masks 51 | # Gated 52 | x = self.maxpool(self.relu(self.conv1(x))) 53 | x=x*gc1.view(1,-1,1,1).expand_as(x) 54 | x = self.maxpool(self.relu(self.conv2(x))) 55 | x=x*gc2.view(1,-1,1,1).expand_as(x) 56 | x = self.relu(self.conv3(x)) 57 | x=x*gc3.view(1,-1,1,1).expand_as(x) 58 | x = self.relu(self.conv4(x)) 59 | x=x*gc4.view(1,-1,1,1).expand_as(x) 60 | x = self.maxpool(self.relu(self.conv5(x))) 61 | x=x*gc5.view(1,-1,1,1).expand_as(x) 62 | 63 | x = torch.flatten(x, 1) 64 | x=self.dropout(self.relu(self.fc1(x))) 65 | x=x*gfc1.expand_as(x) 66 | x=self.dropout(self.relu(self.fc2(x))) 67 | x=x*gfc2.expand_as(x) 68 | 69 | y = [] 70 | for t,i in self.taskcla: 71 | y.append(self.last[t](x)) 72 | 73 | if mask_return: 74 | return y,masks 75 | return y 76 | 77 | def mask(self,t,s=1): 78 | gc1=self.gate(s*self.ec1(t)) 79 | gc2=self.gate(s*self.ec2(t)) 80 | gc3=self.gate(s*self.ec3(t)) 81 | gc4=self.gate(s*self.ec4(t)) 82 | gc5=self.gate(s*self.ec5(t)) 83 | 
gfc1=self.gate(s*self.efc1(t)) 84 | gfc2=self.gate(s*self.efc2(t)) 85 | return [gc1,gc2,gc3,gc4,gc5,gfc1,gfc2] 86 | 87 | def get_view_for(self,n,masks): 88 | gc1,gc2,gc3,gc4,gc5,gfc1,gfc2=masks 89 | if n=='fc1.weight': 90 | post=gfc1.data.view(-1,1).expand_as(self.fc1.weight) 91 | pre=gc6.data.view(-1,1,1).expand((self.ec6.weight.size(1), 92 | self.smid, 93 | self.smid)).contiguous().view(1,-1).expand_as(self.fc1.weight) 94 | return torch.min(post,pre) 95 | elif n=='fc1.bias': 96 | return gfc1.data.view(-1) 97 | elif n=='fc2.weight': 98 | post=gfc2.data.view(-1,1).expand_as(self.fc2.weight) 99 | pre=gfc1.data.view(1,-1).expand_as(self.fc2.weight) 100 | return torch.min(post,pre) 101 | elif n=='fc2.bias': 102 | return gfc2.data.view(-1) 103 | elif n=='c1.weight': 104 | return gc1.data.view(-1,1,1,1).expand_as(self.c1.weight) 105 | elif n=='c1.bias': 106 | return gc1.data.view(-1) 107 | elif n=='c2.weight': 108 | post=gc2.data.view(-1,1,1,1).expand_as(self.c2.weight) 109 | pre=gc1.data.view(1,-1,1,1).expand_as(self.c2.weight) 110 | return torch.min(post,pre) 111 | elif n=='c2.bias': 112 | return gc2.data.view(-1) 113 | elif n=='c3.weight': 114 | post=gc3.data.view(-1,1,1,1).expand_as(self.c3.weight) 115 | pre=gc2.data.view(1,-1,1,1).expand_as(self.c3.weight) 116 | return torch.min(post,pre) 117 | elif n=='c3.bias': 118 | return gc3.data.view(-1) 119 | elif n=='c4.weight': 120 | post=gc4.data.view(-1,1,1,1).expand_as(self.c4.weight) 121 | pre=gc3.data.view(1,-1,1,1).expand_as(self.c4.weight) 122 | return torch.min(post,pre) 123 | elif n=='c4.bias': 124 | return gc4.data.view(-1) 125 | elif n=='c5.weight': 126 | post=gc5.data.view(-1,1,1,1).expand_as(self.c5.weight) 127 | pre=gc4.data.view(1,-1,1,1).expand_as(self.c5.weight) 128 | return torch.min(post,pre) 129 | elif n=='c5.bias': 130 | return gc5.data.view(-1) 131 | elif n=='c6.weight': 132 | post=gc6.data.view(-1,1,1,1).expand_as(self.c6.weight) 133 | pre=gc5.data.view(1,-1,1,1).expand_as(self.c6.weight) 134 | return torch.min(post,pre) 135 | elif n=='c6.bias': 136 | return gc6.data.view(-1) 137 | return None 138 | 139 | 140 | def alexnet(taskcla, pretrained=False): 141 | r"""AlexNet model architecture from the 142 | `"One weird trick..." `_ paper. 
143 | Args: 144 | pretrained (bool): If True, returns a model pre-trained on ImageNet 145 | progress (bool): If True, displays a progress bar of the download to stderr 146 | """ 147 | model = AlexNet(taskcla) 148 | 149 | if pretrained: 150 | pre_model = torchvision.models.alexnet(pretrained=True) 151 | for key in model.state_dict().keys(): 152 | print(key) 153 | for key in pre_model.state_dict().keys(): 154 | print(key) 155 | for key1, key2 in zip(model.state_dict().keys(), pre_model.state_dict().keys()): 156 | if 'last' in key1: 157 | break 158 | if model.state_dict()[key1].shape == torch.tensor(1).shape: 159 | model.state_dict()[key1] = pre_model.state_dict()[key2] 160 | else: 161 | model.state_dict()[key1][:] = pre_model.state_dict()[key2][:] 162 | 163 | return model 164 | 165 | -------------------------------------------------------------------------------- /networks/conv_net_hat.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | from utils import * 5 | 6 | class Net(torch.nn.Module): 7 | 8 | def __init__(self,inputsize,taskcla): 9 | super(Net,self).__init__() 10 | 11 | ncha,size,_=inputsize 12 | self.taskcla=taskcla 13 | 14 | self.c1 = nn.Conv2d(ncha,32,kernel_size=3,padding=1) 15 | s = compute_conv_output_size(size,3, padding=1) # 32 16 | self.c2 = nn.Conv2d(32,32,kernel_size=3,padding=1) 17 | s = compute_conv_output_size(s,3, padding=1) # 32 18 | s = s//2 # 16 19 | self.c3 = nn.Conv2d(32,64,kernel_size=3,padding=1) 20 | s = compute_conv_output_size(s,3, padding=1) # 16 21 | self.c4 = nn.Conv2d(64,64,kernel_size=3,padding=1) 22 | s = compute_conv_output_size(s,3, padding=1) # 16 23 | s = s//2 # 8 24 | self.c5 = nn.Conv2d(64,128,kernel_size=3,padding=1) 25 | s = compute_conv_output_size(s,3, padding=1) # 8 26 | self.c6 = nn.Conv2d(128,128,kernel_size=3,padding=1) 27 | s = compute_conv_output_size(s,3, padding=1) # 8 28 | # self.c7 = nn.Conv2d(128,128,kernel_size=3,padding=1) 29 | # s = compute_conv_output_size(s,3, padding=1) # 8 30 | s = s//2 # 4 31 | self.fc1 = nn.Linear(s*s*128,256) # 2048 32 | self.drop1=torch.nn.Dropout(0.2) 33 | self.drop2=torch.nn.Dropout(0.5) 34 | 35 | self.smid=s 36 | self.MaxPool = torch.nn.MaxPool2d(2) 37 | self.relu=torch.nn.ReLU() 38 | 39 | self.last=torch.nn.ModuleList() 40 | for t,n in self.taskcla: 41 | self.last.append(torch.nn.Linear(256,n)) 42 | 43 | self.gate=torch.nn.Sigmoid() 44 | # All embedding stuff should start with 'e' 45 | self.ec1=torch.nn.Embedding(len(self.taskcla),32) 46 | self.ec2=torch.nn.Embedding(len(self.taskcla),32) 47 | self.ec3=torch.nn.Embedding(len(self.taskcla),64) 48 | self.ec4=torch.nn.Embedding(len(self.taskcla),64) 49 | self.ec5=torch.nn.Embedding(len(self.taskcla),128) 50 | self.ec6=torch.nn.Embedding(len(self.taskcla),128) 51 | # self.ec7=torch.nn.Embedding(len(self.taskcla),128) 52 | self.efc1=torch.nn.Embedding(len(self.taskcla),256) 53 | 54 | """ (e.g., used in the compression experiments) 55 | lo,hi=0,2 56 | self.ec1.weight.data.uniform_(lo,hi) 57 | self.ec2.weight.data.uniform_(lo,hi) 58 | self.ec3.weight.data.uniform_(lo,hi) 59 | self.ec4.weight.data.uniform_(lo,hi) 60 | self.ec5.weight.data.uniform_(lo,hi) 61 | self.ec6.weight.data.uniform_(lo,hi) 62 | self.ec7.weight.data.uniform_(lo,hi) 63 | self.efc1.weight.data.uniform_(lo,hi) 64 | #""" 65 | 66 | return 67 | 68 | def forward(self,t,x,s=1): 69 | # Gates 70 | masks=self.mask(t,s=s) 71 | # gc1,gc2,gc3,gc4,gc5,gc6,gc7,gfc1=masks 72 | gc1,gc2,gc3,gc4,gc5,gc6,gfc1=masks 
73 | 74 | # Gated 75 | h=self.relu(self.c1(x)) 76 | h=h*gc1.view(1,-1,1,1).expand_as(h) 77 | h=self.relu(self.c2(h)) 78 | h=h*gc2.view(1,-1,1,1).expand_as(h) 79 | h=self.drop1(self.MaxPool(h)) 80 | 81 | h=self.relu(self.c3(h)) 82 | h=h*gc3.view(1,-1,1,1).expand_as(h) 83 | h=self.relu(self.c4(h)) 84 | h=h*gc4.view(1,-1,1,1).expand_as(h) 85 | h=self.drop1(self.MaxPool(h)) 86 | 87 | h=self.relu(self.c5(h)) 88 | h=h*gc5.view(1,-1,1,1).expand_as(h) 89 | h=self.relu(self.c6(h)) 90 | h=h*gc6.view(1,-1,1,1).expand_as(h) 91 | # h=self.relu(self.c7(h)) 92 | # h=h*gc7.view(1,-1,1,1).expand_as(h) 93 | h=self.drop1(self.MaxPool(h)) 94 | 95 | h=h.view(x.shape[0],-1) 96 | h=self.drop2(self.relu(self.fc1(h))) 97 | h=h*gfc1.expand_as(h) 98 | y=[] 99 | for i,_ in self.taskcla: 100 | y.append(self.last[i](h)) 101 | return y,masks 102 | 103 | def mask(self,t,s=1): 104 | gc1=self.gate(s*self.ec1(t)) 105 | gc2=self.gate(s*self.ec2(t)) 106 | gc3=self.gate(s*self.ec3(t)) 107 | gc4=self.gate(s*self.ec4(t)) 108 | gc5=self.gate(s*self.ec5(t)) 109 | gc6=self.gate(s*self.ec6(t)) 110 | # gc7=self.gate(s*self.ec7(t)) 111 | gfc1=self.gate(s*self.efc1(t)) 112 | # return [gc1,gc2,gc3,gc4,gc5,gc6,gc7,gfc1] 113 | return [gc1,gc2,gc3,gc4,gc5,gc6,gfc1] 114 | 115 | def get_view_for(self,n,masks): 116 | # gc1,gc2,gc3,gc4,gc5,gc6,gc7,gfc1=masks 117 | gc1,gc2,gc3,gc4,gc5,gc6,gfc1=masks 118 | if n=='fc1.weight': 119 | post=gfc1.data.view(-1,1).expand_as(self.fc1.weight) 120 | pre=gc6.data.view(-1,1,1).expand((self.ec6.weight.size(1), 121 | self.smid, 122 | self.smid)).contiguous().view(1,-1).expand_as(self.fc1.weight) 123 | return torch.min(post,pre) 124 | elif n=='fc1.bias': 125 | return gfc1.data.view(-1) 126 | elif n=='c1.weight': 127 | return gc1.data.view(-1,1,1,1).expand_as(self.c1.weight) 128 | elif n=='c1.bias': 129 | return gc1.data.view(-1) 130 | elif n=='c2.weight': 131 | post=gc2.data.view(-1,1,1,1).expand_as(self.c2.weight) 132 | pre=gc1.data.view(1,-1,1,1).expand_as(self.c2.weight) 133 | return torch.min(post,pre) 134 | elif n=='c2.bias': 135 | return gc2.data.view(-1) 136 | elif n=='c3.weight': 137 | post=gc3.data.view(-1,1,1,1).expand_as(self.c3.weight) 138 | pre=gc2.data.view(1,-1,1,1).expand_as(self.c3.weight) 139 | return torch.min(post,pre) 140 | elif n=='c3.bias': 141 | return gc3.data.view(-1) 142 | elif n=='c4.weight': 143 | post=gc4.data.view(-1,1,1,1).expand_as(self.c4.weight) 144 | pre=gc3.data.view(1,-1,1,1).expand_as(self.c4.weight) 145 | return torch.min(post,pre) 146 | elif n=='c4.bias': 147 | return gc4.data.view(-1) 148 | elif n=='c5.weight': 149 | post=gc5.data.view(-1,1,1,1).expand_as(self.c5.weight) 150 | pre=gc4.data.view(1,-1,1,1).expand_as(self.c5.weight) 151 | return torch.min(post,pre) 152 | elif n=='c5.bias': 153 | return gc5.data.view(-1) 154 | elif n=='c6.weight': 155 | post=gc6.data.view(-1,1,1,1).expand_as(self.c6.weight) 156 | pre=gc5.data.view(1,-1,1,1).expand_as(self.c6.weight) 157 | return torch.min(post,pre) 158 | elif n=='c6.bias': 159 | return gc6.data.view(-1) 160 | # elif n=='c7.weight': 161 | # post=gc7.data.view(-1,1,1,1).expand_as(self.c7.weight) 162 | # pre=gc6.data.view(1,-1,1,1).expand_as(self.c7.weight) 163 | # return torch.min(post,pre) 164 | # elif n=='c7.bias': 165 | # return gc7.data.view(-1) 166 | return None 167 | 168 | -------------------------------------------------------------------------------- /LargeScale/trainer/si.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import copy 
4 | import logging 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | import torch.utils.data as td 11 | from PIL import Image 12 | from tqdm import tqdm 13 | import trainer 14 | 15 | import networks 16 | 17 | 18 | class Trainer(trainer.GenericTrainer): 19 | def __init__(self, model, args, optimizer, evaluator, taskcla): 20 | super().__init__(model, args, optimizer, evaluator, taskcla) 21 | 22 | self.lamb=args.lamb 23 | self.epsilon=0.01 24 | self.omega = {} 25 | self.W = {} 26 | self.p_old = {} 27 | 28 | n=0 29 | 30 | # Register starting param-values (needed for “intelligent synapses”). 31 | for n, p in self.model.named_parameters(): 32 | if p.requires_grad: 33 | n = n.replace('.', '__') 34 | self.model.register_buffer('{}_SI_prev_task'.format(n), p.data.clone()) 35 | 36 | def update_lr(self, epoch, schedule): 37 | for temp in range(0, len(schedule)): 38 | if schedule[temp] == epoch: 39 | for param_group in self.optimizer.param_groups: 40 | self.current_lr = param_group['lr'] 41 | param_group['lr'] = self.current_lr * self.args.gammas[temp] 42 | print("Changing learning rate from %0.4f to %0.4f"%(self.current_lr, 43 | self.current_lr * self.args.gammas[temp])) 44 | self.current_lr *= self.args.gammas[temp] 45 | 46 | 47 | def setup_training(self, lr): 48 | 49 | for param_group in self.optimizer.param_groups: 50 | print("Setting LR to %0.4f"%lr) 51 | param_group['lr'] = lr 52 | self.current_lr = lr 53 | 54 | def update_frozen_model(self): 55 | self.model.eval() 56 | self.model_fixed = copy.deepcopy(self.model) 57 | self.model_fixed.eval() 58 | for param in self.model_fixed.parameters(): 59 | param.requires_grad = False 60 | 61 | def train(self, train_loader, test_loader, t): 62 | 63 | lr = self.args.lr 64 | self.setup_training(lr) 65 | 66 | # Do not update self.t 67 | if t>0: 68 | self.update_frozen_model() 69 | self.update_omega() 70 | 71 | # Now, you can update self.t 72 | self.W = {} 73 | self.p_old = {} 74 | for n, p in self.model.named_parameters(): 75 | if p.requires_grad: 76 | n = n.replace('.', '__') 77 | self.W[n] = p.data.clone().zero_() 78 | self.p_old[n] = p.data.clone() 79 | 80 | 81 | self.t = t 82 | #kwargs = {'num_workers': 8, 'pin_memory': True} 83 | kwargs = {'num_workers': 0, 'pin_memory': False} 84 | self.train_iterator = torch.utils.data.DataLoader(train_loader, batch_size=self.args.batch_size, shuffle=True, **kwargs) 85 | self.test_iterator = torch.utils.data.DataLoader(test_loader, 100, shuffle=False, **kwargs) 86 | 87 | for epoch in range(self.args.nepochs): 88 | self.model.train() 89 | self.update_lr(epoch, self.args.schedule) 90 | for samples in tqdm(self.train_iterator): 91 | data, target = samples 92 | data, target = data.cuda(), target.cuda() 93 | 94 | batch_size = data.shape[0] 95 | output = self.model(data)[t] 96 | loss_CE = self.criterion(output,target) 97 | 98 | self.optimizer.zero_grad() 99 | (loss_CE).backward() 100 | self.optimizer.step() 101 | for n, p in self.model.named_parameters(): 102 | if p.requires_grad: 103 | n = n.replace('.', '__') 104 | if p.grad is not None: 105 | self.W[n].add_(-p.grad * (p.detach() - self.p_old[n])) 106 | self.p_old[n] = p.detach().clone() 107 | 108 | 109 | train_loss,train_acc = self.evaluator.evaluate(self.model, self.train_iterator, t) 110 | num_batch = len(self.train_iterator) 111 | print('| Epoch {:3d} | Train: loss={:.3f}, acc={:5.1f}% |'.format(epoch+1,train_loss,100*train_acc),end='') 112 | valid_loss,valid_acc=self.evaluator.evaluate(self.model, 
self.test_iterator, t) 113 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 114 | print() 115 | 116 | 117 | 118 | def criterion(self,output,targets): 119 | # Regularization for all previous tasks 120 | loss_reg = 0 121 | if self.t>0: 122 | loss_reg=self.surrogate_loss() 123 | 124 | return self.ce(output,targets)+self.lamb*loss_reg 125 | 126 | def update_omega(self): 127 | """After completing training on a task, update the per-parameter regularization strength. 128 | [W] estimated parameter-specific contribution to changes in total loss of completed task 129 | [epsilon] dampening parameter (to bound [omega] when [p_change] goes to 0)""" 130 | 131 | # Loop over all parameters 132 | for n, p in self.model.named_parameters(): 133 | if p.requires_grad: 134 | n = n.replace('.', '__') 135 | 136 | # Find/calculate new values for quadratic penalty on parameters 137 | p_prev = getattr(self.model, '{}_SI_prev_task'.format(n)) 138 | p_current = p.detach().clone() 139 | p_change = p_current - p_prev 140 | omega_add = self.W[n] / (p_change ** 2 + self.epsilon) 141 | try: 142 | omega = getattr(self.model, '{}_SI_omega'.format(n)) 143 | except AttributeError: 144 | omega = p.detach().clone().zero_() 145 | omega_new = omega + omega_add 146 | 147 | # Store these new values in the model 148 | self.model.register_buffer('{}_SI_prev_task'.format(n), p_current) 149 | self.model.register_buffer('{}_SI_omega'.format(n), omega_new) 150 | 151 | def surrogate_loss(self): 152 | """Calculate SI’s surrogate loss""" 153 | try: 154 | losses = [] 155 | for n, p in self.model.named_parameters(): 156 | if p.requires_grad: 157 | # Retrieve previous parameter values and their normalized path integral (i.e., omega) 158 | n = n.replace('.', '__') 159 | prev_values = getattr(self.model, '{}_SI_prev_task'.format(n)) 160 | omega = getattr(self.model, '{}_SI_omega'.format(n)) 161 | # Calculate SI’s surrogate loss, sum over all parameters 162 | losses.append((omega * (p - prev_values) ** 2).sum()) 163 | return sum(losses) 164 | except AttributeError: 165 | # SI-loss is 0 if there is no stored omega yet 166 | return 0. 
167 | 168 | -------------------------------------------------------------------------------- /LargeScale/main.py: -------------------------------------------------------------------------------- 1 | import sys, os, time 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | 8 | import pickle 9 | import torch 10 | from arguments import get_args 11 | import random 12 | import utils 13 | 14 | import data_handler 15 | from sklearn.utils import shuffle 16 | import trainer 17 | import networks 18 | 19 | # Arguments 20 | 21 | args = get_args() 22 | 23 | ##########################################################################################################################33 24 | if args.trainer == 'afec_ewc' or args.trainer == 'ewc' or args.trainer == 'afec_rwalk' or args.trainer == 'rwalk' or args.trainer == 'afec_mas' or args.trainer == 'mas' or args.trainer == 'afec_si' or args.trainer == 'si': 25 | log_name = '{}_{}_{}_{}_lamb_{}_lr_{}_batch_{}_epoch_{}'.format(args.date, args.dataset, args.trainer,args.seed, 26 | args.lamb, args.lr, args.batch_size, args.nepochs) 27 | elif args.trainer == 'gs': 28 | log_name = '{}_{}_{}_{}_lamb_{}_mu_{}_rho_{}_eta_{}_lr_{}_batch_{}_epoch_{}'.format(args.date, 29 | args.dataset, 30 | args.trainer, 31 | args.seed, 32 | args.lamb, 33 | args.mu, 34 | args.rho, 35 | args.eta, 36 | args.lr, 37 | args.batch_size, 38 | args.nepochs) 39 | 40 | if args.output == '': 41 | args.output = './result_data/' + log_name + '.txt' 42 | ######################################################################################################################## 43 | # Seed 44 | np.random.seed(args.seed) 45 | random.seed(args.seed) 46 | torch.manual_seed(args.seed) 47 | if torch.cuda.is_available(): 48 | torch.cuda.manual_seed(args.seed) 49 | else: 50 | print('[CUDA unavailable]'); sys.exit() 51 | torch.backends.cudnn.deterministic = True 52 | # torch.backends.cudnn.benchmark = False 53 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 54 | 55 | if not os.path.isdir('dat'): 56 | print('Make directory for dataset') 57 | os.makedirs('dat') 58 | 59 | print('Load data...') 60 | data_dict = None 61 | dataset = data_handler.DatasetFactory.get_dataset(args.dataset) 62 | loader = dataset.loader 63 | if args.dataset == 'CUB200' or 'ImageNet': 64 | data_dict = dataset.data_dict 65 | taskcla = dataset.taskcla 66 | print('\nTask info =', taskcla) 67 | 68 | if not os.path.isdir('result_data'): 69 | print('Make directory for saving results') 70 | os.makedirs('result_data') 71 | 72 | if not os.path.isdir('trained_model'): 73 | print('Make directory for saving trained models') 74 | os.makedirs('trained_model') 75 | 76 | # Args -- Experiment 77 | 78 | # Loader used for training data 79 | shuffle_idx = shuffle(np.arange(dataset.classes), random_state=args.seed) 80 | 81 | 82 | train_dataset_loaders = data_handler.make_ResultLoaders(dataset.train_data, 83 | dataset.train_labels, 84 | taskcla, 85 | transform=dataset.train_transform, 86 | shuffle_idx = shuffle_idx, 87 | loader = loader, 88 | data_dict = data_dict, 89 | ) 90 | 91 | test_dataset_loaders = data_handler.make_ResultLoaders(dataset.test_data, 92 | dataset.test_labels, 93 | taskcla, 94 | transform=dataset.test_transform, 95 | shuffle_idx = shuffle_idx, 96 | loader = loader, 97 | data_dict = data_dict, 98 | ) 99 | 100 | # Get the required model 101 | myModel = networks.ModelFactory.get_model(args.dataset, args.trainer, taskcla).cuda() 102 | 
103 | # Define the optimizer used in the experiment 104 | 105 | optimizer = torch.optim.SGD(myModel.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.decay) 106 | 107 | # Initilize the evaluators used to measure the performance of the system. 108 | t_classifier = trainer.EvaluatorFactory.get_evaluator("trainedClassifier") 109 | 110 | # Trainer object used for training 111 | myTrainer = trainer.TrainerFactory.get_trainer(myModel, args, optimizer, t_classifier, taskcla) 112 | 113 | ######################################################################################################################## 114 | 115 | utils.print_model_report(myModel) 116 | utils.print_optimizer_config(optimizer) 117 | print('-' * 100) 118 | 119 | # Loop tasks 120 | acc = np.zeros((len(taskcla), len(taskcla)), dtype=np.float32) 121 | lss = np.zeros((len(taskcla), len(taskcla)), dtype=np.float32) 122 | #kwargs = {'num_workers': 0, 'pin_memory': False} 123 | kwargs = {'num_workers': 8, 'pin_memory': True} 124 | for t, ncla in taskcla: 125 | print("tasknum:", t) 126 | # Add new classes to the train, and test iterator 127 | 128 | train_loader = train_dataset_loaders[t] 129 | test_loader = test_dataset_loaders[t] 130 | 131 | myTrainer.train(train_loader, test_loader, t) 132 | 133 | for u in range(t+1): 134 | 135 | test_loader = test_dataset_loaders[u] 136 | test_iterator = torch.utils.data.DataLoader(test_loader, 100, shuffle=False, **kwargs) 137 | test_loss, test_acc = t_classifier.evaluate(myTrainer.model, test_iterator, u) 138 | print('>>> Test on task {:2d}: loss={:.3f}, acc={:5.1f}% <<<'.format(u, test_loss, 100 * test_acc)) 139 | acc[t, u] = test_acc 140 | lss[t, u] = test_loss 141 | 142 | print('Average accuracy={:5.1f}%'.format(100 * np.mean(acc[t,:t+1]))) 143 | 144 | print('Save at ' + args.output) 145 | np.savetxt(args.output, acc, '%.4f') 146 | #torch.save(myModel.state_dict(), './trained_model/' + log_name + '_task_{}.pt'.format(t)) 147 | 148 | print('*' * 100) 149 | print('Accuracies =') 150 | for i in range(acc.shape[0]): 151 | print('\t', end='') 152 | for j in range(acc.shape[1]): 153 | print('{:5.1f}% '.format(100 * acc[i, j]), end='') 154 | print() 155 | print('*' * 100) 156 | 157 | 158 | print('*' * 100) 159 | print('Accuracies =') 160 | for i in range(acc.shape[0]): 161 | print('\t', end='') 162 | for j in range(acc.shape[1]): 163 | print('{:5.1f}% '.format(100 * acc[i, j]), end='') 164 | print() 165 | print('*' * 100) 166 | print('Done!') 167 | 168 | 169 | -------------------------------------------------------------------------------- /approaches/rwalk.py: -------------------------------------------------------------------------------- 1 | import sys,time,os 2 | import numpy as np 3 | import torch 4 | from copy import deepcopy 5 | import utils 6 | from utils import * 7 | sys.path.append('..') 8 | from arguments import get_args 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | args = get_args() 12 | 13 | class Appr(object): 14 | """ Class implementing the Elastic Weight Consolidation approach described in http://arxiv.org/abs/1612.00796 """ 15 | 16 | def __init__(self,model,nepochs=100,sbatch=256,lr=0.001,lr_min=2e-6,lr_factor=3,lr_patience=5,clipgrad=100,args=None,log_name = None): 17 | self.model=model 18 | self.model_old=model 19 | 20 | self.nepochs = nepochs 21 | self.sbatch = sbatch 22 | self.lr = lr 23 | self.lr_min = lr_min *1/3 24 | self.lr_factor = lr_factor 25 | self.lr_patience = lr_patience 26 | self.clipgrad = clipgrad 27 | 28 | 
self.ce=torch.nn.CrossEntropyLoss() 29 | self.optimizer=self._get_optimizer() 30 | self.lamb=args.lamb 31 | self.alpha = 0.9 32 | if len(args.parameter)>=1: 33 | params=args.parameter.split(',') 34 | print('Setting parameters to',params) 35 | self.lamb=float(params[0]) 36 | 37 | self.s = {} 38 | self.s_running = {} 39 | self.fisher = {} 40 | self.fisher_running = {} 41 | self.p_old = {} 42 | 43 | self.eps = 0.01 44 | 45 | 46 | for n, p in self.model.named_parameters(): 47 | if p.requires_grad: 48 | self.s[n] = 0 49 | self.s_running[n] = 0 50 | self.fisher[n] = 0 51 | self.fisher_running[n] = 0 52 | self.p_old[n] = p.data.clone() 53 | 54 | 55 | return 56 | 57 | def _get_optimizer(self,lr=None): 58 | if lr is None: lr=self.lr 59 | optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) 60 | return optimizer 61 | 62 | def train(self, t, xtrain, ytrain, xvalid, yvalid, data, input_size, taskcla): 63 | best_loss = np.inf 64 | best_model = utils.get_model(self.model) 65 | lr = self.lr 66 | patience = self.lr_patience 67 | self.optimizer = self._get_optimizer(lr) 68 | 69 | # Loop epochs 70 | for e in range(self.nepochs): 71 | # Train 72 | clock0=time.time() 73 | num_batch = xtrain.size(0) 74 | 75 | self.train_epoch(t,xtrain,ytrain) 76 | 77 | clock1=time.time() 78 | train_loss,train_acc=self.eval(t,xtrain,ytrain) 79 | clock2=time.time() 80 | print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format( 81 | e+1,1000*self.sbatch*(clock1-clock0)/num_batch, 82 | 1000*self.sbatch*(clock2-clock1)/num_batch,train_loss,100*train_acc),end='') 83 | # Valid 84 | valid_loss,valid_acc=self.eval(t,xvalid,yvalid) 85 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 86 | print() 87 | #save log for current task & old tasks at every epoch 88 | 89 | 90 | if valid_loss < best_loss: 91 | best_loss = valid_loss 92 | best_model = utils.get_model(self.model) 93 | patience = self.lr_patience 94 | print(' *', end='') 95 | 96 | else: 97 | patience -= 1 98 | if patience <= 0: 99 | lr /= self.lr_factor 100 | print(' lr={:.1e}'.format(lr), end='') 101 | if lr < self.lr_min: 102 | print() 103 | patience = self.lr_patience 104 | self.optimizer = self._get_optimizer(lr) 105 | print() 106 | # Restore best 107 | utils.set_model_(self.model, best_model) 108 | 109 | # Update old 110 | self.model_old = deepcopy(self.model) 111 | utils.freeze_model(self.model_old) # Freeze the weights 112 | 113 | 114 | # Update fisher & s 115 | for n,p in self.model.named_parameters(): 116 | if p.requires_grad: 117 | if p.grad is not None: 118 | self.fisher[n] = self.fisher_running[n].clone() 119 | self.s[n] = (1/2) * self.s_running[n].clone() 120 | self.s_running[n] = self.s[n].clone() 121 | 122 | return 123 | 124 | def train_epoch(self,t,x,y): 125 | self.model.train() 126 | 127 | r=np.arange(x.size(0)) 128 | np.random.shuffle(r) 129 | r=torch.LongTensor(r).cuda() 130 | 131 | # Loop batches 132 | for i in range(0,len(r),self.sbatch): 133 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 134 | else: b=r[i:] 135 | images=x[b] 136 | targets=y[b] 137 | 138 | # Forward current model 139 | outputs = self.model.forward(images)[t] 140 | loss=self.criterion(t,outputs,targets) 141 | 142 | # Backward 143 | self.optimizer.zero_grad() 144 | loss.backward() 145 | self.optimizer.step() 146 | 147 | # Compute Fisher & s 148 | self.update_fisher_and_s() 149 | 150 | return 151 | 152 | def eval(self,t,x,y): 153 | total_loss=0 154 | total_acc=0 155 | total_num=0 156 | self.model.eval() 157 | 158 | r 
= np.arange(x.size(0)) 159 | r = torch.LongTensor(r).cuda() 160 | 161 | # Loop batches 162 | for i in range(0,len(r),self.sbatch): 163 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 164 | else: b=r[i:] 165 | images=x[b] 166 | targets=y[b] 167 | 168 | # Forward 169 | output = self.model.forward(images)[t] 170 | loss=self.criterion(t,output,targets) 171 | _,pred=output.max(1) 172 | hits=(pred==targets).float() 173 | 174 | # Log 175 | total_loss+=loss.data.cpu().numpy()*len(b) 176 | total_acc+=hits.sum().data.cpu().numpy() 177 | total_num+=len(b) 178 | 179 | return total_loss/total_num,total_acc/total_num 180 | 181 | def criterion(self,t,output,targets): 182 | # Regularization for all previous tasks 183 | loss_reg=0 184 | if t>0: 185 | for (n,param),(_,param_old) in zip(self.model.named_parameters(),self.model_old.named_parameters()): 186 | loss_reg+=torch.sum((self.fisher[n] + self.s[n])*(param_old-param).pow(2)) 187 | return self.ce(output,targets)+self.lamb*loss_reg 188 | 189 | def update_fisher_and_s(self): 190 | for n, p in self.model.named_parameters(): 191 | if p.requires_grad: 192 | if p.grad is not None: 193 | # Compute running fisher 194 | fisher_current = p.grad.data.pow(2) 195 | self.fisher_running[n] = self.alpha*fisher_current + (1-self.alpha)*self.fisher_running[n] 196 | 197 | # Compute running s 198 | loss_diff = -p.grad * (p.detach() - self.p_old[n]) 199 | fisher_distance = (1/2) * (self.fisher_running[n]*(p.detach() - self.p_old[n])**2) 200 | s = loss_diff /(fisher_distance+self.eps) 201 | self.s_running[n] = self.s_running[n] + s 202 | 203 | self.p_old[n] = p.detach().clone() -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys, os, time 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | 8 | import pickle 9 | import utils 10 | import torch 11 | from arguments import get_args 12 | 13 | 14 | tstart = time.time() 15 | 16 | # Arguments 17 | 18 | args = get_args() 19 | 20 | ##########################################################################################################################33 21 | 22 | if args.approach == 'afec_ewc' or args.approach == 'ewc' or args.approach == 'afec_rwalk' or args.approach == 'rwalk' or args.approach == 'afec_mas' or args.approach == 'mas' or args.approach == 'afec_si' or args.approach == 'si' or args.approach == 'ft' or args.approach == 'random_init': 23 | log_name = '{}_{}_{}_{}_lamb_{}_lr_{}_batch_{}_epoch_{}'.format(args.date, args.experiment, args.approach,args.seed, 24 | args.lamb, args.lr, args.batch_size, args.nepochs) 25 | elif args.approach == 'gs': 26 | log_name = '{}_{}_{}_{}_lamb_{}_mu_{}_rho_{}_eta_{}_lr_{}_batch_{}_epoch_{}'.format(args.date, args.experiment, 27 | args.approach, args.seed, 28 | args.lamb, args.mu, args.rho, 29 | args.eta, 30 | args.lr, args.batch_size, args.nepochs) 31 | 32 | 33 | if args.output == '': 34 | args.output = './result_data/' + log_name + '.txt' 35 | tr_output = './result_data/' + log_name + '_train' '.txt' 36 | ######################################################################################################################## 37 | # Seed 38 | np.random.seed(args.seed) 39 | torch.manual_seed(args.seed) 40 | if torch.cuda.is_available(): 41 | torch.cuda.manual_seed(args.seed) 42 | else: 43 | print('[CUDA unavailable]'); sys.exit() 44 | 
torch.backends.cudnn.deterministic = True 45 | torch.backends.cudnn.benchmark = False 46 | 47 | # Args -- Experiment 48 | if args.experiment == 'split_cifar100': 49 | from dataloaders import split_cifar100 as dataloader 50 | if args.experiment == 'split_cifar100_SC': 51 | from dataloaders import split_cifar100_SC as dataloader 52 | elif args.experiment == 'split_cifar10_100': 53 | from dataloaders import split_cifar10_100 as dataloader 54 | 55 | # Args -- Approach 56 | if args.approach == 'gs': 57 | from approaches import gs as approach 58 | elif args.approach == 'afec_ewc': 59 | from approaches import afec_ewc as approach 60 | elif args.approach == 'ewc': 61 | from approaches import ewc as approach 62 | elif args.approach == 'afec_si': 63 | from approaches import afec_si as approach 64 | elif args.approach == 'si': 65 | from approaches import si as approach 66 | elif args.approach == 'afec_rwalk': 67 | from approaches import afec_rwalk as approach 68 | elif args.approach == 'rwalk': 69 | from approaches import rwalk as approach 70 | elif args.approach == 'afec_mas': 71 | from approaches import afec_mas as approach 72 | elif args.approach == 'mas': 73 | from approaches import mas as approach 74 | elif args.approach == 'ft': 75 | from approaches import ft as approach 76 | 77 | if args.experiment == 'split_cifar100' or args.experiment == 'split_cifar100_SC' or args.experiment == 'split_cifar10_100': 78 | from networks import conv_net as network 79 | 80 | ######################################################################################################################## 81 | 82 | 83 | # Load 84 | print('Load data...') 85 | data, taskcla, inputsize = dataloader.get(seed=args.seed, tasknum=args.tasknum) # num_task is provided by dataloader 86 | print('\nInput size =', inputsize, '\nTask info =', taskcla) 87 | 88 | # Inits 89 | print('Inits...') 90 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 91 | 92 | if not os.path.isdir('result_data'): 93 | print('Make directory for saving results') 94 | os.makedirs('result_data') 95 | 96 | if not os.path.isdir('trained_model'): 97 | print('Make directory for saving trained models') 98 | os.makedirs('trained_model') 99 | 100 | net = network.Net(inputsize, taskcla).cuda() 101 | net_emp = network.Net(inputsize, taskcla).cuda() 102 | if 'afec' in args.approach: 103 | appr = approach.Appr(net, sbatch=args.batch_size, lr=args.lr, nepochs=args.nepochs, args=args, log_name=log_name, empty_net = net_emp) 104 | else: 105 | appr = approach.Appr(net, sbatch=args.batch_size, lr=args.lr, nepochs=args.nepochs, args=args, log_name=log_name) 106 | 107 | utils.print_model_report(net) 108 | print(appr.criterion) 109 | utils.print_optimizer_config(appr.optimizer) 110 | print('-' * 100) 111 | relevance_set = {} 112 | # Loop tasks 113 | acc = np.zeros((len(taskcla), len(taskcla)), dtype=np.float32) 114 | lss = np.zeros((len(taskcla), len(taskcla)), dtype=np.float32) 115 | for t, ncla in taskcla: 116 | if t==1 and 'find_mu' in args.date: 117 | break 118 | 119 | print('*' * 100) 120 | print('Task {:2d} ({:s})'.format(t, data[t]['name'])) 121 | print('*' * 100) 122 | 123 | # Get data 124 | xtrain = data[t]['train']['x'].cuda() 125 | xvalid = data[t]['valid']['x'].cuda() 126 | 127 | ytrain = data[t]['train']['y'].cuda() 128 | yvalid = data[t]['valid']['y'].cuda() 129 | task = t 130 | 131 | # Train 132 | appr.train(task, xtrain, ytrain, xvalid, yvalid, data, inputsize, taskcla) 133 | print('-' * 100) 134 | 135 | # Test 136 | for u in range(t + 1): 137 | xtest = 
data[u]['test']['x'].cuda() 138 | ytest = data[u]['test']['y'].cuda() 139 | test_loss, test_acc = appr.eval(u, xtest, ytest) 140 | print('>>> Test on task {:2d} - {:15s}: loss={:.3f}, acc={:5.1f}% <<<'.format(u, data[u]['name'], test_loss, 141 | 100 * test_acc)) 142 | acc[t, u] = test_acc 143 | lss[t, u] = test_loss 144 | 145 | # Save 146 | 147 | print('Average accuracy={:5.1f}%'.format(100 * np.mean(acc[t,:t+1]))) 148 | print('Save at ' + args.output) 149 | np.savetxt(args.output, acc, '%.4f') 150 | #if args.approach != 'gs': 151 | # torch.save(net.state_dict(), './trained_model/' + log_name + '_task_{}.pt'.format(t)) 152 | 153 | 154 | # Done 155 | print('*' * 100) 156 | print('Accuracies =') 157 | for i in range(acc.shape[0]): 158 | print('\t', end='') 159 | for j in range(acc.shape[1]): 160 | print('{:5.1f}% '.format(100 * acc[i, j]), end='') 161 | print() 162 | print('*' * 100) 163 | print('Done!') 164 | 165 | print('[Elapsed time = {:.1f} h]'.format((time.time() - tstart) / (60 * 60))) 166 | 167 | if hasattr(appr, 'logs'): 168 | if appr.logs is not None: 169 | # save task names 170 | from copy import deepcopy 171 | 172 | appr.logs['task_name'] = {} 173 | appr.logs['test_acc'] = {} 174 | appr.logs['test_loss'] = {} 175 | for t, ncla in taskcla: 176 | appr.logs['task_name'][t] = deepcopy(data[t]['name']) 177 | appr.logs['test_acc'][t] = deepcopy(acc[t, :]) 178 | appr.logs['test_loss'][t] = deepcopy(lss[t, :]) 179 | # pickle 180 | import gzip 181 | import pickle 182 | 183 | with gzip.open(os.path.join(appr.logpath), 'wb') as output: 184 | pickle.dump(appr.logs, output, pickle.HIGHEST_PROTOCOL) 185 | 186 | ######################################################################################################################## 187 | 188 | -------------------------------------------------------------------------------- /dataloaders/split_cifar10_100.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import numpy as np 3 | import torch 4 | import utils 5 | from torchvision import datasets,transforms 6 | from sklearn.utils import shuffle 7 | 8 | def get(seed=0,pc_valid=0.10, tasknum = 10): 9 | data={} 10 | taskcla=[] 11 | size=[3,32,32] 12 | 13 | if not os.path.isdir('../dat/binary_cifar10_5/'): 14 | # CIFAR10 15 | os.makedirs('../dat/binary_cifar10_5') 16 | 17 | mean=[x/255 for x in [125.3,123.0,113.9]] 18 | std=[x/255 for x in [63.0,62.1,66.7]] 19 | 20 | # CIFAR10 21 | dat={} 22 | dat['train']=datasets.CIFAR10('../dat/',train=True,download=True, 23 | transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])) 24 | dat['test']=datasets.CIFAR10('../dat/',train=False,download=True, 25 | transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])) 26 | 27 | 28 | for n in range(2): 29 | data[n]={} 30 | data[n]['name']='cifar10' 31 | data[n]['ncla']= 5 32 | data[n]['train']={'x': [],'y': []} 33 | data[n]['test']={'x': [],'y': []} 34 | 35 | for s in ['train','test']: 36 | loader=torch.utils.data.DataLoader(dat[s],batch_size=1,shuffle=False) 37 | for image,target in loader: 38 | task_idx = target.numpy()[0] // 5 #num_task 39 | #print("task_idx", task_idx) 40 | data[task_idx][s]['x'].append(image) 41 | data[task_idx][s]['y'].append(target.numpy()[0] % 5) 42 | 43 | # "Unify" and save 44 | for t in range(2): 45 | for s in ['train','test']: 46 | data[t][s]['x']=torch.stack(data[t][s]['x']).view(-1,size[0],size[1],size[2]) 47 | 
data[t][s]['y']=torch.LongTensor(np.array(data[t][s]['y'],dtype=int)).view(-1) 48 | torch.save(data[t][s]['x'], os.path.join(os.path.expanduser('../dat/binary_cifar10_5'), 49 | 'data'+str(t)+s+'x.bin')) 50 | torch.save(data[t][s]['y'], os.path.join(os.path.expanduser('../dat/binary_cifar10_5'), 51 | 'data'+str(t)+s+'y.bin')) 52 | 53 | ''' 54 | data[0]={} 55 | data[0]['name']='cifar10' 56 | data[0]['ncla']=10 57 | data[0]['train']={'x': [],'y': []} 58 | data[0]['test']={'x': [],'y': []} 59 | for s in ['train','test']: 60 | loader=torch.utils.data.DataLoader(dat[s],batch_size=1,shuffle=False) 61 | for image,target in loader: 62 | data[0][s]['x'].append(image) 63 | data[0][s]['y'].append(target.numpy()[0]) 64 | 65 | # "Unify" and save 66 | for s in ['train','test']: 67 | data[0][s]['x']=torch.stack(data[0][s]['x']).view(-1,size[0],size[1],size[2]) 68 | data[0][s]['y']=torch.LongTensor(np.array(data[0][s]['y'],dtype=int)).view(-1) 69 | torch.save(data[0][s]['x'], os.path.join(os.path.expanduser('../dat/binary_cifar10_5'),'data'+s+'x.bin')) 70 | torch.save(data[0][s]['y'], os.path.join(os.path.expanduser('../dat/binary_cifar10_5'),'data'+s+'y.bin')) 71 | ''' 72 | 73 | 74 | if not os.path.isdir('../dat/binary_split_cifar100_5/'): 75 | # CIFAR100 76 | os.makedirs('../dat/binary_split_cifar100_5') 77 | dat={} 78 | 79 | mean = [0.5071, 0.4867, 0.4408] 80 | std = [0.2675, 0.2565, 0.2761] 81 | 82 | dat['train']=datasets.CIFAR100('../dat/',train=True,download=True, 83 | transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])) 84 | dat['test']=datasets.CIFAR100('../dat/',train=False,download=True, 85 | transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])) 86 | for n in range(2,22): 87 | data[n]={} 88 | data[n]['name']='cifar100' 89 | data[n]['ncla']=5 90 | data[n]['train']={'x': [],'y': []} 91 | data[n]['test']={'x': [],'y': []} 92 | for s in ['train','test']: 93 | loader=torch.utils.data.DataLoader(dat[s],batch_size=1,shuffle=False) 94 | for image,target in loader: 95 | task_idx = target.numpy()[0] // 5 + 2 96 | data[task_idx][s]['x'].append(image) 97 | data[task_idx][s]['y'].append(target.numpy()[0]%5) 98 | 99 | 100 | 101 | for t in range(2,22): 102 | for s in ['train','test']: 103 | data[t][s]['x']=torch.stack(data[t][s]['x']).view(-1,size[0],size[1],size[2]) 104 | data[t][s]['y']=torch.LongTensor(np.array(data[t][s]['y'],dtype=int)).view(-1) 105 | torch.save(data[t][s]['x'], os.path.join(os.path.expanduser('../dat/binary_split_cifar100_5'), 106 | 'data'+str(t)+s+'x.bin')) 107 | torch.save(data[t][s]['y'], os.path.join(os.path.expanduser('../dat/binary_split_cifar100_5'), 108 | 'data'+str(t)+s+'y.bin')) 109 | 110 | # Load binary files 111 | data={} 112 | data[0] = dict.fromkeys(['name','ncla','train','test']) 113 | for i in range(2): 114 | data[i] = dict.fromkeys(['name','ncla','train','test']) 115 | for s in ['train','test']: 116 | data[i][s]={'x':[],'y':[]} 117 | data[i][s]['x']=torch.load(os.path.join(os.path.expanduser('../dat/binary_cifar10_5'), 118 | 'data'+str(i)+s+'x.bin')) 119 | data[i][s]['y']=torch.load(os.path.join(os.path.expanduser('../dat/binary_cifar10_5'), 120 | 'data'+str(i)+s+'y.bin')) 121 | data[i]['ncla']=len(np.unique(data[i]['train']['y'].numpy())) 122 | data[i]['name']='cifar10-'+str(i) 123 | 124 | ''' 125 | for s in ['train','test']: 126 | data[0][s]={'x':[],'y':[]} 127 | data[0][s]['x']=torch.load(os.path.join(os.path.expanduser('../dat/binary_cifar10_5'),'data'+s+'x.bin')) 128 | 
data[0][s]['y']=torch.load(os.path.join(os.path.expanduser('../dat/binary_cifar10_5'),'data'+s+'y.bin')) 129 | data[0]['ncla']=len(np.unique(data[0]['train']['y'].numpy())) 130 | data[0]['name']='cifar10' 131 | ''' 132 | 133 | ids=list(shuffle(np.arange(20),random_state=seed) + 1) 134 | # ids=list(range(1,11)) 135 | print('Task order =',ids) 136 | for i in range(2,22): 137 | data[i] = dict.fromkeys(['name','ncla','train','test']) 138 | for s in ['train','test']: 139 | data[i][s]={'x':[],'y':[]} 140 | data[i][s]['x']=torch.load(os.path.join(os.path.expanduser('../dat/binary_split_cifar100_5'), 141 | 'data'+str(ids[i-2]+1)+s+'x.bin')) 142 | data[i][s]['y']=torch.load(os.path.join(os.path.expanduser('../dat/binary_split_cifar100_5'), 143 | 'data'+str(ids[i-2]+1)+s+'y.bin')) 144 | data[i]['ncla']=len(np.unique(data[i]['train']['y'].numpy())) 145 | data[i]['name']='cifar100-'+str(ids[i-2]) 146 | 147 | # Validation 148 | for t in range(22): 149 | r=np.arange(data[t]['train']['x'].size(0)) 150 | r=np.array(shuffle(r,random_state=seed),dtype=int) 151 | nvalid=int(pc_valid*len(r)) 152 | ivalid=torch.LongTensor(r[:nvalid]) 153 | itrain=torch.LongTensor(r[nvalid:]) 154 | data[t]['valid']={} 155 | data[t]['valid']['x']=data[t]['train']['x'][ivalid].clone() 156 | data[t]['valid']['y']=data[t]['train']['y'][ivalid].clone() 157 | data[t]['train']['x']=data[t]['train']['x'][itrain].clone() 158 | data[t]['train']['y']=data[t]['train']['y'][itrain].clone() 159 | 160 | # Others 161 | n=0 162 | for t in range(22): 163 | taskcla.append((t,data[t]['ncla'])) 164 | n+=data[t]['ncla'] 165 | data['ncla']=n 166 | 167 | return data,taskcla,size 168 | -------------------------------------------------------------------------------- /approaches/si.py: -------------------------------------------------------------------------------- 1 | import sys,time,os 2 | import numpy as np 3 | import random 4 | import torch 5 | from copy import deepcopy 6 | import utils 7 | from utils import * 8 | sys.path.append('..') 9 | from arguments import get_args 10 | import torch.nn.functional as F 11 | import torch.nn as nn 12 | args = get_args() 13 | 14 | class Appr(): 15 | """ Class implementing the Synaptic intelligence approach described in https://arxiv.org/abs/1703.04200 """ 16 | 17 | def __init__(self,model,nepochs=100,sbatch=256,lr=0.001,lr_min=1e-6,lr_factor=3,lr_patience=5,clipgrad=100,args=None,log_name = None): 18 | super().__init__() 19 | self.model=model 20 | self.model_old=model 21 | 22 | self.nepochs = nepochs 23 | self.sbatch = sbatch 24 | self.lr = lr 25 | self.lr_min = lr_min * 1/3 26 | self.lr_factor = lr_factor 27 | self.lr_patience = lr_patience 28 | self.clipgrad = clipgrad 29 | 30 | self.ce=torch.nn.CrossEntropyLoss() 31 | self.optimizer=self._get_optimizer() 32 | self.c=args.lamb 33 | self.epsilon=0.01 34 | self.omega = {} 35 | self.W = {} 36 | self.p_old = {} 37 | 38 | n=0 39 | 40 | # Register starting param-values (needed for “intelligent synapses”). 
41 | for n, p in self.model.named_parameters(): 42 | if p.requires_grad: 43 | n = n.replace('.', '__') 44 | self.model.register_buffer('{}_SI_prev_task'.format(n), p.data.clone()) 45 | 46 | return 47 | 48 | def _get_optimizer(self,lr=None): 49 | if lr is None: lr=self.lr 50 | optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) 51 | return optimizer 52 | 53 | 54 | def train(self, t, xtrain, ytrain, xvalid, yvalid, data, input_size, taskcla): 55 | best_loss = np.inf 56 | best_model = utils.get_model(self.model) 57 | lr = self.lr 58 | patience = self.lr_patience 59 | self.optimizer = self._get_optimizer(lr) 60 | 61 | self.W = {} 62 | self.p_old = {} 63 | for n, p in self.model.named_parameters(): 64 | if p.requires_grad: 65 | n = n.replace('.', '__') 66 | self.W[n] = p.data.clone().zero_() 67 | self.p_old[n] = p.data.clone() 68 | 69 | # Loop epochs 70 | for e in range(self.nepochs): 71 | # Train 72 | clock0=time.time() 73 | num_batch = xtrain.size(0) 74 | 75 | self.train_epoch(t,xtrain,ytrain) 76 | 77 | clock1=time.time() 78 | train_loss,train_acc=self.eval(t,xtrain,ytrain) 79 | clock2=time.time() 80 | print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format( 81 | e+1,1000*self.sbatch*(clock1-clock0)/num_batch,1000*self.sbatch*(clock2-clock1)/num_batch,train_loss,100*train_acc),end='') 82 | # Valid 83 | valid_loss,valid_acc=self.eval(t,xvalid,yvalid) 84 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 85 | print() 86 | #save log for current task & old tasks at every epoch 87 | 88 | if valid_loss < best_loss: 89 | best_loss = valid_loss 90 | best_model = utils.get_model(self.model) 91 | patience = self.lr_patience 92 | print(' *', end='') 93 | 94 | else: 95 | patience -= 1 96 | if patience <= 0: 97 | lr /= self.lr_factor 98 | print(' lr={:.1e}'.format(lr), end='') 99 | if lr < self.lr_min: 100 | print() 101 | patience = self.lr_patience 102 | self.optimizer = self._get_optimizer(lr) 103 | print() 104 | 105 | 106 | # Restore best 107 | utils.set_model_(self.model, best_model) 108 | 109 | self.update_omega(self.W, self.epsilon) 110 | self.model_old = deepcopy(self.model) 111 | utils.freeze_model(self.model_old) # Freeze the weights 112 | 113 | return 114 | 115 | def train_epoch(self,t,x,y): 116 | self.model.train() 117 | 118 | r=np.arange(x.size(0)) 119 | np.random.shuffle(r) 120 | r=torch.LongTensor(r).cuda() 121 | 122 | # Loop batches 123 | for i in range(0,len(r),self.sbatch): 124 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 125 | else: b=r[i:] 126 | images=x[b] 127 | targets=y[b] 128 | 129 | # Forward current model 130 | output = self.model.forward(images)[t] 131 | loss=self.criterion(t,output,targets) 132 | 133 | n = 0 134 | # Backward 135 | self.optimizer.zero_grad() 136 | loss.backward() 137 | self.optimizer.step() 138 | for n, p in self.model.named_parameters(): 139 | if p.requires_grad: 140 | n = n.replace('.', '__') 141 | if p.grad is not None: 142 | self.W[n].add_(-p.grad * (p.detach() - self.p_old[n])) 143 | self.p_old[n] = p.detach().clone() 144 | 145 | return 146 | 147 | def eval(self,t,x,y): 148 | total_loss=0 149 | total_acc=0 150 | total_num=0 151 | self.model.eval() 152 | 153 | r = np.arange(x.size(0)) 154 | r = torch.LongTensor(r).cuda() 155 | 156 | # Loop batches 157 | for i in range(0,len(r),self.sbatch): 158 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 159 | else: b=r[i:] 160 | images=x[b] 161 | targets=y[b] 162 | 163 | # Forward 164 | output = self.model.forward(images)[t] 165 | 166 | 
loss=self.criterion(t,output,targets) 167 | _,pred=output.max(1) 168 | hits=(pred==targets).float() 169 | 170 | # Log 171 | total_loss+=loss.data.cpu().numpy()*len(b) 172 | total_acc+=hits.sum().data.cpu().numpy() 173 | total_num+=len(b) 174 | 175 | return total_loss/total_num,total_acc/total_num 176 | 177 | def criterion(self,t,output,targets): 178 | # Regularization for all previous tasks 179 | loss_reg = 0 180 | if t>0: 181 | loss_reg=self.surrogate_loss() 182 | 183 | return self.ce(output,targets)+self.c*loss_reg 184 | 185 | def update_omega(self, W, epsilon): 186 | """After completing training on a task, update the per-parameter regularization strength. 187 | [W] estimated parameter-specific contribution to changes in total loss of completed task 188 | [epsilon] dampening parameter (to bound [omega] when [p_change] goes to 0)""" 189 | 190 | # Loop over all parameters 191 | for n, p in self.model.named_parameters(): 192 | if p.requires_grad: 193 | n = n.replace('.', '__') 194 | 195 | # Find/calculate new values for quadratic penalty on parameters 196 | p_prev = getattr(self.model, '{}_SI_prev_task'.format(n)) 197 | p_current = p.detach().clone() 198 | p_change = p_current - p_prev 199 | omega_add = W[n] / (p_change ** 2 + epsilon) 200 | try: 201 | omega = getattr(self.model, '{}_SI_omega'.format(n)) 202 | except AttributeError: 203 | omega = p.detach().clone().zero_() 204 | omega_new = omega + omega_add 205 | 206 | # Store these new values in the model 207 | self.model.register_buffer('{}_SI_prev_task'.format(n), p_current) 208 | self.model.register_buffer('{}_SI_omega'.format(n), omega_new) 209 | 210 | def surrogate_loss(self): 211 | """Calculate SI’s surrogate loss""" 212 | try: 213 | losses = [] 214 | for n, p in self.model.named_parameters(): 215 | if p.requires_grad: 216 | # Retrieve previous parameter values and their normalized path integral (i.e., omega) 217 | n = n.replace('.', '__') 218 | prev_values = getattr(self.model, '{}_SI_prev_task'.format(n)) 219 | omega = getattr(self.model, '{}_SI_omega'.format(n)) 220 | # Calculate SI’s surrogate loss, sum over all parameters 221 | losses.append((omega * (p - prev_values) ** 2).sum()) 222 | return sum(losses) 223 | except AttributeError: 224 | # SI-loss is 0 if there is no stored omega yet 225 | return 0. 
-------------------------------------------------------------------------------- /approaches/afec_ewc.py: -------------------------------------------------------------------------------- 1 | import sys,time,os 2 | import numpy as np 3 | import torch 4 | from copy import deepcopy 5 | import utils 6 | from utils import * 7 | sys.path.append('..') 8 | from arguments import get_args 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | args = get_args() 12 | import itertools 13 | 14 | 15 | class Appr(object): 16 | """ Class implementing the Elastic Weight Consolidation approach described in http://arxiv.org/abs/1612.00796 """ 17 | def __init__(self,model,nepochs=100,sbatch=256,lr=0.001,lr_min=1e-6,lr_factor=3,lr_patience=5,clipgrad=100,args=None,log_name = None, empty_net = None): 18 | self.model=model 19 | self.model_old=model 20 | self.model_emp = empty_net 21 | self.model_emp_tmp = empty_net 22 | self.model_pt = None 23 | 24 | self.fisher = None 25 | self.fisher_emp = None 26 | 27 | self.nepochs = nepochs 28 | self.sbatch = sbatch 29 | self.lr = lr 30 | self.lr_min = lr_min * 1/3 31 | self.lr_factor = lr_factor 32 | self.lr_patience = lr_patience 33 | self.clipgrad = clipgrad 34 | 35 | self.ce=torch.nn.CrossEntropyLoss() 36 | self.optimizer=self._get_optimizer() 37 | self.optimizer_emp = self._get_optimizer_emp() 38 | self.lamb = args.lamb 39 | self.lamb_emp = args.lamb_emp 40 | 41 | if len(args.parameter)>=1: 42 | params=args.parameter.split(',') 43 | print('Setting parameters to',params) 44 | self.lamb=float(params[0]) 45 | 46 | return 47 | 48 | def _get_optimizer(self,lr=None): 49 | if lr is None: lr=self.lr 50 | 51 | optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) 52 | return optimizer 53 | 54 | def _get_optimizer_emp(self, lr=None): 55 | if lr is None: lr = self.lr 56 | 57 | optimizer = torch.optim.Adam(self.model_emp.parameters(), lr=lr) 58 | return optimizer 59 | 60 | def train(self, t, xtrain, ytrain, xvalid, yvalid, data, input_size, taskcla): 61 | best_loss = np.inf 62 | best_model = utils.get_model(self.model) 63 | lr = self.lr 64 | self.optimizer = self._get_optimizer(lr) 65 | self.optimizer_emp = self._get_optimizer_emp(lr) 66 | self.add_emp = 0 67 | 68 | if t == 0: 69 | self.model_emp = deepcopy(self.model) #use the same initialization 70 | self.model_emp_tmp = deepcopy(self.model) 71 | 72 | # Loop epochs 73 | for e in range(self.nepochs): 74 | # Train 75 | clock0=time.time() 76 | 77 | num_batch = xtrain.size(0) 78 | 79 | #train the empty net and measure fim 80 | if t > self.add_emp-1: 81 | 82 | self.train_emp_epoch(t, xtrain, ytrain, e) 83 | # freeze the empty net 84 | self.model_emp_tmp = deepcopy(self.model_emp) 85 | self.model_emp_tmp.train() 86 | utils.freeze_model(self.model_emp_tmp) 87 | 88 | # Fisher ops 89 | self.fisher_emp, _ = utils.fisher_matrix_diag_emp(t, xtrain, ytrain, self.model_emp, self.criterion) 90 | 91 | self.train_epoch(t, xtrain, ytrain, e) 92 | 93 | clock1=time.time() 94 | train_loss,train_acc=self.eval(t,xtrain,ytrain) 95 | clock2=time.time() 96 | print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format( 97 | e+1,1000*self.sbatch*(clock1-clock0)/num_batch, 98 | 1000*self.sbatch*(clock2-clock1)/num_batch,train_loss,100*train_acc),end='') 99 | # Valid 100 | valid_loss,valid_acc=self.eval(t,xvalid,yvalid) 101 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 102 | print(' lr : {:.6f}'.format(self.optimizer.param_groups[0]['lr'])) 103 | #save log for current 
task & old tasks at every epoch 104 | 105 | # Adapt lr 106 | if valid_loss < best_loss: 107 | best_loss = valid_loss 108 | best_model = utils.get_model(self.model) 109 | patience = self.lr_patience 110 | print(' *', end='') 111 | 112 | else: 113 | patience -= 1 114 | if patience <= 0: 115 | lr /= self.lr_factor 116 | print(' lr={:.1e}'.format(lr), end='') 117 | if lr < self.lr_min: 118 | print() 119 | patience = self.lr_patience 120 | self.optimizer = self._get_optimizer(lr) 121 | self.optimizer_emp = self._get_optimizer_emp(lr) 122 | print() 123 | 124 | # after pretrain in task 0, copy the PT model as empty 125 | if t == 0: 126 | self.model_pt = deepcopy(self.model) 127 | 128 | # Restore best 129 | utils.set_model_(self.model, best_model) 130 | 131 | # Update old 132 | self.model_old = deepcopy(self.model) 133 | self.model_old.train() 134 | utils.freeze_model(self.model_old) # Freeze the weights 135 | 136 | # Fisher ops 137 | if t>0: 138 | fisher_old={} 139 | for n,_ in self.model.named_parameters(): 140 | fisher_old[n]=self.fisher[n].clone() 141 | self.fisher=utils.fisher_matrix_diag(t,xtrain,ytrain,self.model,self.criterion) 142 | if t>0: 143 | # Watch out! We do not want to keep t models (or fisher diagonals) in memory, therefore we have to merge fisher diagonals 144 | for n,_ in self.model.named_parameters(): 145 | self.fisher[n]=(self.fisher[n]+fisher_old[n]*t)/(t+1) # Checked: it is better than the other option 146 | 147 | return 148 | 149 | def train_epoch(self,t,x,y, epoch): 150 | self.model.train() 151 | 152 | r=np.arange(x.size(0)) 153 | np.random.shuffle(r) 154 | r=torch.LongTensor(r).cuda() 155 | 156 | # Loop batches 157 | for i in range(0,len(r),self.sbatch): 158 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 159 | else: b=r[i:] 160 | images=x[b] 161 | targets=y[b] 162 | 163 | # Forward current model 164 | outputs = self.model.forward(images)[t] 165 | loss = self.ce(outputs,targets) 166 | 167 | if t > self.add_emp: 168 | loss_fg = self.criterion_fg(t) 169 | loss += loss_fg 170 | 171 | self.optimizer.zero_grad() 172 | loss.backward() 173 | self.optimizer.step() 174 | 175 | del loss 176 | del images, targets, outputs 177 | 178 | return 179 | 180 | 181 | def train_emp_epoch(self,t,x,y, epoch): 182 | self.model_emp.train() 183 | 184 | r=np.arange(x.size(0)) 185 | np.random.shuffle(r) 186 | r=torch.LongTensor(r).cuda() 187 | 188 | # Loop batches 189 | for i in range(0,len(r),self.sbatch): 190 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 191 | else: b=r[i:] 192 | images=x[b] 193 | targets=y[b] 194 | 195 | # train empty net 196 | # Forward current model 197 | outputs = self.model_emp.forward(images)[t] 198 | loss = self.ce(outputs, targets) 199 | 200 | # Backward 201 | self.optimizer_emp.zero_grad() 202 | loss.backward() 203 | self.optimizer_emp.step() 204 | 205 | return 206 | 207 | def eval(self,t,x,y): 208 | total_loss=0 209 | total_acc=0 210 | total_num=0 211 | self.model.eval() 212 | 213 | r = np.arange(x.size(0)) 214 | r = torch.LongTensor(r).cuda() 215 | 216 | # Loop batches 217 | for i in range(0,len(r),self.sbatch): 218 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 219 | else: b=r[i:] 220 | images=x[b] 221 | targets=y[b] 222 | 223 | # Forward 224 | output = self.model.forward(images)[t] 225 | 226 | loss=self.criterion(t,output,targets) 227 | _,pred=output.max(1) 228 | hits=(pred==targets).float() 229 | 230 | # Log 231 | total_loss+=loss.data.cpu().numpy()*len(b) 232 | total_acc+=hits.sum().data.cpu().numpy() 233 | total_num+=len(b) 234 | 235 | return 
total_loss/total_num,total_acc/total_num 236 | 237 | def criterion(self,t,output,targets): 238 | # Regularization for all previous tasks 239 | loss_reg=0 240 | if t>0: 241 | for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_old.named_parameters()): 242 | loss_reg+=torch.sum(self.fisher[name]*(param_old-param).pow(2))/2 243 | return self.ce(output,targets)+self.lamb*loss_reg 244 | 245 | 246 | def criterion_fg(self,t): 247 | # Regularization for all previous tasks 248 | loss_reg=0 249 | loss_reg_emp = 0 250 | 251 | if t>0: 252 | 253 | for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_old.named_parameters()): 254 | if 'last' not in name: 255 | loss_reg+=torch.sum(self.fisher[name]*(param_old-param).pow(2))/2 256 | 257 | 258 | for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_emp_tmp.named_parameters()): 259 | if 'last' not in name: 260 | loss_reg_emp+=torch.sum(self.fisher_emp[name]*(param_old-param).pow(2))/2 261 | 262 | return self.lamb*loss_reg + self.lamb_emp*loss_reg_emp 263 | -------------------------------------------------------------------------------- /LargeScale/trainer/afec_mas.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import copy 4 | import logging 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | import torch.utils.data as td 11 | from PIL import Image 12 | from tqdm import tqdm 13 | import trainer 14 | from copy import deepcopy 15 | import itertools 16 | 17 | import networks 18 | 19 | 20 | class Trainer(trainer.GenericTrainer): 21 | def __init__(self, model, args, optimizer, evaluator, taskcla, model_emp=None): 22 | super().__init__(model, args, optimizer, evaluator, taskcla) 23 | 24 | self.omega = {} 25 | for n,_ in self.model.named_parameters(): 26 | self.omega[n] = 0 27 | 28 | self.lamb = args.lamb 29 | self.lamb_emp = args.lamb_emp 30 | self.lamb_emp_tmp = args.lamb_emp 31 | 32 | self.fisher_emp = None 33 | 34 | def update_lr(self, epoch, schedule): 35 | for temp in range(0, len(schedule)): 36 | if schedule[temp] == epoch: 37 | for param_group in self.optimizer.param_groups: 38 | self.current_lr = param_group['lr'] 39 | param_group['lr'] = self.current_lr * self.args.gammas[temp] 40 | print("full net: Changing learning rate from %0.4f to %0.4f"%(self.current_lr, 41 | self.current_lr * self.args.gammas[temp])) 42 | self.current_lr *= self.args.gammas[temp] 43 | 44 | def update_lr_emp(self, epoch, schedule): 45 | for temp in range(0, len(schedule)): 46 | if schedule[temp] == epoch: 47 | for param_group in self.optimizer_emp.param_groups: 48 | self.current_lr = param_group['lr'] 49 | param_group['lr'] = self.current_lr * self.args.gammas[temp] 50 | print("emp net: Changing learning rate from %0.4f to %0.4f"%(self.current_lr, 51 | self.current_lr * self.args.gammas[temp])) 52 | self.current_lr *= self.args.gammas[temp] 53 | 54 | def setup_training(self, lr): 55 | 56 | for param_group in self.optimizer.param_groups: 57 | print("Setting LR to %0.4f in full net"%lr) 58 | param_group['lr'] = lr 59 | self.current_lr = lr 60 | 61 | def setup_training_emp(self, lr): 62 | 63 | for param_group in self.optimizer_emp.param_groups: 64 | print("Setting LR to %0.4f in emp net" % lr) 65 | param_group['lr'] = lr 66 | self.current_lr = lr 67 | 68 | def update_frozen_model(self): 69 | self.model.eval() 70 | self.model_fixed = copy.deepcopy(self.model) 71 | 
self.model_fixed.eval() 72 | for param in self.model_fixed.parameters(): 73 | param.requires_grad = False 74 | 75 | def train(self, train_loader, test_loader, t): 76 | if t == 0: 77 | self.model_emp_tmp = deepcopy(self.model) 78 | self.model_emp = deepcopy(self.model) 79 | #self.model_emp_pt = deepcopy(self.model) 80 | #self.model_emp = deepcopy(self.model_emp_pt) 81 | 82 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr, momentum=0.9, weight_decay=self.args.decay) 83 | self.optimizer_emp = torch.optim.SGD(self.model_emp.parameters(), lr=self.args.lr, momentum=0.9, weight_decay=self.args.decay) 84 | 85 | 86 | lr = self.args.lr 87 | self.setup_training(lr) 88 | self.setup_training_emp(lr) 89 | 90 | # Do not update self.t 91 | if t>0: 92 | self.update_frozen_model() 93 | self.omega_update() 94 | 95 | # Now, you can update self.t 96 | self.t = t 97 | 98 | kwargs = {'num_workers': 8, 'pin_memory': True} 99 | #kwargs = {'num_workers': 0, 'pin_memory': False} 100 | self.train_iterator = torch.utils.data.DataLoader(train_loader, batch_size=self.args.batch_size, shuffle=True, **kwargs) 101 | self.test_iterator = torch.utils.data.DataLoader(test_loader, 100, shuffle=False, **kwargs) 102 | self.omega_iterator = torch.utils.data.DataLoader(train_loader, batch_size=20, shuffle=True, **kwargs) 103 | self.fisher_iterator = torch.utils.data.DataLoader(train_loader, batch_size=20, shuffle=True, **kwargs) 104 | 105 | for epoch in range(self.args.nepochs): 106 | self.model.train() 107 | self.update_lr(epoch, self.args.schedule) 108 | self.update_lr_emp(epoch, self.args.schedule) 109 | 110 | #train the empty net and measure fim 111 | if t > -1: 112 | #train empty net 113 | for samples in tqdm(self.train_iterator): 114 | data, target = samples 115 | data, target = data.cuda(), target.cuda() 116 | batch_size = data.shape[0] 117 | 118 | output = self.model_emp(data)[t] 119 | loss_CE = self.ce(output, target) 120 | 121 | self.optimizer_emp.zero_grad() 122 | (loss_CE).backward() 123 | self.optimizer_emp.step() 124 | 125 | # freeze the empty net 126 | self.model_emp_tmp = deepcopy(self.model_emp) 127 | self.model_emp_tmp.eval() 128 | for param in self.model_emp_tmp.parameters(): 129 | param.requires_grad = False 130 | 131 | # Fisher ops 132 | self.fisher_emp = self.fisher_matrix_diag_emp() 133 | 134 | for samples in tqdm(self.train_iterator): 135 | data, target = samples 136 | data, target = data.cuda(), target.cuda() 137 | 138 | output = self.model(data)[t] 139 | loss_CE = self.ce(output,target) 140 | 141 | loss = loss_CE 142 | 143 | if t > 0: 144 | loss_fg = self.criterion_fg() 145 | 146 | loss = loss_CE + loss_fg 147 | 148 | self.optimizer.zero_grad() 149 | loss.backward() 150 | self.optimizer.step() 151 | 152 | train_loss,train_acc = self.evaluator.evaluate(self.model, self.train_iterator, t) 153 | num_batch = len(self.train_iterator) 154 | print('| Epoch {:3d} | Train: loss={:.3f}, acc={:5.1f}% |'.format(epoch+1,train_loss,100*train_acc),end='') 155 | valid_loss,valid_acc=self.evaluator.evaluate(self.model, self.test_iterator, t) 156 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 157 | print() 158 | 159 | 160 | 161 | def criterion(self,output,targets): 162 | # Regularization for all previous tasks 163 | loss_reg=0 164 | for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_fixed.named_parameters()): 165 | loss_reg+=torch.sum(self.omega[name]*(param_old-param).pow(2))/2 166 | 167 | return 
self.ce(output,targets)+self.lamb*loss_reg 168 | 169 | def criterion_fg(self): 170 | # Regularization for all previous tasks 171 | loss_reg = 0 172 | loss_reg_emp = 0 173 | 174 | if self.t > 0: 175 | for (name, param), (_, param_old) in zip(self.model.named_parameters(), 176 | self.model_fixed.named_parameters()): 177 | # if 'conv' in name: 178 | loss_reg += torch.sum(self.omega[name] * (param_old - param).pow(2)) / 2 179 | 180 | for (name, param), (_, param_old) in zip(self.model.named_parameters(), 181 | self.model_emp_tmp.named_parameters()): 182 | # if 'conv' in name: 183 | loss_reg_emp += torch.sum(self.fisher_emp[name] * (param_old - param).pow(2)) / 2 184 | 185 | return self.lamb * loss_reg + self.lamb_emp * loss_reg_emp 186 | 187 | def omega_update(self): 188 | sbatch = 20 189 | 190 | # Compute 191 | self.model.train() 192 | for samples in tqdm(self.omega_iterator): 193 | data, target = samples 194 | data, target = data.cuda(), target.cuda() 195 | # Forward and backward 196 | self.model.zero_grad() 197 | outputs = self.model.forward(data)[self.t] 198 | 199 | # Sum of L2 norm of output scores 200 | loss = torch.sum(outputs.norm(2, dim = -1)) 201 | loss.backward() 202 | 203 | # Get gradients 204 | for n,p in self.model.named_parameters(): 205 | if p.grad is not None: 206 | self.omega[n]+= p.grad.data.abs() / len(self.train_iterator) 207 | 208 | return 209 | 210 | def fisher_matrix_diag_emp(self): 211 | # Init 212 | fisher = {} 213 | for n, p in self.model_emp.named_parameters(): 214 | fisher[n] = 0 * p.data 215 | # Compute 216 | self.model_emp.eval() 217 | criterion = torch.nn.CrossEntropyLoss() 218 | for samples in tqdm(self.fisher_iterator): 219 | data, target = samples 220 | data, target = data.cuda(), target.cuda() 221 | 222 | # Forward and backward 223 | self.model_emp.zero_grad() 224 | outputs = self.model_emp.forward(data)[self.t] 225 | loss = self.ce(outputs, target) 226 | loss.backward() 227 | 228 | # Get gradients 229 | for n, p in self.model_emp.named_parameters(): 230 | if p.grad is not None: 231 | fisher[n] += self.args.batch_size * p.grad.data.pow(2) 232 | # Mean 233 | with torch.no_grad(): 234 | for n, _ in self.model_emp.named_parameters(): 235 | fisher[n] = fisher[n] / len(self.train_iterator) 236 | self.model_emp.train() 237 | return fisher 238 | -------------------------------------------------------------------------------- /approaches/afec_mas.py: -------------------------------------------------------------------------------- 1 | import sys,time,os 2 | import numpy as np 3 | import torch 4 | from copy import deepcopy 5 | import utils 6 | from utils import * 7 | sys.path.append('..') 8 | from arguments import get_args 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | from tqdm import tqdm 12 | args = get_args() 13 | import itertools 14 | 15 | class Appr(object): 16 | """ Class implementing the Elastic Weight Consolidation approach described in http://arxiv.org/abs/1612.00796 """ 17 | 18 | def __init__(self,model,nepochs=100,sbatch=256,lr=0.001,lr_min=1e-6,lr_factor=3,lr_patience=5,clipgrad=100,args=None,log_name = None, empty_net = None): 19 | self.model=model 20 | self.model_old=model 21 | self.model_emp = empty_net 22 | self.model_emp_tmp = empty_net 23 | self.model_pt = None 24 | 25 | self.fisher = None 26 | self.fisher_emp = None 27 | 28 | self.nepochs = nepochs 29 | self.sbatch = sbatch 30 | self.lr = lr 31 | self.lr_min = lr_min * 1/3 32 | self.lr_factor = lr_factor 33 | self.lr_patience = lr_patience 34 | self.clipgrad = clipgrad 35 | 36 
| self.ce=torch.nn.CrossEntropyLoss() 37 | self.optimizer=self._get_optimizer() 38 | self.optimizer_emp = self._get_optimizer_emp() 39 | self.lamb = args.lamb 40 | self.lamb_emp = args.lamb_emp 41 | 42 | self.omega = {} 43 | 44 | for n,_ in self.model.named_parameters(): 45 | self.omega[n] = 0 46 | 47 | return 48 | 49 | def _get_optimizer(self, lr=None): 50 | if lr is None: lr = self.lr 51 | optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) 52 | return optimizer 53 | 54 | def _get_optimizer_emp(self, lr=None): 55 | if lr is None: lr = self.lr 56 | optimizer = torch.optim.Adam(self.model_emp.parameters(), lr=lr) 57 | return optimizer 58 | 59 | def train(self, t, xtrain, ytrain, xvalid, yvalid, data, input_size, taskcla): 60 | best_loss = np.inf 61 | best_model = utils.get_model(self.model) 62 | lr = self.lr 63 | self.optimizer = self._get_optimizer(lr) 64 | self.optimizer_emp = self._get_optimizer_emp(lr) 65 | self.add_emp = 0 66 | 67 | if t == 0: 68 | self.model_emp = deepcopy(self.model) #use the same initialization 69 | self.model_emp_tmp = deepcopy(self.model) 70 | 71 | # Loop epochs 72 | for e in range(self.nepochs): 73 | # Train 74 | clock0=time.time() 75 | num_batch = xtrain.size(0) 76 | 77 | #train the empty net and measure fim 78 | if t > self.add_emp-1: #self.add_emp-1 ptemp 79 | 80 | self.train_emp_epoch(t, xtrain, ytrain, e) 81 | 82 | # freeze the empty net 83 | self.model_emp_tmp = deepcopy(self.model_emp) 84 | self.model_emp_tmp.train() 85 | utils.freeze_model(self.model_emp_tmp) 86 | 87 | # Fisher ops 88 | self.fisher_emp, _ = utils.fisher_matrix_diag_emp(t, xtrain, ytrain, self.model_emp, self.criterion) 89 | 90 | self.train_epoch(t,xtrain,ytrain) 91 | 92 | clock1=time.time() 93 | train_loss,train_acc=self.eval(t,xtrain,ytrain) 94 | clock2=time.time() 95 | print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format( 96 | e+1,1000*self.sbatch*(clock1-clock0)/num_batch, 97 | 1000*self.sbatch*(clock2-clock1)/num_batch,train_loss,100*train_acc),end='') 98 | # Valid 99 | valid_loss,valid_acc=self.eval(t,xvalid,yvalid) 100 | print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss,100*valid_acc),end='') 101 | print(' lr : {:.6f}'.format(self.optimizer.param_groups[0]['lr'])) 102 | #save log for current task & old tasks at every epoch 103 | 104 | # Adapt lr 105 | if valid_loss < best_loss: 106 | best_loss = valid_loss 107 | best_model = utils.get_model(self.model) 108 | patience = self.lr_patience 109 | print(' *', end='') 110 | 111 | else: 112 | patience -= 1 113 | if patience <= 0: 114 | lr /= self.lr_factor 115 | print(' lr={:.1e}'.format(lr), end='') 116 | if lr < self.lr_min: 117 | print() 118 | patience = self.lr_patience 119 | self.optimizer = self._get_optimizer(lr) 120 | self.optimizer_emp = self._get_optimizer_emp(lr) 121 | print() 122 | 123 | # after pretrain in task 0, copy the PT model as empty 124 | #if t == 0: 125 | # self.model_pt = deepcopy(self.model) 126 | 127 | # Restore best 128 | utils.set_model_(self.model, best_model) 129 | 130 | # Update old 131 | self.model_old = deepcopy(self.model) 132 | utils.freeze_model(self.model_old) # Freeze the weights 133 | self.omega_update(t,xtrain) 134 | 135 | return 136 | 137 | def train_epoch(self,t,x,y): 138 | self.model.train() 139 | 140 | r=np.arange(x.size(0)) 141 | np.random.shuffle(r) 142 | r=torch.LongTensor(r).cuda() 143 | 144 | # Loop batches 145 | for i in range(0,len(r),self.sbatch): 146 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 147 | else: b=r[i:] 148 | 
images=x[b] 149 | targets=y[b] 150 | 151 | # Forward current model 152 | outputs = self.model.forward(images)[t] 153 | loss=self.criterion(t,outputs,targets) 154 | 155 | if t > self.add_emp: 156 | loss_fg = self.criterion_fg(t) 157 | loss += loss_fg 158 | 159 | # Backward 160 | self.optimizer.zero_grad() 161 | loss.backward() 162 | self.optimizer.step() 163 | 164 | del loss 165 | del images, targets, outputs 166 | 167 | 168 | return 169 | 170 | 171 | def train_emp_epoch(self,t,x,y, epoch): 172 | self.model_emp.train() 173 | 174 | r=np.arange(x.size(0)) 175 | np.random.shuffle(r) 176 | r=torch.LongTensor(r).cuda() 177 | 178 | # Loop batches 179 | for i in range(0,len(r),self.sbatch): 180 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 181 | else: b=r[i:] 182 | images=x[b] 183 | targets=y[b] 184 | 185 | # train empty net 186 | # Forward current model 187 | outputs = self.model_emp.forward(images)[t] 188 | loss = self.ce(outputs, targets) 189 | 190 | # Backward 191 | self.optimizer_emp.zero_grad() 192 | loss.backward() 193 | self.optimizer_emp.step() 194 | 195 | return 196 | 197 | def eval(self,t,x,y): 198 | total_loss=0 199 | total_acc=0 200 | total_num=0 201 | self.model.eval() 202 | 203 | r = np.arange(x.size(0)) 204 | r = torch.LongTensor(r).cuda() 205 | 206 | # Loop batches 207 | for i in range(0,len(r),self.sbatch): 208 | if i+self.sbatch<=len(r): b=r[i:i+self.sbatch] 209 | else: b=r[i:] 210 | images=x[b] 211 | targets=y[b] 212 | 213 | # Forward 214 | output = self.model.forward(images)[t] 215 | 216 | loss=self.criterion(t,output,targets) 217 | _,pred=output.max(1) 218 | hits=(pred==targets).float() 219 | 220 | # Log 221 | total_loss+=loss.data.cpu().numpy()*len(b) 222 | total_acc+=hits.sum().data.cpu().numpy() 223 | total_num+=len(b) 224 | 225 | return total_loss/total_num,total_acc/total_num 226 | 227 | def criterion(self,t,output,targets): 228 | # Regularization for all previous tasks 229 | loss_reg=0 230 | for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_old.named_parameters()): 231 | loss_reg+=torch.sum(self.omega[name]*(param_old-param).pow(2))/2 232 | 233 | return self.ce(output,targets)+self.lamb*loss_reg 234 | 235 | def criterion_fg(self,t): 236 | # Regularization for all previous tasks 237 | loss_reg=0 238 | loss_reg_emp = 0 239 | 240 | if t>0: 241 | for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_old.named_parameters()): 242 | if 'last' not in name: 243 | loss_reg+=torch.sum(self.omega[name]*(param_old-param).pow(2))/2 244 | 245 | for (name,param),(_,param_old) in zip(self.model.named_parameters(),self.model_emp_tmp.named_parameters()): 246 | if 'last' not in name: 247 | loss_reg_emp+=torch.sum(self.fisher_emp[name]*(param_old-param).pow(2))/2 248 | 249 | return self.lamb*loss_reg + self.lamb_emp*loss_reg_emp 250 | 251 | def omega_update(self,t,x): 252 | sbatch = 20 253 | 254 | # Compute 255 | self.model.train() 256 | for i in tqdm(range(0,x.size(0),sbatch),desc='Omega',ncols=100,ascii=True): 257 | b=torch.LongTensor(np.arange(i,np.min([i+sbatch,x.size(0)]))).cuda() 258 | images = x[b] 259 | # Forward and backward 260 | self.model.zero_grad() 261 | outputs = self.model.forward(images)[t] 262 | 263 | # Sum of L2 norm of output scores 264 | loss = torch.sum(outputs.norm(2, dim = -1)) 265 | 266 | loss.backward() 267 | 268 | # Get gradients 269 | for n,p in self.model.named_parameters(): 270 | if p.grad is not None: 271 | self.omega[n]+= p.grad.data.abs() / x.size(0) 272 | 273 | return 274 | 
--------------------------------------------------------------------------------
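# Minimal standalone sketch (hypothetical helper, not part of the repository) of the combined
# AFEC penalty computed by criterion_fg() in approaches/afec_mas.py and approaches/afec_ewc.py:
# a quadratic pull toward the model consolidated on previous tasks, weighted by the accumulated
# importance (omega or merged Fisher), plus a quadratic pull toward the freshly trained "empty"
# network, weighted by its Fisher diagonal. Head parameters (names containing 'last') are skipped,
# and lamb / lamb_emp correspond to args.lamb / args.lamb_emp.
def afec_penalty(model, model_old, model_emp, importance_old, fisher_emp, lamb, lamb_emp):
    reg_old, reg_emp = 0.0, 0.0
    for (name, p), (_, p_old), (_, p_emp) in zip(model.named_parameters(),
                                                 model_old.named_parameters(),
                                                 model_emp.named_parameters()):
        if 'last' in name:
            continue
        # penalty toward the parameters consolidated on previous tasks
        reg_old = reg_old + (importance_old[name] * (p_old - p).pow(2)).sum() / 2
        # penalty toward the single-task "empty" network (active forgetting term)
        reg_emp = reg_emp + (fisher_emp[name] * (p_emp - p).pow(2)).sum() / 2
    return lamb * reg_old + lamb_emp * reg_emp

# During training this penalty is added to the cross-entropy loss of the current task
# (see train_epoch / the training loop above).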