├── convs ├── __init__.py ├── conv_cifar.py ├── conv_imagenet.py ├── ACL_buffer.py ├── memo_cifar_resnet.py ├── modified_represnet.py ├── cifar_resnet.py ├── ucir_cifar_resnet.py └── resnet_cbam.py ├── models ├── __init__.py ├── dsal.py ├── simplecil.py ├── finetune.py ├── replay.py ├── lwf.py ├── icarl.py ├── bic.py ├── wa.py ├── der.py ├── aper_finetune.py ├── ewc.py ├── pa2s.py └── fetril.py ├── utils ├── __init__.py ├── rl_utils │ ├── rl_utils.py │ └── ddpg.py ├── factory.py ├── toolkit.py ├── ops.py ├── data.py └── autoaugment.py ├── .gitignore ├── resources ├── logo.png ├── logo_v2.png ├── cifar100.png ├── ImageNet100.png ├── imagenet20st5.png └── PR_policy.md ├── exps ├── rmm-pretrain.json ├── der.json ├── ewc.json ├── gem.json ├── wa.json ├── bic.json ├── lwf.json ├── replay.json ├── finetune.json ├── icarl.json ├── podnet.json ├── rmm-icarl.json ├── coil.json ├── fetril.json ├── pass.json ├── aper_finetune.json ├── il2a.json ├── simplecil.json ├── ssre.json ├── tagfex.json ├── beef.json ├── foster.json ├── rmm-foster.json ├── memo.json ├── acil.json └── ds-al.json ├── main.py ├── LICENSE ├── trainer.py └── rmm_train.py /convs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | __pycache__/ 3 | logs/ 4 | .vscode/ 5 | -------------------------------------------------------------------------------- /resources/logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAMDA-CL/PyCIL/HEAD/resources/logo.png -------------------------------------------------------------------------------- /resources/logo_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAMDA-CL/PyCIL/HEAD/resources/logo_v2.png -------------------------------------------------------------------------------- /resources/cifar100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAMDA-CL/PyCIL/HEAD/resources/cifar100.png -------------------------------------------------------------------------------- /resources/ImageNet100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAMDA-CL/PyCIL/HEAD/resources/ImageNet100.png -------------------------------------------------------------------------------- /resources/imagenet20st5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAMDA-CL/PyCIL/HEAD/resources/imagenet20st5.png -------------------------------------------------------------------------------- /exps/rmm-pretrain.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "pretrain-rmm", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "shuffle": true, 6 | "model_name": "rmm-icarl", 7 | "convnet_type": "resnet32", 8 | "device": ["0"], 9 | "seed": [1993] 10 | } 11 | -------------------------------------------------------------------------------- /exps/der.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | 
"increment": 10, 10 | "model_name": "der", 11 | "convnet_type": "resnet32", 12 | "device": ["0","1","2","3"], 13 | "seed": [1993] 14 | } -------------------------------------------------------------------------------- /exps/ewc.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "ewc", 11 | "convnet_type": "resnet32", 12 | "device": ["0","1","2","3"], 13 | "seed": [1993] 14 | } -------------------------------------------------------------------------------- /exps/gem.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "gem", 11 | "convnet_type": "resnet32", 12 | "device": ["0","1","2","3"], 13 | "seed": [1993] 14 | } -------------------------------------------------------------------------------- /exps/wa.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "wa", 11 | "convnet_type": "resnet32", 12 | "device": ["0","1","2","3"], 13 | "seed": [1993] 14 | } -------------------------------------------------------------------------------- /exps/bic.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | 
"increment": 10, 10 | "model_name": "bic", 11 | "convnet_type": "resnet32", 12 | "device": ["0","1","2","3"], 13 | "seed": [1993] 14 | } 15 | -------------------------------------------------------------------------------- /exps/lwf.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "lwf", 11 | "convnet_type": "resnet32", 12 | "device": ["0","1","2","3"], 13 | "seed": [1993] 14 | } 15 | -------------------------------------------------------------------------------- /exps/replay.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "replay", 11 | "convnet_type": "resnet32", 12 | "device": ["0","1","2","3"], 13 | "seed": [1993] 14 | } -------------------------------------------------------------------------------- /exps/finetune.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "finetune", 11 | "convnet_type": "resnet32", 12 | "device": ["0","1","2","3"], 13 | "seed": [1993] 14 | } -------------------------------------------------------------------------------- /exps/icarl.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": 
true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "icarl", 11 | "convnet_type": "resnet32", 12 | "device": ["0","1","2","3"], 13 | "seed": [1993] 14 | } 15 | 16 | -------------------------------------------------------------------------------- /exps/podnet.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "podnet", 11 | "convnet_type": "cosine_resnet32", 12 | "device": ["0","1","2","3"], 13 | "seed": [1993] 14 | } 15 | -------------------------------------------------------------------------------- /exps/rmm-icarl.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "m_rate_list":[0.8, 0.8, 0.6, 0.6, 0.6, 0.6], 5 | "c_rate_list":[0.0, 0.0, 0.1, 0.1, 0.1, 0.0], 6 | "memory_size": 2000, 7 | "shuffle": true, 8 | "init_cls": 50, 9 | "increment": 10, 10 | "model_name": "rmm-icarl", 11 | "convnet_type": "resnet32", 12 | "device": ["0"], 13 | "seed": [1993] 14 | } 15 | 16 | -------------------------------------------------------------------------------- /exps/coil.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 5, 9 | "increment": 5, 10 | "sinkhorn":0.464, 11 | "calibration_term":1.5, 12 | "norm_term":3.0, 13 | "reg_term":1e-3, 14 | "model_name": "coil", 15 | "convnet_type": "cosine_resnet32", 16 | "device": ["0","1","2","3"], 17 | "seed": [1993] 18 | } 19 | -------------------------------------------------------------------------------- /exps/fetril.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "prefix": "train", 3 | "dataset": "cifar100", 4 | "memory_size": 0, 5 | "shuffle": true, 6 | "init_cls": 40, 7 | "increment": 1, 8 | "model_name": "fetril", 9 | "convnet_type": "resnet32", 10 | "device": ["0"], 11 | "seed": [1993], 12 | "init_epochs": 200, 13 | "init_lr" : 0.1, 14 | "init_weight_decay" : 5e-4, 15 | "epochs" : 50, 16 | "lr" : 0.05, 17 | "batch_size" : 128, 18 | "weight_decay" : 5e-4, 19 | "num_workers" : 8, 20 | "T" : 2 21 | } -------------------------------------------------------------------------------- /exps/pass.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "train", 3 | "dataset": "cifar100", 4 | "memory_size": 0, 5 | "shuffle": true, 6 | "init_cls": 50, 7 | "increment": 10, 8 | "model_name": "pass", 9 | "convnet_type": "resnet18_cbam", 10 | "device": ["0"], 11 | "seed": [1993], 12 | "lambda_fkd":10, 13 | "lambda_proto":10, 14 | "temp":0.1, 15 | "epochs" : 101, 16 | "lr" : 0.001, 17 | "batch_size" : 64, 18 | "weight_decay" : 2e-4, 19 | "step_size":45, 20 | "gamma":0.1, 21 | "num_workers" : 8, 22 | "T" : 2 23 | } -------------------------------------------------------------------------------- /exps/aper_finetune.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "aper_finetune", 11 | "convnet_type": "cosine_resnet32", 12 | "device": ["0"], 13 | "trained_epoch": 200, 14 | "tuned_epoch": 20, 15 | "optimizer": "sgd", 16 | "init_weight_decay": 0.05, 17 | "weight_decay": 0.05, 18 | "finetune_lr": 0.005, 19 | 20 | "seed": [1993] 21 | } -------------------------------------------------------------------------------- /exps/il2a.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "prefix": "cil", 3 | "dataset": "cifar100", 4 | "memory_size": 0, 5 | "shuffle": true, 6 | "init_cls": 50, 7 | "increment": 5, 8 | "model_name": "il2a", 9 | "convnet_type": "resnet18_cbam", 10 | "device": ["0"], 11 | "seed": [1993], 12 | "lambda_fkd":10, 13 | "lambda_proto":10, 14 | "temp":0.1, 15 | "epochs" : 101, 16 | "lr" : 0.001, 17 | "batch_size" : 64, 18 | "weight_decay" : 2e-4, 19 | "step_size":45, 20 | "gamma":0.1, 21 | "num_workers" : 8, 22 | "ratio": 2.5, 23 | "T" : 2 24 | } -------------------------------------------------------------------------------- /exps/simplecil.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 50, 9 | "increment": 10, 10 | "model_name": "simplecil", 11 | "convnet_type": "cosine_resnet32", 12 | "device": ["0"], 13 | "seed": [1993], 14 | 15 | "init_epoch": 200, 16 | "init_lr": 0.01, 17 | "batch_size": 128, 18 | "weight_decay": 0.05, 19 | "init_lr_decay": 0.1, 20 | "init_weight_decay": 5e-4, 21 | "min_lr": 0 22 | } 23 | 24 | -------------------------------------------------------------------------------- /exps/ssre.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "ssre", 3 | "dataset": "cifar100", 4 | "memory_size": 0, 5 | "shuffle": true, 6 | "init_cls": 50, 7 | "increment": 10, 8 | "model_name": "ssre", 9 | "convnet_type": "resnet18_rep", 10 | "device": ["0"], 11 | "seed": [1993], 12 | "lambda_fkd":1, 13 | "lambda_proto":10, 14 | "temp":0.1, 15 | "mode": "parallel_adapters", 16 | "epochs" : 101, 17 | "lr" : 0.001, 18 | "batch_size" : 128, 19 | "weight_decay" : 5e-4, 20 | "step_size":45, 21 | "gamma":0.1, 22 | "threshold": 0.8, 23 | "num_workers" : 8, 24 | "T" : 2 25 | } 
-------------------------------------------------------------------------------- /utils/rl_utils/rl_utils.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import torch 4 | import collections 5 | import random 6 | 7 | class ReplayBuffer: # Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions. 8 | def __init__(self, capacity): 9 | self.buffer = collections.deque(maxlen=capacity) # a deque with maxlen silently discards the oldest transition once capacity is reached 10 | 11 | def add(self, state, action, reward, next_state, done): # Append one transition as a 5-tuple. 12 | self.buffer.append((state, action, reward, next_state, done)) 13 | 14 | def sample(self, batch_size): # Draw batch_size transitions uniformly without replacement; random.sample raises ValueError if batch_size exceeds self.size(). 15 | transitions = random.sample(self.buffer, batch_size) 16 | state, action, reward, next_state, done = zip(*transitions) # transpose the list of 5-tuples into five parallel tuples 17 | return np.array(state), np.array(action), reward, np.array(next_state), done # NOTE(review): reward and done are returned as plain tuples while the other fields become np.ndarray; confirm callers rely on this before changing it. 18 | 19 | def size(self): # Number of transitions currently stored. 20 | return len(self.buffer) -------------------------------------------------------------------------------- /exps/tagfex.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100_aa", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "tagfex", 11 | "convnet_type": "resnet18", 12 | "device": ["0"], 13 | "seed": [1993], 14 | "init_interpolation_factor": 0.95, 15 | "infonce_temp":0.2, 16 | "kd_temp":2, 17 | "infonce_kd_temp":0.2, 18 | "ta_convnet_type":"resnet18", 19 | "proj_hidden_dim":2048, 20 | "proj_output_dim":1024, 21 | "attn_num_heads":8, 22 | "contrast_factor":1, 23 | "trans_cls_factor":1, 24 | "transfer_factor":1, 25 | "aux_factor":2, 26 | "contrast_kd_factor":2, 27 | "aug":2 28 | } -------------------------------------------------------------------------------- /exps/beef.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "fusion-energy-0.01-1.7-fixed", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5
| "memory_per_class": 20, 6 | "fixed_memory": true, 7 | "shuffle": true, 8 | "init_cls": 50, 9 | "increment": 10, 10 | "model_name": "beefiso", 11 | "convnet_type": "resnet32", 12 | "device": ["0"], 13 | "seed": [1993], 14 | "logits_alignment": 1.7, 15 | "energy_weight": 0.01, 16 | "is_compress":false, 17 | "reduce_batch_size": false, 18 | "init_epochs": 200, 19 | "init_lr" : 0.1, 20 | "init_weight_decay" : 5e-4, 21 | "expansion_epochs" : 170, 22 | "fusion_epochs" : 60, 23 | "lr" : 0.1, 24 | "batch_size" : 128, 25 | "weight_decay" : 5e-4, 26 | "num_workers" : 8, 27 | "T" : 2 28 | } -------------------------------------------------------------------------------- /exps/foster.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "cil", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class": 20, 6 | "fixed_memory": true, 7 | "shuffle": true, 8 | "init_cls": 50, 9 | "increment": 10, 10 | "model_name": "foster", 11 | "convnet_type": "resnet32", 12 | "device": ["0"], 13 | "seed": [1993], 14 | "beta1":0.96, 15 | "beta2":0.97, 16 | "oofc":"ft", 17 | "is_teacher_wa":false, 18 | "is_student_wa":false, 19 | "lambda_okd":1, 20 | "wa_value":1, 21 | "init_epochs": 200, 22 | "init_lr" : 0.1, 23 | "init_weight_decay" : 5e-4, 24 | "boosting_epochs" : 170, 25 | "compression_epochs" : 130, 26 | "lr" : 0.1, 27 | "batch_size" : 128, 28 | "weight_decay" : 5e-4, 29 | "num_workers" : 8, 30 | "T" : 2 31 | } -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from trainer import train 4 | 5 | 6 | def main(): 7 | args = setup_parser().parse_args() 8 | param = load_json(args.config) 9 | args = vars(args) # Converting argparse Namespace to a dict. 
10 | args.update(param) # Add parameters from json 11 | 12 | train(args) 13 | 14 | 15 | def load_json(settings_path): 16 | with open(settings_path) as data_file: 17 | param = json.load(data_file) 18 | 19 | return param 20 | 21 | 22 | def setup_parser(): 23 | parser = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorithms.') 24 | parser.add_argument('--config', type=str, default='./exps/finetune.json', 25 | help='Json file of settings.') 26 | 27 | return parser 28 | 29 | 30 | if __name__ == '__main__': 31 | main() 32 | -------------------------------------------------------------------------------- /exps/rmm-foster.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "rmm-foster", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "m_rate_list":[0.3, 0.3, 0.3, 0.4, 0.4, 0.4], 6 | "c_rate_list":[0.0, 0.0, 0.1, 0.1, 0.1, 0.0], 7 | "shuffle": true, 8 | "init_cls": 50, 9 | "increment": 10, 10 | "model_name": "rmm-foster", 11 | "convnet_type": "resnet32", 12 | "device": ["0"], 13 | "seed": [1993], 14 | "beta1":0.97, 15 | "beta2":0.97, 16 | "oofc":"ft", 17 | "is_teacher_wa":false, 18 | "is_student_wa":false, 19 | "lambda_okd":1, 20 | "wa_value":1, 21 | "init_epochs": 200, 22 | "init_lr" : 0.1, 23 | "init_weight_decay" : 5e-4, 24 | "boosting_epochs" : 170, 25 | "compression_epochs" : 130, 26 | "lr" : 0.1, 27 | "batch_size" : 128, 28 | "weight_decay" : 5e-4, 29 | "num_workers" : 8, 30 | "T" : 2 31 | } 32 | -------------------------------------------------------------------------------- /exps/memo.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "benchmark", 3 | "dataset": "cifar100", 4 | "memory_size": 2000, 5 | "memory_per_class":20, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "memo", 11 | "convnet_type": "memo_resnet32", 12 | "train_base": true, 13 | 
"train_adaptive": false, 14 | "debug": false, 15 | "skip": false, 16 | "device": ["0", "1", "2", "3"], 17 | "seed":[1993], 18 | "scheduler": "steplr", 19 | "init_epoch": 200, 20 | "t_max": null, 21 | "init_lr" : 0.1, 22 | "init_weight_decay" : 5e-4, 23 | "init_lr_decay" : 0.1, 24 | "init_milestones" : [60,120,170], 25 | "milestones" : [80,120,150], 26 | "epochs": 170, 27 | "lrate" : 0.1, 28 | "batch_size" : 128, 29 | "weight_decay" : 2e-4, 30 | "lrate_decay" : 0.1, 31 | "alpha_aux" : 1.0 32 | } -------------------------------------------------------------------------------- /exps/acil.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "acil", 3 | "prefix": "ACIL", 4 | "memory_size": 0, 5 | 6 | "dataset": "cifar100", 7 | "seed": [1993], 8 | "shuffle": true, 9 | "device": ["0"], 10 | 11 | "convnet_type": "resnet32", 12 | "init_cls": 50, 13 | "increment": 5, 14 | 15 | "num_workers": 16, 16 | "init_batch_size": 128, 17 | "IL_batch_size": 4096, 18 | "inplace_repeat": 1, 19 | 20 | "configurations": { 21 | "cifar100": { 22 | "buffer_size": 8192, 23 | "gamma": 0.1, 24 | "init_weight_decay": 5e-4, 25 | "scheduler": { 26 | "type": "MultiStep", 27 | "init_lr": 0.1, 28 | "init_epochs": 160, 29 | "warmup": 0, 30 | "milestones": [120, 140], 31 | "decay": 0.1 32 | } 33 | }, 34 | "imagenet1000": { 35 | "buffer_size": 16384, 36 | "gamma": 0.1, 37 | "init_weight_decay": 5e-5, 38 | "scheduler": { 39 | "type": "MultiStep", 40 | "init_lr": 0.1, 41 | "init_epochs": 90, 42 | "warmup": 0, 43 | "milestones": [30, 60], 44 | "decay": 0.1 45 | } 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /resources/PR_policy.md: -------------------------------------------------------------------------------- 1 | # Things you need to do to get your PR merged: 2 | 3 | 1. 
**Make sure you have cited the toolbox paper in your camera-ready version.** For example, "All models are deployed with PyCIL~\cite{zhou2023pycil}". **Include the link to the paper in the pull request description.** 4 | 2. Make sure you understand the code structure and the training process defined in the toolbox. **All new methods should inherit from the base class in base.py.** 5 | 3. **Do not modify the README file in your PR.** We will update the README file after your PR is merged. 6 | 4. **Open a pull request.** We will review your code and merge it if it meets the requirements. The code review process may take some time, so please be patient and do not rush the reviewers. 7 | 5. Send an email to the authors of the toolbox to inform them of your PR. **Include your name, position, affiliation, and the title of the paper you are submitting.** 8 | 9 | 10 | @article{zhou2023pycil, 11 | author = {Da-Wei Zhou and Fu-Yun Wang and Han-Jia Ye and De-Chuan Zhan}, 12 | title = {PyCIL: a Python toolbox for class-incremental learning}, 13 | journal = {SCIENCE CHINA Information Sciences}, 14 | year = {2023}, 15 | volume = {66}, 16 | number = {9}, 17 | pages = {197101}, 18 | doi = {10.1007/s11432-022-3600-y} 19 | } 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /exps/ds-al.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "ds-al", 3 | "prefix": "DS-AL", 4 | "memory_size": 0, 5 | 6 | "dataset": "cifar100", 7 | "seed": [1993], 8 | "shuffle": true, 9 | "device": ["0"], 10 | 11 | "convnet_type": "resnet32", 12 | "init_cls": 50, 13 | "increment": 5, 14 | 15 | "num_workers": 16, 16 | "init_batch_size": 128, 17 | "IL_batch_size": 4096, 18 | "inplace_repeat": 1, 19 | 20 | "configurations": { 21 | "cifar100": { 22 | "buffer_size": 8192, 23 | "gamma": 0.1, 24 | "gamma_comp": 0.1, 25 | "compensation_ratio": 0.6, 26 | "init_weight_decay": 5e-4, 27 |
"scheduler": { 28 | "type": "MultiStep", 29 | "init_lr": 0.1, 30 | "init_epochs": 160, 31 | "warmup": 0, 32 | "milestones": [120, 140], 33 | "decay": 0.1 34 | } 35 | }, 36 | "imagenet1000": { 37 | "buffer_size": 16384, 38 | "gamma": 0.1, 39 | "gamma_comp": 0.1, 40 | "compensation_ratio": 1.5, 41 | "init_weight_decay": 5e-5, 42 | "scheduler": { 43 | "type": "MultiStep", 44 | "init_lr": 0.1, 45 | "init_epochs": 90, 46 | "warmup": 0, 47 | "milestones": [30, 60], 48 | "decay": 0.1 49 | } 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /models/dsal.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Proper implementation of the DS-AL [1]. 4 | 5 | This implementation refers to the official implementation https://github.com/ZHUANGHP/Analytic-continual-learning. 6 | 7 | References: 8 | [1] Zhuang, Huiping, et al. 9 | "DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning." 10 | Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024. 11 | """ 12 | 13 | from .acil import ACIL 14 | from utils.inc_net import DSALNet 15 | 16 | 17 | class DSAL(ACIL): 18 | """ 19 | Training process of the DS-AL [1]. 20 | 21 | This implementation refers to the official implementation https://github.com/ZHUANGHP/Analytic-continual-learning. 22 | 23 | References: 24 | [1] Zhuang, Huiping, et al. 25 | "DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning." 26 | Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024. 27 | """ 28 | 29 | def create_network(self) -> None: 30 | # We recommend using the grid search to find the best compensation ratio `C` in the interval [0, 2]. 31 | # The best value is 0.6 for the CIFAR-100, while the best value for the ImageNet-1k is 1.5. 
32 | self._network = DSALNet( 33 | self.args, 34 | self.buffer_size, 35 | self.gamma, 36 | self.args["gamma_comp"], 37 | self.args["compensation_ratio"], 38 | ) 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Changhong Zhong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | MIT License 24 | 25 | Copyright (c) 2021 Fu-Yun Wang. 
26 | 27 | Permission is hereby granted, free of charge, to any person obtaining a copy 28 | of this software and associated documentation files (the "Software"), to deal 29 | in the Software without restriction, including without limitation the rights 30 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 31 | copies of the Software, and to permit persons to whom the Software is 32 | furnished to do so, subject to the following conditions: 33 | 34 | The above copyright notice and this permission notice shall be included in all 35 | copies or substantial portions of the Software. 36 | 37 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 38 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 39 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 40 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 41 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 42 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 43 | SOFTWARE. 
44 | -------------------------------------------------------------------------------- /convs/conv_cifar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | For MEMO implementations of CIFAR-ConvNet 3 | Reference: 4 | https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_cifar.py 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | # for cifar 11 | def conv_block(in_channels, out_channels): 12 | return nn.Sequential( 13 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 14 | nn.BatchNorm2d(out_channels), 15 | nn.ReLU(), 16 | nn.MaxPool2d(2) 17 | ) 18 | 19 | class ConvNet2(nn.Module): 20 | def __init__(self, x_dim=3, hid_dim=64, z_dim=64): 21 | super().__init__() 22 | self.out_dim = 64 23 | self.avgpool = nn.AvgPool2d(8) 24 | self.encoder = nn.Sequential( 25 | conv_block(x_dim, hid_dim), 26 | conv_block(hid_dim, z_dim), 27 | ) 28 | 29 | def forward(self, x): 30 | x = self.encoder(x) 31 | x = self.avgpool(x) 32 | features = x.view(x.shape[0], -1) 33 | return { 34 | "features":features 35 | } 36 | 37 | class GeneralizedConvNet2(nn.Module): 38 | def __init__(self, x_dim=3, hid_dim=64, z_dim=64): 39 | super().__init__() 40 | self.encoder = nn.Sequential( 41 | conv_block(x_dim, hid_dim), 42 | ) 43 | 44 | def forward(self, x): 45 | base_features = self.encoder(x) 46 | return base_features 47 | 48 | class SpecializedConvNet2(nn.Module): 49 | def __init__(self,hid_dim=64,z_dim=64): 50 | super().__init__() 51 | self.feature_dim = 64 52 | self.avgpool = nn.AvgPool2d(8) 53 | self.AdaptiveBlock = conv_block(hid_dim,z_dim) 54 | 55 | def forward(self,x): 56 | base_features = self.AdaptiveBlock(x) 57 | pooled = self.avgpool(base_features) 58 | features = pooled.view(pooled.size(0),-1) 59 | return features 60 | 61 | def conv2(): 62 | return ConvNet2() 63 | 64 | def get_conv_a2fc(): 65 | basenet = GeneralizedConvNet2() 66 | adaptivenet = SpecializedConvNet2() 67 | return 
basenet,adaptivenet 68 | 69 | if __name__ == '__main__': 70 | a, b = get_conv_a2fc() 71 | _base = sum(p.numel() for p in a.parameters()) 72 | _adap = sum(p.numel() for p in b.parameters()) 73 | print(f"conv :{_base+_adap}") 74 | 75 | conv2 = conv2() 76 | conv2_sum = sum(p.numel() for p in conv2.parameters()) 77 | print(f"conv2 :{conv2_sum}") -------------------------------------------------------------------------------- /convs/conv_imagenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | For MEMO implementations of ImageNet-ConvNet 3 | Reference: 4 | https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_imagenet.py 5 | ''' 6 | import torch.nn as nn 7 | import torch 8 | 9 | # for imagenet 10 | def first_block(in_channels, out_channels): 11 | return nn.Sequential( 12 | nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3), 13 | nn.BatchNorm2d(out_channels), 14 | nn.ReLU(), 15 | nn.MaxPool2d(2) 16 | ) 17 | 18 | def conv_block(in_channels, out_channels): 19 | return nn.Sequential( 20 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 21 | nn.BatchNorm2d(out_channels), 22 | nn.ReLU(), 23 | nn.MaxPool2d(2) 24 | ) 25 | 26 | class ConvNet(nn.Module): 27 | def __init__(self, x_dim=3, hid_dim=128, z_dim=512): 28 | super().__init__() 29 | self.block1 = first_block(x_dim, hid_dim) 30 | self.block2 = conv_block(hid_dim, hid_dim) 31 | self.block3 = conv_block(hid_dim, hid_dim) 32 | self.block4 = conv_block(hid_dim, z_dim) 33 | self.avgpool = nn.AvgPool2d(7) 34 | self.out_dim = 512 35 | 36 | def forward(self, x): 37 | x = self.block1(x) 38 | x = self.block2(x) 39 | x = self.block3(x) 40 | x = self.block4(x) 41 | 42 | x = self.avgpool(x) 43 | features = x.view(x.shape[0], -1) 44 | 45 | return { 46 | "features": features 47 | } 48 | 49 | class GeneralizedConvNet(nn.Module): 50 | def __init__(self, x_dim=3, hid_dim=128, z_dim=512): 51 | super().__init__() 52 | self.block1 = first_block(x_dim, hid_dim) 53 | 
self.block2 = conv_block(hid_dim, hid_dim) 54 | self.block3 = conv_block(hid_dim, hid_dim) 55 | 56 | def forward(self, x): 57 | x = self.block1(x) 58 | x = self.block2(x) 59 | x = self.block3(x) 60 | return x 61 | 62 | class SpecializedConvNet(nn.Module): 63 | def __init__(self, hid_dim=128,z_dim=512): 64 | super().__init__() 65 | self.block4 = conv_block(hid_dim, z_dim) 66 | self.avgpool = nn.AvgPool2d(7) 67 | self.feature_dim = 512 68 | 69 | def forward(self, x): 70 | x = self.block4(x) 71 | x = self.avgpool(x) 72 | features = x.view(x.shape[0], -1) 73 | return features 74 | 75 | def conv4(): 76 | model = ConvNet() 77 | return model 78 | 79 | def conv_a2fc_imagenet(): 80 | _base = GeneralizedConvNet() 81 | _adaptive_net = SpecializedConvNet() 82 | return _base, _adaptive_net -------------------------------------------------------------------------------- /utils/factory.py: -------------------------------------------------------------------------------- 1 | def get_model(model_name, args): 2 | name = model_name.lower() 3 | if name == "icarl": 4 | from models.icarl import iCaRL 5 | return iCaRL(args) 6 | elif name == "bic": 7 | from models.bic import BiC 8 | return BiC(args) 9 | elif name == "podnet": 10 | from models.podnet import PODNet 11 | return PODNet(args) 12 | elif name == "lwf": 13 | from models.lwf import LwF 14 | return LwF(args) 15 | elif name == "ewc": 16 | from models.ewc import EWC 17 | return EWC(args) 18 | elif name == "wa": 19 | from models.wa import WA 20 | return WA(args) 21 | elif name == "der": 22 | from models.der import DER 23 | return DER(args) 24 | elif name == "finetune": 25 | from models.finetune import Finetune 26 | return Finetune(args) 27 | elif name == "replay": 28 | from models.replay import Replay 29 | return Replay(args) 30 | elif name == "gem": 31 | from models.gem import GEM 32 | return GEM(args) 33 | elif name == "coil": 34 | from models.coil import COIL 35 | return COIL(args) 36 | elif name == "foster": 37 | from models.foster 
import FOSTER 38 | return FOSTER(args) 39 | elif name == "rmm-icarl": 40 | from models.rmm import RMM_FOSTER, RMM_iCaRL 41 | return RMM_iCaRL(args) 42 | elif name == "rmm-foster": 43 | from models.rmm import RMM_FOSTER, RMM_iCaRL 44 | return RMM_FOSTER(args) 45 | elif name == "fetril": 46 | from models.fetril import FeTrIL 47 | return FeTrIL(args) 48 | elif name == "pass": 49 | from models.pa2s import PASS 50 | return PASS(args) 51 | elif name == "il2a": 52 | from models.il2a import IL2A 53 | return IL2A(args) 54 | elif name == "ssre": 55 | from models.ssre import SSRE 56 | return SSRE(args) 57 | elif name == "memo": 58 | from models.memo import MEMO 59 | return MEMO(args) 60 | elif name == "beefiso": 61 | from models.beef_iso import BEEFISO 62 | return BEEFISO(args) 63 | elif name == "simplecil": 64 | from models.simplecil import SimpleCIL 65 | return SimpleCIL(args) 66 | elif name == "acil": 67 | from models.acil import ACIL 68 | return ACIL(args) 69 | elif name == "ds-al": 70 | from models.dsal import DSAL 71 | return DSAL(args) 72 | elif name == "aper_finetune": 73 | from models.aper_finetune import APER_FINETUNE 74 | return APER_FINETUNE(args) 75 | elif name == "tagfex": 76 | from models.tagfex import TagFex 77 | return TagFex(args) 78 | else: 79 | assert 0 80 | -------------------------------------------------------------------------------- /convs/ACL_buffer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Buffer layers for the analytic continual learning (ACL) [1-3]. 4 | 5 | This implementation refers to the official implementation https://github.com/ZHUANGHP/Analytic-continual-learning. 6 | 7 | References: 8 | [1] Zhuang, Huiping, et al. 9 | "ACIL: Analytic class-incremental learning with absolute memorization and privacy protection." 10 | Advances in Neural Information Processing Systems 35 (2022): 11602-11614. 11 | [2] Zhuang, Huiping, et al. 
12 | "GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task." 13 | Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023. 14 | [3] Zhuang, Huiping, et al. 15 | "DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning." 16 | Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024. 17 | """ 18 | 19 | import torch 20 | from typing import Optional, Union, Callable 21 | from abc import ABCMeta, abstractmethod 22 | 23 | __all__ = [ 24 | "Buffer", 25 | "RandomBuffer", 26 | "activation_t", 27 | ] 28 | 29 | activation_t = Union[Callable[[torch.Tensor], torch.Tensor], torch.nn.Module] 30 | 31 | 32 | class Buffer(torch.nn.Module, metaclass=ABCMeta): 33 | def __init__(self) -> None: 34 | super().__init__() 35 | 36 | @abstractmethod 37 | def forward(self, X: torch.Tensor) -> torch.Tensor: 38 | raise NotImplementedError() 39 | 40 | 41 | class RandomBuffer(torch.nn.Linear, Buffer): 42 | """ 43 | Random buffer layer for the ACIL [1] and DS-AL [2]. 44 | 45 | This implementation refers to the official implementation https://github.com/ZHUANGHP/Analytic-continual-learning. 46 | 47 | References: 48 | [1] Zhuang, Huiping, et al. 49 | "ACIL: Analytic class-incremental learning with absolute memorization and privacy protection." 50 | Advances in Neural Information Processing Systems 35 (2022): 11602-11614. 51 | [2] Zhuang, Huiping, et al. 52 | "DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning." 53 | Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024. 
54 | """ 55 | 56 | def __init__( 57 | self, 58 | in_features: int, 59 | out_features: int, 60 | bias: bool = False, 61 | device=None, 62 | dtype=torch.float, 63 | activation: Optional[activation_t] = torch.relu_, 64 | ) -> None: 65 | super(torch.nn.Linear, self).__init__() 66 | factory_kwargs = {"device": device, "dtype": dtype} 67 | self.in_features = in_features 68 | self.out_features = out_features 69 | self.activation: activation_t = ( 70 | torch.nn.Identity() if activation is None else activation 71 | ) 72 | 73 | W = torch.empty((out_features, in_features), **factory_kwargs) 74 | b = torch.empty(out_features, **factory_kwargs) if bias else None 75 | 76 | # Using buffer instead of parameter 77 | self.register_buffer("weight", W) 78 | self.register_buffer("bias", b) 79 | 80 | # Random Initialization 81 | self.reset_parameters() 82 | 83 | @torch.no_grad() 84 | def forward(self, X: torch.Tensor) -> torch.Tensor: 85 | X = X.to(self.weight) 86 | return self.activation(super().forward(X)) 87 | -------------------------------------------------------------------------------- /utils/toolkit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import json 5 | from enum import Enum 6 | 7 | class ConfigEncoder(json.JSONEncoder): 8 | def default(self, o): 9 | if isinstance(o, type): 10 | return {'$class': o.__module__ + "." + o.__name__} 11 | elif isinstance(o, Enum): 12 | return { 13 | '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name 14 | } 15 | elif callable(o): 16 | return { 17 | '$function': o.__module__ + "." 
+ o.__name__ 18 | } 19 | return json.JSONEncoder.default(self, o) 20 | 21 | def count_parameters(model, trainable=False): 22 | if trainable: 23 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 24 | return sum(p.numel() for p in model.parameters()) 25 | 26 | 27 | def tensor2numpy(x): 28 | return x.cpu().data.numpy() if x.is_cuda else x.data.numpy() 29 | 30 | 31 | def target2onehot(targets, n_classes): 32 | onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device) 33 | onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.0) 34 | return onehot 35 | 36 | 37 | def makedirs(path): 38 | if not os.path.exists(path): 39 | os.makedirs(path) 40 | 41 | 42 | def accuracy(y_pred, y_true, nb_old, increment=10): 43 | assert len(y_pred) == len(y_true), "Data length error." 44 | all_acc = {} 45 | all_acc["total"] = np.around( 46 | (y_pred == y_true).sum() * 100 / len(y_true), decimals=2 47 | ) 48 | 49 | # Grouped accuracy 50 | for class_id in range(0, np.max(y_true), increment): 51 | idxes = np.where( 52 | np.logical_and(y_true >= class_id, y_true < class_id + increment) 53 | )[0] 54 | label = "{}-{}".format( 55 | str(class_id).rjust(2, "0"), str(class_id + increment - 1).rjust(2, "0") 56 | ) 57 | all_acc[label] = np.around( 58 | (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2 59 | ) 60 | 61 | # Old accuracy 62 | idxes = np.where(y_true < nb_old)[0] 63 | all_acc["old"] = ( 64 | 0 65 | if len(idxes) == 0 66 | else np.around( 67 | (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2 68 | ) 69 | ) 70 | 71 | # New accuracy 72 | idxes = np.where(y_true >= nb_old)[0] 73 | all_acc["new"] = np.around( 74 | (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2 75 | ) 76 | 77 | return all_acc 78 | 79 | 80 | def split_images_labels(imgs): 81 | # split trainset.imgs in ImageFolder 82 | images = [] 83 | labels = [] 84 | for item in imgs: 85 | images.append(item[0]) 86 | labels.append(item[1]) 87 | 
88 | return np.array(images), np.array(labels) 89 | 90 | def save_fc(args, model): 91 | _path = os.path.join(args['logfilename'], "fc.pt") 92 | if len(args['device']) > 1: 93 | fc_weight = model._network.fc.weight.data 94 | else: 95 | fc_weight = model._network.fc.weight.data.cpu() 96 | torch.save(fc_weight, _path) 97 | 98 | _save_dir = os.path.join(f"./results/fc_weights/{args['prefix']}") 99 | os.makedirs(_save_dir, exist_ok=True) 100 | _save_path = os.path.join(_save_dir, f"{args['csv_name']}.csv") 101 | with open(_save_path, "a+") as f: 102 | f.write(f"{args['time_str']},{args['model_name']},{_path} \n") 103 | 104 | def save_model(args, model): 105 | #used in PODNet 106 | _path = os.path.join(args['logfilename'], "model.pt") 107 | if len(args['device']) > 1: 108 | weight = model._network 109 | else: 110 | weight = model._network.cpu() 111 | torch.save(weight, _path) -------------------------------------------------------------------------------- /utils/ops.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageEnhance, ImageOps 2 | import random 3 | import torch 4 | import numpy as np 5 | class Cutout(object): 6 | def __init__(self, n_holes, length): 7 | self.n_holes = n_holes 8 | self.length = length 9 | 10 | def __call__(self, img): 11 | h = img.size(1) 12 | w = img.size(2) 13 | 14 | mask = np.ones((h, w), np.float32) 15 | 16 | for n in range(self.n_holes): 17 | y = np.random.randint(h) 18 | x = np.random.randint(w) 19 | 20 | y1 = np.clip(y - self.length // 2, 0, h) 21 | y2 = np.clip(y + self.length // 2, 0, h) 22 | x1 = np.clip(x - self.length // 2, 0, w) 23 | x2 = np.clip(x + self.length // 2, 0, w) 24 | 25 | mask[y1: y2, x1: x2] = 0. 
26 | 27 | mask = torch.from_numpy(mask) 28 | mask = mask.expand_as(img) 29 | img = img * mask 30 | 31 | return img 32 | 33 | class ShearX(object): 34 | def __init__(self, fillcolor=(128, 128, 128)): 35 | self.fillcolor = fillcolor 36 | 37 | def __call__(self, x, magnitude): 38 | return x.transform( 39 | x.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 40 | Image.BICUBIC, fillcolor=self.fillcolor) 41 | 42 | 43 | class ShearY(object): 44 | def __init__(self, fillcolor=(128, 128, 128)): 45 | self.fillcolor = fillcolor 46 | 47 | def __call__(self, x, magnitude): 48 | return x.transform( 49 | x.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 50 | Image.BICUBIC, fillcolor=self.fillcolor) 51 | 52 | 53 | class TranslateX(object): 54 | def __init__(self, fillcolor=(128, 128, 128)): 55 | self.fillcolor = fillcolor 56 | 57 | def __call__(self, x, magnitude): 58 | return x.transform( 59 | x.size, Image.AFFINE, (1, 0, magnitude * x.size[0] * random.choice([-1, 1]), 0, 1, 0), 60 | fillcolor=self.fillcolor) 61 | 62 | 63 | class TranslateY(object): 64 | def __init__(self, fillcolor=(128, 128, 128)): 65 | self.fillcolor = fillcolor 66 | 67 | def __call__(self, x, magnitude): 68 | return x.transform( 69 | x.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * x.size[1] * random.choice([-1, 1])), 70 | fillcolor=self.fillcolor) 71 | 72 | 73 | class Rotate(object): 74 | def __call__(self, x, magnitude): 75 | rot = x.convert("RGBA").rotate(magnitude * random.choice([-1, 1])) 76 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(x.mode) 77 | 78 | 79 | class Color(object): 80 | def __call__(self, x, magnitude): 81 | return ImageEnhance.Color(x).enhance(1 + magnitude * random.choice([-1, 1])) 82 | 83 | 84 | class Posterize(object): 85 | def __call__(self, x, magnitude): 86 | return ImageOps.posterize(x, magnitude) 87 | 88 | 89 | class Solarize(object): 90 | def __call__(self, x, magnitude): 91 | return 
ImageOps.solarize(x, magnitude) 92 | 93 | 94 | class Contrast(object): 95 | def __call__(self, x, magnitude): 96 | return ImageEnhance.Contrast(x).enhance(1 + magnitude * random.choice([-1, 1])) 97 | 98 | 99 | class Sharpness(object): 100 | def __call__(self, x, magnitude): 101 | return ImageEnhance.Sharpness(x).enhance(1 + magnitude * random.choice([-1, 1])) 102 | 103 | 104 | class Brightness(object): 105 | def __call__(self, x, magnitude): 106 | return ImageEnhance.Brightness(x).enhance(1 + magnitude * random.choice([-1, 1])) 107 | 108 | 109 | class AutoContrast(object): 110 | def __call__(self, x, magnitude): 111 | return ImageOps.autocontrast(x) 112 | 113 | 114 | class Equalize(object): 115 | def __call__(self, x, magnitude): 116 | return ImageOps.equalize(x) 117 | 118 | 119 | class Invert(object): 120 | def __call__(self, x, magnitude): 121 | return ImageOps.invert(x) 122 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import datasets, transforms 3 | from utils.toolkit import split_images_labels 4 | from . import autoaugment 5 | from . 
import ops 6 | 7 | class iData(object): 8 | train_trsf = [] 9 | test_trsf = [] 10 | common_trsf = [] 11 | class_order = None 12 | 13 | 14 | class iCIFAR10(iData): 15 | use_path = False 16 | train_trsf = [ 17 | transforms.RandomCrop(32, padding=4), 18 | transforms.RandomHorizontalFlip(p=0.5), 19 | transforms.ColorJitter(brightness=63 / 255), 20 | transforms.ToTensor(), 21 | ] 22 | test_trsf = [transforms.ToTensor()] 23 | common_trsf = [ 24 | transforms.Normalize( 25 | mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010) 26 | ), 27 | ] 28 | 29 | class_order = np.arange(10).tolist() 30 | 31 | def download_data(self): 32 | train_dataset = datasets.cifar.CIFAR10("./data", train=True, download=True) 33 | test_dataset = datasets.cifar.CIFAR10("./data", train=False, download=True) 34 | self.train_data, self.train_targets = train_dataset.data, np.array( 35 | train_dataset.targets 36 | ) 37 | self.test_data, self.test_targets = test_dataset.data, np.array( 38 | test_dataset.targets 39 | ) 40 | 41 | 42 | class iCIFAR100(iData): 43 | use_path = False 44 | train_trsf = [ 45 | transforms.RandomCrop(32, padding=4), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.ColorJitter(brightness=63 / 255), 48 | transforms.ToTensor() 49 | ] 50 | test_trsf = [transforms.ToTensor()] 51 | common_trsf = [ 52 | transforms.Normalize( 53 | mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761) 54 | ), 55 | ] 56 | 57 | class_order = np.arange(100).tolist() 58 | 59 | def download_data(self): 60 | train_dataset = datasets.cifar.CIFAR100("./data", train=True, download=True) 61 | test_dataset = datasets.cifar.CIFAR100("./data", train=False, download=True) 62 | self.train_data, self.train_targets = train_dataset.data, np.array( 63 | train_dataset.targets 64 | ) 65 | self.test_data, self.test_targets = test_dataset.data, np.array( 66 | test_dataset.targets 67 | ) 68 | 69 | 70 | class iCIFAR100_AA(iCIFAR100): 71 | train_trsf = [ 72 | transforms.RandomCrop(32, padding=4), 73 | 
transforms.RandomHorizontalFlip(p=0.5), 74 | transforms.ColorJitter(brightness=63 / 255), 75 | autoaugment.CIFAR10Policy(), 76 | transforms.ToTensor(), 77 | ops.Cutout(n_holes=1, length=16), 78 | ] 79 | 80 | 81 | class iCIFAR10_AA(iCIFAR10): 82 | train_trsf = [ 83 | transforms.RandomCrop(32, padding=4), 84 | transforms.RandomHorizontalFlip(p=0.5), 85 | transforms.ColorJitter(brightness=63 / 255), 86 | autoaugment.CIFAR10Policy(), 87 | transforms.ToTensor(), 88 | ops.Cutout(n_holes=1, length=16), 89 | ] 90 | 91 | 92 | class iImageNet1000(iData): 93 | use_path = True 94 | train_trsf = [ 95 | transforms.RandomResizedCrop(224), 96 | transforms.RandomHorizontalFlip(), 97 | transforms.ColorJitter(brightness=63 / 255), 98 | ] 99 | test_trsf = [ 100 | transforms.Resize(256), 101 | transforms.CenterCrop(224), 102 | ] 103 | common_trsf = [ 104 | transforms.ToTensor(), 105 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 106 | ] 107 | 108 | class_order = np.arange(1000).tolist() 109 | 110 | def download_data(self): 111 | assert 0, "You should specify the folder of your dataset" 112 | train_dir = "[DATA-PATH]/train/" 113 | test_dir = "[DATA-PATH]/val/" 114 | 115 | train_dset = datasets.ImageFolder(train_dir) 116 | test_dset = datasets.ImageFolder(test_dir) 117 | 118 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs) 119 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs) 120 | 121 | 122 | class iImageNet100(iData): 123 | use_path = True 124 | train_trsf = [ 125 | transforms.RandomResizedCrop(224), 126 | transforms.RandomHorizontalFlip(), 127 | ] 128 | test_trsf = [ 129 | transforms.Resize(256), 130 | transforms.CenterCrop(224), 131 | ] 132 | common_trsf = [ 133 | transforms.ToTensor(), 134 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 135 | ] 136 | 137 | class_order = np.arange(1000).tolist() 138 | 139 | def download_data(self): 140 | assert 0, "You should specify the 
folder of your dataset" 141 | train_dir = "[DATA-PATH]/train/" 142 | test_dir = "[DATA-PATH]/val/" 143 | 144 | train_dset = datasets.ImageFolder(train_dir) 145 | test_dset = datasets.ImageFolder(test_dir) 146 | 147 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs) 148 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs) 149 | -------------------------------------------------------------------------------- /convs/memo_cifar_resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | For MEMO implementations of CIFAR-ResNet 3 | Reference: 4 | https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py 5 | ''' 6 | import math 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | class DownsampleA(nn.Module): 13 | def __init__(self, nIn, nOut, stride): 14 | super(DownsampleA, self).__init__() 15 | assert stride == 2 16 | self.avg = nn.AvgPool2d(kernel_size=1, stride=stride) 17 | 18 | def forward(self, x): 19 | x = self.avg(x) 20 | return torch.cat((x, x.mul(0)), 1) 21 | 22 | class ResNetBasicblock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(ResNetBasicblock, self).__init__() 27 | 28 | self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 29 | self.bn_a = nn.BatchNorm2d(planes) 30 | 31 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 32 | self.bn_b = nn.BatchNorm2d(planes) 33 | 34 | self.downsample = downsample 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | basicblock = self.conv_a(x) 40 | basicblock = self.bn_a(basicblock) 41 | basicblock = F.relu(basicblock, inplace=True) 42 | 43 | basicblock = self.conv_b(basicblock) 44 | basicblock = self.bn_b(basicblock) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | 
return F.relu(residual + basicblock, inplace=True) 50 | 51 | 52 | 53 | class GeneralizedResNet_cifar(nn.Module): 54 | def __init__(self, block, depth, channels=3): 55 | super(GeneralizedResNet_cifar, self).__init__() 56 | assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' 57 | layer_blocks = (depth - 2) // 6 58 | self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) 59 | self.bn_1 = nn.BatchNorm2d(16) 60 | 61 | self.inplanes = 16 62 | self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) 63 | self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) 64 | 65 | self.out_dim = 64 * block.expansion 66 | 67 | for m in self.modules(): 68 | if isinstance(m, nn.Conv2d): 69 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 70 | m.weight.data.normal_(0, math.sqrt(2. / n)) 71 | # m.bias.data.zero_() 72 | elif isinstance(m, nn.BatchNorm2d): 73 | m.weight.data.fill_(1) 74 | m.bias.data.zero_() 75 | elif isinstance(m, nn.Linear): 76 | nn.init.kaiming_normal_(m.weight) 77 | m.bias.data.zero_() 78 | 79 | def _make_layer(self, block, planes, blocks, stride=1): 80 | downsample = None 81 | if stride != 1 or self.inplanes != planes * block.expansion: 82 | downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) 83 | 84 | layers = [] 85 | layers.append(block(self.inplanes, planes, stride, downsample)) 86 | self.inplanes = planes * block.expansion 87 | for i in range(1, blocks): 88 | layers.append(block(self.inplanes, planes)) 89 | 90 | return nn.Sequential(*layers) 91 | 92 | def forward(self, x): 93 | x = self.conv_1_3x3(x) # [bs, 16, 32, 32] 94 | x = F.relu(self.bn_1(x), inplace=True) 95 | 96 | x_1 = self.stage_1(x) # [bs, 16, 32, 32] 97 | x_2 = self.stage_2(x_1) # [bs, 32, 16, 16] 98 | return x_2 99 | 100 | class SpecializedResNet_cifar(nn.Module): 101 | def __init__(self, block, depth, inplanes=32, feature_dim=64): 102 | super(SpecializedResNet_cifar, self).__init__() 103 | self.inplanes = 
inplanes 104 | self.feature_dim = feature_dim 105 | layer_blocks = (depth - 2) // 6 106 | self.final_stage = self._make_layer(block, 64, layer_blocks, 2) 107 | self.avgpool = nn.AvgPool2d(8) 108 | 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 112 | m.weight.data.normal_(0, math.sqrt(2. / n)) 113 | # m.bias.data.zero_() 114 | elif isinstance(m, nn.BatchNorm2d): 115 | m.weight.data.fill_(1) 116 | m.bias.data.zero_() 117 | elif isinstance(m, nn.Linear): 118 | nn.init.kaiming_normal_(m.weight) 119 | m.bias.data.zero_() 120 | 121 | def _make_layer(self, block, planes, blocks, stride=2): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) 125 | layers = [] 126 | layers.append(block(self.inplanes, planes, stride, downsample)) 127 | self.inplanes = planes * block.expansion 128 | for i in range(1, blocks): 129 | layers.append(block(self.inplanes, planes)) 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, base_feature_map): 133 | final_feature_map = self.final_stage(base_feature_map) 134 | pooled = self.avgpool(final_feature_map) 135 | features = pooled.view(pooled.size(0), -1) #bs x 64 136 | return features 137 | 138 | #For cifar & MEMO 139 | def get_resnet8_a2fc(): 140 | basenet = GeneralizedResNet_cifar(ResNetBasicblock,8) 141 | adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,8) 142 | return basenet,adaptivenet 143 | 144 | def get_resnet14_a2fc(): 145 | basenet = GeneralizedResNet_cifar(ResNetBasicblock,14) 146 | adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,14) 147 | return basenet,adaptivenet 148 | 149 | def get_resnet20_a2fc(): 150 | basenet = GeneralizedResNet_cifar(ResNetBasicblock,20) 151 | adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,20) 152 | return basenet,adaptivenet 153 | 154 | def get_resnet26_a2fc(): 155 | basenet = 
GeneralizedResNet_cifar(ResNetBasicblock,26) 156 | adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,26) 157 | return basenet,adaptivenet 158 | 159 | def get_resnet32_a2fc(): 160 | basenet = GeneralizedResNet_cifar(ResNetBasicblock,32) 161 | adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,32) 162 | return basenet,adaptivenet 163 | 164 | 165 | -------------------------------------------------------------------------------- /models/simplecil.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Re-implementation of SimpleCIL (https://arxiv.org/abs/2303.07338) without pre-trained weights. 3 | The training process is as follows: train the model with cross-entropy in the first stage and replace the classifier with prototypes for all the classes in the subsequent stages. 4 | Please refer to the original implementation (https://github.com/zhoudw-zdw/RevisitingCIL) if you are using pre-trained weights. 5 | ''' 6 | import logging 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | from torch.serialization import load 11 | from tqdm import tqdm 12 | from torch import optim 13 | from torch.nn import functional as F 14 | from torch.utils.data import DataLoader 15 | from utils.inc_net import SimpleCosineIncrementalNet 16 | from models.base import BaseLearner 17 | from utils.toolkit import target2onehot, tensor2numpy 18 | 19 | 20 | num_workers = 8 21 | batch_size = 128 22 | milestones = [80, 120] 23 | 24 | class SimpleCIL(BaseLearner): 25 | def __init__(self, args): 26 | super().__init__(args) 27 | self._network = SimpleCosineIncrementalNet(args, False) 28 | self.min_lr = args['min_lr'] if args['min_lr'] is not None else 1e-8 29 | self.args = args 30 | 31 | def after_task(self): 32 | self._known_classes = self._total_classes 33 | 34 | def replace_fc(self,trainloader, model, args): 35 | model = model.eval() 36 | embedding_list = [] 37 | label_list = [] 38 | with torch.no_grad(): 39 | for i, batch in 
enumerate(trainloader): 40 | (_,data,label) = batch 41 | data = data.cuda() 42 | label = label.cuda() 43 | embedding = model(data)["features"] 44 | embedding_list.append(embedding.cpu()) 45 | label_list.append(label.cpu()) 46 | embedding_list = torch.cat(embedding_list, dim=0) 47 | label_list = torch.cat(label_list, dim=0) 48 | 49 | class_list = np.unique(self.train_dataset.labels) 50 | proto_list = [] 51 | for class_index in class_list: 52 | # print('Replacing...',class_index) 53 | data_index = (label_list == class_index).nonzero().squeeze(-1) 54 | embedding = embedding_list[data_index] 55 | proto = embedding.mean(0) 56 | self._network.fc.weight.data[class_index] = proto 57 | return model 58 | 59 | def incremental_train(self, data_manager): 60 | self._cur_task += 1 61 | self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task) 62 | self._network.update_fc(self._total_classes) 63 | logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes)) 64 | 65 | train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes),source="train", mode="train", ) 66 | self.train_dataset = train_dataset 67 | self.data_manager = data_manager 68 | self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) 69 | test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test" ) 70 | self.test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) 71 | 72 | train_dataset_for_protonet = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes),source="train", mode="test", ) 73 | self.train_loader_for_protonet = DataLoader(train_dataset_for_protonet, batch_size=batch_size, shuffle=True, num_workers=num_workers) 74 | 75 | if len(self._multiple_gpus) > 1: 76 | print('Multiple GPUs') 77 | self._network = nn.DataParallel(self._network, self._multiple_gpus) 78 | 
self._train(self.train_loader, self.test_loader, self.train_loader_for_protonet) 79 | if len(self._multiple_gpus) > 1: 80 | self._network = self._network.module 81 | 82 | def _train(self, train_loader, test_loader, train_loader_for_protonet): 83 | self._network.to(self._device) 84 | if self._cur_task == 0: 85 | optimizer = optim.SGD( 86 | self._network.parameters(), 87 | momentum=0.9, 88 | lr=self.args["init_lr"], 89 | weight_decay=self.args["init_weight_decay"] 90 | ) 91 | scheduler = optim.lr_scheduler.CosineAnnealingLR( 92 | optimizer=optimizer, T_max=self.args['init_epoch'], eta_min=self.min_lr 93 | ) 94 | self._init_train(train_loader, test_loader, optimizer, scheduler) 95 | self.replace_fc(train_loader_for_protonet, self._network, None) 96 | 97 | def _init_train(self, train_loader, test_loader, optimizer, scheduler): 98 | prog_bar = tqdm(range(self.args["init_epoch"])) 99 | for _, epoch in enumerate(prog_bar): 100 | self._network.train() 101 | losses = 0.0 102 | correct, total = 0, 0 103 | for i, (_, inputs, targets) in enumerate(train_loader): 104 | inputs, targets = inputs.to(self._device), targets.to(self._device) 105 | logits = self._network(inputs)["logits"] 106 | 107 | loss = F.cross_entropy(logits, targets) 108 | optimizer.zero_grad() 109 | loss.backward() 110 | optimizer.step() 111 | losses += loss.item() 112 | 113 | _, preds = torch.max(logits, dim=1) 114 | correct += preds.eq(targets.expand_as(preds)).cpu().sum() 115 | total += len(targets) 116 | 117 | scheduler.step() 118 | train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) 119 | 120 | if epoch % 5 == 0: 121 | test_acc = self._compute_accuracy(self._network, test_loader) 122 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( 123 | self._cur_task, 124 | epoch + 1, 125 | self.args['init_epoch'], 126 | losses / len(train_loader), 127 | train_acc, 128 | test_acc, 129 | ) 130 | else: 131 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, 
Train_accy {:.2f}".format( 132 | self._cur_task, 133 | epoch + 1, 134 | self.args['init_epoch'], 135 | losses / len(train_loader), 136 | train_acc, 137 | ) 138 | 139 | prog_bar.set_description(info) 140 | 141 | logging.info(info) 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import copy 4 | import torch 5 | from utils import factory 6 | from utils.data_manager import DataManager 7 | from utils.toolkit import count_parameters 8 | import os 9 | import numpy as np 10 | 11 | 12 | def train(args): 13 | seed_list = copy.deepcopy(args["seed"]) 14 | device = copy.deepcopy(args["device"]) 15 | 16 | for seed in seed_list: 17 | args["seed"] = seed 18 | args["device"] = device 19 | _train(args) 20 | 21 | 22 | def _train(args): 23 | 24 | init_cls = 0 if args ["init_cls"] == args["increment"] else args["init_cls"] 25 | logs_name = "logs/{}/{}/{}/{}".format(args["model_name"],args["dataset"], init_cls, args['increment']) 26 | 27 | if not os.path.exists(logs_name): 28 | os.makedirs(logs_name) 29 | 30 | logfilename = "logs/{}/{}/{}/{}/{}_{}_{}".format( 31 | args["model_name"], 32 | args["dataset"], 33 | init_cls, 34 | args["increment"], 35 | args["prefix"], 36 | args["seed"], 37 | args["convnet_type"], 38 | ) 39 | logging.basicConfig( 40 | level=logging.INFO, 41 | format="%(asctime)s [%(filename)s] => %(message)s", 42 | handlers=[ 43 | logging.FileHandler(filename=logfilename + ".log"), 44 | logging.StreamHandler(sys.stdout), 45 | ], 46 | ) 47 | 48 | _set_random() 49 | _set_device(args) 50 | print_args(args) 51 | data_manager = DataManager( 52 | args["dataset"], 53 | args["shuffle"], 54 | args["seed"], 55 | args["init_cls"], 56 | args["increment"], 57 | args["aug"] if "aug" in args else 1 58 | ) 59 | model = factory.get_model(args["model_name"], args) 60 | 61 | cnn_curve, nme_curve = {"top1": 
[], "top5": []}, {"top1": [], "top5": []} 62 | cnn_matrix, nme_matrix = [], [] 63 | 64 | for task in range(data_manager.nb_tasks): 65 | logging.info("All params: {}".format(count_parameters(model._network))) 66 | logging.info( 67 | "Trainable params: {}".format(count_parameters(model._network, True)) 68 | ) 69 | model.incremental_train(data_manager) 70 | cnn_accy, nme_accy = model.eval_task() 71 | model.after_task() 72 | 73 | if nme_accy is not None: 74 | logging.info("CNN: {}".format(cnn_accy["grouped"])) 75 | logging.info("NME: {}".format(nme_accy["grouped"])) 76 | 77 | cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key] 78 | cnn_keys_sorted = sorted(cnn_keys) 79 | cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted] 80 | cnn_matrix.append(cnn_values) 81 | 82 | nme_keys = [key for key in nme_accy["grouped"].keys() if '-' in key] 83 | nme_keys_sorted = sorted(nme_keys) 84 | nme_values = [nme_accy["grouped"][key] for key in nme_keys_sorted] 85 | nme_matrix.append(nme_values) 86 | 87 | 88 | cnn_curve["top1"].append(cnn_accy["top1"]) 89 | cnn_curve["top5"].append(cnn_accy["top5"]) 90 | 91 | nme_curve["top1"].append(nme_accy["top1"]) 92 | nme_curve["top5"].append(nme_accy["top5"]) 93 | 94 | logging.info("CNN top1 curve: {}".format(cnn_curve["top1"])) 95 | logging.info("CNN top5 curve: {}".format(cnn_curve["top5"])) 96 | logging.info("NME top1 curve: {}".format(nme_curve["top1"])) 97 | logging.info("NME top5 curve: {}\n".format(nme_curve["top5"])) 98 | 99 | print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"])) 100 | print('Average Accuracy (NME):', sum(nme_curve["top1"])/len(nme_curve["top1"])) 101 | 102 | logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"]))) 103 | logging.info("Average Accuracy (NME): {}".format(sum(nme_curve["top1"])/len(nme_curve["top1"]))) 104 | else: 105 | logging.info("No NME accuracy.") 106 | logging.info("CNN: {}".format(cnn_accy["grouped"])) 107 
| 108 | cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key] 109 | cnn_keys_sorted = sorted(cnn_keys) 110 | cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted] 111 | cnn_matrix.append(cnn_values) 112 | 113 | cnn_curve["top1"].append(cnn_accy["top1"]) 114 | cnn_curve["top5"].append(cnn_accy["top5"]) 115 | 116 | logging.info("CNN top1 curve: {}".format(cnn_curve["top1"])) 117 | logging.info("CNN top5 curve: {}\n".format(cnn_curve["top5"])) 118 | 119 | print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"])) 120 | logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"]))) 121 | 122 | 123 | if len(cnn_matrix)>0: 124 | np_acctable = np.zeros([task + 1, task + 1]) 125 | for idxx, line in enumerate(cnn_matrix): 126 | idxy = len(line) 127 | np_acctable[idxx, :idxy] = np.array(line) 128 | np_acctable = np_acctable.T 129 | forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, task])[:task]) 130 | print('Accuracy Matrix (CNN):') 131 | print(np_acctable) 132 | print('Forgetting (CNN):', forgetting) 133 | logging.info('Forgetting (CNN): {}'.format(forgetting)) 134 | if len(nme_matrix)>0: 135 | np_acctable = np.zeros([task + 1, task + 1]) 136 | for idxx, line in enumerate(nme_matrix): 137 | idxy = len(line) 138 | np_acctable[idxx, :idxy] = np.array(line) 139 | np_acctable = np_acctable.T 140 | forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, task])[:task]) 141 | print('Accuracy Matrix (NME):') 142 | print(np_acctable) 143 | print('Forgetting (NME):', forgetting) 144 | logging.info('Forgetting (NME): {}'.format(forgetting)) 145 | 146 | 147 | def _set_device(args): 148 | device_type = args["device"] 149 | gpus = [] 150 | 151 | for device in device_type: 152 | if device == -1: 153 | device = torch.device("cpu") 154 | else: 155 | device = torch.device("cuda:{}".format(device)) 156 | 157 | gpus.append(device) 158 | 159 | args["device"] = gpus 160 | 161 | 162 | 
def _set_random(seed=1):
    """Seed torch's RNGs and make cuDNN kernel selection deterministic.

    Args:
        seed: Value passed to ``torch.manual_seed`` and both CUDA seeding calls.
            Defaults to 1 (the value this helper always used before it took an
            argument), so bare ``_set_random()`` calls behave exactly as before.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Disable cuDNN autotuning so repeated runs select the same kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def print_args(args):
    """Log each ``key: value`` pair of the ``args`` mapping at INFO level."""
    for key, value in args.items():
        logging.info("{}: {}".format(key, value))


# ----------------------------------------------------------------------------
# /convs/modified_represnet.py
# ----------------------------------------------------------------------------
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

__all__ = ['ResNet', 'resnet18_rep', 'resnet34_rep']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and a bias term."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=True)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution with a bias term."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=True)


class conv_block(nn.Module):
    """3x3 conv, optionally summed with a parallel 1x1 "adapter" conv.

    The adapter is created only when ``mode == 'parallel_adapters'`` at
    construction time. ``forward`` adds the adapter output whenever
    ``self.mode`` equals that string (``ResNet.switch`` can change it later).
    """

    def __init__(self, in_planes, planes, mode, stride=1):
        super(conv_block, self).__init__()
        self.conv = conv3x3(in_planes, planes, stride)
        self.mode = mode
        if mode == 'parallel_adapters':
            self.adapter = conv1x1(in_planes, planes, stride)

    def re_init_conv(self):
        """Re-initialise the adapter weight with Kaiming-normal (fan_out, relu).

        Blocks built without an adapter are left untouched. Previously this
        raised ``AttributeError``, which made ``ResNet.re_init_params`` crash
        for models not constructed in ``'parallel_adapters'`` mode.
        """
        adapter = getattr(self, 'adapter', None)
        if adapter is not None:
            nn.init.kaiming_normal_(adapter.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        y = self.conv(x)
        if self.mode == 'parallel_adapters':
            # NOTE(review): still raises AttributeError if the mode is switched
            # on for a block that was built without an adapter.
            y = y + self.adapter(x)
        return y


class BasicBlock(nn.Module):
    """Residual block: conv_block-BN-ReLU-conv_block-BN, plus shortcut, then ReLU."""

    expansion = 1

    def __init__(self, inplanes, planes, mode, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv_block(inplanes, planes, mode, stride)
        self.norm1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv_block(planes, planes, mode)
        self.norm2 = nn.BatchNorm2d(planes)
        self.mode = mode

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    """Four-stage ResNet backbone whose 3x3 convs may carry parallel adapters.

    Args:
        block: residual block class (``BasicBlock`` here).
        layers: number of blocks in each of the four stages.
        num_classes: accepted for signature compatibility; not used.
        args: config mapping. Reads ``"mode"`` and ``"dataset"``, and for
            datasets containing ``'imagenet'`` also ``"init_cls"`` and
            ``"increment"``.

    ``forward`` returns ``{"features": tensor}`` with ``out_dim`` (512) features
    per sample.

    NOTE(review): if ``args["dataset"]`` contains neither ``'cifar'`` nor
    ``'imagenet'``, no stem (``self.conv1``) is created and ``forward`` fails.
    """

    def __init__(self, block, layers, num_classes=100, args=None):
        self.inplanes = 64
        super(ResNet, self).__init__()
        assert args is not None
        self.mode = args["mode"]

        if 'cifar' in args["dataset"]:
            self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True))
            print("use cifar")
        elif 'imagenet' in args["dataset"]:
            if args["init_cls"] == args["increment"]:
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
                    nn.BatchNorm2d(self.inplanes),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
            else:
                # Following the PODNet implementation.
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(self.inplanes),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Not used by forward(); kept so the module tree is unchanged.
        self.feature = nn.AvgPool2d(4, stride=1)
        self.out_dim = 512

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage; the first block gets a 1x1 projection shortcut if needed."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=True),
            )
        layers = [block(self.inplanes, planes, self.mode, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, self.mode))

        return nn.Sequential(*layers)

    def switch(self, mode='normal'):
        """Set ``mode`` on every submodule (including self) that has a ``mode`` attribute."""
        for module in self.modules():
            if hasattr(module, 'mode'):
                module.mode = mode

    def re_init_params(self):
        """Call ``re_init_conv`` on every submodule that defines it."""
        for module in self.modules():
            if hasattr(module, 're_init_conv'):
                module.re_init_conv()

    def forward(self, x):
        x = self.conv1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Average-pool with a square kernel as wide as the final feature map
        # (a full global pool when the map is square); same as the former
        # per-call nn.AvgPool2d(dim, stride=1) without building a module.
        dim = x.size()[-1]
        x = F.avg_pool2d(x, dim, stride=1)
        x = x.view(x.size(0), -1)
        return {"features": x}


def resnet18_rep(pretrained=False, **kwargs):
    """Construct a ResNet-18 (stage sizes [2, 2, 2, 2]) with optional adapters.

    Args:
        pretrained (bool): must be False. No checkpoint URL is defined for this
            architecture (the old code referenced an undefined ``model_urls``),
            and its parameter names (e.g. ``conv1.0``, ``layer1.0.conv1.conv``,
            ``norm1``) differ from torchvision's standard ResNet naming, so a
            strict ``load_state_dict`` of such a checkpoint would fail anyway.
        **kwargs: forwarded to ``ResNet`` (must include ``args``).

    Raises:
        NotImplementedError: if ``pretrained`` is True.
    """
    if pretrained:
        raise NotImplementedError(
            "resnet18_rep has no compatible pretrained checkpoint; "
            "build it with pretrained=False and load weights explicitly."
        )
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet34_rep(pretrained=False, **kwargs):
    """Construct a ResNet-34 (stage sizes [3, 4, 6, 3]) with optional adapters.

    Args:
        pretrained (bool): must be False; see ``resnet18_rep`` for why no
            pretrained checkpoint can be loaded into this architecture.
        **kwargs: forwarded to ``ResNet`` (must include ``args``).

    Raises:
        NotImplementedError: if ``pretrained`` is True.
    """
    if pretrained:
        raise NotImplementedError(
            "resnet34_rep has no compatible pretrained checkpoint; "
            "build it with pretrained=False and load weights explicitly."
        )
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


# ----------------------------------------------------------------------------
# /convs/cifar_resnet.py
# ----------------------------------------------------------------------------
'''
Reference:
https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py
'''
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class DownsampleA(nn.Module):
    """Parameter-free shortcut: subsample by ``stride``, then zero-pad channels.

    The output has twice the input's channels (input concatenated with a
    zeroed copy); ``nOut`` is not used.
    """

    def __init__(self, nIn, nOut, stride):
        super(DownsampleA, self).__init__()
        assert stride == 2
        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)

    def forward(self, x):
        x = self.avg(x)
        return torch.cat((x, x.mul(0)), 1)


class DownsampleB(nn.Module):
    """Projection shortcut: 1x1 conv (no bias) followed by BatchNorm."""

    def __init__(self, nIn, nOut, stride):
        super(DownsampleB, self).__init__()
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(nOut)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class DownsampleC(nn.Module):
    """Projection shortcut: a single 1x1 conv; requires a stride or channel change."""

    def __init__(self, nIn, nOut, stride):
        super(DownsampleC, self).__init__()
        assert stride != 1 or nIn != nOut
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        x = self.conv(x)
        return x


class DownsampleD(nn.Module):
    """Projection shortcut: 2x2 stride-2 conv (no bias) followed by BatchNorm."""

    def __init__(self, nIn, nOut, stride):
        super(DownsampleD, self).__init__()
        assert stride == 2
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(nOut)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class ResNetBasicblock(nn.Module):
    """Two 3x3 conv-BN layers plus an identity/``downsample`` shortcut, then ReLU."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResNetBasicblock, self).__init__()

        self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn_a = nn.BatchNorm2d(planes)

        self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_b = nn.BatchNorm2d(planes)

        self.downsample = downsample

    def forward(self, x):
        residual = x

        basicblock = self.conv_a(x)
        basicblock = self.bn_a(basicblock)
        basicblock = F.relu(basicblock, inplace=True)

        basicblock = self.conv_b(basicblock)
        basicblock = self.bn_b(basicblock)

        if self.downsample is not None:
            residual = self.downsample(x)

        return F.relu(residual + basicblock, inplace=True)


