├── .gitignore ├── model ├── __init__.py ├── sequential.py ├── independent.py ├── base.py ├── lwf.py ├── ewc.py ├── er.py ├── clpu_derpp.py ├── clu_er.py ├── derpp.py ├── lsf.py └── backbone.py ├── README.md ├── LICENCE.md ├── run.sh ├── config.py ├── data.py ├── plot.py └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | results/* 3 | data/* 4 | *.txt 5 | *.pyc 6 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .sequential import Sequential 2 | from .lsf import LSF 3 | from .lwf import LwF 4 | from .ewc import EWC 5 | from .er import ER 6 | from .clu_er import CLU_ER 7 | from .derpp import Derpp 8 | from .clpu_derpp import CLPU_Derpp 9 | from .independent import Independent 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Continual-Learning-Private-Unlearning 2 | Official PyTorch implementation of Continual Learning and Private Unlearning, accepted at **CoLLAs 2022**. 3 | 4 | ## To run the experiments 5 | ``` 6 | chmod +x run.sh && ./run.sh 7 | ``` 8 | 9 | ## Citations 10 | If you find our work interesting or the repo useful, please consider citing [this paper](https://arxiv.org/abs/2203.12817): 11 | ``` 12 | @article{liu2022continual, 13 | title={Continual Learning and Private Unlearning}, 14 | author={Liu, Bo and Liu, Qiang and Stone, Peter}, 15 | journal={arXiv preprint arXiv:2203.12817}, 16 | year={2022} 17 | } 18 | ``` 19 | -------------------------------------------------------------------------------- /LICENCE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Bo Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /model/sequential.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import quadprog 4 | import random 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.data as data 10 | 11 | from torch.optim import Adam, SGD, RMSprop 12 | from torch.utils.data import DataLoader 13 | from .backbone import resnet18 14 | from .base import Base 15 | 16 | 17 | class Sequential(Base): 18 | def __init__(self, config): 19 | super(Sequential, self).__init__(config) 20 | 21 | def learn(self, task_id, dataset): 22 | loader = DataLoader(dataset, 23 | batch_size=self.config.batch_size, 24 | shuffle=True, 25 | num_workers=2) 26 | self.opt = SGD(self.net.parameters(), 27 | lr=self.config.lr, 28 | momentum=self.config.momentum, 29 | weight_decay=self.config.weight_decay) 30 | 31 | for i in range(self.config.n_epochs): 32 | for x, y in loader: 33 | x = x.to(self.device) 34 | y = y.to(self.device) 35 | self.opt.zero_grad() 36 | pred = self.forward(x, task_id) 37 | loss = self.loss_fn(pred, y) 38 | loss.backward() 39 | self.opt.step() 40 | -------------------------------------------------------------------------------- /model/independent.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import quadprog 4 | import random 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.data as data 10 | 11 | from torch.optim import Adam, SGD, RMSprop 12 | from torch.utils.data import DataLoader 13 | from .backbone import resnet18 14 | from .base import Base 15 | 16 | 17 | class Independent(Base): 18 | def __init__(self, config): 19 | super(Independent, self).__init__(config) 20 | self.nets = {} # learn a separate network per task 21 | 22 | def forward(self, x, task): 23 | if (task in self.nets): 24 | pred = self.nets[task].forward(x).view(x.shape[0], -1) 25 | else: 26 | pred = self.net.forward(x).view(x.shape[0], -1) 27 | 28 | if task > 0: 29 | pred[:, :self.cpt*task].data.fill_(-10e10) 30 | if task < self.n_tasks-1: 31 | pred[:, self.cpt*(task+1):].data.fill_(-10e10) 32 | return pred 33 | 34 | def learn(self, task_id, dataset): 35 | self.nets[task_id] = copy.deepcopy(self.net) 36 | 37 | loader = DataLoader(dataset, 38 | batch_size=self.config.batch_size, 39 | shuffle=True, 40 | num_workers=2) 41 | 42 | opt = SGD(self.nets[task_id].parameters(), 43 | lr=self.config.lr, 44 | momentum=self.config.momentum, 45 | weight_decay=self.config.weight_decay) 46 | 47 | for epoch in range(self.config.n_epochs): 48 | for i, (x, y) in enumerate(loader): 49 | x = x.to(self.device) 50 | y = y.to(self.device) 51 | 52 | # current loss 53 | h = self.forward(x, task_id) 54 | loss = self.loss_fn(h, y) 55 | opt.zero_grad() 56 | loss.backward() 57 | opt.step() 58 | 59 | def forget(self, task_id): 60 | del self.nets[task_id] 61 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | seed=$1 2 | 3 | dataset=perm_mnist 4 | alpha=0.5 5 | beta=1.0 6 | python main.py --dataset $dataset --method sequential --seed $seed 7 | python main.py --dataset $dataset --method independent --seed $seed 8 | python main.py --dataset $dataset --method ewc --seed $seed 9 | python main.py --dataset 
$dataset --method er --seed $seed 10 | python main.py --dataset $dataset --method derpp --seed $seed --beta $beta --alpha $alpha 11 | python main.py --dataset $dataset --method clpu_derpp --seed $seed --beta $beta --alpha $alpha 12 | python main.py --dataset $dataset --method clpu_derpp --seed $seed --beta $beta --alpha $alpha --use_pretrain 13 | 14 | 15 | dataset=rot_mnist 16 | alpha=0.5 17 | beta=1.0 18 | python main.py --dataset $dataset --method sequential --seed $seed 19 | python main.py --dataset $dataset --method independent --seed $seed 20 | python main.py --dataset $dataset --method ewc --seed $seed 21 | python main.py --dataset $dataset --method er --seed $seed 22 | python main.py --dataset $dataset --method derpp --seed $seed --beta $beta --alpha $alpha 23 | python main.py --dataset $dataset --method clpu_derpp --seed $seed --beta $beta --alpha $alpha 24 | python main.py --dataset $dataset --method clpu_derpp --seed $seed --beta $beta --alpha $alpha --use_pretrain 25 | 26 | dataset=cifar10 27 | alpha=0.5 28 | beta=0.5 29 | python main.py --dataset $dataset --method sequential --seed $seed 30 | python main.py --dataset $dataset --method independent --seed $seed 31 | python main.py --dataset $dataset --method ewc --seed $seed 32 | python main.py --dataset $dataset --method er --seed $seed 33 | python main.py --dataset $dataset --method derpp --seed $seed --beta $beta --alpha $alpha 34 | python main.py --dataset $dataset --method clpu_derpp --seed $seed --beta $beta --alpha $alpha 35 | python main.py --dataset $dataset --method clpu_derpp --seed $seed --beta $beta --alpha $alpha --use_pretrain 36 | 37 | dataset=cifar10 38 | alpha=0.5 39 | beta=1.0 40 | python main.py --dataset $dataset --method sequential --seed $seed 41 | python main.py --dataset $dataset --method independent --seed $seed 42 | python main.py --dataset $dataset --method ewc --seed $seed 43 | python main.py --dataset $dataset --method er --seed $seed 44 | python main.py --dataset $dataset --method derpp --seed $seed --beta $beta --alpha $alpha 45 | python main.py --dataset $dataset --method clpu_derpp --seed $seed --beta $beta --alpha $alpha 46 | python main.py --dataset $dataset --method clpu_derpp --seed $seed --beta $beta --alpha $alpha --use_pretrain 47 | -------------------------------------------------------------------------------- /model/base.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import quadprog 4 | import random 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.data as data 10 | 11 | from torch.optim import Adam, SGD, RMSprop 12 | from .backbone import resnet18, MLP 13 | 14 | 15 | class Base(nn.Module): 16 | def __init__(self, config): 17 | super(Base, self).__init__() 18 | self.config = config 19 | self.device = config.device 20 | self.n_tasks = config.n_tasks 21 | self.cpt = config.class_per_task 22 | if "cifar" in config.dataset: 23 | if config.scenario == 'domain': 24 | self.net = resnet18(config.class_per_task) 25 | else: 26 | self.net = resnet18(config.class_per_task*config.n_tasks) 27 | else: 28 | self.net = MLP(config) 29 | 30 | self.loss_fn = nn.CrossEntropyLoss() 31 | self.loss_fn_reduction_none = nn.CrossEntropyLoss(reduction='none') 32 | self.opt = None 33 | self.task_status = {} 34 | self.prev_tasks = [] 35 | self.n_iters = 1 36 | 37 | def forward(self, x, task): 38 | x = self.net(x) 39 | x = x.view(x.shape[0], -1) 40 | 41 | if 
self.config.scenario != 'domain': 42 | if task > 0: 43 | x[:, :self.cpt*task].data.fill_(-10e10) 44 | if task < self.n_tasks-1: 45 | x[:, self.cpt*(task+1):].data.fill_(-10e10) 46 | return x 47 | 48 | def learn(self, task_id, dataset): 49 | pass 50 | 51 | def temporarily_learn(self, task_id, dataset): 52 | return self.learn(task_id, dataset) # default to same as learn 53 | 54 | def finally_learn(self, task_id): 55 | return # default: do nothing when we finally learn a task 56 | 57 | def forget(self, task_id): 58 | return # default: do nothing when we want to forget a task 59 | 60 | def continual_learn_and_unlearn(self, task_id, dataset, learn_type): 61 | if learn_type in ["R", "T"]: 62 | if task_id not in self.task_status: # first time learn the task 63 | self.task_status[task_id] = learn_type 64 | if learn_type == "R": 65 | self.learn(task_id, dataset) 66 | else: 67 | self.temporarily_learn(task_id, dataset) 68 | else: # second time consolidate 69 | assert learn_type == "R", "[ERROR] second time learn task should be memorize" 70 | assert self.task_status[task_id] == "T", "[ERROR] the task should have been temporarily learned" 71 | self.finally_learn(task_id) 72 | self.task_status[task_id] = "R" 73 | else: # learn type is "F" forget 74 | assert learn_type == "F", f"[ERROR] unknown learning type {learn_type}" 75 | assert task_id in self.task_status, f"[ERROR] {task_id} is not learned" 76 | assert self.task_status[task_id] == "T", f"[ERROR] {task_id} was remembered, cannot unlearn" 77 | self.forget(task_id) 78 | self.task_status[task_id] = "F" 79 | -------------------------------------------------------------------------------- /model/lwf.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import quadprog 4 | import random 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.data as data 10 | 11 | from torch.optim import Adam, SGD, RMSprop 12 | from torch.utils.data import DataLoader 13 | from .backbone import resnet18 14 | from .base import * 15 | 16 | 17 | def smooth(logits, temp, dim): 18 | log = logits ** (1 / temp) 19 | return log / torch.sum(log, dim).unsqueeze(1) 20 | 21 | 22 | def modified_kl_div(old, new): 23 | return -torch.mean(torch.sum(old * torch.log(new), 1)) 24 | 25 | 26 | class LwF(Base): 27 | def __init__(self, config): 28 | super(LwF, self).__init__(config) 29 | self.old_net = None 30 | self.soft = torch.nn.Softmax(dim=-1) 31 | 32 | def learn(self, task_id, dataset): 33 | self.old_net = copy.deepcopy(self.net) 34 | self.old_net.eval() 35 | for p in self.old_net.parameters(): 36 | p.requires_grad = False 37 | 38 | loader = DataLoader(dataset, 39 | batch_size=self.config.batch_size, 40 | shuffle=True, 41 | num_workers=2) 42 | self.opt = SGD(self.net.parameters(), 43 | lr=self.config.lr, 44 | momentum=self.config.momentum, 45 | weight_decay=self.config.weight_decay) 46 | 47 | if task_id > 0: 48 | self.opt_cls = SGD(self.net.classifier.parameters(), 49 | lr=self.config.lr, 50 | momentum=self.config.momentum, 51 | weight_decay=self.config.weight_decay) 52 | for epoch in range(self.config.n_epochs): 53 | for i, (x, y) in enumerate(loader): 54 | x = x.to(self.device) 55 | y = y.to(self.device) 56 | 57 | loss = self.loss_fn(self.forward(x, task_id), y) 58 | self.opt_cls.zero_grad() 59 | loss.backward() 60 | self.opt_cls.step() 61 | 62 | for epoch in range(self.config.n_epochs): 63 | for i, (x, y) in enumerate(loader): 64 | x = 
x.to(self.device) 65 | y = y.to(self.device) 66 | 67 | loss = self.loss_fn(self.forward(x, task_id), y) 68 | 69 | # current loss 70 | if task_id > 0: 71 | n_prev_tasks = len(self.prev_tasks) 72 | for t in self.prev_tasks: 73 | outputs = self.forward(x, t)[...,t*self.cpt:(t+1)*self.cpt] 74 | with torch.no_grad(): 75 | targets = self.old_net.forward(x)[..., t*self.cpt:(t+1)*self.cpt] 76 | loss += self.config.lwf_alpha * modified_kl_div( 77 | smooth(self.soft(targets), self.config.lwf_temp, 1), 78 | smooth(self.soft(outputs), self.config.lwf_temp, 1)) / n_prev_tasks 79 | self.opt.zero_grad() 80 | loss.backward() 81 | self.opt.step() 82 | 83 | self.prev_tasks.append(task_id) 84 | 85 | def forget(self, task_id): 86 | self.prev_tasks.remove(task_id) 87 | -------------------------------------------------------------------------------- /model/ewc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import quadprog 4 | import random 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.data as data 10 | 11 | from torch.optim import Adam, SGD, RMSprop 12 | from torch.utils.data import DataLoader 13 | from .backbone import resnet18 14 | from .base import * 15 | 16 | 17 | class EWC(Base): 18 | def __init__(self, config): 19 | super(EWC, self).__init__(config) 20 | self.logsoft = nn.LogSoftmax(dim=1) 21 | self.fish = {} 22 | self.checkpoint = {} 23 | 24 | def penalty(self): 25 | ### ewc penalty 26 | if len(self.prev_tasks) == 0: 27 | return torch.tensor(0.0).to(self.device) 28 | current_param = self.net.get_params() 29 | penalty = 0.0 30 | for t in self.prev_tasks: 31 | penalty += (self.fish[t] * (current_param - self.checkpoint[t]).pow(2)).sum() 32 | return penalty 33 | 34 | def learn(self, task_id, dataset): 35 | loader = DataLoader(dataset, 36 | batch_size=self.config.batch_size, 37 | shuffle=True, 38 | num_workers=2) 39 | self.opt = SGD(self.net.parameters(), 40 | lr=self.config.lr, 41 | momentum=self.config.momentum, 42 | weight_decay=self.config.weight_decay) 43 | 44 | ewc_lmbd = self.config.ewc_lmbd 45 | self.n_iters = self.config.n_epochs * len(loader) 46 | 47 | for epoch in range(self.config.n_epochs): 48 | for i, (x, y) in enumerate(loader): 49 | x = x.to(self.device) 50 | y = y.to(self.device) 51 | 52 | loss = self.loss_fn(self.forward(x, task_id), y) 53 | 54 | # current loss 55 | if task_id > 0: 56 | loss += ewc_lmbd * self.penalty() 57 | self.opt.zero_grad() 58 | loss.backward() 59 | self.opt.step() 60 | 61 | #if task_id > 0: 62 | # self.opt.zero_grad() 63 | # Lr.backward() 64 | # import pdb; pdb.set_trace() 65 | # self.opt.step() 66 | 67 | ### end of task 68 | fish = torch.zeros_like(self.net.get_params()) 69 | 70 | for j, (x, y) in enumerate(loader): 71 | x = x.to(self.device) 72 | y = y.to(self.device) 73 | for ex, lab in zip(x, y): 74 | self.opt.zero_grad() 75 | output = self.forward(ex.unsqueeze(0), task_id) 76 | loss = - F.nll_loss(self.logsoft(output), lab.unsqueeze(0), reduction='none') 77 | exp_cond_prob = torch.mean(torch.exp(loss.detach().clone())) 78 | loss = torch.mean(loss) 79 | loss.backward() 80 | fish += exp_cond_prob * self.net.get_grads() ** 2 81 | 82 | fish /= (len(loader)*self.config.batch_size) 83 | self.prev_tasks.append(task_id) 84 | self.fish[task_id] = fish 85 | self.checkpoint[task_id] = self.net.get_params().data.clone() 86 | 87 | def forget(self, task_id): 88 | assert task_id in self.prev_tasks, f"[ERROR] {task_id} not seen 
before" 89 | self.prev_tasks.remove(task_id) 90 | del self.fish[task_id] 91 | del self.checkpoint[task_id] 92 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | 5 | class Config: 6 | def __init__(self): 7 | self.parser = argparse.ArgumentParser() 8 | self.device = 'cpu' 9 | self.verbose = True 10 | self.init_args() 11 | 12 | def add_argument(self, *args, **kwargs): 13 | self.parser.add_argument(*args, **kwargs) 14 | 15 | def init_args(self): 16 | self.parser.add_argument('--method', default='er', help='[er, clu, retrain]') 17 | self.parser.add_argument('--seed', default=1, type=int, help='seed') 18 | self.parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100', 'mnist', 'fashion', 'rot_mnist', 'perm_mnist', 'law']) 19 | self.parser.add_argument('--batch_size', default=32, type=int, help='batch size') 20 | self.parser.add_argument('--n_epochs', default=10, type=int, help='number of iterations') 21 | 22 | self.parser.add_argument('--lr', default=0.01, type=float, help='learning rate') 23 | self.parser.add_argument('--momentum', default=0.0, type=float, help='momentum') 24 | self.parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight_decay') 25 | 26 | # EWC 27 | self.parser.add_argument('--ewc_lmbd', default=100., type=float, help='ewc lambda') 28 | 29 | # LSF 30 | self.parser.add_argument('--lsf_gamma', default=10.0, type=float, help='lsf gamma') 31 | 32 | # LWF 33 | self.parser.add_argument('--lwf_alpha', default=1.0, type=float, help='lsf gamma') 34 | self.parser.add_argument('--lwf_temp', default=2.0, type=float, help='lsf gamma') 35 | 36 | # ER 37 | self.parser.add_argument('--forget_iters', default=1000, type=int, help='number of forgetting iterations') 38 | self.parser.add_argument('--memorize_iters', default=1000, type=int, help='number of forgetting iterations') 39 | self.parser.add_argument('--mem_budget', default=200, type=int, help='memory budgeet') 40 | self.parser.add_argument('--mem_batch_size', default=32, type=int, help='memory batch size') 41 | 42 | # CLU_ER & CLPU_Derpp 43 | self.parser.add_argument('--use_pretrain', default=False, action='store_true', help='whether to initialize from previous model') 44 | 45 | # Derpp & CLPU_Derpp 46 | self.parser.add_argument('--alpha', default=0.5, type=float, help='memory batch size') 47 | self.parser.add_argument('--beta', default=1.0, type=float, help='memory batch size') 48 | 49 | args = self.parser.parse_args() 50 | dict_ = vars(args) 51 | 52 | for k, v in dict_.items(): 53 | setattr(self, k, v) 54 | self.scenario = 'class' 55 | 56 | if self.dataset == 'cifar100': 57 | self.dim_input = (3, 32, 32) 58 | self.class_per_task = 20 59 | self.n_tasks = 5 60 | elif self.dataset == 'cifar10': 61 | self.dim_input = (3, 32, 32) 62 | self.class_per_task = 2 63 | self.n_tasks = 5 64 | elif self.dataset == 'mnist': 65 | self.dim_input = (1, 32, 32) 66 | self.class_per_task = 2 67 | self.n_tasks = 5 68 | elif self.dataset == 'fashion': 69 | self.dim_input = (1, 32, 32) 70 | self.class_per_task = 2 71 | self.n_tasks = 5 72 | elif self.dataset == 'rot_mnist': 73 | self.dim_input = (784,) 74 | self.class_per_task = 10 75 | self.n_tasks = 5 76 | elif self.dataset == 'perm_mnist': 77 | self.dim_input = (784,) 78 | self.class_per_task = 10 79 | self.n_tasks = 5 80 | elif self.dataset == 'law': 81 | self.scenario = 'domain' 82 | 
self.dim_input = (11,) 83 | self.class_per_task = 2 84 | self.n_tasks = 5 85 | 86 | if self.verbose: 87 | print("="*80) 88 | print("[INFO] -- Experiment Configs --") 89 | print(" 1. data & task") 90 | print(" dataset: %s" % self.dataset) 91 | print(" n_tasks: %d" % self.n_tasks) 92 | print(" # class/task: %d" % self.class_per_task) 93 | print(" 2. training") 94 | print(" lr: %5.4f" % self.lr) 95 | print(" 3. model") 96 | print(" method: %s" % self.method) 97 | print("="*80) 98 | -------------------------------------------------------------------------------- /model/er.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import quadprog 4 | import random 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.data as data 10 | 11 | from torch.optim import Adam, SGD, RMSprop 12 | from torch.utils.data import DataLoader 13 | from .backbone import resnet18 14 | from .base import * 15 | 16 | 17 | class Replay(data.Dataset): 18 | """ 19 | A dataset wrapper used as a memory to store the data 20 | """ 21 | def __init__(self, buffer_size, dim_x, device): 22 | super(Replay, self).__init__() 23 | self.dim_x = dim_x 24 | self.buffer_size = buffer_size 25 | self.device = device 26 | self.buffer = {} 27 | 28 | def __len__(self): 29 | if not self.buffer: 30 | return 0 31 | else: 32 | n = 0 33 | for t in self.buffer: 34 | n += min(self.buffer[t]['num_seen'], self.buffer_size) 35 | return n 36 | 37 | def add(self, x, y, t): 38 | x = x.cpu() 39 | y = y.cpu() 40 | if t not in self.buffer: 41 | self.buffer[t] = { 42 | "X": torch.zeros([self.buffer_size] + list(self.dim_x)), 43 | "Y": torch.zeros(self.buffer_size).long(), 44 | "num_seen": 0, 45 | } 46 | 47 | n = x.shape[0] 48 | for i in range(n): 49 | self.buffer[t]['num_seen'] += 1 50 | 51 | if self.buffer[t]['num_seen'] <= self.buffer_size: 52 | idx = self.buffer[t]['num_seen'] - 1 53 | else: 54 | rand = np.random.randint(0, self.buffer[t]['num_seen']) 55 | idx = rand if rand < self.buffer_size else -1 56 | 57 | self.buffer[t]['X'][idx] = x[i] 58 | self.buffer[t]['Y'][idx] = y[i] 59 | 60 | def sample(self, n, exclude=[]): 61 | nb = self.__len__() 62 | if nb == 0: 63 | return None, None 64 | 65 | X = []; Y = [] 66 | for t, v in self.buffer.items(): 67 | if t in exclude: 68 | continue 69 | idx = torch.randperm(min(v['num_seen'], v['X'].shape[0]))[:min(min(n, v['num_seen']), v['X'].shape[0])] 70 | X.append(v['X'][idx]) 71 | Y.append(v['Y'][idx]) 72 | return torch.cat(X, 0).to(self.device), torch.cat(Y, 0).to(self.device) 73 | 74 | def sample_task(self, n, task_id): 75 | X = []; Y = [] 76 | assert task_id in self.buffer, f"[ERROR] not found {task_id} in buffer" 77 | v = self.buffer[task_id] 78 | idx = torch.randperm(min(v['num_seen'], v['X'].shape[0]))[:min(min(n, v['num_seen']), v['X'].shape[0])] 79 | X.append(v['X'][idx]) 80 | Y.append(v['Y'][idx]) 81 | return torch.cat(X, 0).to(self.device), torch.cat(Y, 0).to(self.device) 82 | 83 | def remove(self, t): 84 | X = self.buffer[t]['X'] 85 | Y = self.buffer[t]['Y'] 86 | del self.buffer[t] 87 | return X, Y 88 | 89 | 90 | class ER(Base): 91 | def __init__(self, config): 92 | super(ER, self).__init__(config) 93 | self.n_mem = self.config.mem_budget 94 | self.memory = Replay(self.n_mem, config.dim_input, self.device) 95 | 96 | def learn(self, task_id, dataset): 97 | loader = DataLoader(dataset, 98 | batch_size=self.config.batch_size, 99 | shuffle=True, 100 | num_workers=2) 101 | 
self.opt = SGD(self.net.parameters(), 102 | lr=self.config.lr, 103 | momentum=self.config.momentum, 104 | weight_decay=self.config.weight_decay) 105 | 106 | self.n_iters = self.config.n_epochs * len(loader) 107 | for epoch in range(self.config.n_epochs): 108 | for i, (x, y) in enumerate(loader): 109 | x = x.to(self.device) 110 | y = y.to(self.device) 111 | 112 | # current loss 113 | loss = self.loss_fn(self.forward(x, task_id), y) 114 | 115 | n_prev_tasks = len(self.prev_tasks) 116 | for t in self.prev_tasks: 117 | x_past, y_past = self.memory.sample_task(self.config.mem_batch_size//n_prev_tasks, t) 118 | loss += self.loss_fn(self.forward(x_past, t), y_past) / n_prev_tasks 119 | 120 | self.opt.zero_grad() 121 | loss.backward() 122 | self.opt.step() 123 | self.memory.add(x, y, task_id) 124 | 125 | self.prev_tasks.append(task_id) 126 | 127 | def forget(self, task_id): 128 | self.memory.remove(task_id) 129 | self.prev_tasks.remove(task_id) 130 | 131 | self.opt = SGD(self.net.parameters(), 132 | lr=self.config.lr, 133 | momentum=self.config.momentum, 134 | weight_decay=self.config.weight_decay) 135 | 136 | n_prev_tasks = len(self.prev_tasks) 137 | #for i in range(self.config.forget_iters): 138 | for i in range(self.n_iters): 139 | self.opt.zero_grad() 140 | loss = 0.0 141 | for t in self.prev_tasks: 142 | x_past, y_past = self.memory.sample_task(self.config.mem_batch_size // n_prev_tasks, t) 143 | loss += self.loss_fn(self.forward(x_past, t), y_past) / n_prev_tasks 144 | loss.backward() 145 | self.opt.step() 146 | -------------------------------------------------------------------------------- /model/clpu_derpp.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import quadprog 4 | import random 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.data as data 10 | 11 | from torch.optim import Adam, SGD, RMSprop 12 | from torch.utils.data import DataLoader 13 | from .backbone import resnet18 14 | from .base import * 15 | from .derpp import Derpp, Replay 16 | 17 | 18 | class CLPU_Derpp(Derpp): 19 | def __init__(self, config): 20 | super(CLPU_Derpp, self).__init__(config) 21 | self.side_nets = {} 22 | 23 | def forward(self, x, task): 24 | if (task in self.side_nets): 25 | pred = self.side_nets[task].forward(x).view(x.shape[0], -1) 26 | else: 27 | pred = self.net.forward(x).view(x.shape[0], -1) 28 | 29 | if self.config.scenario != 'domain': 30 | if task > 0: 31 | pred[:, :self.cpt*task].data.fill_(-10e10) 32 | if task < self.n_tasks-1: 33 | pred[:, self.cpt*(task+1):].data.fill_(-10e10) 34 | return pred 35 | 36 | def learn(self, task_id, dataset): 37 | loader = DataLoader(dataset, 38 | batch_size=self.config.batch_size, 39 | shuffle=True, 40 | num_workers=2) 41 | self.opt = SGD(self.net.parameters(), 42 | lr=self.config.lr, 43 | momentum=self.config.momentum, 44 | weight_decay=self.config.weight_decay) 45 | 46 | exclude_list = [t for t in self.task_status.keys() if self.task_status[t] == 'T' ] 47 | 48 | self.n_iters = self.config.n_epochs * len(loader) 49 | for epoch in range(self.config.n_epochs): 50 | for i, (x, y) in enumerate(loader): 51 | x = x.to(self.device) 52 | y = y.to(self.device) 53 | 54 | # current loss 55 | h = self.forward(x, task_id) 56 | loss = self.loss_fn(h, y) 57 | 58 | R_prev_list = [t for t in self.prev_tasks if t not in exclude_list] 59 | n_prev_tasks = len(self.prev_tasks) 60 | for t in self.prev_tasks: 61 | x_past, y_past, h_past = 
self.memory.sample_task(self.config.mem_batch_size//n_prev_tasks, t) 62 | h_tmp = self.forward(x_past, t) 63 | loss += self.alpha * self.der_loss(h_tmp, h_past, t) / n_prev_tasks 64 | loss += self.beta * self.loss_fn(h_tmp, y_past) / n_prev_tasks 65 | 66 | self.opt.zero_grad() 67 | loss.backward() 68 | self.opt.step() 69 | self.memory.add(x, y, h.detach(), task_id) 70 | 71 | self.prev_tasks.append(task_id) 72 | 73 | def temporarily_learn(self, task_id, dataset): 74 | # initialize a side network 75 | assert task_id not in self.side_nets, f"[ERROR] should not see {task_id} in side nets" 76 | 77 | if self.config.use_pretrain: 78 | self.side_nets[task_id] = copy.deepcopy(self.net) 79 | else: 80 | if "cifar" in self.config.dataset: 81 | if self.config.scenario != 'domain': 82 | self.side_nets[task_id] = resnet18(self.cpt*self.n_tasks).to(self.config.device) 83 | else: 84 | self.side_nets[task_id] = resnet18(self.cpt).to(self.config.device) 85 | else: 86 | self.side_nets[task_id] = MLP(self.config).to(self.config.device) 87 | 88 | 89 | loader = DataLoader(dataset, 90 | batch_size=self.config.batch_size, 91 | shuffle=True, 92 | num_workers=2) 93 | 94 | opt = SGD(self.side_nets[task_id].parameters(), 95 | lr=self.config.lr, 96 | momentum=self.config.momentum, 97 | weight_decay=self.config.weight_decay) 98 | 99 | for epoch in range(self.config.n_epochs): 100 | for i, (x, y) in enumerate(loader): 101 | x = x.to(self.device) 102 | y = y.to(self.device) 103 | 104 | # current loss 105 | h = self.forward(x, task_id) 106 | loss = self.loss_fn(h, y) 107 | opt.zero_grad() 108 | loss.backward() 109 | opt.step() 110 | self.memory.add(x, y, h.detach(), task_id) 111 | 112 | def finally_learn(self, task_id): 113 | # use knowledge distillation to merge two networks 114 | self.opt = SGD(self.net.parameters(), 115 | lr=self.config.lr, 116 | momentum=self.config.momentum, 117 | weight_decay=self.config.weight_decay) 118 | 119 | task_net = self.side_nets[task_id] 120 | task_net.eval() 121 | del self.side_nets[task_id] 122 | 123 | exclude_list = [t for t in self.task_status.keys() if self.task_status[t] == 'T' ] 124 | R_prev_list = [t for t in self.prev_tasks if t not in exclude_list] 125 | n_prev_tasks = len(R_prev_list) 126 | 127 | #for it in range(self.config.memorize_iters): 128 | for it in range(self.n_iters): 129 | loss = 0.0 130 | x_t, y_t, h_t = self.memory.sample_task(self.config.mem_batch_size, task_id) 131 | for tt in R_prev_list: 132 | x_past, y_past, h_past = self.memory.sample_task(self.config.mem_batch_size//n_prev_tasks, tt) 133 | h_tmp = self.forward(x_past, tt) 134 | loss += self.alpha * self.der_loss(h_tmp, h_past, tt) / n_prev_tasks 135 | loss += self.beta * self.loss_fn(h_tmp, y_past) / n_prev_tasks 136 | 137 | h_tmp_t = self.forward(x_t, task_id) 138 | #h_tmp_t_target = task_net.forward(x_t).detach() 139 | loss += self.beta * self.loss_fn(h_tmp_t, y_t) + self.alpha * self.der_loss(h_tmp_t, h_t, task_id) 140 | 141 | self.opt.zero_grad() 142 | loss.backward() 143 | self.opt.step() 144 | self.prev_tasks.append(task_id) 145 | 146 | def forget(self, task_id): 147 | self.memory.remove(task_id) 148 | del self.side_nets[task_id] 149 | -------------------------------------------------------------------------------- /model/clu_er.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import quadprog 4 | import random 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import 
torch.utils.data as data 10 | 11 | from torch.optim import Adam, SGD, RMSprop 12 | from torch.utils.data import DataLoader 13 | from .backbone import resnet18, MLP 14 | from .base import Base 15 | from .er import Replay, ER 16 | 17 | 18 | class CLU_ER(ER): 19 | def __init__(self, config): 20 | super(CLU_ER, self).__init__(config) 21 | self.side_nets = {} 22 | 23 | def forward(self, x, task): 24 | if (task in self.side_nets): 25 | pred = self.side_nets[task].forward(x).view(x.shape[0], -1) 26 | else: 27 | pred = self.net.forward(x).view(x.shape[0], -1) 28 | 29 | if task > 0: 30 | pred[:, :self.cpt*task].data.fill_(-10e10) 31 | if task < self.n_tasks-1: 32 | pred[:, self.cpt*(task+1):].data.fill_(-10e10) 33 | return pred 34 | 35 | def distill_loss(self, scores, target_scores, task_id, T=2.): 36 | cpt = self.cpt 37 | log_scores_norm = F.log_softmax(scores / T, dim=1)[:, task_id*cpt:(task_id+1)*cpt] 38 | targets_norm = F.softmax(target_scores / T, dim=1)[:, task_id*cpt:(task_id+1)*cpt] 39 | # calculate distillation loss (see e.g., Li and Hoiem, 2017) 40 | loss = -(targets_norm * log_scores_norm).sum(1).mean() * T**2 41 | return loss 42 | 43 | def learn(self, task_id, dataset): 44 | loader = DataLoader(dataset, 45 | batch_size=self.config.batch_size, 46 | shuffle=False, 47 | num_workers=2) 48 | 49 | self.opt = SGD(self.net.parameters(), 50 | lr=self.config.lr, 51 | momentum=self.config.momentum, 52 | weight_decay=self.config.weight_decay) 53 | 54 | #exclude_list = [t if self.task_status[t] == 'T' for t in self.task_status.keys()] 55 | exclude_list = [t for t in self.task_status.keys() if self.task_status[t] == 'T' ] 56 | 57 | for epoch in range(self.config.n_epochs): 58 | for i, (x, y) in enumerate(loader): 59 | x = x.to(self.device) 60 | y = y.to(self.device) 61 | 62 | #x_past, y_past = self.memory.sample(self.config.mem_batch_size, exclude_list) 63 | #if x_past is None: 64 | # x_ = x; y_ = y; 65 | #else: 66 | # x_ = torch.cat((x, x_past)) 67 | # y_ = torch.cat((y, y_past)) 68 | 69 | # current loss 70 | #loss = self.loss_fn(self.forward(x_), y_) 71 | #self.opt.zero_grad() 72 | #loss.backward() 73 | #self.opt.step() 74 | #self.memory.add(x, y, task_id) 75 | 76 | loss = self.loss_fn(self.forward(x, task_id), y) 77 | 78 | R_prev_list = [t for t in self.prev_tasks if t not in exclude_list] 79 | n_prev_tasks = len(R_prev_list) 80 | for t in R_prev_list: 81 | x_past, y_past = self.memory.sample_task(self.config.mem_batch_size//n_prev_tasks, t) 82 | loss += self.loss_fn(self.forward(x_past, t), y_past) / n_prev_tasks 83 | 84 | self.opt.zero_grad() 85 | loss.backward() 86 | self.opt.step() 87 | self.memory.add(x, y, task_id) 88 | 89 | self.prev_tasks.append(task_id) 90 | 91 | def temporarily_learn(self, task_id, dataset): 92 | # initialize a side network 93 | assert task_id not in self.side_nets, f"[ERROR] should not see {task_id} in side nets" 94 | 95 | if self.config.use_pretrain: 96 | self.side_nets[task_id] = copy.deepcopy(self.net) 97 | else: 98 | if "cifar" in self.config.dataset: 99 | self.side_nets[task_id] = resnet18(self.cpt*self.n_tasks).to(self.config.device) 100 | else: 101 | self.side_nets[task_id] = MLP(self.config).to(self.config.device) 102 | 103 | 104 | loader = DataLoader(dataset, 105 | batch_size=self.config.batch_size, 106 | shuffle=True, 107 | num_workers=2) 108 | 109 | opt = SGD(self.side_nets[task_id].parameters(), 110 | lr=self.config.lr, 111 | momentum=self.config.momentum, 112 | weight_decay=self.config.weight_decay) 113 | 114 | for epoch in range(self.config.n_epochs): 115 
| for i, (x, y) in enumerate(loader): 116 | x = x.to(self.device) 117 | y = y.to(self.device) 118 | 119 | # current loss 120 | loss = self.loss_fn(self.forward(x, task_id), y) 121 | opt.zero_grad() 122 | loss.backward() 123 | opt.step() 124 | self.memory.add(x, y, task_id) 125 | 126 | def finally_learn(self, task_id): 127 | # use knowledge distillation to merge two networks 128 | self.opt = SGD(self.net.parameters(), 129 | lr=self.config.lr, 130 | momentum=self.config.momentum, 131 | weight_decay=self.config.weight_decay) 132 | 133 | task_net = self.side_nets[task_id] 134 | task_net.eval() 135 | del self.side_nets[task_id] 136 | 137 | exclude_list = [t for t in self.task_status.keys() if self.task_status[t] == 'T' ] 138 | R_prev_list = [t for t in self.prev_tasks if t not in exclude_list] 139 | n_prev_tasks = len(R_prev_list) 140 | 141 | for it in range(self.config.memorize_iters): 142 | loss = 0.0 143 | x_t, y_t = self.memory.sample_task(self.config.mem_batch_size, task_id) 144 | for tt in R_prev_list: 145 | x_past, y_past = self.memory.sample_task(self.config.mem_batch_size//n_prev_tasks, tt) 146 | loss += self.loss_fn(self.forward(x_past, tt), y_past) / n_prev_tasks 147 | loss += self.distill_loss(self.forward(x_t, task_id), task_net.forward(x_t).detach(), task_id) 148 | loss += self.loss_fn(self.forward(x_t, task_id), y_t) 149 | self.opt.zero_grad() 150 | loss.backward() 151 | self.opt.step() 152 | self.prev_tasks.append(task_id) 153 | 154 | def forget(self, task_id): 155 | self.memory.remove(task_id) 156 | del self.side_nets[task_id] 157 | -------------------------------------------------------------------------------- /model/derpp.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import quadprog 4 | import random 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.data as data 10 | 11 | from torch.optim import Adam, SGD, RMSprop 12 | from torch.utils.data import DataLoader 13 | from .backbone import resnet18 14 | from .base import * 15 | 16 | 17 | class Replay(data.Dataset): 18 | """ 19 | A dataset wrapper used as a memory to store the data 20 | """ 21 | def __init__(self, buffer_size, dim_x, device): 22 | super(Replay, self).__init__() 23 | self.dim_x = dim_x 24 | self.buffer_size = buffer_size 25 | self.device = device 26 | self.buffer = {} 27 | 28 | def __len__(self): 29 | if not self.buffer: 30 | return 0 31 | else: 32 | n = 0 33 | for t in self.buffer: 34 | n += min(self.buffer[t]['num_seen'], self.buffer_size) 35 | return n 36 | 37 | def add(self, x, y, h, t): 38 | x = x.cpu() 39 | y = y.cpu() 40 | h = h.cpu() 41 | dim_h = h.shape[-1] 42 | 43 | if t not in self.buffer: 44 | self.buffer[t] = { 45 | "X": torch.zeros([self.buffer_size] + list(self.dim_x)), 46 | "Y": torch.zeros(self.buffer_size).long(), 47 | "H": torch.zeros([self.buffer_size] + [dim_h]), 48 | "num_seen": 0, 49 | } 50 | 51 | n = x.shape[0] 52 | for i in range(n): 53 | self.buffer[t]['num_seen'] += 1 54 | 55 | if self.buffer[t]['num_seen'] <= self.buffer_size: 56 | idx = self.buffer[t]['num_seen'] - 1 57 | else: 58 | rand = np.random.randint(0, self.buffer[t]['num_seen']) 59 | idx = rand if rand < self.buffer_size else -1 60 | 61 | self.buffer[t]['X'][idx] = x[i] 62 | self.buffer[t]['Y'][idx] = y[i] 63 | self.buffer[t]['H'][idx] = h[i] 64 | 65 | def sample(self, n, exclude=[]): 66 | nb = self.__len__() 67 | if nb == 0: 68 | return None, None 69 | 70 | X = []; Y = []; H = [] 
71 | for t, v in self.buffer.items(): 72 | if t in exclude: 73 | continue 74 | idx = torch.randperm(min(v['num_seen'], v['X'].shape[0]))[:min(min(n, v['num_seen']), v['X'].shape[0])] 75 | X.append(v['X'][idx]) 76 | Y.append(v['Y'][idx]) 77 | H.append(v['H'][idx]) 78 | return torch.cat(X, 0).to(self.device), torch.cat(Y, 0).to(self.device), torch.cat(H, 0).to(self.device) 79 | 80 | def sample_task(self, n, task_id): 81 | X = []; Y = []; H = [] 82 | assert task_id in self.buffer, f"[ERROR] not found {task_id} in buffer" 83 | v = self.buffer[task_id] 84 | idx = torch.randperm(min(v['num_seen'], v['X'].shape[0]))[:min(min(n, v['num_seen']), v['X'].shape[0])] 85 | X.append(v['X'][idx]) 86 | Y.append(v['Y'][idx]) 87 | H.append(v['H'][idx]) 88 | return torch.cat(X, 0).to(self.device), torch.cat(Y, 0).to(self.device), torch.cat(H, 0).to(self.device) 89 | 90 | def remove(self, t): 91 | X = self.buffer[t]['X'] 92 | Y = self.buffer[t]['Y'] 93 | H = self.buffer[t]['H'] 94 | del self.buffer[t] 95 | return X, Y, H 96 | 97 | 98 | class Derpp(Base): 99 | def __init__(self, config): 100 | super(Derpp, self).__init__(config) 101 | self.n_mem = self.config.mem_budget 102 | self.memory = Replay(self.n_mem, config.dim_input, self.device) 103 | self.alpha = config.alpha 104 | self.beta = config.beta 105 | 106 | def der_loss(self, a, b, task_id): 107 | if self.config.scenario != 'domain': 108 | cpt = self.cpt 109 | a_ = a[..., task_id*cpt:(task_id+1)*cpt] 110 | b_ = b[..., task_id*cpt:(task_id+1)*cpt] 111 | loss = F.mse_loss(a_, b_) 112 | else: 113 | loss = F.mse_loss(a, b) 114 | return loss 115 | 116 | def learn(self, task_id, dataset): 117 | loader = DataLoader(dataset, 118 | batch_size=self.config.batch_size, 119 | shuffle=True, 120 | num_workers=2) 121 | self.opt = SGD(self.net.parameters(), 122 | lr=self.config.lr, 123 | momentum=self.config.momentum, 124 | weight_decay=self.config.weight_decay) 125 | 126 | self.n_iters = self.config.n_epochs * len(loader) 127 | for epoch in range(self.config.n_epochs): 128 | for i, (x, y) in enumerate(loader): 129 | x = x.to(self.device) 130 | y = y.to(self.device) 131 | 132 | # current loss 133 | h = self.forward(x, task_id) 134 | loss = self.loss_fn(h, y) 135 | 136 | n_prev_tasks = len(self.prev_tasks) 137 | for t in self.prev_tasks: 138 | x_past, y_past, h_past = self.memory.sample_task(self.config.mem_batch_size//n_prev_tasks, t) 139 | h_tmp = self.forward(x_past, t) 140 | loss += self.alpha * self.der_loss(h_tmp, h_past, t) / n_prev_tasks 141 | loss += self.beta * self.loss_fn(h_tmp, y_past) / n_prev_tasks 142 | 143 | self.opt.zero_grad() 144 | loss.backward() 145 | self.opt.step() 146 | self.memory.add(x, y, h.detach(), task_id) 147 | 148 | self.prev_tasks.append(task_id) 149 | 150 | def forget(self, task_id): 151 | self.memory.remove(task_id) 152 | self.prev_tasks.remove(task_id) 153 | 154 | self.opt = SGD(self.net.parameters(), 155 | lr=self.config.lr, 156 | momentum=self.config.momentum, 157 | weight_decay=self.config.weight_decay) 158 | 159 | n_prev_tasks = len(self.prev_tasks) 160 | #for i in range(self.config.forget_iters): 161 | for i in range(self.n_iters): 162 | self.opt.zero_grad() 163 | loss = 0.0 164 | for t in self.prev_tasks: 165 | x_past, y_past, h_past = self.memory.sample_task(self.config.mem_batch_size // n_prev_tasks, t) 166 | h_tmp = self.forward(x_past, t) 167 | loss += self.der_loss(h_tmp, h_past, t) / n_prev_tasks 168 | loss += self.loss_fn(h_tmp, y_past) / n_prev_tasks 169 | loss.backward() 170 | self.opt.step() 171 | 
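Note on the objective above: per batch, `Derpp.learn` combines a cross-entropy loss on the current task with, for each previously learned task, an `alpha`-weighted MSE between the network's current logits and the logits stored in the `Replay` buffer (the dark-experience-replay term) plus a `beta`-weighted cross-entropy on the replayed labels (the "++" term). The snippet below is a minimal, self-contained sketch of that loss for a single replayed batch; the small MLP, the random tensors standing in for the current and buffered batches, and the hyperparameter values are illustrative placeholders, not the repository's actual backbone or buffer.

```python
# Minimal sketch of the DER++ objective used in Derpp.learn (illustrative only;
# the network and the tensors below are hypothetical stand-ins).
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_classes, alpha, beta = 10, 0.5, 1.0

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 100), nn.ReLU(), nn.Linear(100, n_classes))
opt = torch.optim.SGD(net.parameters(), lr=0.01)

# current-task batch
x, y = torch.randn(32, 1, 28, 28), torch.randint(0, n_classes, (32,))
# replayed batch: inputs, labels, and the logits recorded when they entered the buffer
x_past, y_past = torch.randn(32, 1, 28, 28), torch.randint(0, n_classes, (32,))
h_past = torch.randn(32, n_classes)

h = net(x)
loss = F.cross_entropy(h, y)                             # current-task cross-entropy
h_replay = net(x_past)
loss = loss + alpha * F.mse_loss(h_replay, h_past)       # match stored logits (DER term)
loss = loss + beta * F.cross_entropy(h_replay, y_past)   # cross-entropy on replayed labels ("++" term)

opt.zero_grad()
loss.backward()
opt.step()
```

In the repository the two replay terms are computed once per previously learned task, using a batch of size `mem_batch_size // n_prev_tasks` drawn from that task's reservoir-style buffer, and each term is divided by the number of previous tasks; in the class-incremental setting `der_loss` further restricts the MSE to the columns belonging to the replayed task's classes. CLPU-DER++ (next file) reuses this objective for its main network and trains temporarily learned tasks on separate side networks instead.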
-------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | 5 | from config import Config 6 | from sklearn import preprocessing 7 | from torchvision import datasets, transforms 8 | from torch.utils.data import ConcatDataset, Dataset, random_split 9 | 10 | # ================== 11 | # Dataset Transforms 12 | # ================== 13 | 14 | _MNIST_TRAIN_TRANSFORMS = _MNIST_TEST_TRANSFORMS = [ 15 | transforms.Pad(2), 16 | transforms.ToTensor(), 17 | transforms.Normalize((0.5,), (0.5,)), 18 | ] 19 | 20 | _FASHION_MNIST_TRAIN_TRANSFORMS = _FASHION_MNIST_TEST_TRANSFORMS = [ 21 | transforms.Pad(2), 22 | transforms.ToTensor(), 23 | transforms.Normalize((0.5,), (0.5,)), 24 | ] 25 | 26 | _E_MNIST_TRAIN_TRANSFORMS = _E_MNIST_TEST_TRANSFORMS = [ 27 | transforms.Pad(2), 28 | transforms.ToTensor(), 29 | transforms.Normalize((0.5,), (0.5,)), 30 | ] 31 | 32 | _CIFAR100_TRAIN_TRANSFORMS = [ 33 | transforms.RandomCrop(32, padding=4), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) 37 | ] 38 | 39 | 40 | _CIFAR100_TEST_TRANSFORMS = [ 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) 43 | ] 44 | 45 | _CIFAR10_TRAIN_TRANSFORMS = [ 46 | transforms.RandomCrop(32, padding=4), 47 | transforms.RandomHorizontalFlip(), 48 | transforms.ToTensor(), 49 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2615)) 50 | ] 51 | 52 | 53 | _CIFAR10_TEST_TRANSFORMS = [ 54 | transforms.ToTensor(), 55 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2615)) 56 | ] 57 | 58 | 59 | class SubDataset(Dataset): 60 | '''To sub-sample a dataset, taking only those samples with label in [sub_labels]. 
61 | 62 | After this selection of samples has been made, it is possible to transform the target-labels, 63 | which can be useful when doing continual learning with fixed number of output units.''' 64 | 65 | def __init__(self, original_dataset, sub_labels, is_train, target_transform=None): 66 | super().__init__() 67 | self.dataset = original_dataset 68 | self.sub_indices = [] 69 | for index in range(len(self.dataset)): 70 | if self.dataset.target_transform is None: 71 | label = self.dataset.targets[index] 72 | else: 73 | label = self.dataset.target_transform(self.dataset.targets[index]) 74 | if label in sub_labels: 75 | self.sub_indices.append(index) 76 | self.target_transform = target_transform 77 | 78 | def __len__(self): 79 | return len(self.sub_indices) 80 | 81 | def __getitem__(self, index): 82 | sample = self.dataset[self.sub_indices[index]] 83 | if self.target_transform: 84 | target = self.target_transform(sample[1]) 85 | sample = (sample[0], target) 86 | return sample 87 | 88 | 89 | def split_data(config): 90 | T = config.n_tasks 91 | CPT = config.class_per_task 92 | C = T * CPT 93 | permutation = np.arange(C) 94 | 95 | data = { 96 | 'mnist': datasets.MNIST, 97 | 'fashion': datasets.FashionMNIST, 98 | 'emnist': datasets.EMNIST, 99 | 'cifar100': datasets.CIFAR100, 100 | 'cifar10': datasets.CIFAR10, 101 | } 102 | train_transform = { 103 | 'mnist': _MNIST_TRAIN_TRANSFORMS, 104 | 'fashion': _FASHION_MNIST_TRAIN_TRANSFORMS, 105 | 'emnist': _E_MNIST_TRAIN_TRANSFORMS, 106 | 'cifar100': _CIFAR100_TRAIN_TRANSFORMS, 107 | 'cifar10': _CIFAR10_TRAIN_TRANSFORMS, 108 | } 109 | test_transform = { 110 | 'mnist': _MNIST_TEST_TRANSFORMS, 111 | 'fashion': _FASHION_MNIST_TEST_TRANSFORMS, 112 | 'emnist': _E_MNIST_TEST_TRANSFORMS, 113 | 'cifar100': _CIFAR100_TEST_TRANSFORMS, 114 | 'cifar10': _CIFAR10_TEST_TRANSFORMS, 115 | } 116 | 117 | train = data[config.dataset]('./data', 118 | train=True, 119 | download=True, 120 | transform=transforms.Compose(train_transform[config.dataset])) 121 | test = data[config.dataset]('./data', 122 | train=False, 123 | download=True, 124 | transform=transforms.Compose(test_transform[config.dataset])) 125 | 126 | # generate labels-per-task 127 | labels_per_task = [list(np.array(range(CPT)) + CPT * task_id) for task_id in range(T)] 128 | 129 | SD = SubDataset 130 | # split them up into sub-tasks 131 | train_datasets = [] 132 | test_datasets = [] 133 | for labels in labels_per_task: 134 | target_transform = None 135 | #target_transform = transforms.Lambda(lambda y, x=labels[0]: y-x) 136 | train_datasets.append(SD(train, labels, True, target_transform=target_transform)) 137 | test_datasets.append(SD(test, labels, False, target_transform=target_transform)) 138 | 139 | user_request_sequence = [ 140 | (0, "R"), 141 | (1, "T"), 142 | (2, "T"), 143 | (3, "R"), 144 | (1, "R"), 145 | (2, "F"), 146 | (4, "T"), 147 | (4, "F"), 148 | ] 149 | return train_datasets, test_datasets, user_request_sequence 150 | 151 | 152 | class TransformedMNISTDataset(Dataset): 153 | def __init__(self, data): 154 | self.X = data[1] 155 | self.Y = data[2] 156 | 157 | def __len__(self): 158 | return self.X.shape[0] 159 | 160 | def __getitem__(self, index): 161 | return self.X[index], self.Y[index] 162 | 163 | 164 | def get_transformed_mnist_dataset(config): 165 | if config.dataset == "rot_mnist": 166 | load_path = "./data/mnist_rotations.pt" 167 | elif config.dataset == "perm_mnist": 168 | load_path = "./data/mnist_permutations.pt" 169 | else: 170 | raise Exception(f"[ERROR] unknown dataset {config.dataset}") 
171 | 172 | d_tr, d_te = torch.load(load_path) 173 | d_tr = d_tr[:config.n_tasks] 174 | d_te = d_te[:config.n_tasks] 175 | 176 | for t, (tr, te) in enumerate(zip(d_tr, d_te)): 177 | tr[2] = tr[2] + t * config.class_per_task 178 | te[2] = te[2] + t * config.class_per_task 179 | 180 | n_inputs = d_tr[0][1].size(1) 181 | n_outputs = 0 182 | for i in range(len(d_tr)): 183 | n_outputs = max(n_outputs, d_tr[i][2].max().item()) 184 | n_outputs = max(n_outputs, d_te[i][2].max().item()) 185 | 186 | train_datasets = [TransformedMNISTDataset(dataset) for dataset in d_tr] 187 | test_datasets = [TransformedMNISTDataset(dataset) for dataset in d_te] 188 | 189 | user_request_sequence = [ 190 | (0, "R"), 191 | (1, "T"), 192 | (2, "T"), 193 | (3, "R"), 194 | (1, "R"), 195 | (2, "F"), 196 | (4, "T"), 197 | (4, "F"), 198 | ] 199 | return train_datasets, test_datasets, user_request_sequence 200 | 201 | 202 | def get_cl_dataset(config): 203 | if config.dataset in ["rot_mnist", "perm_mnist"]: 204 | return get_transformed_mnist_dataset(config) 205 | else: 206 | return split_data(config) 207 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import matplotlib.pyplot as plt 3 | import torch 4 | from config import Config 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from scipy import stats as st 8 | 9 | 10 | methods = { 11 | "perm-mnist": [ 12 | "perm_mnist_sequential_lr0.01", 13 | "perm_mnist_independent_lr0.01", 14 | "perm_mnist_ewc_lr0.01_ewclmbd100.0", 15 | "perm_mnist_er_lr0.01_mem200", 16 | "perm_mnist_derpp_lr0.01_mem200_alpha0.5_beta1.0", 17 | "perm_mnist_clpu_derpp_lr0.01_mem200_alpha0.5_beta1.0_pretrainFalse", 18 | "perm_mnist_clpu_derpp_lr0.01_mem200_alpha0.5_beta1.0_pretrainTrue", 19 | ], 20 | 21 | "rot-mnist": [ 22 | "rot_mnist_sequential_lr0.01", 23 | "rot_mnist_independent_lr0.01", 24 | "rot_mnist_ewc_lr0.01_ewclmbd100.0", 25 | "rot_mnist_er_lr0.01_mem200", 26 | "rot_mnist_derpp_lr0.01_mem200_alpha0.5_beta1.0", 27 | "rot_mnist_clpu_derpp_lr0.01_mem200_alpha0.5_beta1.0_pretrainFalse", 28 | "rot_mnist_clpu_derpp_lr0.01_mem200_alpha0.5_beta1.0_pretrainTrue", 29 | ], 30 | 31 | "split-cifar10": [ 32 | "cifar10_sequential_lr0.01", 33 | "cifar10_independent_lr0.01", 34 | "cifar10_ewc_lr0.01_ewclmbd500.0", 35 | "cifar10_er_lr0.01_mem200", 36 | "cifar10_derpp_lr0.01_mem200_alpha0.5_beta0.5", 37 | "cifar10_clpu_derpp_lr0.01_mem200_alpha0.5_beta0.5_pretrainFalse", 38 | "cifar10_clpu_derpp_lr0.01_mem200_alpha0.5_beta0.5_pretrainTrue", 39 | ], 40 | 41 | "split-cifar100": [ 42 | "cifar100_sequential_lr0.01", 43 | "cifar100_independent_lr0.01", 44 | "cifar100_ewc_lr0.01_ewclmbd1000.0", 45 | "cifar100_er_lr0.01_mem200", 46 | "cifar100_derpp_lr0.01_mem200_alpha0.5_beta1.0", 47 | "cifar100_clpu_derpp_lr0.01_mem200_alpha0.5_beta1.0_pretrainFalse", 48 | "cifar100_clpu_derpp_lr0.01_mem200_alpha0.5_beta1.0_pretrainTrue", 49 | ], 50 | } 51 | 52 | def get_log(name): 53 | 54 | seeds = [1,2,3,4,5] 55 | 56 | AA, FF = [], [] 57 | FA = [] 58 | 59 | IL = {} 60 | AL = {} 61 | mask = {} 62 | 63 | for seed in seeds: 64 | IL[seed] = {} 65 | AL[seed] = {} 66 | mask[seed] = {} 67 | 68 | res = torch.load(f"./results/{name}_seed{seed}.log") 69 | 70 | stats = res["stats"][0] 71 | user_requests = res["user_requests"] 72 | 73 | acc = stats["accuracy"] # (n_requests, n_tasks) 74 | Dr = stats["Dr_mask"] # (n_requests, n_tasks) 75 | Df = stats["Df_mask"] # (n_requests, n_tasks) 76 | 
logits = stats["logits"] # List (n_requests, n_data, class_per_task) * n_tasks 77 | 78 | a = (acc * Dr).sum(-1).sum(0) / Dr.sum() 79 | #print("="*80) 80 | #print(user_requests) 81 | #print(acc) 82 | #print(res["stats"][1]["accuracy"]) 83 | 84 | task_set = set() 85 | first_time_acc = torch.zeros(acc.shape[-1]) 86 | forget_idx = 0 87 | kl = 0.0 88 | 89 | forget_mask = torch.zeros_like(acc) 90 | 91 | for request_id, (task_id, learn_type, dr) in enumerate(user_requests): 92 | if (learn_type in ["R", "T"]) and (task_id not in task_set): 93 | first_time_acc[task_id] = acc[request_id, task_id] 94 | task_set.add(task_id) 95 | elif learn_type == "F": 96 | #forget_idx = min(1, forget_idx + 1) 97 | forget_idx = forget_idx + 1 98 | stats_f = res["stats"][1]#forget_idx] 99 | logits_ = stats_f["logits"] 100 | logits_ = [l[-1] for l in logits_] # List(n_data, class_per_task) * n_tasks 101 | 102 | IL[seed][forget_idx] = [l[request_id] for l in logits] 103 | AL[seed][forget_idx] = logits_ 104 | mask[seed][forget_idx] = Df[request_id] 105 | 106 | #L[forget_idx]["in"].append(l[request_id] for l in logits]) 107 | #L[forget_idx]["across"].append(logits_) 108 | #L[forget_idx]["mask"].append(Df[request_id]) 109 | 110 | #for my_l, forget_l, mask in zip([l[request_id] for l in logits], logits_, Df[request_id]): 111 | # if mask.item() > 0: 112 | # try: 113 | # kl += (F.softmax(my_l, -1) * (F.log_softmax(my_l, -1) - F.log_softmax(forget_l, -1))).sum(-1).mean() 114 | # except: 115 | # import pdb; pdb.set_trace() 116 | forget_mask[request_id][task_id] = 1.0 117 | 118 | #kl = kl / (forget_idx+1e-4) 119 | f = ((first_time_acc.view(1,-1) - acc) * Dr).sum(-1).sum(0) / Dr.sum() 120 | 121 | #fm = torch.Tensor(forget_mask).view(-1, 1) 122 | #fa = (acc * Df * fm).sum(-1).sum(0) / (Df * fm).sum() 123 | fa = (acc * forget_mask).sum() / forget_mask.sum() 124 | 125 | AA.append(a) 126 | FF.append(f) 127 | FA.append(fa) 128 | 129 | AA = torch.stack(AA) 130 | FF = torch.stack(FF) 131 | FA = torch.stack(FA) 132 | 133 | # calculate the Jensen-Shannon divergence 134 | 135 | def js(a, b, m): 136 | #kl_ = 0.0 137 | js = 0.0 138 | cnt = 0 139 | for k in a.keys(): 140 | cnt += 1 141 | for my_l, forget_l, mm in zip(a[k], b[k], m[k]): 142 | if mm.item() > 0: 143 | p = F.softmax(my_l, -1) 144 | q = F.softmax(forget_l, -1) 145 | mix = (p+q)/2 # mixture distribution; distinct name so the mask argument m is not overwritten 146 | js += 0.5 * (p * (p.log() - mix.log())).sum(-1).mean() + 0.5 * (q * (q.log() - mix.log())).sum(-1).mean() 147 | #kl_ += (F.softmax(my_l, -1) * (F.log_softmax(my_l, -1) - F.log_softmax(forget_l, -1))).sum(-1).mean() 148 | return js / cnt 149 | 150 | IGKL = [] 151 | AGKL = [] 152 | 153 | # in group 154 | for i in seeds: 155 | for j in seeds: 156 | if i == j: 157 | continue 158 | i1 = AL[i] 159 | i2 = AL[j] 160 | m = mask[i] 161 | IGKL.append(js(i1, i2, m)) 162 | 163 | # across group 164 | for i in seeds: 165 | for j in seeds: 166 | i1 = IL[i] 167 | i2 = AL[j] 168 | m = mask[i] 169 | AGKL.append(js(i1, i2, m)) 170 | 171 | IGKL = torch.stack(IGKL) 172 | AGKL = torch.stack(AGKL) 173 | return AA, FF, FA, IGKL, AGKL 174 | 175 | def get_method(x): 176 | if "sequential" in x: 177 | return "Seq" 178 | elif "lwf" in x: 179 | return "LwF" 180 | elif "ewc" in x and "lsf" not in x: 181 | return "EWC" 182 | elif "lsf" in x: 183 | return "LSF" 184 | elif "er" in x and "derpp" not in x: 185 | return "ER" 186 | elif "derpp" in x and "clpu" not in x: 187 | return "DER++" 188 | elif "clpu" in x and "pretrainTrue" in x: 189 | return "CLPU-DER++" 190 | else: 191 | return "CLPU-DER++ (w/o pretraining)" 192 | 193 | for data, method in 
methods.items(): 194 | print("="*30 + f" {data} "+ "="*30) 195 | for m in method: 196 | try: 197 | AA, FF, FA, IGKL, AGKL = get_log(m) 198 | except: 199 | print('failed ', m) 200 | continue 201 | mmax = IGKL.numpy().max() 202 | rate = (AGKL < mmax).sum().item() / AGKL.shape[0] 203 | mm = get_method(m) 204 | print(f"{mm:50s} & {AA.mean()*100:4.2f} \\fs {{ {AA.std()*100:4.2f} }} & {FF.mean()*100:4.2f} \\fs {{ {FF.std()*100:4.2f} }} & {IGKL.mean():4.2f} \\fs {{ {IGKL.std():4.2f} }} & {AGKL.mean():4.2f} \\fs {{ {AGKL.std():4.2f} }} & {abs(AGKL.mean()-IGKL.mean())/IGKL.mean():4.2f} & {rate:4.2f} \\\\") 205 | -------------------------------------------------------------------------------- /model/lsf.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import quadprog 4 | import random 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.data as data 10 | 11 | from torch.optim import Adam, SGD, RMSprop 12 | from torch.utils.data import DataLoader 13 | from .backbone import resnet18 14 | from .base import * 15 | from .lwf import smooth, modified_kl_div 16 | 17 | 18 | class LSF(Base): 19 | def __init__(self, config): 20 | super(LSF, self).__init__(config) 21 | self.dim_input = config.dim_input 22 | self.mnemonic_code = torch.randn( 23 | config.n_tasks*config.class_per_task, 24 | *config.dim_input 25 | ).to(self.device) # all mnemonic codes 26 | 27 | self.mnemonic_target = torch.arange( 28 | config.n_tasks*config.class_per_task 29 | ).to(self.device) 30 | if self.config.scenario == 'domain': 31 | self.mnemonic_target = self.mnemonic_target % config.class_per_task 32 | 33 | self.logsoft = nn.LogSoftmax(dim=1) 34 | self.fish = {} 35 | self.checkpoint = {} 36 | 37 | self.old_net = None 38 | self.soft = torch.nn.Softmax(dim=-1) 39 | self.prev_dataset = None 40 | 41 | def penalty(self): 42 | ### ewc penalty 43 | if len(self.prev_tasks) == 0: 44 | return torch.tensor(0.0).to(self.device) 45 | current_param = self.net.get_params() 46 | penalty = 0.0 47 | for t in self.prev_tasks: 48 | penalty += (self.fish[t] * (current_param - self.checkpoint[t]).pow(2)).sum() 49 | return penalty 50 | 51 | def learn(self, task_id, dataset, is_forget=False): 52 | self.old_net = copy.deepcopy(self.net) 53 | self.old_net.eval() 54 | for p in self.old_net.parameters(): 55 | p.requires_grad = False 56 | 57 | loader = DataLoader(dataset, 58 | batch_size=self.config.batch_size, 59 | shuffle=True, 60 | num_workers=2) 61 | self.opt = SGD(self.net.parameters(), 62 | lr=self.config.lr, 63 | momentum=self.config.momentum, 64 | weight_decay=self.config.weight_decay) 65 | 66 | if task_id > 0: 67 | self.opt_cls = SGD(self.net.classifier.parameters(), 68 | lr=self.config.lr, 69 | momentum=self.config.momentum, 70 | weight_decay=self.config.weight_decay) 71 | for epoch in range(self.config.n_epochs): 72 | for i, (x, y) in enumerate(loader): 73 | x = x.to(self.device) 74 | y = y.to(self.device) 75 | loss = self.loss_fn(self.forward(x, task_id), y) 76 | self.opt_cls.zero_grad() 77 | loss.backward() 78 | self.opt_cls.step() 79 | 80 | lsf_gamma = self.config.lsf_gamma 81 | ewc_lmbd = self.config.ewc_lmbd 82 | 83 | self.n_iters = self.config.n_epochs * len(loader) 84 | 85 | if not is_forget: 86 | self.prev_dataset = (task_id, dataset) 87 | 88 | for epoch in range(self.config.n_epochs): 89 | for i, (x, y) in enumerate(loader): 90 | x = x.to(self.device) 91 | y = y.to(self.device) 92 | 93 | target_shape = 
[x.shape[0]] + [1] * (len(x.shape) - 1) 94 | lsf_lmbd = torch.rand(*target_shape).to(x.device) 95 | y_idx = y 96 | hat_x = lsf_lmbd * x + (1-lsf_lmbd) * self.mnemonic_code[y_idx] 97 | 98 | x_ = torch.cat([x, hat_x], 0) 99 | y_ = torch.cat([y, y], 0) 100 | loss = self.loss_fn(self.forward(x_, task_id), y_) 101 | 102 | if task_id > 0: 103 | loss += ewc_lmbd * self.penalty() 104 | 105 | n_prev_tasks = len(self.prev_tasks) 106 | for t in self.prev_tasks: 107 | loss += lsf_gamma * self.loss_fn( 108 | self.forward( 109 | self.mnemonic_code[t*self.cpt:(t+1)*self.cpt].view( 110 | -1, *self.dim_input), 111 | t), 112 | self.mnemonic_target[t*self.cpt:(t+1)*self.cpt] 113 | ) / n_prev_tasks 114 | 115 | # lwf 116 | outputs = self.forward(x, t)[...,t*self.cpt:(t+1)*self.cpt] 117 | with torch.no_grad(): 118 | targets = self.old_net.forward(x)[..., t*self.cpt:(t+1)*self.cpt] 119 | loss += self.config.lwf_alpha * modified_kl_div( 120 | smooth(self.soft(targets), 2, 1), 121 | smooth(self.soft(outputs), 2, 1)) / n_prev_tasks 122 | 123 | # current loss 124 | self.opt.zero_grad() 125 | loss.backward() 126 | self.opt.step() 127 | 128 | ### end of task 129 | if not is_forget: 130 | fish = torch.zeros_like(self.net.get_params()) 131 | 132 | for j, (x, y) in enumerate(loader): 133 | x = x.to(self.device) 134 | y = y.to(self.device) 135 | for ex, lab in zip(x, y): 136 | self.opt.zero_grad() 137 | output = self.net(ex.unsqueeze(0)) 138 | loss = - F.nll_loss(self.logsoft(output), lab.unsqueeze(0), reduction='none') 139 | exp_cond_prob = torch.mean(torch.exp(loss.detach().clone())) 140 | loss = torch.mean(loss) 141 | loss.backward() 142 | fish += exp_cond_prob * self.net.get_grads() ** 2 143 | 144 | fish /= (len(loader)*self.config.batch_size) 145 | self.prev_tasks.append(task_id) 146 | self.fish[task_id] = fish 147 | self.checkpoint[task_id] = self.net.get_params().data.clone() 148 | 149 | def forget(self, task_id): 150 | cpt = self.cpt 151 | self.prev_tasks.remove(task_id) 152 | del self.fish[task_id] 153 | del self.checkpoint[task_id] 154 | self.learn(self.prev_dataset[0], self.prev_dataset[1], is_forget=True) 155 | #self.opt = SGD(self.net.parameters(), 156 | # lr=self.config.lr, 157 | # momentum=self.config.momentum, 158 | # weight_decay=self.config.weight_decay) 159 | 160 | #lsf_gamma = self.config.lsf_gamma 161 | #ewc_lmbd = self.config.ewc_lmbd 162 | 163 | #for it in range(self.n_iters): 164 | # loss = ewc_lmbd * self.penalty() 165 | # n_prev_tasks = len(self.prev_tasks) 166 | # for t in self.prev_tasks: 167 | # loss += lsf_gamma * self.loss_fn( 168 | # self.forward( 169 | # self.mnemonic_code[t*self.cpt:(t+1)*self.cpt].view( 170 | # -1, *self.dim_input), t), 171 | # self.mnemonic_target[t*self.cpt:(t+1)*self.cpt] 172 | # ) / n_prev_tasks 173 | 174 | # # lwf 175 | # outputs = self.forward(x, t)[...,t*self.cpt:(t+1)*self.cpt] 176 | # with torch.no_grad(): 177 | # targets = self.old_net.forward(x)[..., t*self.cpt:(t+1)*self.cpt] 178 | # loss += self.config.lwf_alpha * modified_kl_div( 179 | # smooth(self.soft(targets), 2, 1), 180 | # smooth(self.soft(outputs), 2, 1)) / n_prev_tasks 181 | 182 | # # current loss 183 | # self.opt.zero_grad() 184 | # loss.backward() 185 | # self.opt.step() 186 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import os 5 | import random 6 | import time 7 | import torch 8 | import 
torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from config import Config 12 | from data import * 13 | from model import * 14 | from torch.utils.data import DataLoader, ConcatDataset 15 | 16 | 17 | def check_path(path): 18 | if not os.path.exists(path): 19 | print("[INFO] making folder %s" % path) 20 | os.makedirs(path) 21 | 22 | 23 | def evaluate(testsets, config, model, Dr, Df): 24 | model.eval() 25 | n_tasks = config.n_tasks 26 | L = torch.zeros(n_tasks) 27 | A = torch.zeros(n_tasks) 28 | logits = [] 29 | 30 | max_task = -1 31 | cpt = config.class_per_task 32 | with torch.no_grad(): 33 | 34 | for task, dataset in enumerate(testsets): 35 | max_task = max(max_task, task+1) 36 | 37 | bch = config.batch_size 38 | loader = DataLoader(dataset, batch_size=bch, shuffle=False) 39 | 40 | l = a = n = 0.0 41 | logit_ = torch.zeros(len(dataset), cpt) 42 | 43 | for i, (x, y) in enumerate(loader): 44 | y_tensor = y.to(config.device) 45 | x_tensor = x.to(config.device) 46 | y_ = model(x_tensor, task) 47 | 48 | l += F.cross_entropy(y_, y_tensor, reduction='sum').item() 49 | a += y_.argmax(-1).eq(y_tensor).float().sum().item() 50 | logit_[i*bch:i*bch+y_tensor.shape[0]].copy_( 51 | y_[...,cpt*task:cpt*(task+1)].cpu()) 52 | n += y_tensor.shape[0] 53 | 54 | L[task] = l / n 55 | A[task] = a / n 56 | logits.append(logit_) 57 | 58 | model.train() 59 | print("[INFO] loss: ", L[:max_task]) 60 | print("[INFO] acc.: ", A[:max_task]) 61 | return { 62 | 'loss': L, 63 | 'accuracy': A, 64 | 'logits': logits, 65 | } 66 | 67 | 68 | def get_continual_learning_unlearning_dataset(config): 69 | train_datasets, test_datasets, user_request_sequence = get_cl_dataset(config) # list(n_tasks * dataset), list(n_tasks * datasets) 70 | #user_request_sequence = [ 71 | # (0, "R"), 72 | # (1, "T"), 73 | # (2, "T"), 74 | # (3, "R"), 75 | # (1, "R"), 76 | # (2, "F"), 77 | 78 | # #(0, "R"), 79 | # #(1, "T"), 80 | # #(1, "F"), 81 | 82 | # #(0, "R"), 83 | # #(1, "T"), 84 | # #(1, "R"), 85 | #] 86 | 87 | def clear_all_forget_request(li): 88 | remove_list = [] 89 | for request_id, (task_id, learn_type, dr_list) in enumerate(li): 90 | if learn_type == "F": 91 | remove_list.append(request_id) 92 | for j in range(request_id): 93 | if li[j][0] == task_id and li[j][1] == "T": 94 | remove_list.append(j) 95 | break 96 | new_list = [] 97 | Dr_list = [] 98 | for request_id, (task_id, learn_type, dr_list) in enumerate(li): 99 | if request_id not in remove_list: 100 | if (learn_type in ["R", "T"]) and (task_id not in Dr_list): 101 | Dr_list.append(task_id) 102 | new_list.append((task_id, learn_type, list(Dr_list))) 103 | return new_list 104 | 105 | user_request_sequence_with_Dr = [] 106 | Dr_list = [] 107 | for task_id, learn_type in user_request_sequence: 108 | if (learn_type in ["R", "T"]) and (task_id not in Dr_list): 109 | Dr_list.append(task_id) 110 | elif learn_type == "F": 111 | Dr_list.remove(task_id) 112 | user_request_sequence_with_Dr.append((task_id, learn_type, list(Dr_list))) 113 | 114 | forget_learn_list = [] 115 | for request_id, (task_id, learn_type, dr_list) in enumerate(user_request_sequence_with_Dr): 116 | if learn_type == "F": 117 | list_upto = list(user_request_sequence_with_Dr[:request_id+1]) 118 | forget_learn_list.append(clear_all_forget_request(list_upto)) 119 | print(user_request_sequence_with_Dr) 120 | print(forget_learn_list) 121 | #forget_learn_list = [list(user_request_sequence_with_Dr)] 122 | return train_datasets, test_datasets, user_request_sequence_with_Dr, forget_learn_list 123 | 124 | 125 | def 
clu_train(config, model, train_datasets, test_datasets, user_request_sequence_with_Dr): 126 | 127 | n_tasks = config.n_tasks 128 | Df = [] 129 | 130 | loss = torch.zeros(len(user_request_sequence_with_Dr), n_tasks) 131 | accuracy = torch.zeros(len(user_request_sequence_with_Dr), n_tasks) 132 | times = torch.zeros(len(user_request_sequence_with_Dr)) 133 | Df_mask = torch.zeros(len(user_request_sequence_with_Dr), n_tasks) 134 | Dr_mask = torch.zeros(len(user_request_sequence_with_Dr), n_tasks) 135 | logits = [torch.zeros(len(user_request_sequence_with_Dr), 136 | len(ds), config.class_per_task) for ds in test_datasets] 137 | 138 | for request_id, (task_id, learn_type, Dr) in enumerate(user_request_sequence_with_Dr): 139 | if config.verbose: 140 | print('='*80) 141 | learn_type_str = { 142 | "R": "Learning", 143 | "T": "Temporarily learning", 144 | "F": "Forgetting", 145 | }[learn_type] 146 | print(f'[INFO] {learn_type_str} Task {task_id} ...') 147 | 148 | if learn_type == "F": # forget 149 | Df.append(task_id) 150 | 151 | # learn 152 | t0 = time.time() 153 | model.continual_learn_and_unlearn(task_id, train_datasets[task_id], learn_type) 154 | t1 = time.time() 155 | 156 | # evaluate 157 | for df in Df: 158 | Df_mask[request_id][df] = 1. 159 | for dr in Dr: 160 | Dr_mask[request_id][dr] = 1. 161 | stat = evaluate(test_datasets, config, model, Dr, Df) 162 | loss[request_id] = stat['loss'] 163 | accuracy[request_id] = stat['accuracy'] 164 | times[request_id] = t1-t0 165 | for t in range(n_tasks): 166 | #if stat['logits'][t] is not None: 167 | logits[t][request_id] = stat['logits'][t] 168 | 169 | return { 170 | 'loss': loss, 171 | 'accuracy': accuracy, 172 | 'times': times, 173 | 'Df_mask': Df_mask, 174 | 'Dr_mask': Dr_mask, 175 | 'logits': logits 176 | } 177 | 178 | 179 | def run(config): 180 | train_datasets, test_datasets, user_request_sequence_with_Dr, forget_learn_list = get_continual_learning_unlearning_dataset(config) 181 | print("[INFO] finish processing data") 182 | 183 | fn_map = { 184 | "independent": Independent, 185 | "sequential": Sequential, 186 | "ewc": EWC, 187 | "er" : ER, 188 | "lsf": LSF, 189 | "lwf": LwF, 190 | "clu_er": CLU_ER, 191 | "derpp": Derpp, 192 | "clpu_derpp": CLPU_Derpp, 193 | } 194 | model = fn_map[config.method](config).to(config.device) 195 | sd = model.state_dict() 196 | check_path('./results') 197 | 198 | if config.method == "ewc": 199 | exp_name = f"{config.dataset}_{config.method}_lr{config.lr}_ewclmbd{config.ewc_lmbd}_seed{config.seed}" 200 | elif config.method == "lwf": 201 | exp_name = f"{config.dataset}_{config.method}_lr{config.lr}_lwfalpha{config.lwf_alpha}_lwftemp{config.lwf_temp}_seed{config.seed}" 202 | elif config.method == "lsf": 203 | exp_name = f"{config.dataset}_new{config.method}_lr{config.lr}_lsfgamma{config.lsf_gamma}_ewclmbd{config.ewc_lmbd}_seed{config.seed}" 204 | elif config.method == "er": 205 | exp_name = f"{config.dataset}_{config.method}_lr{config.lr}_mem{config.mem_budget}_seed{config.seed}" 206 | elif config.method == "clu_er": 207 | exp_name = f"{config.dataset}_{config.method}_lr{config.lr}_mem{config.mem_budget}_pretrain{config.use_pretrain}_seed{config.seed}" 208 | elif config.method == "derpp": 209 | exp_name = f"{config.dataset}_{config.method}_lr{config.lr}_mem{config.mem_budget}_alpha{config.alpha}_beta{config.beta}_seed{config.seed}" 210 | elif config.method == "clpu_derpp": 211 | exp_name = 
f"{config.dataset}_{config.method}_lr{config.lr}_mem{config.mem_budget}_alpha{config.alpha}_beta{config.beta}_pretrain{config.use_pretrain}_seed{config.seed}" 212 | else: 213 | exp_name = f"{config.dataset}_{config.method}_lr{config.lr}_seed{config.seed}" 214 | 215 | stats = [] 216 | #for urs in [user_request_sequence_with_Dr] + forget_learn_list: 217 | for urs in [user_request_sequence_with_Dr] + [forget_learn_list[0]]: 218 | print("[INFO] processing user's requests:") 219 | print(urs) 220 | model = fn_map[config.method](config).to(config.device) 221 | model.load_state_dict(sd) 222 | stat = clu_train(config, model, train_datasets, test_datasets, urs) 223 | stats.append(stat) 224 | 225 | result = { 226 | 'stats': stats, 227 | 'user_requests': user_request_sequence_with_Dr, 228 | 'forget_learn_list': forget_learn_list, 229 | } 230 | torch.save(result, f"./results/{exp_name}.log") 231 | 232 | 233 | if __name__ == "__main__": 234 | config = Config() 235 | 236 | # control seed 237 | torch.manual_seed(config.seed) 238 | torch.cuda.manual_seed_all(config.seed) 239 | np.random.seed(config.seed) 240 | random.seed(config.seed) 241 | torch.backends.cudnn.enabled=False 242 | torch.backends.cudnn.deterministic=True 243 | 244 | run(config) 245 | -------------------------------------------------------------------------------- /model/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn.functional import relu, avg_pool2d 12 | from typing import List 13 | 14 | ############################################################################### 15 | # 16 | # Backbone 17 | # 18 | ############################################################################### 19 | 20 | 21 | def xavier(m: nn.Module) -> None: 22 | """ 23 | Applies Xavier initialization to linear modules. 24 | 25 | :param m: the module to be initialized 26 | 27 | Example:: 28 | >>> net = nn.Sequential(nn.Linear(10, 10), nn.ReLU()) 29 | >>> net.apply(xavier) 30 | """ 31 | if m.__class__.__name__ == 'Linear': 32 | fan_in = m.weight.data.size(1) 33 | fan_out = m.weight.data.size(0) 34 | std = 1.0 * math.sqrt(2.0 / (fan_in + fan_out)) 35 | a = math.sqrt(3.0) * std 36 | m.weight.data.uniform_(-a, a) 37 | if m.bias is not None: 38 | m.bias.data.fill_(0.0) 39 | 40 | 41 | def num_flat_features(x: torch.Tensor) -> int: 42 | """ 43 | Computes the total number of items except the first dimension. 44 | 45 | :param x: input tensor 46 | :return: number of item from the second dimension onward 47 | """ 48 | size = x.size()[1:] 49 | num_features = 1 50 | for ff in size: 51 | num_features *= ff 52 | return num_features 53 | 54 | 55 | class MammothBackbone(nn.Module): 56 | 57 | def __init__(self, **kwargs) -> None: 58 | super(MammothBackbone, self).__init__() 59 | 60 | def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor: 61 | raise NotImplementedError 62 | 63 | def features(self, x: torch.Tensor) -> torch.Tensor: 64 | return self.forward(x, returnt='features') 65 | 66 | def get_params(self) -> torch.Tensor: 67 | """ 68 | Returns all the parameters concatenated in a single tensor. 
69 | :return: parameters tensor (??) 70 | """ 71 | params = [] 72 | for pp in list(self.parameters()): 73 | params.append(pp.view(-1)) 74 | return torch.cat(params) 75 | 76 | def set_params(self, new_params: torch.Tensor) -> None: 77 | """ 78 | Sets the parameters to a given value. 79 | :param new_params: concatenated values to be set (??) 80 | """ 81 | assert new_params.size() == self.get_params().size() 82 | progress = 0 83 | for pp in list(self.parameters()): 84 | cand_params = new_params[progress: progress + 85 | torch.tensor(pp.size()).prod()].view(pp.size()) 86 | progress += torch.tensor(pp.size()).prod() 87 | pp.data = cand_params 88 | 89 | def get_grads(self) -> torch.Tensor: 90 | """ 91 | Returns all the gradients concatenated in a single tensor. 92 | :return: gradients tensor (??) 93 | """ 94 | return torch.cat(self.get_grads_list()) 95 | 96 | def get_grads_list(self): 97 | """ 98 | Returns a list containing the gradients (a tensor for each layer). 99 | :return: gradients list 100 | """ 101 | grads = [] 102 | for pp in list(self.parameters()): 103 | grads.append(pp.grad.view(-1)) 104 | return grads 105 | 106 | 107 | ############################################################################### 108 | # 109 | # Resnet-18 110 | # 111 | ############################################################################### 112 | 113 | def conv3x3(in_planes: int, out_planes: int, stride: int=1) -> F.conv2d: 114 | """ 115 | Instantiates a 3x3 convolutional layer with no bias. 116 | :param in_planes: number of input channels 117 | :param out_planes: number of output channels 118 | :param stride: stride of the convolution 119 | :return: convolutional layer 120 | """ 121 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 122 | padding=1, bias=False) 123 | 124 | 125 | class BasicBlock(nn.Module): 126 | """ 127 | The basic block of ResNet. 128 | """ 129 | expansion = 1 130 | 131 | def __init__(self, in_planes: int, planes: int, stride: int=1) -> None: 132 | """ 133 | Instantiates the basic block of the network. 134 | :param in_planes: the number of input channels 135 | :param planes: the number of channels (to be possibly expanded) 136 | """ 137 | super(BasicBlock, self).__init__() 138 | self.conv1 = conv3x3(in_planes, planes, stride) 139 | self.bn1 = nn.BatchNorm2d(planes) 140 | self.conv2 = conv3x3(planes, planes) 141 | self.bn2 = nn.BatchNorm2d(planes) 142 | 143 | self.shortcut = nn.Sequential() 144 | if stride != 1 or in_planes != self.expansion * planes: 145 | self.shortcut = nn.Sequential( 146 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, 147 | stride=stride, bias=False), 148 | nn.BatchNorm2d(self.expansion * planes) 149 | ) 150 | 151 | def forward(self, x: torch.Tensor) -> torch.Tensor: 152 | """ 153 | Compute a forward pass. 154 | :param x: input tensor (batch_size, input_size) 155 | :return: output tensor (10) 156 | """ 157 | out = relu(self.bn1(self.conv1(x))) 158 | out = self.bn2(self.conv2(out)) 159 | out += self.shortcut(x) 160 | out = relu(out) 161 | return out 162 | 163 | 164 | class ResNet(MammothBackbone): 165 | """ 166 | ResNet network architecture. Designed for complex datasets. 167 | """ 168 | 169 | def __init__(self, block: BasicBlock, num_blocks: List[int], 170 | num_classes: int, nf: int) -> None: 171 | """ 172 | Instantiates the layers of the network. 
173 | :param block: the basic ResNet block 174 | :param num_blocks: the number of blocks per layer 175 | :param num_classes: the number of output classes 176 | :param nf: the number of filters 177 | """ 178 | super(ResNet, self).__init__() 179 | self.in_planes = nf 180 | self.block = block 181 | self.num_classes = num_classes 182 | self.nf = nf 183 | self.conv1 = conv3x3(3, nf * 1) 184 | self.bn1 = nn.BatchNorm2d(nf * 1) 185 | self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1) 186 | self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2) 187 | self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2) 188 | self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2) 189 | self.linear = nn.Linear(nf * 8 * block.expansion, num_classes) 190 | 191 | self._features = nn.Sequential(self.conv1, 192 | self.bn1, 193 | nn.ReLU(), 194 | self.layer1, 195 | self.layer2, 196 | self.layer3, 197 | self.layer4 198 | ) 199 | self.classifier = self.linear 200 | 201 | def _make_layer(self, block: BasicBlock, planes: int, 202 | num_blocks: int, stride: int) -> nn.Module: 203 | """ 204 | Instantiates a ResNet layer. 205 | :param block: ResNet basic block 206 | :param planes: channels across the network 207 | :param num_blocks: number of blocks 208 | :param stride: stride 209 | :return: ResNet layer 210 | """ 211 | strides = [stride] + [1] * (num_blocks - 1) 212 | layers = [] 213 | for stride in strides: 214 | layers.append(block(self.in_planes, planes, stride)) 215 | self.in_planes = planes * block.expansion 216 | return nn.Sequential(*layers) 217 | 218 | def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor: 219 | """ 220 | Compute a forward pass. 221 | :param x: input tensor (batch_size, *input_shape) 222 | :param returnt: return type (a string among 'out', 'features', 'all') 223 | :return: output tensor (output_classes) 224 | """ 225 | 226 | out = relu(self.bn1(self.conv1(x))) # 64, 32, 32 227 | if hasattr(self, 'maxpool'): 228 | out = self.maxpool(out) 229 | out = self.layer1(out) # -> 64, 32, 32 230 | out = self.layer2(out) # -> 128, 16, 16 231 | out = self.layer3(out) # -> 256, 8, 8 232 | out = self.layer4(out) # -> 512, 4, 4 233 | out = avg_pool2d(out, out.shape[2]) # -> 512, 1, 1 234 | feature = out.view(out.size(0), -1) # 512 235 | 236 | if returnt == 'features': 237 | return feature 238 | 239 | out = self.classifier(feature) 240 | 241 | if returnt == 'out': 242 | return out 243 | elif returnt == 'all': 244 | return (out, feature) 245 | 246 | raise NotImplementedError("Unknown return type") 247 | 248 | 249 | def resnet18(nclasses: int, nf: int=64) -> ResNet: 250 | """ 251 | Instantiates a ResNet18 network. 
252 | :param nclasses: number of output classes 253 | :param nf: number of filters 254 | :return: ResNet network 255 | """ 256 | return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf) 257 | 258 | 259 | class MLP(MammothBackbone): 260 | def __init__(self, config): 261 | super(MLP, self).__init__() 262 | 263 | dim_hidden = 100 264 | dim_input = np.prod(config.dim_input) 265 | if config.scenario == 'domain': 266 | dim_output = config.class_per_task 267 | else: 268 | dim_output = config.n_tasks * config.class_per_task 269 | 270 | # Linear layers 271 | self.net = nn.Sequential( 272 | nn.Linear(dim_input, dim_hidden), 273 | nn.ReLU(), 274 | nn.Linear(dim_hidden, dim_hidden), 275 | nn.ReLU()) 276 | self.classifier = nn.Linear(dim_hidden, dim_output) 277 | 278 | def forward(self, x): 279 | x = torch.flatten(x, 1) 280 | return self.classifier(self.net(x)) 281 | --------------------------------------------------------------------------------
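A minimal usage sketch of the `MLP` backbone defined in `model/backbone.py` above, assuming a hypothetical stand-in `Cfg` for `config.Config` that exposes only the fields the backbone reads (`dim_input`, `scenario`, `class_per_task`, `n_tasks`); the field values here are illustrative assumptions, not the repository's defaults:
```
import torch
from dataclasses import dataclass
from typing import Tuple

from model.backbone import MLP


@dataclass
class Cfg:  # hypothetical stand-in for config.Config; values are assumptions
    dim_input: Tuple[int, ...] = (1, 28, 28)  # e.g. MNIST-sized inputs
    class_per_task: int = 10
    n_tasks: int = 5
    scenario: str = "task"  # anything other than 'domain' -> n_tasks * class_per_task outputs


net = MLP(Cfg())
x = torch.randn(4, 1, 28, 28)  # dummy batch of 4 inputs
logits = net(x)                # inputs are flattened internally
print(logits.shape)            # torch.Size([4, 50])
```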