├── README.md ├── buffer.py ├── current_buffer.py ├── min_norm_solvers.py ├── ours.py └── oursRehSel.py /README.md: -------------------------------------------------------------------------------- 1 | # Exploring Example Influence in Continual Learning 2 | This is the official PyTorch implementation of "Exploring Example Influence in Continual Learning", published at NeurIPS 2022. 3 | 4 | ## Requirements 5 | PyTorch >= 1.3.0 6 | 7 | ## Main Algorithm 8 | 9 | ours.py computes example influence and uses it to update the model; its results are reported as 'Ours'. 10 | 11 | oursRehSel.py additionally uses example influence for rehearsal selection; its results are reported as 'Ours+RehSel'. 12 | 13 | min_norm_solvers.py fuses the stability and plasticity (SP) influences into a Pareto-optimal update. 14 | 15 | buffer.py implements the fixed-size memory buffer. 16 | 17 | current_buffer.py is a buffer for the current (new) task; it is not strictly required and exists only for coding convenience. 18 | 19 | ## Cite us 20 | 21 | ``` 22 | @inproceedings{MetaSP, 23 | title={Exploring Example Influence in Continual Learning}, 24 | author={Sun, Qing and Lyu, Fan and Shang, Fanhua and Feng, Wei and Wan, Liang}, 25 | booktitle={NeurIPS}, 26 | year={2022} 27 | } 28 | ``` 29 | -------------------------------------------------------------------------------- /buffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Tuple 4 | from torchvision import transforms 5 | 6 | 7 | def reservoir(num_seen_examples: int, buffer_size: int) -> int: 8 | if num_seen_examples < buffer_size: 9 | return num_seen_examples 10 | 11 | rand = np.random.randint(0, num_seen_examples + 1) 12 | if rand < buffer_size: 13 | return rand 14 | else: 15 | return -1 16 | 17 | 18 | class Buffer: 19 | """ 20 | The memory buffer. 21 | """ 22 | def __init__(self, buffer_size, device): 23 | self.buffer_size = buffer_size 24 | self.device = device 25 | self.num_seen_examples = 0 26 | self.attributes = ['examples', 'labels', 'logits', 'task_labels', 'score'] 27 | 28 | def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor, 29 | logits: torch.Tensor, task_labels: torch.Tensor, score: torch.Tensor) -> None: 30 | """ 31 | Initializes just the required tensors.
32 | :param examples: tensor containing the images 33 | :param labels: tensor containing the labels 34 | :param logits: tensor containing the outputs of the network 35 | :param task_labels: tensor containing the task labels 36 | :param score: tensor containing the influence 37 | """ 38 | for attr_str in self.attributes: 39 | attr = eval(attr_str) 40 | if attr is not None and not hasattr(self, attr_str): 41 | typ = torch.int64 if attr_str.endswith('els') else torch.float32 42 | setattr(self, attr_str, torch.zeros((self.buffer_size, *attr.shape[1:]), dtype=typ, device=self.device)) 43 | 44 | def add_data(self, examples, labels=None, logits=None, task_labels=None, score=None): 45 | if not hasattr(self, 'examples'): 46 | self.init_tensors(examples, labels, logits, task_labels, score) 47 | 48 | for i in range(examples.shape[0]): 49 | index = reservoir(self.num_seen_examples, self.buffer_size) 50 | self.num_seen_examples += 1 51 | if index >= 0: 52 | self.examples[index] = examples[i].to(self.device) 53 | if labels is not None: 54 | self.labels[index] = labels[i].to(self.device) 55 | if logits is not None: 56 | self.logits[index] = logits[i].to(self.device) 57 | if task_labels is not None: 58 | self.task_labels[index] = task_labels[i].to(self.device) 59 | if score is not None: 60 | self.score[index] = score[i].to(self.device) 61 | 62 | def get_data(self, size: int, transform: transforms=None, fsr=False, current_task=0) -> Tuple: 63 | if size > self.examples.shape[0]: 64 | size = self.examples.shape[0] 65 | if fsr and current_task > 0: 66 | past_examples = self.examples[self.task_labels != current_task] 67 | if size > past_examples.shape[0]: 68 | size = past_examples.shape[0] 69 | if past_examples.shape[0]: 70 | choice = np.random.choice(min(self.num_seen_examples, past_examples.shape[0]), size=size, replace=False) 71 | if transform is None: transform = lambda x: x 72 | ret_tuple = (torch.stack([transform(ee.cpu()) for ee in past_examples[choice]]).to(self.device),) 73 | for attr_str in self.attributes[1:]: 74 | if hasattr(self, attr_str): 75 | attr = getattr(self, attr_str) 76 | ret_tuple += (attr[self.task_labels != current_task][choice],) 77 | else: return tuple([torch.tensor([0])] * 4) 78 | else: 79 | choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]), size=min(self.num_seen_examples, size), replace=False) 80 | if transform is None: transform = lambda x: x 81 | ret_tuple = (torch.stack([transform(ee.cpu()) for ee in self.examples[choice]]).to(self.device),) 82 | for attr_str in self.attributes[1:]: 83 | if hasattr(self, attr_str): 84 | attr = getattr(self, attr_str) 85 | ret_tuple += (attr[choice],) 86 | 87 | return ret_tuple 88 | 89 | def get_data_gmed(self, size: int, transform: transforms=None, fsr=False, current_task=0) -> Tuple: 90 | if size > self.examples.shape[0]: 91 | size = self.examples.shape[0] 92 | if fsr and current_task > 0: 93 | past_examples = self.examples[self.task_labels != current_task] 94 | if size > past_examples.shape[0]: 95 | size = past_examples.shape[0] 96 | if past_examples.shape[0]: 97 | choice = np.random.choice(min(self.num_seen_examples, past_examples.shape[0]), size=size, replace=False) 98 | if transform is None: transform = lambda x: x 99 | ret_tuple = (torch.stack([transform(ee.cpu()) for ee in past_examples[choice]]).to(self.device),) 100 | for attr_str in self.attributes[1:]: 101 | if hasattr(self, attr_str): 102 | attr = getattr(self, attr_str) 103 | ret_tuple += (attr[self.task_labels != current_task][choice],) 104 | else: 
return tuple([torch.tensor([0])] * 4) 105 | return ret_tuple + (choice,) 106 | 107 | 108 | def replace_data(self, index, input, label, task_id): 109 | if index.shape[0] != input.shape[0]: 110 | choice = np.random.choice(min(index.shape[0], input.shape[0]), size=min(index.shape[0], input.shape[0]), replace=False) 111 | if index.shape[0] > input.shape[0]: 112 | self.examples[index[choice]] = input.to(self.device) 113 | self.labels[index[choice]] = label.to(self.device) 114 | self.task_labels[index[choice]] = task_id.to(self.device) 115 | elif index.shape[0] < input.shape[0]: 116 | self.examples[index] = input[choice].to(self.device) 117 | self.labels[index] = label[choice].to(self.device) 118 | self.task_labels[index] = task_id[choice].to(self.device) 119 | else: 120 | self.examples[index] = input.to(self.device) 121 | self.labels[index] = label.to(self.device) 122 | self.task_labels[index] = task_id.to(self.device) 123 | 124 | 125 | def replace_keshihua_data(self, index, input, label, task_id, score): 126 | if index.shape[0] != input.shape[0]: 127 | choice = np.random.choice(min(index.shape[0], input.shape[0]), size=min(index.shape[0], input.shape[0]), replace=False) 128 | if index.shape[0] > input.shape[0]: 129 | self.examples[index[choice]] = input.to(self.device) 130 | self.labels[index[choice]] = label.to(self.device) 131 | self.task_labels[index[choice]] = task_id.to(self.device) 132 | self.score[index[choice]] = score.to(self.device) 133 | elif index.shape[0] < input.shape[0]: 134 | self.examples[index] = input[choice].to(self.device) 135 | self.labels[index] = label[choice].to(self.device) 136 | self.task_labels[index] = task_id[choice].to(self.device) 137 | self.score[index] = score[choice].to(self.device) 138 | else: 139 | self.examples[index] = input.to(self.device) 140 | self.labels[index] = label.to(self.device) 141 | self.task_labels[index] = task_id.to(self.device) 142 | self.score[index] = score.to(self.device) 143 | 144 | def replace_score(self, new_score, index): 145 | self.score[index] = new_score.to(self.device) 146 | 147 | def delete_data(self, num, task): 148 | index = [] 149 | class_num = torch.max(self.labels).item()+1 150 | if num < class_num: 151 | num = class_num 152 | for i in range(class_num): 153 | task_ind = torch.where(self.labels == i) 154 | ranking = task_ind[0][:(num // class_num)].tolist() 155 | index = index + ranking 156 | index = torch.tensor(index) 157 | return index 158 | 159 | def delete_data_basedscore(self, num, task): 160 | index = [] 161 | class_num = torch.max(self.labels).item() + 1 162 | if num < class_num: 163 | num = class_num 164 | for i in range(class_num): 165 | task_ind = torch.where(self.labels == i) 166 | a = torch.sort(self.score[task_ind[0]][:, 2], descending=True) 167 | ranking = task_ind[0][a[1]][:(num // class_num)].tolist() 168 | index = index + ranking 169 | index = torch.tensor(index) 170 | return index 171 | 172 | 173 | def reset_score(self): 174 | new_score = torch.ones((self.examples.shape[0], 3), device=self.device) 175 | self.score = new_score -------------------------------------------------------------------------------- /current_buffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Tuple 4 | from torchvision import transforms 5 | from utils.ourskmeans import cluster 6 | 7 | 8 | def reservoir(num_examples: int, buffer_size: int) -> int: 9 | if num_examples < buffer_size: 10 | return num_examples 11 | 12 | rand = 
np.random.randint(0, num_examples + 1) 13 | if rand < buffer_size: 14 | return rand 15 | else: 16 | return -1 17 | 18 | 19 | class CurrentBuffer: 20 | """ 21 | The new task buffer which is actually not needed. 22 | This part is just for our convenience in coding. 23 | """ 24 | def __init__(self, buffer_size, device): 25 | self.buffer_size = buffer_size 26 | self.device = device 27 | self.num_examples = 0 28 | self.attributes = ['examples', 'labels', 'task_labels', 'scores', 'img_id'] 29 | 30 | def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor, 31 | task_labels: torch.Tensor, scores: torch.Tensor, img_id: torch.Tensor) -> None: 32 | """ 33 | Initializes just the required tensors. 34 | :param examples: tensor containing the images 35 | :param labels: tensor containing the labels 36 | :param logits: tensor containing the outputs of the network 37 | :param task_labels: tensor containing the task labels 38 | :param scores: tensor example influence 39 | :param img_id: tensor image id for compute influence in multi epochs 40 | """ 41 | for attr_str in self.attributes: 42 | attr = eval(attr_str) 43 | if attr is not None and not hasattr(self, attr_str): 44 | typ = torch.int64 if attr_str.endswith('els') else torch.float32 45 | setattr(self, attr_str, torch.zeros((self.buffer_size, *attr.shape[1:]), dtype=typ, device=self.device)) 46 | 47 | def add_data(self, examples, labels=None, task_labels=None, scores=None, img_id=None): 48 | if not hasattr(self, 'examples'): 49 | self.init_tensors(examples, labels, task_labels, scores, img_id) 50 | 51 | for i in range(examples.shape[0]): 52 | index = reservoir(self.num_examples, self.buffer_size) 53 | self.num_examples += 1 54 | if index >= 0: 55 | self.examples[index] = examples[i].to(self.device) 56 | if labels is not None: 57 | self.labels[index] = labels[i].to(self.device) 58 | if task_labels is not None: 59 | self.task_labels[index] = task_labels[i].to(self.device) 60 | if scores is not None: 61 | self.scores[index] = scores[i].to(self.device) 62 | if img_id is not None: 63 | self.img_id[index] = img_id[i].to(self.device) 64 | 65 | def get_data(self, size: int, transform: transforms=None, fsr=False, current_task=0) -> Tuple: 66 | 67 | if size > self.examples.shape[0]: 68 | size = self.examples.shape[0] 69 | choice = np.random.choice(min(self.num_examples, self.examples.shape[0]), size=min(self.num_examples, size), replace=False) 70 | if transform is None: transform = lambda x: x 71 | ret_tuple = (torch.stack([transform(ee.cpu()) for ee in self.examples[choice]]).to(self.device),) 72 | for attr_str in self.attributes[1:]: 73 | if hasattr(self, attr_str): 74 | attr = getattr(self, attr_str) 75 | ret_tuple += (attr[choice],) 76 | return ret_tuple[:2] 77 | 78 | 79 | def get_all_data(self, size: int, transform: transforms=None, fsr=False, current_task=0) -> Tuple: 80 | if size > self.examples.shape[0]: 81 | size = self.examples.shape[0] 82 | choice = torch.from_numpy(np.random.choice(min(self.num_examples, self.examples.shape[0]), size=min(self.num_examples, size), replace=False)).to(self.device) 83 | if transform is None: transform = lambda x: x 84 | ret_tuple = (torch.stack([transform(ee.cpu()) for ee in self.examples[choice]]).to(self.device),) 85 | for attr_str in self.attributes[1:]: 86 | if hasattr(self, attr_str): 87 | attr = getattr(self, attr_str) 88 | ret_tuple += (attr[choice],) 89 | return ret_tuple + (choice,) 90 | 91 | 92 | def get_input_score(self, img_id, shape): 93 | a = [torch.where(self.img_id == 
img_id[i])[0].cpu().numpy()[0] for i in range(shape)] 94 | index = torch.tensor(a) 95 | return index, self.scores[index] 96 | 97 | 98 | def replace_scores(self, index, mem_scores): 99 | for i in range(len(mem_scores)): 100 | self.scores[index[i]] = mem_scores[i].to(self.device) 101 | 102 | def ourkmeans(self, replace): 103 | num_centers = replace 104 | kmeansdata = torch.reshape(self.examples, [self.examples.shape[0], -1]) 105 | centers, codes, distance = cluster(kmeansdata, num_centers, self.device) 106 | return codes, distance 107 | 108 | 109 | def score(self, replace, codes): 110 | ranking = [] 111 | for i in range(replace): 112 | kmeams_label = torch.where(codes == i) 113 | maxscore_index = kmeams_label[0][torch.argmin(self.scores[kmeams_label][:, 2]).item()].item() 114 | ranking.append(maxscore_index) 115 | ranking = torch.tensor(ranking).to(self.device) 116 | return ranking -------------------------------------------------------------------------------- /min_norm_solvers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class MinNormSolver: 6 | MAX_ITER = 250 7 | STOP_CRIT = 1e-5 8 | 9 | def _min_norm_element_from2(v1v1, v1v2, v2v2): 10 | """ 11 | Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2 12 | d is the distance (objective) optimzed 13 | v1v1 = 14 | v1v2 = 15 | v2v2 = 16 | """ 17 | if v1v2 >= v1v1: 18 | # Case: Fig 1, third column 19 | gamma = 0.999 20 | cost = v1v1 21 | return gamma, cost 22 | if v1v2 >= v2v2: 23 | # Case: Fig 1, first column 24 | gamma = 0.001 25 | cost = v2v2 26 | return gamma, cost 27 | # Case: Fig 1, second column 28 | gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2)) 29 | cost = v2v2 + gamma * (v1v2 - v2v2) 30 | return gamma, cost 31 | 32 | def _min_norm_2d(vecs, dps): 33 | """ 34 | Find the minimum norm solution as combination of two points 35 | This is correct only in 2D 36 | ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j 37 | """ 38 | dmin = 1e8 39 | for i in range(len(vecs)): 40 | for j in range(i + 1, len(vecs)): 41 | if (i, j) not in dps: 42 | dps[(i, j)] = 0.0 43 | for k in range(len(vecs[i])): 44 | dps[(i, j)] += torch.mul(vecs[i][k], vecs[j][k]).sum().data.cpu() 45 | dps[(j, i)] = dps[(i, j)] 46 | if (i, i) not in dps: 47 | dps[(i, i)] = 0.0 48 | for k in range(len(vecs[i])): 49 | dps[(i, i)] += torch.mul(vecs[i][k], vecs[i][k]).sum().data.cpu() 50 | if (j, j) not in dps: 51 | dps[(j, j)] = 0.0 52 | for k in range(len(vecs[i])): 53 | dps[(j, j)] += torch.mul(vecs[j][k], vecs[j][k]).sum().data.cpu() 54 | c, d = MinNormSolver._min_norm_element_from2(dps[(i, i)], dps[(i, j)], dps[(j, j)]) 55 | if d < dmin: 56 | dmin = d 57 | sol = [(i, j), c, d] 58 | return sol, dps 59 | 60 | 61 | def find_min_norm_element(vecs): 62 | """ 63 | Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull 64 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1. 
65 | It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j}) 66 | Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence 67 | """ 68 | # Solution lying at the combination of two points 69 | dps = {} 70 | init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps) 71 | 72 | n = len(vecs) 73 | sol_vec = np.zeros(n) 74 | sol_vec[init_sol[0][0]] = init_sol[1] 75 | sol_vec[init_sol[0][1]] = 1 - init_sol[1] 76 | 77 | if n < 3: 78 | # This is optimal for n=2, so return the solution 79 | return sol_vec, init_sol[2] 80 | 81 | def gradient_normalizers(grads, losses, normalization_type): 82 | # gradient normalization method supposed in MGDA 83 | gn = [] 84 | if normalization_type == 'ours': 85 | for t in range(len(grads)): 86 | gn.append(np.sqrt(np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]])) / losses[t]) 87 | return gn -------------------------------------------------------------------------------- /ours.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.buffer import Buffer 3 | from utils.args import * 4 | from models.utils.continual_model import ContinualModel 5 | import random 6 | from utils.current_buffer import CurrentBuffer 7 | import higher 8 | from utils.min_norm_solvers import MinNormSolver, gradient_normalizers 9 | import numpy as np 10 | 11 | def get_parser() -> ArgumentParser: 12 | parser = ArgumentParser(description='Continual learning') 13 | add_management_args(parser) 14 | add_experiment_args(parser) 15 | add_rehearsal_args(parser) 16 | return parser 17 | 18 | 19 | class Ours(ContinualModel): 20 | NAME = 'ours' 21 | COMPATIBILITY = ['class-il', 'task-il'] 22 | 23 | def __init__(self, backbone, loss, args, transform): 24 | super(Ours, self).__init__(backbone, loss, args, transform) 25 | self.buffer = Buffer(self.args.buffer_size, self.device) 26 | self.transform = None 27 | self.current_task = 0 28 | 29 | def end_task(self, dataset): 30 | replace = self.args.buffer_size // (self.current_task + 1) 31 | delete_ind = self.buffer.delete_data(replace, task = self.current_task) 32 | bx, by, b_ids, b_scores, b_imgid, b_ind = self.currentbuffer.get_all_data(self.currentbuffer.num_examples, transform=self.transform) 33 | index = np.random.choice(bx.shape[0], size=min(bx.shape[0], replace), replace=False) 34 | self.buffer.replace_data(delete_ind, bx[index], by[index], b_ids[index]) 35 | self.current_task = self.current_task + 1 36 | 37 | 38 | def observe(self, inputs, labels, img_id, not_aug_inputs, task, args, epoch): 39 | 40 | real_batch_size = inputs.shape[0] 41 | task_labels = torch.ones(real_batch_size, dtype=torch.long).to(self.device) * task 42 | if task == 0: 43 | self.opt.zero_grad() 44 | outputs = self.net(inputs) 45 | loss = self.loss(outputs, labels) 46 | loss.backward() 47 | self.opt.step() 48 | # for task 1, random select data to store 49 | self.buffer.add_data(examples=inputs, labels=labels, task_labels=task_labels) 50 | return loss.item() 51 | else: 52 | if epoch<45: 53 | # naive fine-tuning 54 | self.opt.zero_grad() 55 | outputs = self.net(inputs) 56 | loss = self.loss(outputs, labels) 57 | loss.backward() 58 | self.opt.step() 59 | return loss.item() 60 | else: 61 | self.opt.zero_grad() 62 | # get mem data 63 | mem_x, mem_y, mem_ids = self.buffer.get_data(self.args.minibatch_size, transform=self.transform, fsr=True, current_task=task) 64 | total = torch.cat((inputs, mem_x)) 65 | 
total_labels = torch.cat((labels, mem_y)) 66 | # size of validation sets 67 | subsample = self.buffer.buffer_size // 10 68 | # get old task validation set 69 | bx, by, b_ids = self.buffer.get_data(subsample, transform=self.transform, fsr=True, current_task=task) 70 | # get new task validation set 71 | nx, ny = self.currentbuffer.get_data(subsample, transform=self.transform) 72 | 73 | 74 | iteration = 1 75 | # example influence on stability 76 | with higher.innerloop_ctx(self.net, self.opt) as (meta_model, meta_opt): 77 | base1 = torch.ones(total.shape[0], device=self.device) 78 | eps1 = torch.zeros(total.shape[0], requires_grad=True, device=self.device) 79 | # pseudo update 80 | for i in range(iteration): 81 | meta_train_outputs = meta_model(total) 82 | meta_train_loss = self.loss(meta_train_outputs, total_labels, reduction="none") 83 | meta_train_loss = (torch.sum(eps1 * meta_train_loss) + torch.sum(base1 * meta_train_loss)) / torch.tensor(total.shape[0]) 84 | meta_opt.step(meta_train_loss) 85 | meta_val1_outputs = meta_model(bx) 86 | meta_val1_loss = self.loss(meta_val1_outputs, by, reduction="mean") 87 | eps_grads1 = torch.autograd.grad(meta_val1_loss, eps1)[0].detach() 88 | 89 | # example influence on plasticity 90 | with higher.innerloop_ctx(self.net, self.opt) as (meta_model2, meta_opt2): 91 | base2 = torch.ones(total.shape[0], device=self.device) 92 | eps2 = torch.zeros(total.shape[0], requires_grad=True, device=self.device) 93 | # pseudo update 94 | for i in range(iteration): 95 | meta_train_outputs2 = meta_model2(total) 96 | meta_train_loss2 = self.loss(meta_train_outputs2, total_labels, reduction="none") 97 | meta_train_loss2 = (torch.sum(eps2 * meta_train_loss2) + torch.sum(base2 * meta_train_loss2)) / torch.tensor(total.shape[0]) 98 | meta_opt2.step(meta_train_loss2) 99 | meta_val2_outputs = meta_model2(nx) 100 | meta_val2_loss = self.loss(meta_val2_outputs, ny, reduction="mean") 101 | eps_grads2 = torch.autograd.grad(meta_val2_loss, eps2)[0].detach() 102 | 103 | 104 | gn = gradient_normalizers([eps_grads1, eps_grads2], [meta_val1_loss.item(), meta_val2_loss.item()], "ours") 105 | for gr_i in range(len(eps_grads1)): 106 | eps_grads1[gr_i] = eps_grads1[gr_i] / gn[0] 107 | for gr_i in range(len(eps_grads2)): 108 | eps_grads2[gr_i] = eps_grads2[gr_i] / gn[1] 109 | # compute gamma 110 | sol, min_norm = MinNormSolver.find_min_norm_element([eps_grads1, eps_grads2]) 111 | # fused influence 112 | w_tilde = sol[0] * eps_grads1 + (1 - sol[0]) * eps_grads2 113 | 114 | 115 | # update 116 | w_tilde = torch.ones(total.shape[0], device=self.device) - 1 * w_tilde 117 | l1_norm = torch.sum(w_tilde) 118 | if l1_norm != 0: 119 | w = w_tilde / l1_norm 120 | else: 121 | w = w_tilde 122 | self.opt.zero_grad() 123 | outputs = self.net(total) 124 | loss_batch = self.loss(outputs, total_labels, reduction="none") 125 | loss = torch.sum(w * loss_batch) 126 | loss.backward() 127 | self.opt.step() 128 | return loss.item() 129 | -------------------------------------------------------------------------------- /oursRehSel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.buffer import Buffer 3 | from utils.args import * 4 | from models.utils.continual_model import ContinualModel 5 | import random 6 | from utils.current_buffer import CurrentBuffer 7 | import higher 8 | from utils.min_norm_solvers import MinNormSolver, gradient_normalizers 9 | import numpy as np 10 | 11 | def get_parser() -> ArgumentParser: 12 | parser = 
ArgumentParser(description='Continual learning') 13 | add_management_args(parser) 14 | add_experiment_args(parser) 15 | add_rehearsal_args(parser) 16 | return parser 17 | 18 | 19 | class Mem(ContinualModel): 20 | NAME = 'mem' 21 | COMPATIBILITY = ['class-il', 'task-il'] 22 | 23 | def __init__(self, backbone, loss, args, transform): 24 | super(Mem, self).__init__(backbone, loss, args, transform) 25 | self.buffer = Buffer(self.args.buffer_size, self.device) 26 | self.current_task = 0 27 | self.transform = None 28 | 29 | def end_task(self, dataset): 30 | replace = self.args.buffer_size // (self.current_task + 1) 31 | delete_ind = self.buffer.delete_data_basedscore(replace, task=self.current_task) 32 | bx, by, b_ids, b_scores, b_imgid, b_ind = self.currentbuffer.get_all_data(self.currentbuffer.num_examples, transform=self.transform) 33 | new_codes, distance = self.currentbuffer.ourkmeans(replace) 34 | index = self.currentbuffer.score(replace, new_codes) 35 | new_score = torch.ones((index.shape[0], 3), device=self.device) 36 | self.buffer.replace_keshihua_data(delete_ind, bx[index], by[index], b_ids[index], new_score) 37 | self.current_task = self.current_task + 1 38 | self.times = torch.zeros(self.args.buffer_size, device=self.device) 39 | self.buffer.reset_score() 40 | 41 | 42 | def observe(self, inputs, labels, img_id, not_aug_inputs, task, args, epoch): 43 | 44 | real_batch_size = inputs.shape[0] 45 | task_labels = torch.ones(real_batch_size, dtype=torch.long).to(self.device) * task 46 | if task == 0: 47 | self.opt.zero_grad() 48 | outputs = self.net(inputs) 49 | loss = self.loss(outputs, labels) 50 | loss.backward() 51 | self.opt.step() 52 | score = torch.ones((real_batch_size, 3), dtype=torch.long).to(self.device) 53 | self.buffer.add_data(examples=inputs, labels=labels, task_labels=task_labels, score=score) 54 | return loss.item() 55 | else: 56 | if epoch<45: 57 | self.opt.zero_grad() 58 | outputs = self.net(inputs) 59 | loss = self.loss(outputs, labels) 60 | loss.backward() 61 | self.opt.step() 62 | return loss.item() 63 | else: 64 | self.opt.zero_grad() 65 | mem_x, mem_y, mem_ids, mem_score, mem_index = self.buffer.get_data_gmed(self.args.minibatch_size, transform=self.transform, fsr=True, current_task=task) 66 | total = torch.cat((inputs, mem_x)) 67 | total_labels = torch.cat((labels, mem_y)) 68 | subsample = self.buffer.buffer_size // 10 69 | bx, by, b_ids, b_score = self.buffer.get_data(subsample, transform=self.transform, fsr=True, current_task=task) 70 | nx, ny = self.currentbuffer.get_data(subsample, transform=self.transform) 71 | input_id, get_input_score = self.currentbuffer.get_input_score(img_id, shape=real_batch_size) 72 | 73 | iteration = 1 74 | with higher.innerloop_ctx(self.net, self.opt) as (meta_model, meta_opt): 75 | base1 = torch.ones(total.shape[0], device=self.device) 76 | eps1 = torch.zeros(total.shape[0], requires_grad=True, device=self.device) 77 | for i in range(iteration): 78 | meta_train_outputs = meta_model(total) 79 | meta_train_loss = self.loss(meta_train_outputs, total_labels, reduction="none") 80 | meta_train_loss = (torch.sum(eps1 * meta_train_loss) + torch.sum(base1 * meta_train_loss)) / torch.tensor(total.shape[0]) 81 | meta_opt.step(meta_train_loss) 82 | meta_val1_outputs = meta_model(bx) 83 | meta_val1_loss = self.loss(meta_val1_outputs, by, reduction="mean") 84 | eps_grads1 = torch.autograd.grad(meta_val1_loss, eps1)[0].detach() 85 | 86 | with higher.innerloop_ctx(self.net, self.opt) as (meta_model2, meta_opt2): 87 | base2 = 
torch.ones(total.shape[0], device=self.device) 88 | eps2 = torch.zeros(total.shape[0], requires_grad=True, device=self.device) 89 | for i in range(iteration): 90 | meta_train_outputs2 = meta_model2(total) 91 | meta_train_loss2 = self.loss(meta_train_outputs2, total_labels, reduction="none") 92 | meta_train_loss2 = (torch.sum(eps2 * meta_train_loss2) + torch.sum(base2 * meta_train_loss2)) / torch.tensor(total.shape[0]) 93 | meta_opt2.step(meta_train_loss2) 94 | meta_val2_outputs = meta_model2(nx) 95 | meta_val2_loss = self.loss(meta_val2_outputs, ny, reduction="mean") 96 | eps_grads2 = torch.autograd.grad(meta_val2_loss, eps2)[0].detach() 97 | 98 | gn = gradient_normalizers([eps_grads1, eps_grads2], [meta_val1_loss.item(), meta_val2_loss.item()], "ours") 99 | for gr_i in range(len(eps_grads1)): 100 | eps_grads1[gr_i] = eps_grads1[gr_i] / gn[0] 101 | for gr_i in range(len(eps_grads2)): 102 | eps_grads2[gr_i] = eps_grads2[gr_i] / gn[1] 103 | sol, min_norm = MinNormSolver.find_min_norm_element([eps_grads1, eps_grads2]) 104 | w_tilde = sol[0] * eps_grads1 + (1 - sol[0]) * eps_grads2 105 | 106 | 107 | # store influence 108 | mem_score[:, 0] = (mem_score[:, 0] * self.times[mem_index] + eps_grads1[real_batch_size:]) / (self.times[mem_index] + 1) 109 | mem_score[:, 1] = (mem_score[:, 1] * self.times[mem_index] + eps_grads2[real_batch_size:]) / (self.times[mem_index] + 1) 110 | mem_score[:, 2] = (mem_score[:, 2] * self.times[mem_index] + w_tilde[real_batch_size:]) / (self.times[mem_index] + 1) 111 | # update mem score 112 | self.buffer.replace_score(mem_score, mem_index) 113 | self.times[mem_index] = self.times[mem_index] + 1 114 | 115 | cur_epoch = epoch - 45 116 | # store influence 117 | get_input_score[:, 0] = (get_input_score[:, 0] * cur_epoch + eps_grads1[:real_batch_size]) / (cur_epoch + 1) 118 | get_input_score[:, 1] = (get_input_score[:, 1] * cur_epoch + eps_grads2[:real_batch_size]) / (cur_epoch + 1) 119 | get_input_score[:, 2] = (get_input_score[:, 2] * cur_epoch + w_tilde[:real_batch_size]) / (cur_epoch + 1) 120 | # update new data score 121 | self.currentbuffer.replace_scores(index=input_id, mem_scores=get_input_score) 122 | 123 | 124 | w_tilde = torch.ones(total.shape[0], device=self.device) - 1 * w_tilde 125 | l1_norm = torch.sum(w_tilde) 126 | if l1_norm != 0: 127 | w = w_tilde / l1_norm 128 | else: 129 | w = w_tilde 130 | self.opt.zero_grad() 131 | outputs = self.net(total) 132 | loss_batch = self.loss(outputs, total_labels, reduction="none") 133 | loss = torch.sum(w * loss_batch) 134 | loss.backward() 135 | self.opt.step() 136 | return loss.item() --------------------------------------------------------------------------------
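A minimal, self-contained sketch (not part of the repository) of the influence-fusion and reweighting step performed in observe(): two per-example influence vectors are normalized as in gradient_normalizers(..., "ours"), combined with the analytic two-vector min-norm coefficient mirroring MinNormSolver._min_norm_element_from2, and turned into per-example loss weights. The tensors g1/g2 and the placeholder validation losses are synthetic stand-ins for eps_grads1/eps_grads2 and the meta-validation losses computed with higher.

```python
# Sketch only: synthetic stand-ins for the influence vectors produced in observe().
import torch

def min_norm_gamma(g1: torch.Tensor, g2: torch.Tensor) -> float:
    # Analytic minimizer of |gamma * g1 + (1 - gamma) * g2|_2^2 over gamma,
    # clamped to (0, 1) as in MinNormSolver._min_norm_element_from2.
    v11, v12, v22 = (g1 * g1).sum(), (g1 * g2).sum(), (g2 * g2).sum()
    if v12 >= v11:
        return 0.999
    if v12 >= v22:
        return 0.001
    return ((v22 - v12) / (v11 + v22 - 2 * v12)).item()

torch.manual_seed(0)
batch = 8
g1 = torch.randn(batch)  # influence on stability (grad of old-task validation loss w.r.t. eps)
g2 = torch.randn(batch)  # influence on plasticity (grad of new-task validation loss w.r.t. eps)

# Normalize each influence vector as in gradient_normalizers(..., "ours");
# the losses here are placeholders for meta_val1_loss / meta_val2_loss.
loss1, loss2 = 1.0, 1.0
g1 = g1 / (g1.pow(2).sum().sqrt() / loss1)
g2 = g2 / (g2.pow(2).sum().sqrt() / loss2)

gamma = min_norm_gamma(g1, g2)
w_tilde = gamma * g1 + (1 - gamma) * g2   # fused influence
w = torch.ones(batch) - w_tilde           # negative influence (helpful example) -> larger weight
w = w / w.sum() if w.sum() != 0 else w    # L1 normalization, as in observe()
print(gamma, w)                           # w would reweight the per-example training losses
```

In the repository, these weights multiply the per-example losses on the joint (new + memory) batch before the real optimizer step (loss = torch.sum(w * loss_batch)), which is how ours.py and oursRehSel.py apply the fused influence.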