class CifarResNet(nn.Module):
    """
    ResNet optimized for the Cifar Dataset, as specified in
    https://arxiv.org/abs/1512.03385.pdf

    ``depth`` must satisfy ``(depth - 2) % 6 == 0``; each of the three stages
    gets ``(depth - 2) // 6`` blocks. ``forward`` returns
    ``{'fmaps': [stage1, stage2, stage3], 'features': pooled_vector}``.
    ``self.fc`` is created but not used by ``forward``.
    """

    def __init__(self, block, depth, channels=3):
        super(CifarResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
        layer_blocks = (depth - 2) // 6

        self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(16)

        self.inplanes = 16
        self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
        self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
        self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
        self.avgpool = nn.AvgPool2d(8)
        self.out_dim = 64 * block.expansion
        self.fc = nn.Linear(64*block.expansion, 10)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He init with fan computed from kernel area * output channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage; the first block gets a ``DownsampleA`` shortcut if needed."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_1_3x3(x)  # [bs, 16, 32, 32]
        x = F.relu(self.bn_1(x), inplace=True)

        x_1 = self.stage_1(x)  # [bs, 16, 32, 32]
        x_2 = self.stage_2(x_1)  # [bs, 32, 16, 16]
        x_3 = self.stage_3(x_2)  # [bs, 64, 8, 8]

        pooled = self.avgpool(x_3)  # [bs, 64, 1, 1]
        features = pooled.view(pooled.size(0), -1)  # [bs, 64]

        return {
            'fmaps': [x_1, x_2, x_3],
            'features': features
        }

    @property
    def last_conv(self):
        """Second conv of the final block in stage 3."""
        return self.stage_3[-1].conv_b


def resnet20mnist():
    """Constructs a ResNet-20 model for MNIST."""
    model = CifarResNet(ResNetBasicblock, 20, 1)
    return model


def resnet32mnist():
    """Constructs a ResNet-32 model for MNIST."""
    model = CifarResNet(ResNetBasicblock, 32, 1)
    return model


def resnet20():
    """Constructs a ResNet-20 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 20)
    return model


def resnet32():
    """Constructs a ResNet-32 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 32)
    return model


def resnet44():
    """Constructs a ResNet-44 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 44)
    return model


def resnet56():
    """Constructs a ResNet-56 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 56)
    return model


def resnet110():
    """Constructs a ResNet-110 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 110)
    return model


# for auc
def resnet14():
    """Constructs a ResNet-14 model (2 blocks per stage)."""
    model = CifarResNet(ResNetBasicblock, 14)
    return model


def resnet26():
    """Constructs a ResNet-26 model (4 blocks per stage)."""
    model = CifarResNet(ResNetBasicblock, 26)
    return model


# ----------------------------------------------------------------------------
# /convs/ucir_cifar_resnet.py
# ----------------------------------------------------------------------------
'''
Reference:
https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py
https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_resnet_cifar.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
# from convs.modified_linear import CosineLinear


class DownsampleA(nn.Module):
    """Parameter-free shortcut: subsample by ``stride``, then zero-pad channels.

    The output has twice the input's channels; ``nOut`` is not used.
    """

    def __init__(self, nIn, nOut, stride):
        super(DownsampleA, self).__init__()
        assert stride == 2
        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)

    def forward(self, x):
        x = self.avg(x)
        return torch.cat((x, x.mul(0)), 1)


class DownsampleB(nn.Module):
    """Projection shortcut: 1x1 conv (no bias) followed by BatchNorm."""

    def __init__(self, nIn, nOut, stride):
        super(DownsampleB, self).__init__()
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(nOut)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class DownsampleC(nn.Module):
    """Projection shortcut: a single 1x1 conv; requires a stride or channel change."""

    def __init__(self, nIn, nOut, stride):
        super(DownsampleC, self).__init__()
        assert stride != 1 or nIn != nOut
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        x = self.conv(x)
        return x


class DownsampleD(nn.Module):
    """Projection shortcut: 2x2 stride-2 conv (no bias) followed by BatchNorm."""

    def __init__(self, nIn, nOut, stride):
        super(DownsampleD, self).__init__()
        assert stride == 2
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(nOut)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class ResNetBasicblock(nn.Module):
    """Two 3x3 conv-BN layers plus shortcut; the output ReLU is skipped when ``last``."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, last=False):
        super(ResNetBasicblock, self).__init__()

        self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn_a = nn.BatchNorm2d(planes)

        self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_b = nn.BatchNorm2d(planes)

        self.downsample = downsample
        self.last = last

    def forward(self, x):
        residual = x

        basicblock = self.conv_a(x)
        basicblock = self.bn_a(basicblock)
        basicblock = F.relu(basicblock, inplace=True)

        basicblock = self.conv_b(basicblock)
        basicblock = self.bn_b(basicblock)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = residual + basicblock
        if not self.last:
            out = F.relu(out, inplace=True)

        return out


class CifarResNet(nn.Module):
    """
    ResNet optimized for the Cifar Dataset, as specified in
    https://arxiv.org/abs/1512.03385.pdf

    UCIR variant: projection (``DownsampleB``) shortcuts, and the very last
    block of stage 3 omits its output ReLU. ``forward`` returns
    ``{'fmaps': [stage1, stage2, stage3], 'features': pooled_vector}``.
    """

    def __init__(self, block, depth, channels=3):
        super(CifarResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
        layer_blocks = (depth - 2) // 6

        self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(16)

        self.inplanes = 16
        self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
        self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
        self.stage_3 = self._make_layer(block, 64, layer_blocks, 2, last_phase=True)
        self.avgpool = nn.AvgPool2d(8)
        self.out_dim = 64 * block.expansion
        # self.fc = CosineLinear(64*block.expansion, 10)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, last_phase=False):
        """Build a stage of ``blocks`` residual blocks.

        With ``last_phase=True`` exactly the final block of the stage is built
        with ``last=True`` (no output ReLU). Previously a one-block stage
        (depth 8) had a spurious second block appended for this purpose.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = DownsampleB(self.inplanes, planes * block.expansion, stride)  # DownsampleA => DownsampleB

        layers = [block(self.inplanes, planes, stride, downsample,
                        last=last_phase and blocks == 1)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, last=last_phase and i == blocks - 1))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_1_3x3(x)  # [bs, 16, 32, 32]
        x = F.relu(self.bn_1(x), inplace=True)

        x_1 = self.stage_1(x)  # [bs, 16, 32, 32]
        x_2 = self.stage_2(x_1)  # [bs, 32, 16, 16]
        x_3 = self.stage_3(x_2)  # [bs, 64, 8, 8]

        pooled = self.avgpool(x_3)  # [bs, 64, 1, 1]
        features = pooled.view(pooled.size(0), -1)  # [bs, 64]

        return {
            'fmaps': [x_1, x_2, x_3],
            'features': features
        }

    @property
    def last_conv(self):
        """Second conv of the final block in stage 3."""
        return self.stage_3[-1].conv_b


def resnet20mnist():
    """Constructs a ResNet-20 model for MNIST."""
    model = CifarResNet(ResNetBasicblock, 20, 1)
    return model


def resnet32mnist():
    """Constructs a ResNet-32 model for MNIST."""
    model = CifarResNet(ResNetBasicblock, 32, 1)
    return model


def resnet20():
    """Constructs a ResNet-20 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 20)
    return model


def resnet32():
    """Constructs a ResNet-32 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 32)
    return model


def resnet44():
    """Constructs a ResNet-44 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 44)
    return model


def resnet56():
    """Constructs a ResNet-56 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 56)
    return model


def resnet110():
    """Constructs a ResNet-110 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 110)
    return model
# ----------------------------------------------------------------------------
# /models/finetune.py
# ----------------------------------------------------------------------------
import logging
import numpy as np
import torch
from torch import nn
from torch.serialization import load
from tqdm import tqdm
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from utils.inc_net import IncrementalNet
from models.base import BaseLearner
from utils.toolkit import target2onehot, tensor2numpy


init_epoch = 200
init_lr = 0.1
init_milestones = [60, 120, 170]
init_lr_decay = 0.1
init_weight_decay = 0.0005


epochs = 80
lrate = 0.1
milestones = [40, 70]
lrate_decay = 0.1
batch_size = 128
weight_decay = 2e-4
num_workers = 8


class Finetune(BaseLearner):
    """Plain fine-tuning baseline for class-incremental learning.

    Each task trains only on that task's new classes (no exemplar memory is
    appended). From the second task on, the loss covers only the new-class
    logits.
    """

    def __init__(self, args):
        super().__init__(args)
        self._network = IncrementalNet(args, False)

    def after_task(self):
        """Mark this task's classes as known before the next task starts."""
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager):
        """Grow the classifier for the next task, build the loaders, and train.

        Args:
            data_manager: supplies the task size and the train/test datasets.
        """
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        # Training data: only the classes introduced by this task.
        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
        )
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        # Evaluation data: every class seen so far.
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        if len(self._multiple_gpus) > 1:
            # Unwrap so later code sees the bare network again.
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Choose the optimiser/schedule for the current task and run training."""
        self._network.to(self._device)
        if self._cur_task == 0:
            optimizer = optim.SGD(
                self._network.parameters(),
                momentum=0.9,
                lr=init_lr,
                weight_decay=init_weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        else:
            optimizer = optim.SGD(
                self._network.parameters(),
                lr=lrate,
                momentum=0.9,
                weight_decay=weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=milestones, gamma=lrate_decay
            )
            self._update_representation(train_loader, test_loader, optimizer, scheduler)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        """First task: ``init_epoch`` epochs of cross-entropy over all logits."""
        self._run_epochs(
            train_loader, test_loader, optimizer, scheduler, init_epoch, F.cross_entropy
        )

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        """Later tasks: ``epochs`` epochs of cross-entropy on the new-class logits only."""
        known = self._known_classes

        def new_class_loss(logits, targets):
            # Shift targets so they index into the new-class slice of the logits.
            return F.cross_entropy(logits[:, known:], targets - known)

        self._run_epochs(
            train_loader, test_loader, optimizer, scheduler, epochs, new_class_loss
        )

    def _run_epochs(self, train_loader, test_loader, optimizer, scheduler, n_epochs, loss_fn):
        """Shared SGD loop: train ``n_epochs`` epochs minimising ``loss_fn(logits, targets)``.

        Training accuracy is measured on the full logits. Test accuracy is
        computed on epochs 0, 5, 10, ...; the final epoch's summary line is
        logged once the loop ends.
        """
        prog_bar = tqdm(range(n_epochs))
        for epoch in prog_bar:
            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                loss = loss_fn(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            avg_loss = losses / len(train_loader)

            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    n_epochs,
                    avg_loss,
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    n_epochs,
                    avg_loss,
                    train_acc,
                )
            prog_bar.set_description(info)

        logging.info(info)


# ----------------------------------------------------------------------------
# /models/replay.py
# ----------------------------------------------------------------------------
import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import IncrementalNet
from utils.toolkit import target2onehot, tensor2numpy

EPSILON = 1e-8


init_epoch = 200
init_lr = 0.1
init_milestones = [60, 120, 170]
init_lr_decay = 0.1
init_weight_decay = 0.0005


epochs = 70
lrate = 0.1
milestones = [30, 50]
lrate_decay = 0.1
batch_size = 128
weight_decay = 2e-4
num_workers = 4
T = 2


class Replay(BaseLearner):
    """Experience-replay baseline.

    Each task trains on the new classes' data plus the stored exemplars
    (``self._get_memory()``), with plain cross-entropy over all logits, and then
    rebuilds the rehearsal memory.
    """

    def __init__(self, args):
        super().__init__(args)
        self._network = IncrementalNet(args, False)

    def after_task(self):
        """Mark this task's classes as known and log the exemplar count."""
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def incremental_train(self, data_manager):
        """Grow the classifier, train on new data + memory, then rebuild memory.

        Args:
            data_manager: supplies the task size and the train/test datasets.
        """
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        # Loader: new-class training data with the exemplar memory appended.
        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
        )
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        # Procedure
        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)

        # NOTE(review): runs while the network may still be DataParallel-wrapped;
        # relies on build_rehearsal_memory handling that case.
        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Choose the optimiser/schedule for the current task and run training."""
        self._network.to(self._device)
        if self._cur_task == 0:
            optimizer = optim.SGD(
                self._network.parameters(),
                momentum=0.9,
                lr=init_lr,
                weight_decay=init_weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        else:
            optimizer = optim.SGD(
                self._network.parameters(),
                lr=lrate,
                momentum=0.9,
                weight_decay=weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=milestones, gamma=lrate_decay
            )
            self._update_representation(train_loader, test_loader, optimizer, scheduler)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        """First task: ``init_epoch`` epochs of cross-entropy over all logits."""
        self._run_epochs(
            train_loader, test_loader, optimizer, scheduler, init_epoch, F.cross_entropy
        )

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        """Later tasks: ``epochs`` epochs of cross-entropy over all logits."""
        self._run_epochs(
            train_loader, test_loader, optimizer, scheduler, epochs, F.cross_entropy
        )

    def _run_epochs(self, train_loader, test_loader, optimizer, scheduler, n_epochs, loss_fn):
        """Shared SGD loop: train ``n_epochs`` epochs minimising ``loss_fn(logits, targets)``.

        Test accuracy is computed on epochs 0, 5, 10, ...; the final epoch's
        summary line is logged once the loop ends.
        """
        prog_bar = tqdm(range(n_epochs))
        for epoch in prog_bar:
            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                loss = loss_fn(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                # acc
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            avg_loss = losses / len(train_loader)

            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    n_epochs,
                    avg_loss,
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    n_epochs,
                    avg_loss,
                    train_acc,
                )
            prog_bar.set_description(info)

        logging.info(info)


# ----------------------------------------------------------------------------
# /utils/rl_utils/ddpg.py
# ----------------------------------------------------------------------------
import logging
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


class PolicyNet(torch.nn.Module):
    """One-hidden-layer actor: ReLU hidden layer, tanh output times ``action_bound``."""

    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return torch.tanh(self.fc2(x)) * self.action_bound


class RMMPolicyNet(torch.nn.Module):
    """Two-head actor whose output concatenates a sigmoid and a tanh action.

    The first head maps the state to ``a1`` in (0, 1). The second head sees the
    state concatenated with ``a1`` and yields ``a2`` in (-1, 1). ``forward``
    returns ``cat([a1, a2], dim=1)``.
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super(RMMPolicyNet, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x):
        a1 = torch.sigmoid(self.fc1(x))
        x = torch.cat([x, a1], dim=1)
        a2 = torch.tanh(self.fc2(x))
        return torch.cat([a1, a2], dim=1)


class QValueNet(torch.nn.Module):
    """Critic: scores a (state, action) pair with a one-hidden-layer MLP."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)
        x = F.relu(self.fc1(cat))
        return self.fc2(x)
class TwoLayerFC(torch.nn.Module):
    """MLP with two hidden layers used for the DDPG actor and critic.

    Maps ``num_in -> hidden_dim -> hidden_dim -> num_out``. ``activation`` is
    applied after each hidden layer and ``out_fn`` to the final linear output
    (identity by default).
    """

    def __init__(
        self, num_in, num_out, hidden_dim, activation=F.relu, out_fn=lambda x: x
    ):
        super().__init__()
        self.fc1 = nn.Linear(num_in, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_out)

        self.activation = activation
        self.out_fn = out_fn

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        return self.out_fn(self.fc3(x))


class DDPG:
    """Deep Deterministic Policy Gradient agent.

    With ``use_rmm=True`` the actor is an ``RMMPolicyNet``: its output is the
    concatenation of ``num_out_actor`` sigmoid values and ``num_out_actor``
    tanh values, so one action has ``2 * num_out_actor`` components.
    Otherwise the actor is a ``TwoLayerFC`` with ``num_out_actor`` outputs.

    The critic scores ``cat([state, action], dim=1)``, so ``num_in_critic``
    must equal the state size plus the actor's output size.
    """

    def __init__(
        self,
        num_in_actor,
        num_out_actor,
        num_in_critic,
        hidden_dim,
        discrete,
        action_bound,
        sigma,
        actor_lr,
        critic_lr,
        tau,
        gamma,
        device,
        use_rmm=True,
    ):
        # Output transform for the plain (non-RMM) actor only.
        out_fn = (lambda x: x) if discrete else (lambda x: torch.tanh(x) * action_bound)

        if use_rmm:
            self.actor = RMMPolicyNet(num_in_actor, hidden_dim, num_out_actor).to(device)
            self.target_actor = RMMPolicyNet(num_in_actor, hidden_dim, num_out_actor).to(
                device
            )
        else:
            self.actor = TwoLayerFC(
                num_in_actor, num_out_actor, hidden_dim, activation=F.relu, out_fn=out_fn
            ).to(device)
            self.target_actor = TwoLayerFC(
                num_in_actor, num_out_actor, hidden_dim, activation=F.relu, out_fn=out_fn
            ).to(device)

        self.critic = TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)
        self.target_critic = TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)
        # Target networks start as exact copies of the online networks.
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.sigma = sigma  # std of the Gaussian exploration noise
        self.action_bound = action_bound
        self.tau = tau  # soft-update rate of the target networks
        self.action_dim = num_out_actor
        self.device = device
        self.use_rmm = use_rmm

    def take_action(self, state):
        """Return the actor's action for ``state`` plus Gaussian exploration noise.

        Noise is drawn independently for every output component. RMM actions
        are clipped per head: the sigmoid half to [0, 1] and the tanh half to
        [-1, 1]. Plain actions are clipped to ``[-action_bound, action_bound]``.
        """
        state = torch.tensor(np.expand_dims(state, 0), dtype=torch.float).to(self.device)
        with torch.no_grad():
            action = self.actor(state)[0].cpu().numpy()

        # Size the noise from the real output. An RMM actor emits
        # 2 * action_dim values, so randn(action_dim) would broadcast one
        # shared sample onto m_rate and c_rate alike.
        action = action + self.sigma * np.random.randn(*action.shape)
        if self.use_rmm:
            half = action.shape[0] // 2
            action[:half] = np.clip(action[:half], 0, 1)
            action[half:] = np.clip(action[half:], -1, 1)
        else:
            action = np.clip(action, -self.action_bound, self.action_bound)
        return action

    def save_state_dict(self, name):
        """Save the online and target actor/critic weights to the file ``name``."""
        torch.save(
            {
                "critic": self.critic.state_dict(),
                "target_critic": self.target_critic.state_dict(),
                "actor": self.actor.state_dict(),
                "target_actor": self.target_actor.state_dict(),
            },
            name,
        )

    def load_state_dict(self, name):
        """Load weights written by ``save_state_dict``.

        Tensors are mapped onto ``self.device`` so that a checkpoint saved on
        a GPU can be restored on a CPU-only machine.
        """
        dicts = torch.load(name, map_location=self.device)
        self.critic.load_state_dict(dicts["critic"])
        self.target_critic.load_state_dict(dicts["target_critic"])
        self.actor.load_state_dict(dicts["actor"])
        self.target_actor.load_state_dict(dicts["target_actor"])

    def soft_update(self, net, target_net):
        """Polyak-average ``net`` into ``target_net``: t <- (1 - tau) * t + tau * p."""
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(
                param_target.data * (1.0 - self.tau) + param.data * self.tau
            )

    def _to_tensor(self, values):
        """Stack a sampled batch (e.g. a tuple of arrays) into a float tensor on ``self.device``."""
        return torch.as_tensor(np.asarray(values), dtype=torch.float, device=self.device)

    def update(self, transition_dict):
        """Take one critic step and one actor step on a batch, then soft-update the targets.

        ``transition_dict`` holds the batched ``states``, ``actions``,
        ``rewards``, ``next_states`` and ``dones``.
        """
        states = self._to_tensor(transition_dict["states"])
        actions = self._to_tensor(transition_dict["actions"])
        rewards = self._to_tensor(transition_dict["rewards"]).view(-1, 1)
        next_states = self._to_tensor(transition_dict["next_states"])
        dones = self._to_tensor(transition_dict["dones"]).view(-1, 1)

        # TD target from the target networks. It is a regression label, so
        # keep it out of the autograd graph. No bootstrap past terminal steps.
        with torch.no_grad():
            next_q_values = self.target_critic(
                torch.cat([next_states, self.target_actor(next_states)], dim=1)
            )
            q_targets = rewards + self.gamma * next_q_values * (1 - dones)

        # F.mse_loss already averages over the batch.
        critic_loss = F.mse_loss(
            self.critic(torch.cat([states, actions], dim=1)), q_targets
        )
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Deterministic policy gradient: maximise Q(s, actor(s)).
        actor_loss = -torch.mean(
            self.critic(torch.cat([states, self.actor(states)], dim=1))
        )
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        logging.info(f"update DDPG: actor loss {actor_loss.item():.3f}, critic loss {critic_loss.item():.3f}, ")
        self.soft_update(self.actor, self.target_actor)  # soft-update the target policy net
        self.soft_update(self.critic, self.target_critic)  # soft-update the target Q value net


# --------------------------------------------------------------------------------
# /rmm_train.py:
# --------------------------------------------------------------------------------
"""
We implemented `iCaRL+RMM`, `FOSTER+RMM` in [rmm.py](models/rmm.py). We implemented the `Pretraining Stage` of `RMM` in [rmm_train.py](rmm_train.py).
Use the following training script to run it.
```bash
python rmm_train.py --config=./exps/rmm-pretrain.json
```
"""
import json
import argparse
from trainer import train  # NOTE(review): shadowed by the local `train` defined below
import sys
import logging
import copy
import torch
from utils import factory
from utils.data_manager import DataManager
from utils.rl_utils.ddpg import DDPG
from utils.rl_utils.rl_utils import ReplayBuffer
from utils.toolkit import count_parameters
import os
import numpy as np
import random


class CILEnv:
    """RL environment in which one episode is one class-incremental session.

    Each session samples an ``(init_cls, increment)`` split from
    ``self.settings``. The state is ``[next task size / 100, memory ratio]``,
    where 100 is the total class count this script assumes (e.g. CIFAR-100).
    The action is ``[m_rate, c_rate]`` for the upcoming task. The reward is
    the top-1 CNN accuracy divided by 100.
    """

    def __init__(self, args) -> None:
        self._args = copy.deepcopy(args)
        self.settings = [(50, 2), (50, 5), (50, 10), (50, 20), (10, 10), (20, 20), (5, 5)]
        self._new_session()

    def _new_session(self):
        """Sample a split from ``settings`` and rebuild the data manager and model."""
        self._args["init_cls"], self._args["increment"] = self.settings[
            np.random.randint(len(self.settings))
        ]
        self.data_manager = DataManager(
            self._args["dataset"],
            self._args["shuffle"],
            self._args["seed"],
            self._args["init_cls"],
            self._args["increment"],
        )
        self.model = factory.get_model(self._args["model_name"], self._args)

    @property
    def nb_task(self):
        return self.data_manager.nb_tasks

    @property
    def cur_task(self):
        return self.model._cur_task

    def get_task_size(self, task_id):
        return self.data_manager.get_task_size(task_id)

    def reset(self):
        """Start a new session and return ``(state, None, False, info)``."""
        self._new_session()

        info = "start new task: dataset: {}, init_cls: {}, increment: {}".format(
            self._args["dataset"], self._args["init_cls"], self._args["increment"]
        )
        return np.array([self.get_task_size(0) / 100, 0]), None, False, info

    def step(self, action):
        """Train on the next task with ``action`` and return ``(state, reward, done, info)``.

        ``action[0]`` is appended to the model's ``_m_rate_list`` and
        ``action[1]`` to its ``_c_rate_list`` before training. ``done`` is
        True after the session's last task.
        """
        self.model._m_rate_list.append(action[0])
        self.model._c_rate_list.append(action[1])
        self.model.incremental_train(self.data_manager)
        cnn_accy, _ = self.model.eval_task()
        self.model.after_task()
        done = self.cur_task == self.nb_task - 1
        info = "running task [{}/{}]: dataset: {}, increment: {}, cnn_accy top1: {}, top5: {}".format(
            self.model._known_classes,
            100,
            self._args["dataset"],
            self._args["increment"],
            cnn_accy["top1"],
            cnn_accy["top5"],
        )
        next_task_ratio = self.get_task_size(self.cur_task + 1) / 100 if not done else 0.0
        memory_ratio = self.model.memory_size / (
            self.model.memory_size + self.model.new_memory_size
        )
        return (
            np.array([next_task_ratio, memory_ratio]),
            cnn_accy["top1"] / 100,
            done,
            info,
        )


def _setup_logging(args):
    """Send INFO logs to stdout and to ``logs/RL-CIL/<model_name>/<...>.log``."""
    logs_name = "logs/RL-CIL/{}/".format(args["model_name"])
    os.makedirs(logs_name, exist_ok=True)

    logfilename = "logs/RL-CIL/{}/{}_{}_{}_{}_{}".format(
        args["model_name"],
        args["prefix"],
        args["seed"],
        args["model_name"],
        args["convnet_type"],
        args["dataset"],
    )
    # basicConfig does nothing once the root logger has handlers. Drop the
    # previous seed's handlers; otherwise every later seed keeps writing
    # into the first seed's log file.
    root = logging.getLogger()
    for handler in list(root.handlers):
        root.removeHandler(handler)
        handler.close()
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + ".log"),
            logging.StreamHandler(sys.stdout),
        ],
    )


def _train(args):
    """Pretrain the RMM DDPG agent for one seed (``args["seed"]`` is a single value here)."""
    _setup_logging(args)

    _set_random()
    _set_device(args)
    print_args(args)

    actor_lr = 5e-4
    critic_lr = 5e-3
    num_episodes = 200
    hidden_dim = 32
    gamma = 0.98
    tau = 0.005
    buffer_size = 1000
    minimal_size = 50
    batch_size = 32
    sigma = 0.2  # action noise, encouraging the off-policy algo to explore.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    env = CILEnv(args)
    replay_buffer = ReplayBuffer(buffer_size)
    # State size 2, actor head size 1 (the RMM actor emits [m_rate, c_rate]),
    # critic input 4 = state (2) + action (2).
    agent = DDPG(
        2, 1, 4, hidden_dim, False, 1, sigma, actor_lr, critic_lr, tau, gamma, device
    )
    for _episode in range(num_episodes):
        state, *_, info = env.reset()
        logging.info(info)
        done = False
        while not done:
            action = agent.take_action(state)
            logging.info(f"take action: m_rate {action[0]}, c_rate {action[1]}")
            next_state, reward, done, info = env.step(action)
            logging.info(info)
            replay_buffer.add(state, action, reward, next_state, done)
            state = next_state
            # Start learning only once the buffer holds enough transitions.
            if replay_buffer.size() > minimal_size:
                b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                transition_dict = {
                    "states": b_s,
                    "actions": b_a,
                    "next_states": b_ns,
                    "rewards": b_r,
                    "dones": b_d,
                }
                agent.update(transition_dict)


def _set_device(args):
    """Replace the ids in ``args["device"]`` with ``torch.device`` objects (-1 means CPU)."""
    gpus = []
    for device in args["device"]:
        # Compare each entry, not the whole list, against -1. Config ids are
        # usually strings (e.g. "0"), so accept both -1 and "-1".
        if str(device) == "-1":
            gpus.append(torch.device("cpu"))
        else:
            gpus.append(torch.device("cuda:{}".format(device)))

    args["device"] = gpus


def _set_random():
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN kernels."""
    random.seed(1)
    # CILEnv samples settings and DDPG samples exploration noise via np.random.
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def print_args(args):
    """Log every configuration key/value pair."""
    for key, value in args.items():
        logging.info("{}: {}".format(key, value))


def train(args):
    """Run ``_train`` once for every seed listed in ``args["seed"]``."""
    seed_list = copy.deepcopy(args["seed"])
    device = copy.deepcopy(args["device"])

    for seed in seed_list:
        args["seed"] = seed
        # _set_device rewrites args["device"]; restore the raw ids for each seed.
        args["device"] = device
        _train(args)


def main():
    args = setup_parser().parse_args()
    param = load_json(args.config)
    args = vars(args)  # Converting argparse Namespace to a dict.
    args.update(param)  # Add parameters from json

    train(args)


def load_json(settings_path):
    """Read the JSON settings file at ``settings_path``."""
    with open(settings_path, encoding="utf-8") as data_file:
        return json.load(data_file)


def setup_parser():
    parser = argparse.ArgumentParser(
        description="Reproduce of multiple continual learning algorthms."
    )
    parser.add_argument(
        "--config",
        type=str,
        default="./exps/finetune.json",
        help="Json file of settings.",
    )

    return parser


if __name__ == "__main__":
    main()


# --------------------------------------------------------------------------------
# /models/lwf.py:
# --------------------------------------------------------------------------------
import logging
import numpy as np
import torch
from torch import nn
from torch.serialization import load
from tqdm import tqdm
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from utils.inc_net import IncrementalNet
from models.base import BaseLearner
from utils.toolkit import target2onehot, tensor2numpy

init_epoch = 200
init_lr = 0.1
init_milestones = [60, 120, 160]
init_lr_decay = 0.1
init_weight_decay = 0.0005


epochs = 250
lrate = 0.1
milestones = [60, 120, 180, 220]
lrate_decay = 0.1
batch_size = 128
weight_decay = 2e-4
num_workers = 8
T = 2
lamda = 3


class LwF(BaseLearner):
    """Learning without Forgetting.

    Keeps no exemplar memory: each task trains only on its own classes.
    Earlier classes are protected by distilling (``_KD_loss``) the frozen
    previous network's logits into the current old-class logits.
    """

    def __init__(self, args):
        super().__init__(args)
        self._network = IncrementalNet(args, False)

    def after_task(self):
        # Keep a frozen copy of the current model as the next task's teacher.
        self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager):
        """Extend the classifier to the next task's classes and train on them."""
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        # Train on the new classes only; test on every class seen so far.
        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
        )
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Build the optimizer/scheduler for the current task and train.

        Task 0 uses the ``init_*`` hyper-parameters; later tasks use the
        incremental ones and add distillation.
        """
        self._network.to(self._device)
        if self._old_network is not None:
            self._old_network.to(self._device)

        if self._cur_task == 0:
            optimizer = optim.SGD(
                self._network.parameters(),
                momentum=0.9,
                lr=init_lr,
                weight_decay=init_weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        else:
            optimizer = optim.SGD(
                self._network.parameters(),
                lr=lrate,
                momentum=0.9,
                weight_decay=weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=milestones, gamma=lrate_decay
            )
            self._update_representation(train_loader, test_loader, optimizer, scheduler)

    def _run_epochs(self, train_loader, test_loader, optimizer, scheduler, n_epochs, loss_fn):
        """Shared training loop: optimise ``loss_fn(inputs, logits, targets)`` for ``n_epochs``.

        Progress goes to the tqdm bar. Test accuracy is computed every 5
        epochs, and the final epoch's summary is logged when training ends.
        """
        info = ""
        prog_bar = tqdm(range(n_epochs))
        for epoch in prog_bar:
            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]
                loss = loss_fn(inputs, logits, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                with torch.no_grad():
                    _, preds = torch.max(logits, dim=1)
                    correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                    total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    n_epochs,
                    losses / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    n_epochs,
                    losses / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)

        if info:
            logging.info(info)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        """First task: plain cross-entropy over all current classes."""
        self._run_epochs(
            train_loader,
            test_loader,
            optimizer,
            scheduler,
            init_epoch,
            lambda inputs, logits, targets: F.cross_entropy(logits, targets),
        )

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        """Later tasks: CE on new-class logits plus ``lamda`` * KD on old-class logits."""

        def loss_fn(inputs, logits, targets):
            # Only new-class samples are present, so classify among the new
            # classes with labels shifted to start at 0.
            fake_targets = targets - self._known_classes
            loss_clf = F.cross_entropy(logits[:, self._known_classes :], fake_targets)
            # The teacher only supplies soft targets; no gradient is needed.
            with torch.no_grad():
                old_logits = self._old_network(inputs)["logits"]
            loss_kd = _KD_loss(logits[:, : self._known_classes], old_logits, T)
            return lamda * loss_kd + loss_clf

        self._run_epochs(train_loader, test_loader, optimizer, scheduler, epochs, loss_fn)


def _KD_loss(pred, soft, T):
    """Distillation loss between temperature-``T`` softened logits.

    Returns ``-sum(softmax(soft / T) * log_softmax(pred / T)) / pred.shape[0]``,
    i.e. the soft-target cross-entropy averaged over the batch.
    """
    pred = torch.log_softmax(pred / T, dim=1)
    soft = torch.softmax(soft / T, dim=1)
    return -1 * torch.mul(soft, pred).sum() / pred.shape[0]


# --------------------------------------------------------------------------------
# /models/icarl.py:
# --------------------------------------------------------------------------------
import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import IncrementalNet
from utils.inc_net import CosineIncrementalNet
from utils.toolkit import target2onehot, tensor2numpy

EPSILON = 1e-8

init_epoch = 200
init_lr = 0.1
init_milestones = [60, 120, 170]
init_lr_decay = 0.1
init_weight_decay = 0.0005


epochs = 170
lrate = 0.1
milestones = [80, 120]
lrate_decay = 0.1
batch_size = 128
weight_decay = 2e-4
num_workers = 8
T = 2


class iCaRL(BaseLearner):
    """iCaRL-style learner.

    Trains on new-class data plus the rehearsal memory (``_get_memory``),
    adds distillation from the frozen previous network, and rebuilds the
    exemplar memory after each task.
    """

    def __init__(self, args):
        super().__init__(args)
        self._network = IncrementalNet(args, False)

    def after_task(self):
        # Keep a frozen copy of the current model as the next task's teacher.
        self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def incremental_train(self, data_manager):
        """Extend the classifier to the next task's classes, train, then update the memory."""
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        # New-class data plus stored exemplars; test on every class seen so far.
        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
        )
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Build the optimizer/scheduler for the current task and train.

        Task 0 uses the ``init_*`` hyper-parameters; later tasks use the
        incremental ones and add distillation.
        """
        self._network.to(self._device)
        if self._old_network is not None:
            self._old_network.to(self._device)

        if self._cur_task == 0:
            optimizer = optim.SGD(
                self._network.parameters(),
                momentum=0.9,
                lr=init_lr,
                weight_decay=init_weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        else:
            optimizer = optim.SGD(
                self._network.parameters(),
                lr=lrate,
                momentum=0.9,
                weight_decay=weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=milestones, gamma=lrate_decay
            )
            self._update_representation(train_loader, test_loader, optimizer, scheduler)

    def _run_epochs(self, train_loader, test_loader, optimizer, scheduler, n_epochs, loss_fn):
        """Shared training loop: optimise ``loss_fn(inputs, logits, targets)`` for ``n_epochs``.

        Progress goes to the tqdm bar. Test accuracy is computed every 5
        epochs, and the final epoch's summary is logged when training ends.
        """
        info = ""
        prog_bar = tqdm(range(n_epochs))
        for epoch in prog_bar:
            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]
                loss = loss_fn(inputs, logits, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                with torch.no_grad():
                    _, preds = torch.max(logits, dim=1)
                    correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                    total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    n_epochs,
                    losses / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    n_epochs,
                    losses / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)

        if info:
            logging.info(info)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        """First task: plain cross-entropy over all current classes."""
        self._run_epochs(
            train_loader,
            test_loader,
            optimizer,
            scheduler,
            init_epoch,
            lambda inputs, logits, targets: F.cross_entropy(logits, targets),
        )

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        """Later tasks: CE over all classes plus KD on old-class logits."""

        def loss_fn(inputs, logits, targets):
            loss_clf = F.cross_entropy(logits, targets)
            # The teacher only supplies soft targets; no gradient is needed.
            with torch.no_grad():
                old_logits = self._old_network(inputs)["logits"]
            loss_kd = _KD_loss(logits[:, : self._known_classes], old_logits, T)
            return loss_clf + loss_kd

        self._run_epochs(train_loader, test_loader, optimizer, scheduler, epochs, loss_fn)


def _KD_loss(pred, soft, T):
    """Distillation loss between temperature-``T`` softened logits.

    Returns ``-sum(softmax(soft / T) * log_softmax(pred / T)) / pred.shape[0]``,
    i.e. the soft-target cross-entropy averaged over the batch.
    """
    pred = torch.log_softmax(pred / T, dim=1)
    soft = torch.softmax(soft / T, dim=1)
    return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
-------------------------------------------------------------------------------- /models/bic.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch import optim 6 | from torch.nn import functional as F 7 | from torch.utils.data import DataLoader 8 | from models.base import BaseLearner 9 | from utils.inc_net import IncrementalNetWithBias 10 | 11 | 12 | epochs = 170 13 | lrate = 0.1 14 | milestones = [60, 100, 140] 15 | lrate_decay = 0.1 16 | batch_size = 128 17 | split_ratio = 0.1 18 | T = 2 19 | weight_decay = 2e-4 20 | num_workers = 8 21 | 22 | 23 | class BiC(BaseLearner): 24 | def __init__(self, args): 25 | super().__init__(args) 26 | self._network = IncrementalNetWithBias( 27 | args, False, bias_correction=True 28 | ) 29 | self._class_means = None 30 | 31 | def after_task(self): 32 | self._old_network = self._network.copy().freeze() 33 | self._known_classes = self._total_classes 34 | logging.info("Exemplar size: {}".format(self.exemplar_size)) 35 | 36 | def incremental_train(self, data_manager): 37 | self._cur_task += 1 38 | self._total_classes = self._known_classes + data_manager.get_task_size( 39 | self._cur_task 40 | ) 41 | self._network.update_fc(self._total_classes) 42 | logging.info( 43 | "Learning on {}-{}".format(self._known_classes, self._total_classes) 44 | ) 45 | 46 | if self._cur_task >= 1: 47 | train_dset, val_dset = data_manager.get_dataset_with_split( 48 | np.arange(self._known_classes, self._total_classes), 49 | source="train", 50 | mode="train", 51 | appendent=self._get_memory(), 52 | val_samples_per_class=int( 53 | split_ratio * self._memory_size / self._known_classes 54 | ), 55 | ) 56 | self.val_loader = DataLoader( 57 | val_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers 58 | ) 59 | logging.info( 60 | "Stage1 dset: {}, Stage2 dset: {}".format( 61 | len(train_dset), len(val_dset) 62 | ) 63 | ) 64 | 
self.lamda = self._known_classes / self._total_classes 65 | logging.info("Lambda: {:.3f}".format(self.lamda)) 66 | else: 67 | train_dset = data_manager.get_dataset( 68 | np.arange(self._known_classes, self._total_classes), 69 | source="train", 70 | mode="train", 71 | appendent=self._get_memory(), 72 | ) 73 | test_dset = data_manager.get_dataset( 74 | np.arange(0, self._total_classes), source="test", mode="test" 75 | ) 76 | 77 | self.train_loader = DataLoader( 78 | train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers 79 | ) 80 | self.test_loader = DataLoader( 81 | test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers 82 | ) 83 | 84 | self._log_bias_params() 85 | self._stage1_training(self.train_loader, self.test_loader) 86 | if self._cur_task >= 1: 87 | self._stage2_bias_correction(self.val_loader, self.test_loader) 88 | 89 | self.build_rehearsal_memory(data_manager, self.samples_per_class) 90 | 91 | if len(self._multiple_gpus) > 1: 92 | self._network = self._network.module 93 | self._log_bias_params() 94 | 95 | def _run(self, train_loader, test_loader, optimizer, scheduler, stage): 96 | for epoch in range(1, epochs + 1): 97 | self._network.train() 98 | losses = 0.0 99 | for i, (_, inputs, targets) in enumerate(train_loader): 100 | inputs, targets = inputs.to(self._device), targets.to(self._device) 101 | logits = self._network(inputs)["logits"] 102 | 103 | if stage == "training": 104 | clf_loss = F.cross_entropy(logits, targets) 105 | if self._old_network is not None: 106 | old_logits = self._old_network(inputs)["logits"].detach() 107 | hat_pai_k = F.softmax(old_logits / T, dim=1) 108 | log_pai_k = F.log_softmax( 109 | logits[:, : self._known_classes] / T, dim=1 110 | ) 111 | distill_loss = -torch.mean( 112 | torch.sum(hat_pai_k * log_pai_k, dim=1) 113 | ) 114 | loss = distill_loss * self.lamda + clf_loss * (1 - self.lamda) 115 | else: 116 | loss = clf_loss 117 | elif stage == "bias_correction": 118 | loss = 
F.cross_entropy(torch.softmax(logits, dim=1), targets) 119 | else: 120 | raise NotImplementedError() 121 | 122 | optimizer.zero_grad() 123 | loss.backward() 124 | optimizer.step() 125 | losses += loss.item() 126 | 127 | scheduler.step() 128 | train_acc = self._compute_accuracy(self._network, train_loader) 129 | test_acc = self._compute_accuracy(self._network, test_loader) 130 | info = "{} => Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.3f}, Test_accy {:.3f}".format( 131 | stage, 132 | self._cur_task, 133 | epoch, 134 | epochs, 135 | losses / len(train_loader), 136 | train_acc, 137 | test_acc, 138 | ) 139 | logging.info(info) 140 | 141 | def _stage1_training(self, train_loader, test_loader): 142 | """ 143 | if self._cur_task == 0: 144 | loaded_dict = torch.load('./dict_0.pkl') 145 | self._network.load_state_dict(loaded_dict['model_state_dict']) 146 | self._network.to(self._device) 147 | return 148 | """ 149 | 150 | ignored_params = list(map(id, self._network.bias_layers.parameters())) 151 | base_params = filter( 152 | lambda p: id(p) not in ignored_params, self._network.parameters() 153 | ) 154 | network_params = [ 155 | {"params": base_params, "lr": lrate, "weight_decay": weight_decay}, 156 | { 157 | "params": self._network.bias_layers.parameters(), 158 | "lr": 0, 159 | "weight_decay": 0, 160 | }, 161 | ] 162 | optimizer = optim.SGD( 163 | network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay 164 | ) 165 | scheduler = optim.lr_scheduler.MultiStepLR( 166 | optimizer=optimizer, milestones=milestones, gamma=lrate_decay 167 | ) 168 | 169 | if len(self._multiple_gpus) > 1: 170 | self._network = nn.DataParallel(self._network, self._multiple_gpus) 171 | self._network.to(self._device) 172 | if self._old_network is not None: 173 | self._old_network.to(self._device) 174 | 175 | self._run(train_loader, test_loader, optimizer, scheduler, stage="training") 176 | 177 | def _stage2_bias_correction(self, val_loader, test_loader): 178 | if 
isinstance(self._network, nn.DataParallel): 179 | self._network = self._network.module 180 | network_params = [ 181 | { 182 | "params": self._network.bias_layers[-1].parameters(), 183 | "lr": lrate, 184 | "weight_decay": weight_decay, 185 | } 186 | ] 187 | optimizer = optim.SGD( 188 | network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay 189 | ) 190 | scheduler = optim.lr_scheduler.MultiStepLR( 191 | optimizer=optimizer, milestones=milestones, gamma=lrate_decay 192 | ) 193 | 194 | if len(self._multiple_gpus) > 1: 195 | self._network = nn.DataParallel(self._network, self._multiple_gpus) 196 | self._network.to(self._device) 197 | 198 | self._run( 199 | val_loader, test_loader, optimizer, scheduler, stage="bias_correction" 200 | ) 201 | 202 | def _log_bias_params(self): 203 | logging.info("Parameters of bias layer:") 204 | params = self._network.get_bias_params() 205 | for i, param in enumerate(params): 206 | logging.info("{} => {:.3f}, {:.3f}".format(i, param[0], param[1])) 207 | -------------------------------------------------------------------------------- /models/wa.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | from tqdm import tqdm 4 | import torch 5 | from torch import nn 6 | from torch import optim 7 | from torch.nn import functional as F 8 | from torch.utils.data import DataLoader 9 | from models.base import BaseLearner 10 | from utils.inc_net import IncrementalNet 11 | from utils.toolkit import target2onehot, tensor2numpy 12 | 13 | EPSILON = 1e-8 14 | 15 | 16 | init_epoch = 200 17 | init_lr = 0.1 18 | init_milestones = [60, 120, 170] 19 | init_lr_decay = 0.1 20 | init_weight_decay = 0.0005 21 | 22 | 23 | epochs = 170 24 | lrate = 0.1 25 | milestones = [60, 100, 140] 26 | lrate_decay = 0.1 27 | batch_size = 128 28 | weight_decay = 2e-4 29 | num_workers = 8 30 | T = 2 31 | 32 | 33 | class WA(BaseLearner): 34 | def __init__(self, args): 35 | super().__init__(args) 
36 | self._network = IncrementalNet(args, False) 37 | 38 | def after_task(self): 39 | if self._cur_task > 0: 40 | self._network.weight_align(self._total_classes - self._known_classes) 41 | self._old_network = self._network.copy().freeze() 42 | self._known_classes = self._total_classes 43 | logging.info("Exemplar size: {}".format(self.exemplar_size)) 44 | 45 | def incremental_train(self, data_manager): 46 | self._cur_task += 1 47 | self._total_classes = self._known_classes + data_manager.get_task_size( 48 | self._cur_task 49 | ) 50 | self._network.update_fc(self._total_classes) 51 | logging.info( 52 | "Learning on {}-{}".format(self._known_classes, self._total_classes) 53 | ) 54 | 55 | # Loader 56 | train_dataset = data_manager.get_dataset( 57 | np.arange(self._known_classes, self._total_classes), 58 | source="train", 59 | mode="train", 60 | appendent=self._get_memory(), 61 | ) 62 | self.train_loader = DataLoader( 63 | train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers 64 | ) 65 | test_dataset = data_manager.get_dataset( 66 | np.arange(0, self._total_classes), source="test", mode="test" 67 | ) 68 | self.test_loader = DataLoader( 69 | test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 70 | ) 71 | 72 | # Procedure 73 | if len(self._multiple_gpus) > 1: 74 | self._network = nn.DataParallel(self._network, self._multiple_gpus) 75 | self._train(self.train_loader, self.test_loader) 76 | self.build_rehearsal_memory(data_manager, self.samples_per_class) 77 | if len(self._multiple_gpus) > 1: 78 | self._network = self._network.module 79 | 80 | def _train(self, train_loader, test_loader): 81 | self._network.to(self._device) 82 | if self._old_network is not None: 83 | self._old_network.to(self._device) 84 | 85 | if self._cur_task == 0: 86 | optimizer = optim.SGD( 87 | self._network.parameters(), 88 | momentum=0.9, 89 | lr=init_lr, 90 | weight_decay=init_weight_decay, 91 | ) 92 | scheduler = optim.lr_scheduler.MultiStepLR( 93 | 
optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay 94 | ) 95 | self._init_train(train_loader, test_loader, optimizer, scheduler) 96 | else: 97 | optimizer = optim.SGD( 98 | self._network.parameters(), 99 | lr=lrate, 100 | momentum=0.9, 101 | weight_decay=weight_decay, 102 | ) # 1e-5 103 | scheduler = optim.lr_scheduler.MultiStepLR( 104 | optimizer=optimizer, milestones=milestones, gamma=lrate_decay 105 | ) 106 | self._update_representation(train_loader, test_loader, optimizer, scheduler) 107 | if len(self._multiple_gpus) > 1: 108 | self._network.module.weight_align( 109 | self._total_classes - self._known_classes 110 | ) 111 | else: 112 | self._network.weight_align(self._total_classes - self._known_classes) 113 | 114 | def _init_train(self, train_loader, test_loader, optimizer, scheduler): 115 | prog_bar = tqdm(range(init_epoch)) 116 | for _, epoch in enumerate(prog_bar): 117 | self._network.train() 118 | losses = 0.0 119 | correct, total = 0, 0 120 | for i, (_, inputs, targets) in enumerate(train_loader): 121 | inputs, targets = inputs.to(self._device), targets.to(self._device) 122 | logits = self._network(inputs)["logits"] 123 | 124 | loss = F.cross_entropy(logits, targets) 125 | optimizer.zero_grad() 126 | loss.backward() 127 | optimizer.step() 128 | losses += loss.item() 129 | 130 | _, preds = torch.max(logits, dim=1) 131 | correct += preds.eq(targets.expand_as(preds)).cpu().sum() 132 | total += len(targets) 133 | 134 | scheduler.step() 135 | train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) 136 | 137 | if epoch % 5 == 0: 138 | test_acc = self._compute_accuracy(self._network, test_loader) 139 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( 140 | self._cur_task, 141 | epoch + 1, 142 | init_epoch, 143 | losses / len(train_loader), 144 | train_acc, 145 | test_acc, 146 | ) 147 | else: 148 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( 149 | self._cur_task, 
150 | epoch + 1, 151 | init_epoch, 152 | losses / len(train_loader), 153 | train_acc, 154 | ) 155 | 156 | prog_bar.set_description(info) 157 | 158 | logging.info(info) 159 | 160 | def _update_representation(self, train_loader, test_loader, optimizer, scheduler): 161 | kd_lambda = self._known_classes / self._total_classes 162 | prog_bar = tqdm(range(epochs)) 163 | for _, epoch in enumerate(prog_bar): 164 | self._network.train() 165 | losses = 0.0 166 | correct, total = 0, 0 167 | for i, (_, inputs, targets) in enumerate(train_loader): 168 | inputs, targets = inputs.to(self._device), targets.to(self._device) 169 | logits = self._network(inputs)["logits"] 170 | 171 | loss_clf = F.cross_entropy(logits, targets) 172 | loss_kd = _KD_loss( 173 | logits[:, : self._known_classes], 174 | self._old_network(inputs)["logits"], 175 | T, 176 | ) 177 | 178 | loss = (1-kd_lambda) * loss_clf + kd_lambda * loss_kd 179 | 180 | optimizer.zero_grad() 181 | loss.backward() 182 | optimizer.step() 183 | losses += loss.item() 184 | 185 | # acc 186 | _, preds = torch.max(logits, dim=1) 187 | correct += preds.eq(targets.expand_as(preds)).cpu().sum() 188 | total += len(targets) 189 | 190 | scheduler.step() 191 | train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) 192 | if epoch % 5 == 0: 193 | test_acc = self._compute_accuracy(self._network, test_loader) 194 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( 195 | self._cur_task, 196 | epoch + 1, 197 | epochs, 198 | losses / len(train_loader), 199 | train_acc, 200 | test_acc, 201 | ) 202 | else: 203 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( 204 | self._cur_task, 205 | epoch + 1, 206 | epochs, 207 | losses / len(train_loader), 208 | train_acc, 209 | ) 210 | prog_bar.set_description(info) 211 | logging.info(info) 212 | 213 | 214 | def _KD_loss(pred, soft, T): 215 | pred = torch.log_softmax(pred / T, dim=1) 216 | soft = torch.softmax(soft / T, dim=1) 217 | 
return -1 * torch.mul(soft, pred).sum() / pred.shape[0] 218 | -------------------------------------------------------------------------------- /models/der.py: -------------------------------------------------------------------------------- 1 | # Please note that the current implementation of DER only contains the dynamic expansion process, since masking and pruning are not implemented by the source repo. 2 | import logging 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch 6 | from torch import nn 7 | from torch import optim 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader 10 | from models.base import BaseLearner 11 | from utils.inc_net import DERNet, IncrementalNet 12 | from utils.toolkit import count_parameters, target2onehot, tensor2numpy 13 | 14 | EPSILON = 1e-8 15 | 16 | init_epoch = 200 17 | init_lr = 0.1 18 | init_milestones = [60, 120, 170] 19 | init_lr_decay = 0.1 20 | init_weight_decay = 0.0005 21 | 22 | 23 | epochs = 170 24 | lrate = 0.1 25 | milestones = [80, 120, 150] 26 | lrate_decay = 0.1 27 | batch_size = 128 28 | weight_decay = 2e-4 29 | num_workers = 8 30 | T = 2 31 | 32 | 33 | class DER(BaseLearner): 34 | def __init__(self, args): 35 | super().__init__(args) 36 | self._network = DERNet(args, False) 37 | 38 | def after_task(self): 39 | self._known_classes = self._total_classes 40 | logging.info("Exemplar size: {}".format(self.exemplar_size)) 41 | 42 | def incremental_train(self, data_manager): 43 | self._cur_task += 1 44 | self._total_classes = self._known_classes + data_manager.get_task_size( 45 | self._cur_task 46 | ) 47 | self._network.update_fc(self._total_classes) 48 | logging.info( 49 | "Learning on {}-{}".format(self._known_classes, self._total_classes) 50 | ) 51 | 52 | if self._cur_task > 0: 53 | for i in range(self._cur_task): 54 | for p in self._network.convnets[i].parameters(): 55 | p.requires_grad = False 56 | 57 | logging.info("All params: {}".format(count_parameters(self._network))) 
58 | logging.info( 59 | "Trainable params: {}".format(count_parameters(self._network, True)) 60 | ) 61 | 62 | train_dataset = data_manager.get_dataset( 63 | np.arange(self._known_classes, self._total_classes), 64 | source="train", 65 | mode="train", 66 | appendent=self._get_memory(), 67 | ) 68 | self.train_loader = DataLoader( 69 | train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers 70 | ) 71 | test_dataset = data_manager.get_dataset( 72 | np.arange(0, self._total_classes), source="test", mode="test" 73 | ) 74 | self.test_loader = DataLoader( 75 | test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 76 | ) 77 | 78 | if len(self._multiple_gpus) > 1: 79 | self._network = nn.DataParallel(self._network, self._multiple_gpus) 80 | self._train(self.train_loader, self.test_loader) 81 | self.build_rehearsal_memory(data_manager, self.samples_per_class) 82 | if len(self._multiple_gpus) > 1: 83 | self._network = self._network.module 84 | 85 | def train(self): 86 | self._network.train() 87 | if len(self._multiple_gpus) > 1 : 88 | self._network_module_ptr = self._network.module 89 | else: 90 | self._network_module_ptr = self._network 91 | self._network_module_ptr.convnets[-1].train() 92 | if self._cur_task >= 1: 93 | for i in range(self._cur_task): 94 | self._network_module_ptr.convnets[i].eval() 95 | 96 | def _train(self, train_loader, test_loader): 97 | self._network.to(self._device) 98 | if self._cur_task == 0: 99 | optimizer = optim.SGD( 100 | filter(lambda p: p.requires_grad, self._network.parameters()), 101 | momentum=0.9, 102 | lr=init_lr, 103 | weight_decay=init_weight_decay, 104 | ) 105 | scheduler = optim.lr_scheduler.MultiStepLR( 106 | optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay 107 | ) 108 | self._init_train(train_loader, test_loader, optimizer, scheduler) 109 | else: 110 | optimizer = optim.SGD( 111 | filter(lambda p: p.requires_grad, self._network.parameters()), 112 | lr=lrate, 113 | 
momentum=0.9, 114 | weight_decay=weight_decay, 115 | ) 116 | scheduler = optim.lr_scheduler.MultiStepLR( 117 | optimizer=optimizer, milestones=milestones, gamma=lrate_decay 118 | ) 119 | self._update_representation(train_loader, test_loader, optimizer, scheduler) 120 | if len(self._multiple_gpus) > 1: 121 | self._network.module.weight_align( 122 | self._total_classes - self._known_classes 123 | ) 124 | else: 125 | self._network.weight_align(self._total_classes - self._known_classes) 126 | 127 | def _init_train(self, train_loader, test_loader, optimizer, scheduler): 128 | prog_bar = tqdm(range(init_epoch)) 129 | for _, epoch in enumerate(prog_bar): 130 | self.train() 131 | losses = 0.0 132 | correct, total = 0, 0 133 | for i, (_, inputs, targets) in enumerate(train_loader): 134 | inputs, targets = inputs.to(self._device), targets.to(self._device) 135 | logits = self._network(inputs)["logits"] 136 | 137 | loss = F.cross_entropy(logits, targets) 138 | optimizer.zero_grad() 139 | loss.backward() 140 | optimizer.step() 141 | losses += loss.item() 142 | 143 | _, preds = torch.max(logits, dim=1) 144 | correct += preds.eq(targets.expand_as(preds)).cpu().sum() 145 | total += len(targets) 146 | 147 | scheduler.step() 148 | train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) 149 | 150 | if epoch % 5 == 0: 151 | test_acc = self._compute_accuracy(self._network, test_loader) 152 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( 153 | self._cur_task, 154 | epoch + 1, 155 | init_epoch, 156 | losses / len(train_loader), 157 | train_acc, 158 | test_acc, 159 | ) 160 | else: 161 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( 162 | self._cur_task, 163 | epoch + 1, 164 | init_epoch, 165 | losses / len(train_loader), 166 | train_acc, 167 | ) 168 | prog_bar.set_description(info) 169 | 170 | logging.info(info) 171 | 172 | def _update_representation(self, train_loader, test_loader, optimizer, scheduler): 
173 | prog_bar = tqdm(range(epochs)) 174 | for _, epoch in enumerate(prog_bar): 175 | self.train() 176 | losses = 0.0 177 | losses_clf = 0.0 178 | losses_aux = 0.0 179 | correct, total = 0, 0 180 | for i, (_, inputs, targets) in enumerate(train_loader): 181 | inputs, targets = inputs.to(self._device), targets.to(self._device) 182 | outputs = self._network(inputs) 183 | logits, aux_logits = outputs["logits"], outputs["aux_logits"] 184 | loss_clf = F.cross_entropy(logits, targets) 185 | aux_targets = targets.clone() 186 | aux_targets = torch.where( 187 | aux_targets - self._known_classes + 1 > 0, 188 | aux_targets - self._known_classes + 1, 189 | 0, 190 | ) 191 | loss_aux = F.cross_entropy(aux_logits, aux_targets) 192 | loss = loss_clf + loss_aux 193 | 194 | optimizer.zero_grad() 195 | loss.backward() 196 | optimizer.step() 197 | losses += loss.item() 198 | losses_aux += loss_aux.item() 199 | losses_clf += loss_clf.item() 200 | 201 | _, preds = torch.max(logits, dim=1) 202 | correct += preds.eq(targets.expand_as(preds)).cpu().sum() 203 | total += len(targets) 204 | 205 | scheduler.step() 206 | train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) 207 | if epoch % 5 == 0: 208 | test_acc = self._compute_accuracy(self._network, test_loader) 209 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( 210 | self._cur_task, 211 | epoch + 1, 212 | epochs, 213 | losses / len(train_loader), 214 | losses_clf / len(train_loader), 215 | losses_aux / len(train_loader), 216 | train_acc, 217 | test_acc, 218 | ) 219 | else: 220 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}".format( 221 | self._cur_task, 222 | epoch + 1, 223 | epochs, 224 | losses / len(train_loader), 225 | losses_clf / len(train_loader), 226 | losses_aux / len(train_loader), 227 | train_acc, 228 | ) 229 | prog_bar.set_description(info) 230 | logging.info(info) 231 | 
-------------------------------------------------------------------------------- /models/aper_finetune.py: -------------------------------------------------------------------------------- 1 | """ 2 | Re-implementation of APER-Finetune (https://arxiv.org/abs/2303.07338) without pre-trained weights. 3 | Note: this method was initially designed for PTMs, whereas it has been slightly modified here to adapt to the train-from-scratch setting. 4 | Please refer to the original implementation (https://github.com/zhoudw-zdw/RevisitingCIL) if you are using pre-trained weights. 5 | """ 6 | 7 | import logging 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | from tqdm import tqdm 12 | from torch import optim 13 | from torch.nn import functional as F 14 | from torch.utils.data import DataLoader 15 | from utils.inc_net import IncrementalNet, SimpleCosineIncrementalNet, MultiBranchCosineIncrementalNet 16 | from models.base import BaseLearner 17 | from utils.toolkit import target2onehot, tensor2numpy 18 | import copy 19 | 20 | 21 | num_workers = 8 22 | 23 | 24 | class APER_FINETUNE(BaseLearner): 25 | def __init__(self, args): 26 | super().__init__(args) 27 | 28 | self._network = SimpleCosineIncrementalNet(args, False) 29 | self.batch_size = args.get("batch_size", 128) 30 | self.init_lr = args.get("init_lr", 0.01) 31 | self.finetune_lr = args.get("finetune_lr", 0.001) 32 | 33 | self.init_weight_decay = args.get("init_weight_decay", 0.0005) 34 | self.weight_decay = args.get("weight_decay", 0.005) 35 | self.min_lr = args.get('min_lr', 1e-8) 36 | self.args = args 37 | 38 | self.trained_epoch = args.get('trained_epoch', 50) 39 | self.tuned_epoch = args.get('tuned_epoch', 20) 40 | self.trained_model = None 41 | 42 | def after_task(self): 43 | self._known_classes = self._total_classes 44 | 45 | def replace_fc(self, trainloader, model, args): 46 | # replace fc.weight with the embedding average of train data 47 | model = model.eval() 48 | embedding_list = [] 49 | 
label_list = [] 50 | with torch.no_grad(): 51 | for i, batch in enumerate(trainloader): 52 | (_, data, label) = batch 53 | data = data.cuda() 54 | label = label.cuda() 55 | embedding = model(data)['features'] 56 | embedding_list.append(embedding.cpu()) 57 | label_list.append(label.cpu()) 58 | embedding_list = torch.cat(embedding_list, dim=0) 59 | label_list = torch.cat(label_list, dim=0) 60 | 61 | class_list = np.unique(self.train_dataset.labels) 62 | proto_list = [] 63 | for class_index in class_list: 64 | # print('Replacing...',class_index) 65 | data_index = (label_list == class_index).nonzero().squeeze(-1) 66 | embedding = embedding_list[data_index] 67 | proto = embedding.mean(0) 68 | self._network.fc.weight.data[class_index] = proto 69 | return model 70 | 71 | def incremental_train(self, data_manager): 72 | self._cur_task += 1 73 | self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task) 74 | self._network.update_fc(self._total_classes) 75 | logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes)) 76 | 77 | train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train", 78 | mode="train", ) 79 | self.train_dataset = train_dataset 80 | self.data_manager = data_manager 81 | self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=num_workers) 82 | test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test") 83 | self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers) 84 | 85 | train_dataset_for_protonet = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), 86 | source="train", mode="test", ) 87 | self.train_loader_for_protonet = DataLoader(train_dataset_for_protonet, batch_size=self.batch_size, 88 | shuffle=True, num_workers=num_workers) 89 | 90 | if len(self._multiple_gpus) > 1: 91 | 
print('Multiple GPUs') 92 | self._network = nn.DataParallel(self._network, self._multiple_gpus) 93 | self._train(self.train_loader, self.test_loader, self.train_loader_for_protonet) 94 | if len(self._multiple_gpus) > 1: 95 | self._network = self._network.module 96 | 97 | def _train(self, train_loader, test_loader, train_loader_for_protonet): 98 | 99 | self._network.to(self._device) 100 | 101 | if self._cur_task == 0: 102 | if self.args['optimizer'] == 'sgd': 103 | optimizer = optim.SGD(self._network.parameters(), momentum=0.9, lr=self.init_lr, 104 | weight_decay=self.init_weight_decay) 105 | elif self.args['optimizer'] == 'adam': 106 | optimizer = optim.AdamW(self._network.parameters(), lr=self.init_lr, 107 | weight_decay=self.init_weight_decay) 108 | 109 | total_params = sum(p.numel() for p in self._network.parameters()) 110 | print(f'{total_params:,} total parameters.') 111 | total_trainable_params = sum( 112 | p.numel() for p in self._network.parameters() if p.requires_grad) 113 | print(f'{total_trainable_params:,} training parameters.') 114 | 115 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.trained_epoch, 116 | eta_min=self.min_lr) 117 | self._init_train(train_loader, test_loader, optimizer, scheduler, self.trained_epoch) 118 | self.replace_fc(train_loader_for_protonet, self._network, None) 119 | self.trained_model = copy.deepcopy(self._network.cpu()) 120 | 121 | self.construct_dual_branch_network() 122 | 123 | return 124 | 125 | elif self._cur_task == 1: 126 | self._network = SimpleCosineIncrementalNet(self.args, False) 127 | self._network.regenerate_fc(self.args['init_cls']) # to be compatible with trained_model 128 | self._network.to(self._device) 129 | msg = self._network.load_state_dict(self.trained_model.state_dict(), strict=False) 130 | logging.info('INFO -- state dict loaded', msg) 131 | self._network.regenerate_fc(self.args['increment']) 132 | logging.info('Fully finetuning ...') 133 | 134 | if self.args['optimizer'] == 
'sgd': 135 | optimizer = optim.SGD(self._network.parameters(), momentum=0.9, lr=self.finetune_lr, 136 | weight_decay=self.weight_decay) 137 | elif self.args['optimizer'] == 'adam': 138 | optimizer = optim.AdamW(self._network.parameters(), lr=self.finetune_lr, weight_decay=self.weight_decay) 139 | 140 | total_trainable_params = sum( 141 | p.numel() for p in self._network.parameters() if p.requires_grad) 142 | print(f'{total_trainable_params:,} training parameters.') 143 | 144 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.tuned_epoch, 145 | eta_min=self.min_lr) 146 | self._init_train(train_loader, test_loader, optimizer, scheduler, self.tuned_epoch) 147 | self.construct_dual_branch_network() 148 | 149 | self.replace_fc(train_loader_for_protonet, self._network, None) 150 | 151 | def construct_dual_branch_network(self): 152 | if self._cur_task == 0: 153 | network = MultiBranchCosineIncrementalNet(self.args, False) 154 | network.construct_dual_branch_network(self._network, self._network, self.args['init_cls']) 155 | else: 156 | network = MultiBranchCosineIncrementalNet(self.args, False) 157 | self.trained_model = self.trained_model.to(self._device) 158 | network.construct_dual_branch_network(self.trained_model, self._network, 159 | self.args['init_cls'] + self.args['increment']) 160 | 161 | network.fc.weight.data[:self.args['init_cls'], :] = self.trained_model.fc.weight.data.repeat(1, 2) 162 | self._network = network.to(self._device) 163 | 164 | def _init_train(self, train_loader, test_loader, optimizer, scheduler, epc): 165 | prog_bar = tqdm(range(epc)) 166 | for _, epoch in enumerate(prog_bar): 167 | self._network.train() 168 | losses = 0.0 169 | correct, total = 0, 0 170 | for i, (_, inputs, targets) in enumerate(train_loader): 171 | inputs, targets = inputs.to(self._device), targets.to(self._device) 172 | logits = self._network(inputs)["logits"] 173 | 174 | if self._cur_task == 1: 175 | targets -= self.args['init_cls'] 176 | loss = 
F.cross_entropy(logits, targets) 177 | optimizer.zero_grad() 178 | loss.backward() 179 | optimizer.step() 180 | losses += loss.item() 181 | 182 | _, preds = torch.max(logits, dim=1) 183 | # print('preds', preds) 184 | correct += preds.eq(targets.expand_as(preds)).cpu().sum() 185 | total += len(targets) 186 | 187 | scheduler.step() 188 | train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) 189 | 190 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( 191 | self._cur_task, 192 | epoch + 1, 193 | epc, 194 | losses / len(train_loader), 195 | train_acc, 196 | # test_acc, 197 | ) 198 | prog_bar.set_description(info) 199 | 200 | logging.info(info) 201 | -------------------------------------------------------------------------------- /models/ewc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | from tqdm import tqdm 4 | import torch 5 | from torch import nn 6 | from torch import optim 7 | from torch.nn import functional as F 8 | from torch.utils.data import DataLoader 9 | from models.base import BaseLearner 10 | from models.podnet import pod_spatial_loss 11 | from utils.inc_net import IncrementalNet 12 | from utils.toolkit import target2onehot, tensor2numpy 13 | 14 | EPSILON = 1e-8 15 | 16 | init_epoch = 200 17 | init_lr = 0.1 18 | init_milestones = [60, 120, 170] 19 | init_lr_decay = 0.1 20 | init_weight_decay = 0.0005 21 | 22 | 23 | epochs = 180 24 | lrate = 0.1 25 | milestones = [70, 120, 150] 26 | lrate_decay = 0.1 27 | batch_size = 128 28 | weight_decay = 2e-4 29 | num_workers = 4 30 | T = 2 31 | lamda = 1000 32 | fishermax = 0.0001 33 | 34 | 35 | class EWC(BaseLearner): 36 | def __init__(self, args): 37 | super().__init__(args) 38 | self.fisher = None 39 | self._network = IncrementalNet(args, False) 40 | 41 | def after_task(self): 42 | self._known_classes = self._total_classes 43 | 44 | def incremental_train(self, data_manager): 45 | self._cur_task 
+= 1 46 | self._total_classes = self._known_classes + data_manager.get_task_size( 47 | self._cur_task 48 | ) 49 | self._network.update_fc(self._total_classes) 50 | logging.info( 51 | "Learning on {}-{}".format(self._known_classes, self._total_classes) 52 | ) 53 | 54 | train_dataset = data_manager.get_dataset( 55 | np.arange(self._known_classes, self._total_classes), 56 | source="train", 57 | mode="train", 58 | ) 59 | self.train_loader = DataLoader( 60 | train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers 61 | ) 62 | test_dataset = data_manager.get_dataset( 63 | np.arange(0, self._total_classes), source="test", mode="test" 64 | ) 65 | self.test_loader = DataLoader( 66 | test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 67 | ) 68 | 69 | if len(self._multiple_gpus) > 1: 70 | self._network = nn.DataParallel(self._network, self._multiple_gpus) 71 | self._train(self.train_loader, self.test_loader) 72 | if len(self._multiple_gpus) > 1: 73 | self._network = self._network.module 74 | 75 | if self.fisher is None: 76 | self.fisher = self.getFisherDiagonal(self.train_loader) 77 | else: 78 | alpha = self._known_classes / self._total_classes 79 | new_finsher = self.getFisherDiagonal(self.train_loader) 80 | for n, p in new_finsher.items(): 81 | new_finsher[n][: len(self.fisher[n])] = ( 82 | alpha * self.fisher[n] 83 | + (1 - alpha) * new_finsher[n][: len(self.fisher[n])] 84 | ) 85 | self.fisher = new_finsher 86 | self.mean = { 87 | n: p.clone().detach() 88 | for n, p in self._network.named_parameters() 89 | if p.requires_grad 90 | } 91 | 92 | def _train(self, train_loader, test_loader): 93 | self._network.to(self._device) 94 | if self._cur_task == 0: 95 | optimizer = optim.SGD( 96 | self._network.parameters(), 97 | momentum=0.9, 98 | lr=init_lr, 99 | weight_decay=init_weight_decay, 100 | ) 101 | scheduler = optim.lr_scheduler.MultiStepLR( 102 | optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay 103 | ) 104 | 
self._init_train(train_loader, test_loader, optimizer, scheduler) 105 | else: 106 | optimizer = optim.SGD( 107 | self._network.parameters(), 108 | lr=lrate, 109 | momentum=0.9, 110 | weight_decay=weight_decay, 111 | ) 112 | scheduler = optim.lr_scheduler.MultiStepLR( 113 | optimizer=optimizer, milestones=milestones, gamma=lrate_decay 114 | ) 115 | self._update_representation(train_loader, test_loader, optimizer, scheduler) 116 | 117 | def _init_train(self, train_loader, test_loader, optimizer, scheduler): 118 | prog_bar = tqdm(range(init_epoch)) 119 | for _, epoch in enumerate(prog_bar): 120 | self._network.train() 121 | losses = 0.0 122 | correct, total = 0, 0 123 | for i, (_, inputs, targets) in enumerate(train_loader): 124 | inputs, targets = inputs.to(self._device), targets.to(self._device) 125 | logits = self._network(inputs)["logits"] 126 | loss = F.cross_entropy(logits, targets) 127 | optimizer.zero_grad() 128 | loss.backward() 129 | optimizer.step() 130 | losses += loss.item() 131 | 132 | _, preds = torch.max(logits, dim=1) 133 | correct += preds.eq(targets.expand_as(preds)).cpu().sum() 134 | total += len(targets) 135 | 136 | scheduler.step() 137 | train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) 138 | 139 | if epoch % 5 == 0: 140 | test_acc = self._compute_accuracy(self._network, test_loader) 141 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( 142 | self._cur_task, 143 | epoch + 1, 144 | init_epoch, 145 | losses / len(train_loader), 146 | train_acc, 147 | test_acc, 148 | ) 149 | else: 150 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( 151 | self._cur_task, 152 | epoch + 1, 153 | init_epoch, 154 | losses / len(train_loader), 155 | train_acc, 156 | ) 157 | 158 | prog_bar.set_description(info) 159 | 160 | logging.info(info) 161 | 162 | def _update_representation(self, train_loader, test_loader, optimizer, scheduler): 163 | prog_bar = tqdm(range(epochs)) 164 | for _, 
epoch in enumerate(prog_bar): 165 | self._network.train() 166 | losses = 0.0 167 | correct, total = 0, 0 168 | for i, (_, inputs, targets) in enumerate(train_loader): 169 | inputs, targets = inputs.to(self._device), targets.to(self._device) 170 | logits = self._network(inputs)["logits"] 171 | 172 | loss_clf = F.cross_entropy( 173 | logits[:, self._known_classes :], targets - self._known_classes 174 | ) 175 | loss_ewc = self.compute_ewc() 176 | loss = loss_clf + lamda * loss_ewc 177 | 178 | optimizer.zero_grad() 179 | loss.backward() 180 | optimizer.step() 181 | losses += loss.item() 182 | 183 | _, preds = torch.max(logits, dim=1) 184 | correct += preds.eq(targets.expand_as(preds)).cpu().sum() 185 | total += len(targets) 186 | 187 | scheduler.step() 188 | train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) 189 | if epoch % 5 == 0: 190 | test_acc = self._compute_accuracy(self._network, test_loader) 191 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( 192 | self._cur_task, 193 | epoch + 1, 194 | epochs, 195 | losses / len(train_loader), 196 | train_acc, 197 | test_acc, 198 | ) 199 | else: 200 | info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( 201 | self._cur_task, 202 | epoch + 1, 203 | epochs, 204 | losses / len(train_loader), 205 | train_acc, 206 | ) 207 | prog_bar.set_description(info) 208 | logging.info(info) 209 | 210 | def compute_ewc(self): 211 | loss = 0 212 | if len(self._multiple_gpus) > 1: 213 | for n, p in self._network.module.named_parameters(): 214 | if n in self.fisher.keys(): 215 | loss += ( 216 | torch.sum( 217 | (self.fisher[n]) 218 | * (p[: len(self.mean[n])] - self.mean[n]).pow(2) 219 | ) 220 | / 2 221 | ) 222 | else: 223 | for n, p in self._network.named_parameters(): 224 | if n in self.fisher.keys(): 225 | loss += ( 226 | torch.sum( 227 | (self.fisher[n]) 228 | * (p[: len(self.mean[n])] - self.mean[n]).pow(2) 229 | ) 230 | / 2 231 | ) 232 | return loss 233 | 234 
# ==============================================================================
# /models/ewc.py (tail) -- EWC.getFisherDiagonal; the class header and the
# module-level constant `fishermax` precede this chunk.
# ==============================================================================
def getFisherDiagonal(self, train_loader):
    """Estimate the diagonal of the empirical Fisher information matrix.

    For every trainable parameter of ``self._network`` the squared gradient of
    the cross-entropy loss is accumulated over all batches of ``train_loader``,
    averaged by the number of batches and clipped element-wise at the
    module-level constant ``fishermax``.

    Args:
        train_loader: iterable yielding ``(_, inputs, targets)`` batches.

    Returns:
        dict mapping parameter name -> tensor shaped like that parameter.
    """
    fisher = {
        n: torch.zeros(p.shape, device=self._device)
        for n, p in self._network.named_parameters()
        if p.requires_grad
    }
    self._network.train()
    num_batches = 0
    for _, inputs, targets in train_loader:
        inputs, targets = inputs.to(self._device), targets.to(self._device)
        logits = self._network(inputs)["logits"]
        loss = torch.nn.functional.cross_entropy(logits, targets)
        # Only gradients are needed (no parameter update), so zero them on the
        # network directly instead of building a throwaway optimizer.
        self._network.zero_grad()
        loss.backward()
        for n, p in self._network.named_parameters():
            # A frozen parameter may still carry a stale .grad; skip anything
            # that was not registered as trainable above.
            if p.grad is not None and n in fisher:
                fisher[n] += p.grad.pow(2)
        num_batches += 1
    # Guard against an empty loader (would otherwise divide by zero -> NaN).
    num_batches = max(num_batches, 1)
    for n, p in fisher.items():
        fisher[n] = torch.clamp(p / num_batches, max=fishermax)
    return fisher


# ==============================================================================
# /utils/autoaugment.py
# ==============================================================================
import random

import numpy as np
from .ops import *


class _AutoAugmentPolicy(object):
    """Base for AutoAugment policies: applies one randomly chosen sub-policy.

    Subclasses fill ``self.policies`` with ``SubPolicy`` instances.
    """

    def __call__(self, img):
        # randint (not random.choice) keeps the RNG stream identical to the
        # historical implementation, so seeded runs stay reproducible.
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)


class ImageNetPolicy(_AutoAugmentPolicy):
    """ Randomly choose one of the best 25 Sub-policies on ImageNet.

        Example:
        >>> policy = ImageNetPolicy()
        >>> transformed = policy(image)

        Example as a PyTorch Transform:
        >>> transform = transforms.Compose([
        >>>     transforms.Resize(256),
        >>>     ImageNetPolicy(),
        >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),

            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),

            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),

            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),

            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
        ]

    def __repr__(self):
        return "AutoAugment ImageNet Policy"


class CIFAR10Policy(_AutoAugmentPolicy):
    """ Randomly choose one of the best 25 Sub-policies on CIFAR10.

        Example:
        >>> policy = CIFAR10Policy()
        >>> transformed = policy(image)

        Example as a PyTorch Transform:
        >>> transform=transforms.Compose([
        >>>     transforms.Resize(256),
        >>>     CIFAR10Policy(),
        >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),

            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),

            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),

            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),

            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
        ]

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"


class SVHNPolicy(_AutoAugmentPolicy):
    """ Randomly choose one of the best 25 Sub-policies on SVHN.

        Example:
        >>> policy = SVHNPolicy()
        >>> transformed = policy(image)

        Example as a PyTorch Transform:
        >>> transform=transforms.Compose([
        >>>     transforms.Resize(256),
        >>>     SVHNPolicy(),
        >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),

            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),

            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
        ]

    def __repr__(self):
        return "AutoAugment SVHN Policy"


class SubPolicy(object):
    """Pair of image operations, each applied independently with its own probability.

    Args:
        p1, p2: probability of applying operation 1 / operation 2.
        operation1, operation2: operation names (keys of ``ranges``/``func`` below).
        magnitude_idx1, magnitude_idx2: index 0-9 into that operation's
            magnitude table.
        fillcolor: fill colour passed to the geometric operations.
    """

    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        # Ten discrete magnitudes per operation; the index selects one.
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        func = {
            "shearX": ShearX(fillcolor=fillcolor),
            "shearY": ShearY(fillcolor=fillcolor),
            "translateX": TranslateX(fillcolor=fillcolor),
            "translateY": TranslateY(fillcolor=fillcolor),
            "rotate": Rotate(),
            "color": Color(),
            "posterize": Posterize(),
            "solarize": Solarize(),
            "contrast": Contrast(),
            "sharpness": Sharpness(),
            "brightness": Brightness(),
            "autocontrast": AutoContrast(),
            "equalize": Equalize(),
            "invert": Invert()
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img


# ==============================================================================
# /models/pa2s.py
# ==============================================================================
import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from models.base import BaseLearner
from utils.inc_net import CosineIncrementalNet, FOSTERNet, IncrementalNet
from utils.toolkit import count_parameters, target2onehot, tensor2numpy

EPSILON = 1e-8


class PASS(BaseLearner):
    """PASS class-incremental learner.

    Every training image is expanded into its four 90-degree rotations, each
    with its own label ``4 * class + k``, so the classifier has four outputs per
    class. Only the ``k == 0`` columns are used at evaluation time. From the
    second task on, the loss adds:

    * feature distillation against the frozen previous network, and
    * a classification loss on stored per-class mean features ("prototypes")
      perturbed by Gaussian noise of scale ``self._radius``.
    """

    def __init__(self, args):
        super().__init__(args)
        self.args = args
        self._network = IncrementalNet(args, False)
        self._protos = []    # per-class mean feature vectors, in class order
        self._radius = 0     # std of the noise added to prototypes
        self._radiuses = []  # per-class trace(cov) / feature_dim

    def after_task(self):
        """Freeze a copy of the trained network as teacher and checkpoint it."""
        self._known_classes = self._total_classes
        self._old_network = self._network.copy().freeze()
        if hasattr(self._old_network, "module"):
            self.old_network_module_ptr = self._old_network.module
        else:
            self.old_network_module_ptr = self._old_network
        self.save_checkpoint("{}_{}_{}".format(self.args["model_name"], self.args["init_cls"], self.args["increment"]))

    def incremental_train(self, data_manager):
        """Grow the head for the next task, build loaders and train."""
        self.data_manager = data_manager
        self._cur_task += 1

        self._total_classes = self._known_classes + \
            data_manager.get_task_size(self._cur_task)
        # Four logits per class: one per rotation (0, 90, 180, 270 degrees).
        self._network.update_fc(self._total_classes * 4)
        self._network_module_ptr = self._network
        logging.info(
            'Learning on {}-{}'.format(self._known_classes, self._total_classes))

        logging.info('All params: {}'.format(count_parameters(self._network)))
        logging.info('Trainable params: {}'.format(
            count_parameters(self._network, True)))

        train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train',
                                                 mode='train', appendent=self._get_memory())
        self.train_loader = DataLoader(
            train_dataset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True)
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source='test', mode='test')
        self.test_loader = DataLoader(
            test_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=self.args["num_workers"])

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)

        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Train the current task (or resume it) and then store its prototypes."""
        resume = False
        # Resume hook: list task indices here to load "<name>_<init>_<inc>_<task>.pkl"
        # instead of training that task. Empty by default, so every task trains.
        if self._cur_task in []:
            self._network.load_state_dict(torch.load("{}_{}_{}_{}.pkl".format(self.args["model_name"], self.args["init_cls"], self.args["increment"], self._cur_task))["model_state_dict"])
            resume = True
        self._network.to(self._device)
        if hasattr(self._network, "module"):
            self._network_module_ptr = self._network.module
        if not resume:
            self._epoch_num = self.args["epochs"]
            optimizer = torch.optim.Adam(self._network.parameters(), lr=self.args["lr"], weight_decay=self.args["weight_decay"])
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args["step_size"], gamma=self.args["gamma"])
            self._train_function(train_loader, test_loader, optimizer, scheduler)
        self._build_protos()

    def _build_protos(self):
        """Append the mean feature of every new class and refresh the noise radius."""
        with torch.no_grad():
            for class_idx in range(self._known_classes, self._total_classes):
                _, _, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx + 1), source='train',
                                                                  mode='test', ret_data=True)
                idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
                vectors, _ = self._extract_vectors(idx_loader)
                class_mean = np.mean(vectors, axis=0)
                self._protos.append(class_mean)
                cov = np.cov(vectors.T)
                # Average per-dimension variance of this class.
                self._radiuses.append(np.trace(cov) / vectors.shape[1])
            self._radius = np.sqrt(np.mean(self._radiuses))

    def _train_function(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(self._epoch_num))
        for _, epoch in enumerate(prog_bar):
            self._network.train()
            losses = 0.
            losses_clf, losses_fkd, losses_proto = 0., 0., 0.
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(
                    self._device, non_blocking=True), targets.to(self._device, non_blocking=True)
                # (B, C, H, W) -> (B, 4, C, H, W) with the four rotations, then
                # flatten to (4B, C, H, W). Rotations require square images;
                # the spatial size is taken from the batch instead of being
                # hard-coded to CIFAR's 32x32.
                inputs = torch.stack([torch.rot90(inputs, k, (2, 3)) for k in range(4)], 1)
                inputs = inputs.view(-1, *inputs.shape[2:])
                targets = torch.stack([targets * 4 + k for k in range(4)], 1).view(-1)
                logits, loss_clf, loss_fkd, loss_proto = self._compute_pass_loss(inputs, targets)
                loss = loss_clf + loss_fkd + loss_proto
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                losses_clf += loss_clf.item()
                losses_fkd += loss_fkd.item()
                losses_proto += loss_proto.item()
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()
            train_acc = np.around(tensor2numpy(
                correct) * 100 / total, decimals=2)
            if epoch % 5 != 0:
                info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}'.format(
                    self._cur_task, epoch + 1, self._epoch_num, losses / len(train_loader), losses_clf / len(train_loader), losses_fkd / len(train_loader), losses_proto / len(train_loader), train_acc)
            else:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                    self._cur_task, epoch + 1, self._epoch_num, losses / len(train_loader), losses_clf / len(train_loader), losses_fkd / len(train_loader), losses_proto / len(train_loader), train_acc, test_acc)
            prog_bar.set_description(info)
            logging.info(info)

    def _compute_pass_loss(self, inputs, targets):
        """Return ``(logits, loss_clf, loss_fkd, loss_proto)`` for one batch."""
        logits = self._network(inputs)["logits"]
        loss_clf = F.cross_entropy(logits / self.args["temp"], targets)

        if self._cur_task == 0:
            return logits, loss_clf, torch.tensor(0.), torch.tensor(0.)

        # Feature distillation: L2 distance between new and frozen-old features.
        features = self._network_module_ptr.extract_vector(inputs)
        features_old = self.old_network_module_ptr.extract_vector(inputs)
        loss_fkd = self.args["lambda_fkd"] * torch.dist(features, features_old, 2)

        # Sample old classes in proportion to old/new class counts, perturb their
        # prototypes with Gaussian noise, and classify them as rotation k == 0.
        n_samples = self.args["batch_size"] * (self._known_classes // (self._total_classes - self._known_classes))
        index = np.random.choice(range(self._known_classes), size=n_samples, replace=True)
        proto_features = np.array(self._protos)[index]
        proto_targets = 4 * index
        proto_features = proto_features + np.random.normal(0, 1, proto_features.shape) * self._radius
        proto_features = torch.from_numpy(proto_features).float().to(self._device, non_blocking=True)
        proto_targets = torch.from_numpy(proto_targets).to(self._device, non_blocking=True)

        proto_logits = self._network_module_ptr.fc(proto_features)["logits"]
        loss_proto = self.args["lambda_proto"] * F.cross_entropy(proto_logits / self.args["temp"], proto_targets)
        return logits, loss_clf, loss_fkd, loss_proto

    def _compute_accuracy(self, model, loader):
        """Top-1 accuracy (percent) using only the un-rotated (k == 0) logits."""
        model.eval()
        correct, total = 0, 0
        for i, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = model(inputs)["logits"][:, ::4]
            predicts = torch.max(outputs, dim=1)[1]
            correct += (predicts.cpu() == targets).sum()
            total += len(targets)

        return np.around(tensor2numpy(correct) * 100 / total, decimals=2)

    def _eval_cnn(self, loader):
        """Top-k predictions from the un-rotated (k == 0) logits, plus true labels."""
        self._network.eval()
        y_pred, y_true = [], []
        for _, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = self._network(inputs)["logits"][:, ::4]
            predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1]
            y_pred.append(predicts.cpu().numpy())
            y_true.append(targets.cpu().numpy())

        return np.concatenate(y_pred), np.concatenate(y_true)

    def eval_task(self):
        """Evaluate with the classifier and with nearest-(normalized)-prototype."""
        y_pred, y_true = self._eval_cnn(self.test_loader)
        cnn_accy = self._evaluate(y_pred, y_true)

        if hasattr(self, '_class_means'):
            y_pred, y_true = self._eval_nme(self.test_loader, self._class_means)
            nme_accy = self._evaluate(y_pred, y_true)
        elif hasattr(self, '_protos'):
            protos = np.asarray(self._protos)
            y_pred, y_true = self._eval_nme(self.test_loader, protos / np.linalg.norm(protos, axis=1)[:, None])
            nme_accy = self._evaluate(y_pred, y_true)
        else:
            nme_accy = None

        return cnn_accy, nme_accy


# ==============================================================================
# /convs/resnet_cbam.py
# ==============================================================================
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

__all__ = ['ResNet', 'resnet18_cbam', 'resnet34_cbam', 'resnet50_cbam', 'resnet101_cbam',
           'resnet152_cbam']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class ChannelAttention(nn.Module):
    """CBAM channel attention: ``sigmoid(MLP(avgpool(x)) + MLP(maxpool(x)))``.

    Args:
        in_planes: number of input channels.
        ratio: reduction ratio of the shared 1x1-conv bottleneck MLP.

    Returns a (N, in_planes, 1, 1) gate to be multiplied with the input.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Previously hard-coded to 16 (ignoring `ratio`); the default is unchanged.
        hidden = max(1, in_planes // ratio)
        self.fc1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    """CBAM spatial attention over channel-wise mean and max maps.

    Returns a (N, 1, H, W) gate to be multiplied with the input.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class BasicBlock(nn.Module):
    """Two-conv residual block used by resnet18/34_cbam."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # NOTE(review): `ca`/`sa` are constructed (so they appear in the
        # state_dict) but forward() does not apply them, unlike Bottleneck.
        # Left as-is to keep published results and checkpoints reproducible;
        # confirm whether this is intentional.
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    """1x1-3x3-1x1 residual block with CBAM channel + spatial attention."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.ca = ChannelAttention(planes * 4)
        self.sa = SpatialAttention()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.ca(out) * out
        out = self.sa(out) * out
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    """CBAM ResNet backbone that returns pooled features (no classifier head).

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: unused; kept for signature compatibility.
        args: config dict. ``args["dataset"]`` must contain ``'cifar'`` or
            ``'imagenet'``; for ImageNet, ``args["init_cls"] == args["increment"]``
            selects the 7x7/stride-2 stem, otherwise a 3x3/stride-1 stem is used.

    ``forward`` returns ``{"features": tensor of shape (N, out_dim)}``.

    Raises:
        ValueError: if ``args["dataset"]`` matches neither supported family.
    """

    def __init__(self, block, layers, num_classes=100, args=None):
        self.inplanes = 64
        super(ResNet, self).__init__()
        assert args is not None, "you should pass args to resnet"
        if 'cifar' in args["dataset"]:
            self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True))
        elif 'imagenet' in args["dataset"]:
            if args["init_cls"] == args["increment"]:
                stem_conv = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
            else:
                stem_conv = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
            self.conv1 = nn.Sequential(
                stem_conv,
                nn.BatchNorm2d(self.inplanes),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            )
        else:
            # Previously fell through and failed later with an AttributeError
            # on the missing `conv1`.
            raise ValueError("Unsupported dataset for resnet_cbam: {}".format(args["dataset"]))
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Kept for attribute compatibility; forward() pools adaptively instead.
        self.feature = nn.AvgPool2d(4, stride=1)
        self.out_dim = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage; the first block may downsample / change width."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Global average pooling. Identical to the former AvgPool2d(width) for
        # square maps, and also correct for non-square ones.
        x = F.adaptive_avg_pool2d(x, 1)
        x = torch.flatten(x, 1)
        return {"features": x}


def _load_pretrained(model, arch):
    """Copy matching ImageNet weights for ``arch`` into ``model`` and return it.

    The downloaded checkpoints contain entries this backbone does not have
    (e.g. a classifier ``fc`` and a non-Sequential stem). Merging them wholesale
    and then loading with strict checking raised "Unexpected key(s)". Only
    tensors whose name and shape match an existing entry are copied; all other
    parameters keep their initialisation.
    """
    pretrained_state_dict = model_zoo.load_url(model_urls[arch])
    now_state_dict = model.state_dict()
    now_state_dict.update({
        k: v for k, v in pretrained_state_dict.items()
        if k in now_state_dict and v.shape == now_state_dict[k].shape
    })
    model.load_state_dict(now_state_dict)
    return model


def resnet18_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, loads matching ImageNet-pretrained weights
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        _load_pretrained(model, 'resnet18')
    return model


def resnet34_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, loads matching ImageNet-pretrained weights
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        _load_pretrained(model, 'resnet34')
    return model


def resnet50_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, loads matching ImageNet-pretrained weights
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        _load_pretrained(model, 'resnet50')
    return model


def resnet101_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, loads matching ImageNet-pretrained weights
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        _load_pretrained(model, 'resnet101')
    return model


def resnet152_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, loads matching ImageNet-pretrained weights
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        _load_pretrained(model, 'resnet152')
    return model


# ==============================================================================
# /models/fetril.py
# ==============================================================================
'''

results on CIFAR-100:

            |      Reported Resnet18       |      Reproduced Resnet32        |
Protocols   | Reported FC | Reported SVM   | Reproduced FC | Reproduced SVM  |

T = 5       |    64.7     |     66.3       |    65.775     |     65.375      |

T = 10      |    63.4     |     65.2       |    64.91      |     65.10       |

T = 60      |    50.8     |     59.8       |    62.09      |     61.72       |

'''


import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from models.base import BaseLearner
from utils.inc_net import CosineIncrementalNet, FOSTERNet, IncrementalNet
from utils.toolkit import count_parameters, target2onehot, tensor2numpy
from sklearn.svm import LinearSVC
from torchvision import datasets, transforms
from utils.autoaugment import CIFAR10Policy, ImageNetPolicy
from utils.ops import Cutout

EPSILON = 1e-8


class FeTrIL(BaseLearner):
    """FeTrIL class-incremental learner.

    The backbone is trained on the first task only and frozen afterwards. For
    every later task, pseudo-features of each old class are synthesised by
    translating the features of its most similar new class by the difference of
    class means. Only the linear head is retrained on that feature set. A
    linear SVM on L2-normalised features is also evaluated after every task.
    """

    def __init__(self, args):
        super().__init__(args)
        self.args = args
        self._network = IncrementalNet(args, False)
        self._means = []     # per-class mean feature vectors, in class order
        self._svm_accs = []  # SVM test accuracy (percent) after each task

    def after_task(self):
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager):
        """Set CIFAR-style augmentation, grow the head, freeze the backbone after task 0, train."""
        self.data_manager = data_manager
        # NOTE(review): CIFAR-specific augmentation (32x32 crop, CIFAR10Policy)
        # is installed unconditionally, regardless of the configured dataset.
        self.data_manager._train_trsf = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=63 / 255),
            CIFAR10Policy(),
            transforms.ToTensor(),
            Cutout(n_holes=1, length=16),
        ]
        self._cur_task += 1

        self._total_classes = self._known_classes + \
            data_manager.get_task_size(self._cur_task)
        self._network.update_fc(self._total_classes)
        self._network_module_ptr = self._network
        logging.info(
            'Learning on {}-{}'.format(self._known_classes, self._total_classes))

        if self._cur_task > 0:
            for p in self._network.convnet.parameters():
                p.requires_grad = False

        logging.info('All params: {}'.format(count_parameters(self._network)))
        logging.info('Trainable params: {}'.format(
            count_parameters(self._network, True)))

        train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train',
                                                 mode='train', appendent=self._get_memory())
        self.train_loader = DataLoader(
            train_dataset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True)
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source='test', mode='test')
        self.test_loader = DataLoader(
            test_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=self.args["num_workers"])

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)

        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Task 0: train the whole network. Later tasks: train only the head on features."""
        self._network.to(self._device)
        if hasattr(self._network, "module"):
            self._network_module_ptr = self._network.module
        if self._cur_task == 0:
            self._epoch_num = self.args["init_epochs"]
            optimizer = optim.SGD(filter(lambda p: p.requires_grad, self._network.parameters()),
                                  momentum=0.9, lr=self.args["init_lr"], weight_decay=self.args["init_weight_decay"])
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer, T_max=self.args["init_epochs"])
            self._train_function(train_loader, test_loader, optimizer, scheduler)
            self._compute_means()
            self._build_feature_set()
        else:
            self._epoch_num = self.args["epochs"]
            self._compute_means()
            self._compute_relations()
            self._build_feature_set()

            train_loader = DataLoader(self._feature_trainset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True)
            optimizer = optim.SGD(self._network_module_ptr.fc.parameters(), momentum=0.9, lr=self.args["lr"], weight_decay=self.args["weight_decay"])
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=self.args["epochs"])

            self._train_function(train_loader, test_loader, optimizer, scheduler)
        self._train_svm(self._feature_trainset, self._feature_testset)

    def _class_loader(self, class_idx, source):
        """Non-shuffled, test-mode loader over the samples of one class."""
        _, _, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx + 1), source=source,
                                                          mode='test', ret_data=True)
        return DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)

    def _compute_means(self):
        """Append the mean training feature of every class of the current task."""
        with torch.no_grad():
            for class_idx in range(self._known_classes, self._total_classes):
                vectors, _ = self._extract_vectors(self._class_loader(class_idx, 'train'))
                class_mean = np.mean(vectors, axis=0)
                self._means.append(class_mean)

    def _compute_relations(self):
        """Map each old class to its most cosine-similar new class (absolute index)."""
        old_means = np.array(self._means[:self._known_classes])
        new_means = np.array(self._means[self._known_classes:])
        old_unit = old_means / np.linalg.norm(old_means, axis=1)[:, None]
        new_unit = new_means / np.linalg.norm(new_means, axis=1)[:, None]
        similarity = old_unit @ new_unit.T  # (n_old, n_new)
        self._relations = np.argmax(similarity, axis=1) + self._known_classes

    def _build_feature_set(self):
        """Build feature datasets: real new-class + translated old-class train features, and all test features."""
        self.vectors_train = []
        self.labels_train = []
        for class_idx in range(self._known_classes, self._total_classes):
            vectors, _ = self._extract_vectors(self._class_loader(class_idx, 'train'))
            self.vectors_train.append(vectors)
            self.labels_train.append([class_idx] * len(vectors))
        for class_idx in range(0, self._known_classes):
            new_idx = self._relations[class_idx]
            # Translate the related new class's features onto the old class mean.
            self.vectors_train.append(self.vectors_train[new_idx - self._known_classes] - self._means[new_idx] + self._means[class_idx])
            self.labels_train.append([class_idx] * len(self.vectors_train[-1]))

        self.vectors_train = np.concatenate(self.vectors_train)
        self.labels_train = np.concatenate(self.labels_train)
        self._feature_trainset = FeatureDataset(self.vectors_train, self.labels_train)

        self.vectors_test = []
        self.labels_test = []
        for class_idx in range(0, self._total_classes):
            vectors, _ = self._extract_vectors(self._class_loader(class_idx, 'test'))
            self.vectors_test.append(vectors)
            self.labels_test.append([class_idx] * len(vectors))
        self.vectors_test = np.concatenate(self.vectors_test)
        self.labels_test = np.concatenate(self.labels_test)

        self._feature_testset = FeatureDataset(self.vectors_test, self.labels_test)

    def _train_function(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(self._epoch_num))
        for _, epoch in enumerate(prog_bar):
            # After task 0 the backbone is frozen, so keep BN statistics fixed.
            if self._cur_task == 0:
                self._network.train()
            else:
                self._network.eval()
            losses = 0.
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(
                    self._device, non_blocking=True), targets.to(self._device, non_blocking=True)
                if self._cur_task == 0:
                    logits = self._network(inputs)['logits']
                else:
                    # Later tasks feed precomputed feature vectors straight to the head.
                    logits = self._network_module_ptr.fc(inputs)['logits']
                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()
            train_acc = np.around(tensor2numpy(
                correct) * 100 / total, decimals=2)
            if epoch % 5 != 0:
                info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}'.format(
                    self._cur_task, epoch + 1, self._epoch_num, losses / len(train_loader), train_acc)
            else:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                    self._cur_task, epoch + 1, self._epoch_num, losses / len(train_loader), train_acc, test_acc)
            prog_bar.set_description(info)
            logging.info(info)

    def _train_svm(self, train_set, test_set):
        """Fit a LinearSVC on L2-normalised features and log train / test accuracy."""
        train_features = train_set.features.numpy()
        train_labels = train_set.labels.numpy()
        test_features = test_set.features.numpy()
        test_labels = test_set.labels.numpy()
        train_features = train_features / np.linalg.norm(train_features, axis=1)[:, None]
        test_features = test_features / np.linalg.norm(test_features, axis=1)[:, None]
        svm_classifier = LinearSVC(random_state=42)
        svm_classifier.fit(train_features, train_labels)
        logging.info("svm train: acc: {}".format(np.around(svm_classifier.score(train_features, train_labels) * 100, decimals=2)))
        acc = svm_classifier.score(test_features, test_labels)
        self._svm_accs.append(np.around(acc * 100, decimals=2))
        logging.info("svm evaluation: acc_list: {}".format(self._svm_accs))


class FeatureDataset(Dataset):
    """Dataset over precomputed feature vectors; items are ``(idx, feature, label)``."""

    def __init__(self, features, labels):
        assert len(features) == len(labels), "Data size error!"
        self.features = torch.from_numpy(features)
        self.labels = torch.from_numpy(labels)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        feature = self.features[idx]
        label = self.labels[idx]

        return idx, feature, label

# ------------------------------------------------------------------------------