├── .gitignore ├── README.md ├── approaches ├── after_train.py ├── before_train.py ├── eval.py ├── noncl.py └── train.py ├── config.py ├── dataloader └── data.py ├── eval.py ├── figures └── TPL.png ├── main.py ├── networks ├── __init__.py └── vit_hat.py ├── requirements.txt ├── scripts ├── clip.sh ├── deit_small.sh ├── deit_small_in661.sh ├── deit_tiny.sh ├── dino.sh ├── mae.sh ├── vit_small.sh └── vit_tiny.sh ├── sequence ├── C100_10T ├── C100_20T ├── C10_5T ├── T_10T └── T_5T └── utils ├── __init__.py ├── baseline.py ├── sgd_hat.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | data 3 | results 4 | deit_pretrained 5 | ckpt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Class Incremental Learning via Likelihood Ratio Based Task Prediction 2 | 3 | This repository contains the code for our ICLR 2024 paper [Class Incremental Learning via Likelihood Ratio Based Task Prediction](https://arxiv.org/abs/2309.15048) by [Haowei Lin](https://linhaowei1.github.io/), [Yijia Shao](https://shaoyijia.github.io/), Weinan Qian, Ningxin Pan, Yiduo Guo, and [Bing Liu](https://www.cs.uic.edu/~liub/). 4 | 5 | **Update [2024.2.10]: We now support DER++, Non-CL, and more pre-trained visual encoders!** 6 | 7 | ## Quick Links 8 | 9 | - [Overview](#overview) 10 | - [Requirements](#requirements) 11 | - [Training](#training) 12 | - [Extension](#extension) 13 | - [Bugs or Questions?](#bugs-or-questions) 14 | - [Acknowledgements](#acknowledgements) 15 | - [Citation](#citation) 16 | 17 | ## Overview 18 | 19 | ![](figures/TPL.png) 20 | 21 | ## Requirements 22 | 23 | First, install PyTorch by following the instructions on [the official website](https://pytorch.org/). We ran the experiments on PyTorch 2.0.1, and any PyTorch version higher than `1.6.0` should also work. For example, if you use Linux and **CUDA 11** ([how to check CUDA version](https://varhowto.com/check-cuda-version/)), you can install PyTorch `1.6.0` with the following command, 24 | 25 | ``` 26 | pip install torch==1.6.0+cu110 -f https://download.pytorch.org/whl/torch_stable.html 27 | ``` 28 | 29 | If you instead use **CUDA** `<11` or **CPU**, install PyTorch with the following command, 30 | 31 | ``` 32 | pip install torch==1.6.0 33 | ``` 34 | 35 | Then run the following script to install the remaining dependencies, 36 | 37 | ``` 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | **Attention**: Our model is based on `timm==0.4.12`. Using other versions may cause unexpected bugs. 42 | 43 | ## Training 44 | 45 | In this section, we describe how to train the TPL model using our code. 46 | 47 | **Data** 48 | 49 | Before training and evaluation, please download the datasets (CIFAR-10, CIFAR-100, TinyImageNet). The default data directory is set to ``~/data`` in our code; you can modify it to suit your needs. 50 | 51 | **Pre-trained Model** 52 | 53 | We use the pre-trained DeiT model provided by [MORE](https://github.com/k-gyuhak/MORE). Please download it and save the file as ``./ckpt/pretrained/deit_small_patch16_224_in661.pth``. If you would like to test other pre-trained visual encoders, download them to the same place (the pre-trained weights are available in timm or on Hugging Face). We provide scripts for DINO, MAE, CLIP, ViT (small, tiny), and DeiT (small, tiny).
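If you want to verify that the checkpoint is in the expected place and readable, you can load it with PyTorch. This is only an illustrative sketch; the exact key layout of the MORE checkpoint may differ:

```python
import torch

ckpt_path = './ckpt/pretrained/deit_small_patch16_224_in661.pth'
state = torch.load(ckpt_path, map_location='cpu')
if isinstance(state, dict) and 'state_dict' in state:  # some checkpoints nest the weights
    state = state['state_dict']
print(f'Loaded {len(state)} entries, e.g. {list(state)[:3]}')
```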
54 | 55 | **Training scripts** 56 | 57 | We provide an exemplar training and evaluation script, `deit_small_in661.sh`. Simply run the following command to reproduce the results: 58 | 59 | ```bash 60 | bash scripts/deit_small_in661.sh 61 | ``` 62 | 63 | This script performs both training and testing. By default, TPL is trained with 5 random seeds. During training, the results are logged in `ckpt`; these are the $HAT_{CIL}$ results, i.e., obtained without the TPL inference techniques. After running evaluation, they are replaced with the TPL results. If you get unexpectedly poor results, check that `eval.py` was run correctly. The results for the first run with `seed=2023` will be saved in `./ckpt/seq0/seed2023/progressive_main_2023`. 64 | 65 | The results in the paper were obtained on NVIDIA A100 GPUs with CUDA 11.7. Using different types of devices or different versions of CUDA/other software may lead to slightly different performance. 66 | 67 | ## Extension 68 | 69 | Our repo also supports running baselines such as DER++. If you are interested in other baselines, follow the DER++ integration as a template for adding your new code. If you want to test TIL+OOD methods, you can modify the inference code and include the OOD score computation in `baseline.py` (a minimal illustration of such a score is sketched after the Acknowledgements section below). Our code base is very extensible. 70 | 71 | ## Bugs or Questions? 72 | 73 | If you have any questions related to the code or the paper, feel free to email [Haowei](mailto:linhaowei@pku.edu.cn). If you encounter any problems when using the code, or want to report a bug, please open an issue. Try to describe the problem in detail so we can help you better and faster! 74 | 75 | ## Acknowledgements 76 | 77 | We thank [PyContinual](https://github.com/ZixuanKe/PyContinual) for providing an extensible framework for continual learning. We used their code structure as a reference when developing this code base.
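As mentioned in the Extension section, TIL+OOD methods only require plugging an OOD score into the inference code in `utils/baseline.py`. The snippet below is a minimal, hypothetical sketch of such a score (maximum softmax probability); the function name and signature are illustrative and not part of this repository:

```python
import torch

def msp_score(logits: torch.Tensor) -> torch.Tensor:
    """Maximum softmax probability per sample; higher values suggest more in-distribution inputs."""
    return torch.softmax(logits, dim=1).max(dim=1).values
```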
78 | 79 | ## Citation 80 | 81 | Please cite our paper if you use this code or part of it in your work: 82 | 83 | ```bibtex 84 | @inproceedings{lin2024class, 85 | title={Class Incremental Learning via Likelihood Ratio Based Task Prediction}, 86 | author={Haowei Lin and Yijia Shao and Weinan Qian and Ningxin Pan and Yiduo Guo and Bing Liu}, 87 | year={2024}, 88 | booktitle={International Conference on Learning Representations} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /approaches/after_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from utils.sgd_hat import cum_mask, freeze_mask 4 | def compute(args, model): 5 | if 'hat' in args.baseline: 6 | args.mask_pre = cum_mask(smax=args.smax, t=args.task, model=model, mask_pre=args.mask_pre) 7 | args.mask_back = freeze_mask(args.smax, args.task, model, args.mask_pre) 8 | torch.save(args.mask_pre, os.path.join(args.output_dir, 'mask_pre')) 9 | torch.save(args.mask_back, os.path.join(args.output_dir, 'mask_back')) 10 | torch.save(model.state_dict(), os.path.join(args.output_dir, 'model')) 11 | 12 | -------------------------------------------------------------------------------- /approaches/before_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | def prepare(args, model): 5 | 6 | if 'hat' in args.baseline: 7 | args.mask_pre = None 8 | args.mask_back = None 9 | args.reg_lambda = 0.75 10 | if args.task > 0: 11 | print('load mask matrix ....') 12 | args.mask_pre = torch.load(os.path.join(args.prev_output, 'mask_pre'), map_location='cpu') 13 | args.mask_back = torch.load(os.path.join(args.prev_output, 'mask_back'), map_location='cpu') 14 | 15 | for k, v in args.mask_pre.items(): 16 | args.mask_pre[k] = args.mask_pre[k].cuda() 17 | 18 | for k, v in args.mask_back.items(): 19 | args.mask_back[k] = args.mask_back[k].cuda() 20 | 21 | for n, p in model.named_parameters(): 22 | p.grad = None 23 | if n in args.mask_back.keys(): 24 | p.hat = args.mask_back[n] 25 | else: 26 | p.hat = None 27 | -------------------------------------------------------------------------------- /approaches/eval.py: -------------------------------------------------------------------------------- 1 | from modulefinder import IMPORT_NAME 2 | from approaches import before_train, after_train 3 | from tqdm.auto import tqdm 4 | import torch.nn as nn 5 | import os 6 | import shutil 7 | import torch 8 | import json 9 | import numpy as np 10 | import logging 11 | import math 12 | from transformers import get_scheduler 13 | from utils import utils, baseline 14 | import faiss 15 | logger = logging.getLogger(__name__) 16 | 17 | class Appr(object): 18 | 19 | def __init__(self, args): 20 | super().__init__() 21 | self.args = args 22 | 23 | def eval(self, model, train_loaders, test_loaders, replay_loader, accelerator): 24 | 25 | model = accelerator.prepare(model) 26 | model.eval() 27 | replay_loader = accelerator.prepare(replay_loader) 28 | train_hidden, train_labels = None, None 29 | if os.path.exists(os.path.join(self.args.output_dir, 'train_hidden')): 30 | with open(os.path.join(self.args.output_dir, 'results'), 'r') as f: 31 | results = json.load(f) 32 | results = {int(k): v for (k,v) in results.items()} 33 | with open(os.path.join(self.args.output_dir, 'train_hidden'), 'r') as f: 34 | train_hidden = json.load(f) 35 | train_hidden = {int(k): v for (k,v) in train_hidden.items()} 36 | 
with open(os.path.join(self.args.output_dir, 'train_labels'), 'r') as f: 37 | train_labels = json.load(f) 38 | train_labels = {int(k): v for (k,v) in train_labels.items()} 39 | with open(os.path.join(self.args.output_dir, 'train_logits'), 'r') as f: 40 | train_logits = json.load(f) 41 | train_logits = {int(k): v for (k,v) in train_logits.items()} 42 | else: 43 | results = {} 44 | train_hidden = {} 45 | train_labels = {} 46 | train_logits = {} 47 | model.eval() 48 | 49 | for eval_t in tqdm(range(self.args.task + 1)): 50 | 51 | results[eval_t] = { 52 | 'predictions': [], # [N x data], prediction of N task mask 53 | 'references': [], # [data] 54 | 'hidden': [], # [N x data] 55 | 'logits': [], # [N x data] 56 | 'softmax_prob': [], # [N x data] 57 | 'total_num': 0 58 | } 59 | train_hidden[eval_t] = [] 60 | train_labels[eval_t] = [] 61 | train_logits[eval_t] = [] 62 | test_loader, train_loader = accelerator.prepare(test_loaders[eval_t], train_loaders[eval_t]) 63 | 64 | for task_mask in range(self.args.task + 1): 65 | 66 | train_hidden_list = [] 67 | hidden_list = [] 68 | prediction_list = [] 69 | logits_list = [] 70 | softmax_list = [] 71 | train_logits_list = [] 72 | 73 | for _, batch in enumerate(test_loader): 74 | with torch.no_grad(): 75 | features, _ = model.forward_features(batch[0], task_mask, s=self.args.smax) 76 | output = model.forward_classifier(features, task_mask) 77 | output = output[:, task_mask * self.args.class_num: (task_mask+1) * self.args.class_num] 78 | score, prediction = torch.max(torch.softmax(output, dim=1), dim=1) 79 | 80 | hidden_list += (features).cpu().numpy().tolist() 81 | prediction_list += (prediction + self.args.class_num * task_mask).cpu().numpy().tolist() 82 | softmax_list += score.cpu().numpy().tolist() 83 | logits_list += output.cpu().numpy().tolist() 84 | 85 | if task_mask == 0: 86 | results[eval_t]['total_num'] += batch[0].shape[0] 87 | results[eval_t]['references'] += batch[1].cpu().numpy().tolist() 88 | 89 | results[eval_t]['hidden'].append(hidden_list) 90 | results[eval_t]['predictions'].append(prediction_list) 91 | results[eval_t]['softmax_prob'].append(softmax_list) 92 | results[eval_t]['logits'].append(logits_list) 93 | 94 | 95 | for _, batch in enumerate(train_loader): 96 | with torch.no_grad(): 97 | features, _ = model.forward_features(batch[0], eval_t, s=self.args.smax) 98 | output = model.forward_classifier(features, eval_t) 99 | output = output[:, eval_t * self.args.class_num: (eval_t+1) * self.args.class_num] 100 | train_logits[eval_t] += output.cpu().numpy().tolist() 101 | train_hidden[eval_t] += (features).cpu().numpy().tolist() 102 | train_labels[eval_t] += (batch[1] - self.args.class_num * eval_t).cpu().numpy().tolist() 103 | 104 | # with open(os.path.join(self.args.output_dir, 'results'), 'w') as f: 105 | # json.dump(results, f) 106 | # with open(os.path.join(self.args.output_dir, 'train_hidden'), 'w') as f: 107 | # json.dump(train_hidden, f) 108 | # with open(os.path.join(self.args.output_dir, 'train_labels'), 'w') as f: 109 | # json.dump(train_labels, f) 110 | # with open(os.path.join(self.args.output_dir, 'train_logits'), 'w') as f: 111 | # json.dump(train_logits, f) 112 | 113 | out_features = {task_mask: [] for task_mask in range(self.args.task + 1)} 114 | in_features = {task_mask: [] for task_mask in range(self.args.task + 1)} 115 | features_dict = {task_mask: [] for task_mask in range(self.args.task + 1)} 116 | logits_dict = {task_mask: [] for task_mask in range(self.args.task + 1)} 117 | replay_labels = [] 118 | 119 | for idx, batch 
in enumerate(replay_loader): 120 | 121 | with torch.no_grad(): 122 | for task_mask in range(self.args.task + 1): 123 | if idx == task_mask: 124 | features, _ = model.forward_features(batch[0], task_mask, s=self.args.smax) 125 | in_features[task_mask] += (features).cpu().numpy().tolist() 126 | else: 127 | features, _ = model.forward_features(batch[0], task_mask, s=self.args.smax) 128 | out_features[task_mask] += (features).cpu().numpy().tolist() 129 | logits = model.forward_classifier(features, task_mask)[:, task_mask * self.args.class_num: (task_mask+1) * self.args.class_num] 130 | features_dict[task_mask] += features.cpu().numpy().tolist() 131 | logits_dict[task_mask] += logits.cpu().numpy().tolist() 132 | 133 | replay_labels += batch[1].cpu().numpy().tolist() 134 | 135 | ## replay data 136 | self.args.out_features = out_features 137 | self.args.in_features = in_features 138 | self.args.features_dict = features_dict 139 | self.args.logits_dict = logits_dict 140 | self.args.replay_loader = replay_loader 141 | self.args.replay_labels = replay_labels 142 | 143 | ## train data 144 | self.args.train_logits = train_logits 145 | self.args.train_labels = train_labels 146 | self.args.train_hidden = train_hidden 147 | self.args.model = model 148 | self.args.test_loaders = test_loaders 149 | 150 | ## maha feat 151 | self.args.feat_mean_list, self.args.precision_list = utils.load_maha(self.args, train_hidden, train_labels) 152 | 153 | self.args.calib_w = torch.ones(self.args.task + 1) 154 | self.args.calib_b = torch.zeros(self.args.task + 1) 155 | self.args.mls_scale = [1.0 for _ in range(self.args.task + 1)] 156 | self.args.mds_scale = [1.0 for _ in range(self.args.task + 1)] 157 | self.args.knn_scale = [1.0 for _ in range(self.args.task + 1)] 158 | self.args.index_out = [None for _ in range(self.args.task + 1)] 159 | self.args.tplr_setup = [False for _ in range(self.args.task + 1)] 160 | 161 | for task_mask in range(self.args.task + 1): 162 | self.args.index_out[task_mask] = faiss.IndexFlatL2(len(self.args.out_features[task_mask][0])) 163 | self.args.index_out[task_mask].add(utils.normalize(self.args.out_features[task_mask]).astype(np.float32)) 164 | self.args.tplr_setup[task_mask] = True 165 | 166 | baseline.scaling(self.args) 167 | baseline.calibration(self.args) 168 | baseline.baseline(self.args, results) -------------------------------------------------------------------------------- /approaches/noncl.py: -------------------------------------------------------------------------------- 1 | from modulefinder import IMPORT_NAME 2 | from approaches import before_train, after_train 3 | from tqdm.auto import tqdm 4 | import torch.nn as nn 5 | import os 6 | import shutil 7 | import torch 8 | import numpy as np 9 | import logging 10 | import math 11 | from transformers import get_scheduler 12 | from utils.sgd_hat import HAT_reg, compensation, compensation_clamp 13 | from utils.sgd_hat import SGD_hat as SGD 14 | from utils import utils 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class Appr(object): 19 | 20 | def __init__(self, args): 21 | super().__init__() 22 | self.args = args 23 | 24 | def train(self, model, train_loader, test_loaders, replay_loader): 25 | 26 | # Optimizer 27 | # Split weights in two groups, one with weight decay and the other not. 28 | optimizer = SGD(model.adapter_parameters(), lr=self.args.learning_rate, 29 | momentum=0.9, weight_decay=5e-4, nesterov=True) 30 | # Scheduler and math around the number of training steps. 
31 | num_update_steps_per_epoch = math.ceil(len(train_loader) / self.args.gradient_accumulation_steps) 32 | if self.args.max_train_steps is None: 33 | self.args.max_train_steps = self.args.num_train_epochs * num_update_steps_per_epoch 34 | else: 35 | self.args.num_train_epochs = math.ceil(self.args.max_train_steps / num_update_steps_per_epoch) 36 | 37 | model = model.cuda() 38 | 39 | 40 | # Train! 41 | logger.info("***** Running training *****") 42 | logger.info(" Num examples = {}".format(len(train_loader) * self.args.batch_size)) 43 | logger.info(f" Num Epochs = {self.args.num_train_epochs}, checkpoint Model = {self.args.model_name_or_path}") 44 | logger.info(f" Instantaneous batch size per device = {self.args.batch_size}") 45 | logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") 46 | logger.info(f" Learning Rate = {self.args.learning_rate}") 47 | logger.info(f" Seq ID = {self.args.idrandom}, Task id = {self.args.task}, Task Name = {self.args.task_name}, Num task = {self.args.ntasks}") 48 | 49 | progress_bar = tqdm(range(self.args.max_train_steps)) 50 | completed_steps = 0 51 | starting_epoch = 0 52 | 53 | for epoch in range(starting_epoch, self.args.num_train_epochs): 54 | model.train() 55 | 56 | for step, batch in enumerate(train_loader): 57 | 58 | batch[0] = batch[0].cuda() 59 | batch[1] = batch[1].cuda() 60 | 61 | 62 | outputs = model(batch[0]) 63 | 64 | loss = nn.functional.cross_entropy(outputs, batch[1]) 65 | 66 | loss.backward() 67 | 68 | if step % self.args.gradient_accumulation_steps == 0 or step == len(train_loader) - 1: 69 | 70 | optimizer.step() 71 | optimizer.zero_grad() 72 | progress_bar.update(1) 73 | completed_steps += 1 74 | progress_bar.set_description( 75 | 'Train Iter (Epoch=%3d,loss=%5.3f)' % ((epoch, loss.item()))) # show the loss, mean while 76 | 77 | if completed_steps >= self.args.max_train_steps: 78 | break 79 | 80 | 81 | for eval_t in range(self.args.task + 1): 82 | results = self.eval_cil(model, test_loaders, eval_t) 83 | print("*task {}, til_acc = {}, cil_acc = {}, tp_acc = {}".format( 84 | eval_t, results['til_accuracy'], results['cil_accuracy'], results['TP_accuracy'])) 85 | utils.write_result(results, eval_t, self.args) 86 | 87 | def eval_cil(self, model, test_loaders, eval_t): 88 | model.eval() 89 | dataloader = test_loaders[eval_t] 90 | label_list = [] 91 | cil_prediction_list, til_prediction_list = [], [] 92 | total_num = 0 93 | 94 | for _, batch in enumerate(dataloader): 95 | with torch.no_grad(): 96 | 97 | features = model.forward_features(batch[0].cuda()) 98 | logits = model.forward_classifier(features) 99 | cil_outputs = logits[..., : (self.args.task + 1) * self.args.class_num] 100 | til_outputs = logits[..., eval_t * self.args.class_num: (eval_t+1) * self.args.class_num] 101 | _, cil_prediction = torch.max(torch.softmax(cil_outputs, dim=1), dim=1) 102 | _, til_prediction = torch.max(torch.softmax(til_outputs, dim=1), dim=1) 103 | til_prediction += eval_t * self.args.class_num 104 | 105 | references = batch[1] 106 | total_num += batch[0].shape[0] 107 | 108 | label_list += references.cpu().numpy().tolist() 109 | cil_prediction_list += cil_prediction.cpu().numpy().tolist() 110 | til_prediction_list += til_prediction.cpu().numpy().tolist() 111 | 112 | cil_accuracy = sum( 113 | [1 if label_list[i] == cil_prediction_list[i] else 0 for i in range(total_num)] 114 | ) / total_num 115 | 116 | til_accuracy = sum( 117 | [1 if label_list[i] == til_prediction_list[i] else 0 for i in range(total_num)] 118 | ) / total_num 119 | 
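# TP (task-prediction) accuracy below maps the CIL prediction back to a task index via integer
# division by class_num and compares it with the ground-truth task eval_t.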
120 | tp_accuracy = sum( 121 | [1 if cil_prediction_list[i] // self.args.class_num == eval_t else 0 for i in range(total_num)] 122 | ) / total_num 123 | 124 | results = { 125 | 'til_accuracy': round(til_accuracy, 4), 126 | 'cil_accuracy': round(cil_accuracy, 4), 127 | 'TP_accuracy': round(tp_accuracy, 4) 128 | } 129 | return results -------------------------------------------------------------------------------- /approaches/train.py: -------------------------------------------------------------------------------- 1 | from modulefinder import IMPORT_NAME 2 | from approaches import before_train, after_train 3 | from tqdm.auto import tqdm 4 | import torch.nn as nn 5 | import os 6 | import shutil 7 | import torch 8 | import numpy as np 9 | import logging 10 | import math 11 | from transformers import get_scheduler 12 | from utils.sgd_hat import HAT_reg, compensation, compensation_clamp 13 | from utils.sgd_hat import SGD_hat as SGD 14 | from utils import utils 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class Appr(object): 19 | 20 | def __init__(self, args): 21 | super().__init__() 22 | self.args = args 23 | 24 | def train(self, model, train_loader, test_loaders, replay_loader): 25 | 26 | # Optimizer 27 | # Split weights in two groups, one with weight decay and the other not. 28 | optimizer = SGD(model.adapter_parameters(), lr=self.args.learning_rate, 29 | momentum=0.9, weight_decay=5e-4, nesterov=True) 30 | # Scheduler and math around the number of training steps. 31 | num_update_steps_per_epoch = math.ceil(len(train_loader) / self.args.gradient_accumulation_steps) 32 | if self.args.max_train_steps is None: 33 | self.args.max_train_steps = self.args.num_train_epochs * num_update_steps_per_epoch 34 | else: 35 | self.args.num_train_epochs = math.ceil(self.args.max_train_steps / num_update_steps_per_epoch) 36 | 37 | model = model.cuda() 38 | if 'derpp' in self.args.baseline: 39 | self.args.teacher_model = self.args.teacher_model.cuda() 40 | for p in self.args.teacher_model.parameters(): 41 | p.requires_grad = False 42 | 43 | if replay_loader is not None: 44 | replay_iterator = iter(replay_loader) 45 | 46 | before_train.prepare(self.args, model) 47 | 48 | # Train! 
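# The training loop below anneals the HAT scaling factor s from 1/smax to smax within each epoch,
# optionally concatenates a replay batch onto the current batch, adds the HAT regularizer (or the
# DER++ feature-distillation term) to the cross-entropy loss, and applies HAT gradient compensation
# around each optimizer step.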
49 | logger.info("***** Running training *****") 50 | logger.info(" Num examples = {}".format(len(train_loader) * self.args.batch_size)) 51 | logger.info(f" Num Epochs = {self.args.num_train_epochs}, checkpoint Model = {self.args.model_name_or_path}") 52 | logger.info(f" Instantaneous batch size per device = {self.args.batch_size}") 53 | logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") 54 | logger.info(f" Learning Rate = {self.args.learning_rate}") 55 | logger.info(f" Seq ID = {self.args.idrandom}, Task id = {self.args.task}, Task Name = {self.args.task_name}, Num task = {self.args.ntasks}") 56 | 57 | progress_bar = tqdm(range(self.args.max_train_steps)) 58 | completed_steps = 0 59 | starting_epoch = 0 60 | 61 | for epoch in range(starting_epoch, self.args.num_train_epochs): 62 | model.train() 63 | 64 | for step, batch in enumerate(train_loader): 65 | 66 | s = (self.args.smax - 1 / self.args.smax) * step / len( 67 | train_loader) + 1 / self.args.smax 68 | 69 | if replay_loader is not None: 70 | try: 71 | replay_batch = next(replay_iterator) 72 | batch[0] = torch.cat((batch[0], replay_batch[0]), dim=0) 73 | batch[1] = torch.cat((batch[1], replay_batch[1]), dim=0) 74 | except: 75 | replay_iterator = iter(replay_loader) 76 | replay_batch = next(replay_iterator) 77 | batch[0] = torch.cat((batch[0], replay_batch[0]), dim=0) 78 | batch[1] = torch.cat((batch[1], replay_batch[1]), dim=0) 79 | 80 | batch[0] = batch[0].cuda() 81 | batch[1] = batch[1].cuda() 82 | 83 | if 'hat' in self.args.baseline: 84 | features, masks = model.forward_features(batch[0], self.args.task, s=s) 85 | outputs = model.forward_classifier(features, self.args.task) 86 | else: 87 | features = model.forward_features(batch[0]) 88 | outputs = model.forward_classifier(features)[..., :(self.args.task+1) * self.args.class_num] 89 | 90 | loss = nn.functional.cross_entropy(outputs, batch[1]) 91 | 92 | if 'hat' in self.args.baseline: 93 | loss += HAT_reg(self.args, masks) 94 | elif 'derpp' in self.args.baseline and replay_loader is not None: 95 | with torch.no_grad(): 96 | prev_feature = self.args.teacher_model.forward_features(replay_batch[0].cuda()) 97 | loss += nn.functional.mse_loss(features[-prev_feature.shape[0]:, ...], prev_feature) 98 | 99 | loss.backward() 100 | 101 | if step % self.args.gradient_accumulation_steps == 0 or step == len(train_loader) - 1: 102 | 103 | if 'hat' in self.args.baseline: 104 | compensation(model, self.args, thres_cosh=self.args.thres_cosh, s=s) 105 | optimizer.step(hat=(self.args.task > 0)) 106 | compensation_clamp(model, thres_emb=6) 107 | 108 | else: 109 | optimizer.step() 110 | 111 | optimizer.zero_grad() 112 | progress_bar.update(1) 113 | completed_steps += 1 114 | progress_bar.set_description( 115 | 'Train Iter (Epoch=%3d,loss=%5.3f)' % ((epoch, loss.item()))) # show the loss, mean while 116 | 117 | if completed_steps >= self.args.max_train_steps: 118 | break 119 | 120 | after_train.compute(self.args, model) 121 | 122 | for eval_t in range(self.args.task + 1): 123 | if 'hat' in self.args.baseline: 124 | results = self.eval_hat(model, test_loaders, eval_t) 125 | else: 126 | results = self.eval_cil(model, test_loaders, eval_t) 127 | 128 | print("*task {}, til_acc = {}, cil_acc = {}, tp_acc = {}".format( 129 | eval_t, results['til_accuracy'], results['cil_accuracy'], results['TP_accuracy'])) 130 | utils.write_result(results, eval_t, self.args) 131 | 132 | def eval_cil(self, model, test_loaders, eval_t): 133 | model.eval() 134 | dataloader = test_loaders[eval_t] 
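# In eval_cil, the CIL prediction ranks all (task + 1) * class_num classes seen so far, while the
# TIL prediction is restricted to the class_num logits belonging to task eval_t.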
135 | label_list = [] 136 | cil_prediction_list, til_prediction_list = [], [] 137 | total_num = 0 138 | 139 | for _, batch in enumerate(dataloader): 140 | with torch.no_grad(): 141 | 142 | features = model.forward_features(batch[0].cuda()) 143 | logits = model.forward_classifier(features) 144 | cil_outputs = logits[..., : (self.args.task + 1) * self.args.class_num] 145 | til_outputs = logits[..., eval_t * self.args.class_num: (eval_t+1) * self.args.class_num] 146 | _, cil_prediction = torch.max(torch.softmax(cil_outputs, dim=1), dim=1) 147 | _, til_prediction = torch.max(torch.softmax(til_outputs, dim=1), dim=1) 148 | til_prediction += eval_t * self.args.class_num 149 | 150 | references = batch[1] 151 | total_num += batch[0].shape[0] 152 | 153 | label_list += references.cpu().numpy().tolist() 154 | cil_prediction_list += cil_prediction.cpu().numpy().tolist() 155 | til_prediction_list += til_prediction.cpu().numpy().tolist() 156 | 157 | cil_accuracy = sum( 158 | [1 if label_list[i] == cil_prediction_list[i] else 0 for i in range(total_num)] 159 | ) / total_num 160 | 161 | til_accuracy = sum( 162 | [1 if label_list[i] == til_prediction_list[i] else 0 for i in range(total_num)] 163 | ) / total_num 164 | 165 | tp_accuracy = sum( 166 | [1 if cil_prediction_list[i] // self.args.class_num == eval_t else 0 for i in range(total_num)] 167 | ) / total_num 168 | 169 | results = { 170 | 'til_accuracy': round(til_accuracy, 4), 171 | 'cil_accuracy': round(cil_accuracy, 4), 172 | 'TP_accuracy': round(tp_accuracy, 4) 173 | } 174 | return results 175 | 176 | def eval_hat(self, model, test_loaders, eval_t): 177 | 178 | model.eval() 179 | dataloader = test_loaders[eval_t] 180 | label_list = [] 181 | prediction_list = [] 182 | taskscore_list = [] 183 | total_num = 0 184 | for task_mask in range(self.args.task + 1): 185 | total_num = 0 186 | task_pred = [] 187 | task_confidence = [] 188 | task_label = [] 189 | for _, batch in enumerate(dataloader): 190 | with torch.no_grad(): 191 | 192 | features, _ = model.forward_features(batch[0].cuda(), task_mask, s=self.args.smax) 193 | outputs = model.forward_classifier(features, task_mask)[ 194 | :, task_mask * self.args.class_num: (task_mask+1) * self.args.class_num] 195 | score, prediction = torch.max(torch.softmax(outputs, dim=1), dim=1) 196 | 197 | predictions = prediction + task_mask * self.args.class_num 198 | references = batch[1] 199 | 200 | total_num += batch[0].shape[0] 201 | task_confidence += score.cpu().numpy().tolist() 202 | task_label += references.cpu().numpy().tolist() 203 | task_pred += predictions.cpu().numpy().tolist() 204 | 205 | label_list = task_label 206 | prediction_list.append(task_pred) 207 | taskscore_list.append(np.array(task_confidence)) 208 | 209 | task_pred = np.argmax(np.stack(taskscore_list, axis=0), axis=0) 210 | cil_pred = [prediction_list[task_pred[i]][i] for i in range(total_num)] 211 | til_pred = [prediction_list[eval_t][i] for i in range(total_num)] 212 | 213 | cil_accuracy = sum( 214 | [1 if label_list[i] == cil_pred[i] else 0 for i in range(total_num)] 215 | ) / total_num 216 | til_accuracy = sum( 217 | [1 if label_list[i] == til_pred[i] else 0 for i in range(total_num)] 218 | ) / total_num 219 | TP_accuracy = sum( 220 | [1 if task_pred[i] == eval_t else 0 for i in range(total_num)] 221 | ) / total_num 222 | 223 | results = { 224 | 'til_accuracy': round(til_accuracy, 4), 225 | 'cil_accuracy': round(cil_accuracy, 4), 226 | 'TP_accuracy': round(TP_accuracy, 4) 227 | } 228 | return results 229 | 
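For reference, the per-epoch HAT gate-scaling schedule used in `train()` above can be inspected in isolation. The snippet below simply re-evaluates the formula from the training loop with example numbers (the batch count is an assumption for illustration):

```python
# Per-epoch HAT scaling schedule from approaches/train.py:
# s ramps linearly from 1/smax (first batch) towards smax (last batch) within every epoch.
smax = 400          # default --smax in config.py
num_batches = 100   # example value; in training this is len(train_loader)
for step in (0, 50, 99):
    s = (smax - 1 / smax) * step / num_batches + 1 / smax
    print(f'step={step:3d}  s={s:8.3f}')  # ~0.003, ~200.0, ~396.0
```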
-------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--batch_size', type=int, default=64) 7 | parser.add_argument('--baseline', type=str) 8 | parser.add_argument('--K', type=int, default=5) 9 | parser.add_argument('--task', type=int, default=0) 10 | parser.add_argument('--idrandom', type=int, default=0) 11 | parser.add_argument('--alpha', type=float, default=0.2) 12 | parser.add_argument('--training', action='store_true') 13 | parser.add_argument('--calibration', action='store_true') 14 | parser.add_argument('--scaling', action='store_true') 15 | parser.add_argument('--visual_encoder', type=str, default='deit_small_patch16_224_in661') 16 | parser.add_argument('--class_order', type=int, default=0) 17 | parser.add_argument('--base_dir', type=str, default='~/data') 18 | parser.add_argument("--sequence_file", type=str, help="sequence file") 19 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 20 | parser.add_argument("--smax", default=400, type=int, help="smax") 21 | parser.add_argument('--warmup_ratio', type=float) 22 | parser.add_argument('--replay_buffer_size', type=int, default=200) 23 | parser.add_argument('--latent', type=int, default=64) 24 | parser.add_argument('--eval_during_training', action="store_true") 25 | parser.add_argument('--replay_batch_size', type=int, default=64) 26 | parser.add_argument( 27 | "--learning_rate", 28 | type=float, 29 | default=5e-5, 30 | help="Initial learning rate (after the potential warmup period) to use.", 31 | ) 32 | parser.add_argument( 33 | "--lr_scheduler_type", 34 | type=str, 35 | default="cosine", 36 | help="The scheduler type to use.", 37 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], 38 | ) 39 | parser.add_argument( 40 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 41 | ) 42 | parser.add_argument( 43 | "--max_train_steps", 44 | type=int, 45 | default=None, 46 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 47 | ) 48 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 49 | parser.add_argument("--clipgrad", type=float, default=1.0) 50 | parser.add_argument('--thres_cosh', default=50, type=int, required=False, help='(default=%(default)d)') 51 | parser.add_argument( 52 | "--gradient_accumulation_steps", 53 | type=int, 54 | default=1, 55 | help="Number of updates steps to accumulate before performing a backward/update pass.", 56 | ) 57 | parser.add_argument("--num_train_epochs", type=int, help="Total number of training epochs to perform.") 58 | return parser.parse_args() 59 | -------------------------------------------------------------------------------- /dataloader/data.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | import os 3 | from torch.utils.data.dataset import Subset 4 | from torchvision import datasets, transforms 5 | import timm 6 | from timm.data import resolve_data_config 7 | from timm.data.transforms_factory import create_transform 8 | import numpy as np 9 | from copy import deepcopy 10 | from PIL import Image 11 | import math 12 | import torch 13 | import PIL 14 | 15 | DATA_PATH = './data' 16 | 17 | def get_transform(args): 18 | 19 | config = {'input_size': (3, 224, 224), 'interpolation': 'bicubic', 'mean': ( 20 | 0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'crop_pct': 0.9} 21 | TRANSFORM = create_transform(**config) 22 | return TRANSFORM, TRANSFORM 23 | 24 | 25 | def get_dataset(args): 26 | 27 | f_name = os.path.join('./sequence', args.sequence_file) 28 | 29 | with open(f_name, 'r') as f_random_seq: 30 | random_sep = f_random_seq.readlines()[args.idrandom].split() 31 | dataset_name = random_sep[0] 32 | 33 | if 'C10-5T' in dataset_name: 34 | args.total_class = 10 35 | args.class_num = int(args.total_class / args.ntasks) 36 | args.mean = (0.4914, 0.4822, 0.4465) 37 | args.std = (0.2023, 0.1994, 0.2010) 38 | train_transform, test_transform = get_transform(args) 39 | train = datasets.CIFAR10(DATA_PATH, train=True, download=True, transform=train_transform) 40 | test = datasets.CIFAR10(DATA_PATH, train=False, download=False, transform=test_transform) 41 | label_map = { 42 | 0: list(range(10)), 43 | 1: [3, 9, 1, 8, 0, 2, 6, 4, 5, 7], 44 | 2: [6, 0, 2, 8, 1, 9, 7, 3, 5, 4], 45 | 3: [2, 6, 1, 5, 9, 8, 0, 4, 3, 7], 46 | 4: [1, 5, 7, 2, 0, 3, 4, 6, 8, 9] 47 | } 48 | 49 | train.targets = [label_map[args.class_order].index(x) for x in train.targets] 50 | test.targets = [label_map[args.class_order].index(x) for x in test.targets] 51 | 52 | elif 'C100-' in dataset_name: 53 | args.total_class = 100 54 | args.class_num = int(args.total_class / args.ntasks) 55 | args.mean = [x / 255 for x in [129.3, 124.1, 112.4]] 56 | args.std = [x / 255 for x in [68.2, 65.4, 70.4]] 57 | train_transform, test_transform = get_transform(args) 58 | train = datasets.CIFAR100(DATA_PATH, train=True, download=True, transform=train_transform) 59 | test = datasets.CIFAR100(DATA_PATH, train=False, download=False, transform=test_transform) 60 | label_map = { 61 | 0: list(range(100)), 62 | 1: [49, 98, 97, 53, 48, 62, 89, 23, 82, 13, 40, 35, 71, 59, 34, 95, 67, 11, 27, 7, 47, 85, 36, 70, 51, 32, 60, 16, 29, 84, 39, 8, 17, 42, 72, 18, 15, 55, 83, 10, 37, 99, 66, 22, 14, 57, 24, 38, 80, 0, 52, 88, 77, 3, 50, 6, 41, 44, 93, 9, 96, 81, 45, 58, 5, 64, 86, 2, 78, 68, 75, 56, 46, 91, 43, 20, 87, 1, 33, 28, 19, 61, 30, 74, 65, 79, 63, 25, 12, 21, 31, 90, 69, 73, 4, 94, 
76, 92, 54, 26], 63 | 2: [64, 79, 89, 9, 88, 3, 26, 94, 61, 62, 73, 69, 83, 8, 75, 23, 45, 92, 74, 1, 84, 71, 96, 52, 7, 95, 2, 5, 70, 28, 77, 60, 43, 22, 91, 78, 34, 80, 48, 51, 58, 37, 6, 25, 85, 97, 40, 27, 32, 98, 36, 21, 39, 31, 15, 49, 66, 72, 67, 24, 20, 93, 87, 54, 90, 76, 99, 30, 53, 29, 82, 57, 65, 4, 19, 11, 14, 41, 16, 86, 59, 68, 35, 55, 38, 17, 33, 50, 81, 63, 42, 10, 18, 56, 44, 13, 46, 0, 47, 12], 64 | 3: [97, 1, 48, 88, 58, 46, 87, 18, 35, 71, 45, 6, 31, 69, 21, 96, 9, 44, 14, 68, 98, 27, 56, 38, 13, 63, 47, 57, 22, 64, 8, 73, 78, 94, 52, 4, 23, 28, 85, 2, 19, 10, 92, 7, 93, 76, 42, 34, 49, 80, 40, 37, 66, 83, 33, 99, 36, 12, 41, 39, 75, 25, 3, 95, 16, 0, 29, 53, 60, 11, 24, 82, 86, 32, 91, 43, 65, 89, 15, 81, 17, 62, 90, 54, 51, 20, 55, 30, 77, 59, 50, 5, 74, 84, 67, 79, 70, 61, 72, 26], 65 | 4: [34, 31, 97, 47, 83, 59, 39, 4, 32, 44, 26, 73, 45, 33, 56, 87, 82, 23, 88, 10, 51, 57, 65, 84, 43, 37, 9, 74, 28, 24, 90, 25, 60, 80, 5, 64, 63, 62, 40, 19, 49, 21, 77, 95, 99, 16, 12, 14, 70, 54, 53, 38, 8, 72, 18, 68, 15, 94, 36, 7, 1, 69, 2, 61, 98, 75, 85, 11, 17, 76, 22, 27, 92, 71, 3, 0, 66, 42, 96, 67, 35, 30, 46, 81, 48, 93, 79, 6, 13, 86, 20, 91, 78, 50, 89, 41, 52, 55, 29, 58] 66 | } 67 | 68 | train.targets = [label_map[args.class_order].index(x) for x in train.targets] 69 | test.targets = [label_map[args.class_order].index(x) for x in test.targets] 70 | 71 | elif dataset_name.startswith('T-'): 72 | args.total_class = 200 73 | args.class_num = int(args.total_class / args.ntasks) 74 | args.mean = (0.4914, 0.4822, 0.4465) 75 | args.std = (0.2023, 0.1994, 0.2010) 76 | train_transform, test_transform = get_transform(args) 77 | train = datasets.ImageFolder(root=f'{DATA_PATH}/tiny-imagenet-200/train', transform=train_transform) 78 | test = datasets.ImageFolder(root=f'{DATA_PATH}/tiny-imagenet-200/val', transform=test_transform) 79 | label_map = { 80 | 0: list(range(200)), 81 | 1: [117, 8, 183, 39, 40, 47, 75, 133, 193, 28, 130, 31, 98, 119, 188, 161, 57, 92, 54, 134, 6, 71, 147, 70, 139, 68, 77, 149, 17, 87, 132, 184, 59, 52, 194, 187, 159, 196, 166, 50, 63, 62, 141, 20, 126, 99, 19, 182, 164, 34, 2, 13, 97, 78, 151, 85, 150, 74, 111, 11, 61, 83, 41, 24, 55, 101, 110, 88, 60, 14, 65, 4, 51, 5, 30, 171, 158, 84, 15, 10, 46, 165, 118, 140, 90, 186, 107, 148, 180, 42, 152, 64, 189, 109, 136, 106, 91, 66, 178, 73, 172, 29, 25, 103, 44, 108, 191, 36, 72, 76, 82, 167, 160, 199, 9, 155, 175, 174, 179, 144, 177, 197, 170, 81, 121, 113, 58, 21, 89, 0, 69, 157, 137, 1, 26, 37, 153, 124, 143, 95, 23, 105, 79, 48, 32, 3, 190, 38, 135, 80, 198, 33, 53, 56, 49, 112, 125, 156, 131, 116, 129, 67, 162, 173, 123, 12, 181, 7, 192, 169, 185, 104, 100, 138, 168, 195, 43, 93, 45, 35, 22, 142, 146, 16, 127, 86, 128, 114, 27, 120, 145, 163, 102, 122, 18, 154, 94, 96, 115, 176], 82 | 2: [121, 66, 149, 189, 103, 195, 0, 72, 179, 46, 7, 159, 70, 65, 123, 76, 54, 37, 186, 62, 96, 136, 124, 69, 181, 6, 57, 125, 161, 81, 134, 147, 132, 59, 20, 50, 93, 71, 117, 33, 135, 47, 36, 120, 94, 73, 75, 14, 102, 60, 113, 142, 175, 115, 184, 185, 152, 63, 12, 198, 24, 26, 119, 109, 165, 87, 144, 64, 48, 52, 21, 95, 116, 187, 137, 10, 84, 162, 90, 131, 88, 150, 110, 41, 146, 106, 86, 127, 151, 16, 107, 67, 129, 140, 4, 172, 39, 23, 51, 183, 197, 31, 157, 188, 171, 58, 13, 153, 32, 98, 173, 130, 97, 80, 133, 163, 53, 44, 141, 145, 155, 176, 156, 138, 22, 68, 112, 3, 174, 2, 42, 25, 29, 104, 170, 178, 193, 126, 122, 30, 196, 199, 182, 128, 91, 56, 49, 111, 83, 78, 89, 61, 192, 34, 148, 191, 180, 190, 74, 167, 158, 139, 
1, 101, 166, 143, 28, 8, 43, 105, 38, 177, 118, 55, 108, 19, 5, 168, 15, 79, 160, 45, 169, 164, 85, 82, 77, 27, 40, 99, 92, 194, 18, 11, 154, 35, 100, 17, 9, 114], 83 | 3: [156, 137, 7, 123, 154, 38, 121, 40, 43, 6, 76, 129, 91, 18, 12, 149, 162, 189, 145, 107, 5, 85, 78, 111, 191, 71, 146, 87, 155, 92, 48, 49, 21, 34, 23, 187, 179, 110, 102, 186, 105, 184, 29, 90, 159, 79, 28, 108, 89, 128, 57, 96, 194, 54, 55, 167, 141, 51, 67, 0, 177, 99, 26, 173, 1, 163, 122, 115, 30, 101, 170, 198, 134, 69, 61, 58, 192, 171, 185, 37, 124, 15, 114, 132, 181, 9, 157, 83, 19, 131, 73, 86, 153, 138, 32, 8, 33, 165, 42, 180, 44, 168, 188, 81, 64, 166, 24, 172, 142, 95, 35, 161, 160, 13, 119, 199, 39, 100, 97, 125, 52, 195, 65, 158, 197, 127, 46, 4, 175, 20, 56, 190, 41, 174, 151, 84, 182, 183, 109, 75, 3, 93, 106, 136, 50, 17, 74, 10, 150, 60, 112, 164, 193, 53, 14, 169, 152, 82, 116, 80, 63, 77, 120, 117, 11, 72, 31, 104, 113, 68, 144, 88, 178, 47, 16, 27, 176, 98, 148, 94, 25, 126, 143, 62, 118, 70, 140, 45, 66, 130, 196, 147, 133, 59, 103, 36, 139, 2, 135, 22], 84 | 4: [187, 110, 38, 174, 97, 189, 39, 109, 122, 37, 42, 65, 101, 188, 134, 191, 153, 194, 3, 147, 78, 129, 52, 1, 185, 85, 22, 60, 98, 51, 155, 145, 24, 103, 2, 73, 139, 74, 18, 175, 48, 105, 46, 31, 161, 171, 14, 117, 69, 167, 12, 163, 25, 121, 13, 177, 16, 102, 56, 142, 107, 151, 53, 44, 62, 169, 176, 150, 67, 86, 91, 82, 5, 156, 128, 70, 149, 179, 144, 19, 146, 160, 21, 49, 0, 35, 119, 6, 141, 131, 94, 30, 162, 159, 76, 45, 17, 100, 118, 84, 66, 158, 15, 64, 54, 27, 89, 123, 193, 4, 80, 96, 58, 152, 93, 168, 108, 59, 113, 29, 34, 182, 83, 55, 11, 10, 111, 136, 133, 28, 192, 79, 127, 180, 140, 95, 68, 106, 61, 41, 157, 195, 90, 183, 130, 7, 125, 124, 40, 63, 116, 186, 199, 148, 120, 104, 75, 138, 178, 43, 181, 8, 143, 137, 20, 33, 99, 170, 184, 32, 87, 92, 154, 166, 88, 198, 26, 115, 190, 71, 72, 77, 50, 132, 165, 135, 164, 9, 47, 126, 57, 112, 114, 172, 197, 81, 36, 173, 23, 196], 85 | } 86 | 87 | train.samples = [(x[0], label_map[args.class_order].index(x[1])) for x in train.samples] 88 | test.samples = [(x[0], label_map[args.class_order].index(x[1])) for x in test.samples] 89 | train.targets = [label_map[args.class_order].index(x) for x in train.targets] 90 | test.targets = [label_map[args.class_order].index(x) for x in test.targets] 91 | 92 | else: 93 | raise NotImplementedError 94 | 95 | data = {} 96 | args.task2cls = [int(random_sep[t].split('-')[-1]) for t in range(args.ntasks)] 97 | cls_id_past = [] 98 | for t in range(args.ntasks): 99 | data[t] = {} 100 | 101 | cls_id = [int(random_sep[t].split('-')[-1]) * args.class_num + i for i in range(args.class_num)] 102 | ## train 103 | train_ = deepcopy(train) 104 | 105 | targets_aux, data_aux, full_target_aux, names_aux = [], [], [], [] 106 | idx_aux = [] 107 | 108 | for c in cls_id: 109 | idx = np.where(np.array(train.targets) == c)[0] 110 | 111 | if dataset_name.startswith('T-'): # for tinyImagenet 112 | idx_aux.append(idx) 113 | else: 114 | data_aux.append(train.data[idx]) 115 | targets_aux.append(np.zeros(len(idx), dtype=np.int64) + c) 116 | full_target_aux.append([[c, c]] for _ in range(len(idx))) 117 | names_aux.append([str(c) for _ in range(len(idx))]) 118 | 119 | if dataset_name.startswith('T-'): 120 | idx_list = np.concatenate(idx_aux) 121 | train_ = Subset(train_, idx_list) 122 | train_.data = [] 123 | train_.targets = np.array(train_.dataset.targets)[idx_list] 124 | train_.transform = train_.dataset.transform 125 | elif dataset_name.startswith('M-'): 126 | train_.data = 
torch.from_numpy(np.concatenate(data_aux, 0)) 127 | train_.targets = torch.from_numpy(np.concatenate(targets_aux, 0)) 128 | else: 129 | train_.data = np.array(list(chain(*data_aux))) 130 | train_.targets = np.array(list(chain(*targets_aux))) 131 | train_.full_labels = np.array(list(chain(*full_target_aux))) 132 | train_.names = list(chain(*names_aux)) 133 | del data_aux, targets_aux, full_target_aux, names_aux, idx_aux 134 | data[t]['train'] = train_ 135 | 136 | ## test 137 | test_ = deepcopy(test) 138 | targets_aux, data_aux, full_target_aux, names_aux = [], [], [], [] 139 | idx_aux = [] 140 | for c in cls_id: 141 | idx = np.where(np.array(test.targets) == c)[0] 142 | if dataset_name.startswith('T-'): 143 | idx_aux.append(idx) 144 | else: 145 | data_aux.append(test.data[idx]) 146 | targets_aux.append(np.zeros(len(idx), dtype=np.int64) + c) 147 | full_target_aux.append([[c, c]] for _ in range(len(idx))) 148 | names_aux.append([str(c) for _ in range(len(idx))]) 149 | 150 | if dataset_name.startswith('T-'): 151 | idx_list = np.concatenate(idx_aux) 152 | test_ = Subset(test_, idx_list) 153 | test_.data = [] 154 | test_.targets = np.array(test_.dataset.targets)[idx_list] 155 | test_.transform = test_.dataset.transform 156 | elif dataset_name.startswith('M-'): 157 | test_.data = torch.from_numpy(np.concatenate(data_aux, 0)) 158 | test_.targets = torch.from_numpy(np.concatenate(targets_aux, 0)) 159 | else: 160 | test_.data = np.array(list(chain(*data_aux))) 161 | test_.targets = np.array(list(chain(*targets_aux))) 162 | test_.full_labels = np.array(list(chain(*full_target_aux))) 163 | test_.names = list(chain(*names_aux)) 164 | del data_aux, targets_aux, full_target_aux, names_aux 165 | data[t]['test'] = test_ 166 | 167 | ## replay 168 | replay_ = deepcopy(train) 169 | targets_aux, data_aux, full_target_aux, names_aux = [], [], [], [] 170 | idx_aux = [] 171 | # mix replay dataset 172 | if t > 0: 173 | for c in cls_id_past: 174 | idx = np.where(np.array(train.targets) == c)[0][:(args.replay_buffer_size // len(cls_id_past))] 175 | if dataset_name.startswith('T-'): 176 | idx_aux.append(idx) 177 | else: 178 | data_aux.append(train.data[idx]) 179 | targets_aux.append(np.zeros(len(idx), dtype=np.int64) + c) 180 | full_target_aux.append([[c, c]] for _ in range(len(idx))) 181 | names_aux.append([str(c) for _ in range(len(idx))]) 182 | 183 | if dataset_name.startswith('T-'): 184 | idx_list = np.concatenate(idx_aux) 185 | replay_ = Subset(replay_, idx_list) 186 | replay_.data = [] 187 | replay_.targets = np.array(replay_.dataset.targets)[idx_list] 188 | replay_.transform = replay_.dataset.transform 189 | elif dataset_name.startswith('M-'): 190 | replay_.data = torch.from_numpy(np.concatenate(data_aux, 0)) 191 | replay_.targets = torch.from_numpy(np.concatenate(targets_aux, 0)) 192 | else: 193 | replay_.data = np.array(list(chain(*data_aux))) 194 | replay_.targets = np.array(list(chain(*targets_aux))) 195 | replay_.full_labels = np.array(list(chain(*full_target_aux))) 196 | replay_.names = list(chain(*names_aux)) 197 | else: 198 | replay_ = None 199 | del data_aux, targets_aux, full_target_aux, names_aux 200 | data[t]['replay'] = replay_ 201 | 202 | cls_id_past += [int(random_sep[t].split('-')[-1]) * args.class_num + i for i in range(args.class_num)] 203 | 204 | data[args.ntasks] = {} 205 | ## replay for the final task 206 | replay_ = deepcopy(train) 207 | targets_aux, data_aux, full_target_aux, names_aux = [], [], [], [] 208 | idx_aux = [] 209 | # mix replay dataset 210 | for c in cls_id_past: 211 | idx = 
np.where(np.array(train.targets) == c)[0][:(args.replay_buffer_size // len(cls_id_past))] 212 | if dataset_name.startswith('T-'): 213 | idx_aux.append(idx) 214 | else: 215 | data_aux.append(train.data[idx]) 216 | targets_aux.append(np.zeros(len(idx), dtype=np.int64) + c) 217 | full_target_aux.append([[c, c]] for _ in range(len(idx))) 218 | names_aux.append([str(c) for _ in range(len(idx))]) 219 | 220 | if dataset_name.startswith('T-'): 221 | idx_list = np.concatenate(idx_aux) 222 | replay_ = Subset(replay_, idx_list) 223 | replay_.data = [] 224 | replay_.targets = np.array(replay_.dataset.targets)[idx_list] 225 | replay_.transform = replay_.dataset.transform 226 | elif dataset_name.startswith('M-'): 227 | replay_.data = torch.from_numpy(np.concatenate(data_aux, 0)) 228 | replay_.targets = torch.from_numpy(np.concatenate(targets_aux, 0)) 229 | else: 230 | replay_.data = np.array(list(chain(*data_aux))) 231 | replay_.targets = np.array(list(chain(*targets_aux))) 232 | replay_.full_labels = np.array(list(chain(*full_target_aux))) 233 | replay_.names = list(chain(*names_aux)) 234 | 235 | del data_aux, targets_aux, full_target_aux, names_aux 236 | data[args.ntasks]['replay'] = replay_ 237 | 238 | return data 239 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import config 3 | from utils import utils 4 | from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs 5 | from accelerate.utils import set_seed 6 | import os 7 | from dataloader.data import get_dataset 8 | from torch.utils.data import DataLoader 9 | from approaches.eval import Appr 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | args = config.parse_args() 14 | args = utils.prepare_sequence_eval(args) 15 | 16 | accelerator = Accelerator() 17 | 18 | # Make one log on every process with the configuration for debugging. 19 | logging.basicConfig( 20 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 21 | datefmt="%m/%d/%Y %H:%M:%S", 22 | level=logging.INFO, 23 | ) 24 | # If passed along, set the training seed now. 
25 | if args.seed is not None: 26 | set_seed(args.seed) 27 | 28 | if accelerator.is_main_process: 29 | if args.output_dir is not None: 30 | os.makedirs(args.output_dir, exist_ok=True) 31 | accelerator.wait_for_everyone() 32 | 33 | dataset = get_dataset(args) 34 | model = utils.lookfor_model(args) 35 | 36 | test_loaders = [] 37 | train_loaders = [] 38 | for eval_t in range(args.ntasks): 39 | test_dataset = dataset[eval_t]['test'] 40 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8) 41 | test_loaders.append(test_dataloader) 42 | train_dataset = dataset[eval_t]['train'] 43 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8) 44 | train_loaders.append(train_dataloader) 45 | 46 | replay_loader = DataLoader(dataset[args.task+1]['replay'], batch_size=int((args.replay_buffer_size // (args.class_num * (args.task + 1))) * args.class_num), shuffle=False, num_workers=8) 47 | 48 | appr = Appr(args) 49 | appr.eval(model, train_loaders, test_loaders, replay_loader, accelerator) -------------------------------------------------------------------------------- /figures/TPL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaowei1/TPL/fc7aa86e0b197048f4ede9a7c914bb026aaa13ec/figures/TPL.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import config 4 | from utils import utils 5 | import os 6 | from dataloader.data import get_dataset 7 | from torch.utils.data import DataLoader 8 | from approaches.train import Appr 9 | from approaches.noncl import Appr as Appr_noncl 10 | import torch 11 | from torch.utils.data import ConcatDataset 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | args = config.parse_args() 16 | args = utils.prepare_sequence_train(args) 17 | ## set seed 18 | random_seed = args.seed # or any of your favorite number 19 | torch.manual_seed(random_seed) 20 | torch.cuda.manual_seed(random_seed) 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | np.random.seed(random_seed) 24 | 25 | logging.basicConfig( 26 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 27 | datefmt="%m/%d/%Y %H:%M:%S", 28 | level=logging.INFO, 29 | ) 30 | 31 | if args.output_dir is not None: 32 | os.makedirs(args.output_dir, exist_ok=True) 33 | 34 | dataset = get_dataset(args) 35 | model = utils.lookfor_model(args) 36 | 37 | if 'full' in args.baseline: 38 | train_loader = DataLoader(ConcatDataset([dataset[t]['train'] for t in range(args.task+1)]), batch_size=args.batch_size, shuffle=True, num_workers=8) 39 | else: 40 | train_loader = DataLoader(dataset[args.task]['train'], batch_size=args.batch_size, shuffle=True, num_workers=8) 41 | 42 | test_loaders = [] 43 | 44 | for eval_t in range(args.ntasks): 45 | test_dataset = dataset[eval_t]['test'] 46 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8) 47 | test_loaders.append(test_dataloader) 48 | 49 | replay_loader = None 50 | if dataset[args.task]['replay'] is not None: 51 | replay_loader = DataLoader(dataset[args.task]['replay'], batch_size=args.replay_batch_size, shuffle=True, num_workers=8) 52 | 53 | if 'full' in args.baseline: 54 | appr = Appr_noncl(args) 55 | else: 56 | appr = Appr(args) 57 | 58 | appr.train(model, 
train_loader, test_loaders, replay_loader) -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaowei1/TPL/fc7aa86e0b197048f4ede9a7c914bb026aaa13ec/networks/__init__.py -------------------------------------------------------------------------------- /networks/vit_hat.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of Vision Transformers as described in: 4 | 5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' 6 | - https://arxiv.org/abs/2010.11929 7 | 8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` 9 | - https://arxiv.org/abs/2106.10270 10 | 11 | The official jax code is released and available at https://github.com/google-research/vision_transformer 12 | 13 | DeiT model defs and weights from https://github.com/facebookresearch/deit, 14 | paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 15 | 16 | Acknowledgments: 17 | * The paper authors for releasing code and weights, thanks! 18 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out 19 | for some einops/einsum fun 20 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 21 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 22 | 23 | Hacked together by / Copyright 2021 Ross Wightman 24 | """ 25 | import math 26 | import logging 27 | from functools import partial 28 | from collections import OrderedDict 29 | from copy import deepcopy 30 | import itertools 31 | 32 | import torch 33 | import torch.nn as nn 34 | import torch.nn.functional as F 35 | 36 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 37 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 38 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 39 | from timm.models.registry import register_model 40 | 41 | _logger = logging.getLogger(__name__) 42 | 43 | 44 | def _cfg(url='', **kwargs): 45 | return { 46 | 'url': url, 47 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 48 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 49 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 50 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 51 | **kwargs 52 | } 53 | 54 | 55 | default_cfgs = { 56 | # patch models (weights from official Google JAX impl) 57 | 'vit_tiny_patch16_224': _cfg( 58 | url='https://storage.googleapis.com/vit_models/augreg/' 59 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 60 | 'vit_tiny_patch16_384': _cfg( 61 | url='https://storage.googleapis.com/vit_models/augreg/' 62 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 63 | input_size=(3, 384, 384), crop_pct=1.0), 64 | 'vit_small_patch32_224': _cfg( 65 | url='https://storage.googleapis.com/vit_models/augreg/' 66 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 67 | 'vit_small_patch32_384': _cfg( 68 | 
url='https://storage.googleapis.com/vit_models/augreg/' 69 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 70 | input_size=(3, 384, 384), crop_pct=1.0), 71 | 'vit_small_patch16_224': _cfg( 72 | url='https://storage.googleapis.com/vit_models/augreg/' 73 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 74 | 'vit_small_patch16_384': _cfg( 75 | url='https://storage.googleapis.com/vit_models/augreg/' 76 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 77 | input_size=(3, 384, 384), crop_pct=1.0), 78 | 'vit_base_patch32_224': _cfg( 79 | url='https://storage.googleapis.com/vit_models/augreg/' 80 | 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 81 | 'vit_base_patch32_384': _cfg( 82 | url='https://storage.googleapis.com/vit_models/augreg/' 83 | 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 84 | input_size=(3, 384, 384), crop_pct=1.0), 85 | 'vit_base_patch16_224': _cfg( 86 | url='https://storage.googleapis.com/vit_models/augreg/' 87 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 88 | 'vit_base_patch16_384': _cfg( 89 | url='https://storage.googleapis.com/vit_models/augreg/' 90 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', 91 | input_size=(3, 384, 384), crop_pct=1.0), 92 | 'vit_large_patch32_224': _cfg( 93 | url='', # no official model weights for this combo, only for in21k 94 | ), 95 | 'vit_large_patch32_384': _cfg( 96 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 97 | input_size=(3, 384, 384), crop_pct=1.0), 98 | 'vit_large_patch16_224': _cfg( 99 | url='https://storage.googleapis.com/vit_models/augreg/' 100 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 101 | 'vit_large_patch16_384': _cfg( 102 | url='https://storage.googleapis.com/vit_models/augreg/' 103 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', 104 | input_size=(3, 384, 384), crop_pct=1.0), 105 | 106 | # patch models, imagenet21k (weights from official Google JAX impl) 107 | 'vit_tiny_patch16_224_in21k': _cfg( 108 | url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', 109 | num_classes=21843), 110 | 'vit_small_patch32_224_in21k': _cfg( 111 | url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 112 | num_classes=21843), 113 | 'vit_small_patch16_224_in21k': _cfg( 114 | url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 115 | num_classes=21843), 116 | 'vit_base_patch32_224_in21k': _cfg( 117 | url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', 118 | num_classes=21843), 119 | 'vit_base_patch16_224_in21k': _cfg( 120 | url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', 121 | num_classes=21843), 122 | 'vit_large_patch32_224_in21k': _cfg( 123 | 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', 124 | num_classes=21843), 125 | 'vit_large_patch16_224_in21k': _cfg( 126 | url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', 127 | num_classes=21843), 128 | 'vit_huge_patch14_224_in21k': _cfg( 129 | url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', 130 | hf_hub='timm/vit_huge_patch14_224_in21k', 131 | num_classes=21843), 132 | 133 | # deit models (FB weights) 134 | 'deit_tiny_patch16_224': _cfg( 135 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', 136 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 137 | 'deit_small_patch16_224': _cfg( 138 | url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', 139 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 140 | 'deit_base_patch16_224': _cfg( 141 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', 142 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 143 | 'deit_base_patch16_384': _cfg( 144 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', 145 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), 146 | 'deit_tiny_distilled_patch16_224': _cfg( 147 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', 148 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 149 | 'deit_small_distilled_patch16_224': _cfg( 150 | url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', 151 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 152 | 'deit_base_distilled_patch16_224': _cfg( 153 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', 154 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 155 | 'deit_base_distilled_patch16_384': _cfg( 156 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', 157 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, 158 | classifier=('head', 'head_dist')), 159 | 160 | # ViT ImageNet-21K-P pretraining by MILL 161 | 'vit_base_patch16_224_miil_in21k': _cfg( 162 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', 163 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, 164 | ), 165 | 'vit_base_patch16_224_miil': _cfg( 166 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' 167 | '/vit_base_patch16_224_1k_miil_84_4.pth', 168 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', 169 | ), 170 | } 171 | 172 | class mySequential(nn.Sequential): 173 | def forward(self, *inputs): 174 | for module in self._modules.values(): 175 | if type(inputs) == tuple: 176 | inputs = module(*inputs) 177 | else: 178 | inputs = module(inputs) 179 | return inputs 180 | 181 | class Attention(nn.Module): 182 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 183 | super().__init__() 184 | self.num_heads = num_heads 185 | head_dim = dim // num_heads 186 | self.scale = head_dim ** -0.5 187 | 188 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 189 | 
self.attn_drop = nn.Dropout(attn_drop) 190 | self.proj = nn.Linear(dim, dim) 191 | self.proj_drop = nn.Dropout(proj_drop) 192 | 193 | def forward(self, x): 194 | B, N, C = x.shape 195 | 196 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 197 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 198 | 199 | attn = (q @ k.transpose(-2, -1)) * self.scale 200 | attn = attn.softmax(dim=-1) 201 | attn = self.attn_drop(attn) 202 | 203 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 204 | x = self.proj(x) 205 | x = self.proj_drop(x) 206 | return x 207 | 208 | class Adapter(nn.Module): 209 | def __init__(self, in_dim, out_dim, hat=False): 210 | super().__init__() 211 | self.in_dim, self.out_dim = in_dim, out_dim 212 | self.fc1 = nn.Linear(in_dim, out_dim) 213 | self.fc2 = nn.Linear(out_dim, in_dim) 214 | self.relu = nn.ReLU() 215 | 216 | if hat: 217 | self.gate = torch.sigmoid 218 | self.ec1 = nn.ParameterList() 219 | self.ec2 = nn.ParameterList() 220 | 221 | self.hat = hat 222 | self.init_weights() 223 | 224 | def forward(self, x, t=None, msk=None, s=None): 225 | if self.hat: 226 | masks = self.mask(t, s=s) 227 | gc1, gc2 = masks 228 | 229 | msk.append(masks) 230 | 231 | h = self.relu(self.mask_out(self.fc1(x), gc1)) 232 | h = self.mask_out(self.fc2(h), gc2) 233 | return x + h, msk 234 | else: 235 | h = self.relu(self.fc1(x)) 236 | h = self.fc2(h) 237 | return x + h 238 | 239 | def init_weights(self): 240 | for n, p in self.named_parameters(): 241 | p.data = p * 0 + 1e-24 / p.size(0) 242 | 243 | def mask(self, t, s): 244 | gc1 = self.gate(s * self.ec1[t]) 245 | gc2 = self.gate(s * self.ec2[t]) 246 | return [gc1, gc2] 247 | 248 | def mask_out(self, out, mask): 249 | out = out * mask.expand_as(out) 250 | return out 251 | 252 | def append_embeddings(self): 253 | self.ec1.append(nn.Parameter(torch.randn(1, self.out_dim, device='cuda'))) 254 | self.ec2.append(nn.Parameter(torch.randn(1, self.in_dim, device='cuda'))) 255 | 256 | class Block(nn.Module): 257 | 258 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 259 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, latent=None, hat=False): 260 | super().__init__() 261 | self.norm1 = norm_layer(dim) 262 | self.list_norm1 = nn.ModuleList() if hat else None 263 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 264 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 265 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 266 | self.norm2 = norm_layer(dim) 267 | self.list_norm2 = nn.ModuleList() if hat else None 268 | mlp_hidden_dim = int(dim * mlp_ratio) 269 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 270 | 271 | self.hat = hat 272 | self.adapter1 = Adapter(dim, latent, hat=hat) 273 | self.adapter2 = Adapter(dim, latent, hat=hat) 274 | 275 | def forward(self, x, t=None, msk=None, s=None): 276 | if self.hat: 277 | h, msk = self.adapter1( 278 | self.drop_path(self.attn(self.list_norm1[t](x))), 279 | t, 280 | msk, 281 | s 282 | ) 283 | x = x + h 284 | 285 | h, msk = self.adapter2( 286 | self.drop_path(self.mlp(self.list_norm2[t](x))), 287 | t, 288 | msk, 289 | s 290 | ) 291 | x = x + h 292 | return x, t, msk, s 293 | else: 294 | x = self.adapter1(self.drop_path(self.attn(self.norm1(x)))) + x 295 | x = self.adapter2(self.drop_path(self.mlp(self.norm2(x)))) + x 296 | return x 297 | 298 | class MyVisionTransformer(nn.Module): 299 | """ Vision Transformer 300 | 301 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 302 | - https://arxiv.org/abs/2010.11929 303 | 304 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 305 | - https://arxiv.org/abs/2012.12877 306 | """ 307 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 308 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 309 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 310 | act_layer=None, weight_init='', latent=None, args=None, hat=False): 311 | """ 312 | Args: 313 | img_size (int, tuple): input image size 314 | patch_size (int, tuple): patch size 315 | in_chans (int): number of input channels 316 | num_classes (int): number of classes for classification head 317 | embed_dim (int): embedding dimension 318 | depth (int): depth of transformer 319 | num_heads (int): number of attention heads 320 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 321 | qkv_bias (bool): enable bias for qkv if True 322 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 323 | distilled (bool): model includes a distillation token and head as in DeiT models 324 | drop_rate (float): dropout rate 325 | attn_drop_rate (float): attention dropout rate 326 | drop_path_rate (float): stochastic depth rate 327 | embed_layer (nn.Module): patch embedding layer 328 | norm_layer: (nn.Module): normalization layer 329 | weight_init: (str): weight init scheme 330 | """ 331 | super().__init__() 332 | self.freeze_head = False 333 | self.num_classes = num_classes 334 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 335 | self.num_tokens = 2 if distilled else 1 336 | self.norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 337 | act_layer = act_layer or nn.GELU 338 | 339 | self.patch_embed = embed_layer( 340 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=self.norm_layer) 341 | num_patches = self.patch_embed.num_patches 342 | 343 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 344 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 345 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 346 | self.pos_drop = nn.Dropout(p=drop_rate) 347 | 348 | dpr = 
[x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 349 | 350 | self.blocks = mySequential(*[ 351 | Block( 352 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 353 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer, act_layer=act_layer, latent=latent, hat=hat) 354 | for i in range(depth)]) 355 | self.norm = self.norm_layer(embed_dim) 356 | self.list_norm = nn.ModuleList() if hat else None 357 | 358 | # Representation layer 359 | if representation_size and not distilled: 360 | self.num_features = representation_size 361 | self.pre_logits = mySequential(OrderedDict([ 362 | ('fc', nn.Linear(embed_dim, representation_size)), 363 | ('act', nn.Tanh()) 364 | ])) 365 | else: 366 | self.pre_logits = nn.Identity() 367 | 368 | # Classifier head(s) 369 | self.head = nn.ModuleList() if hat else nn.Linear(self.embed_dim, self.num_classes) 370 | self.head_dist = None 371 | if distilled: 372 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 373 | 374 | self.init_weights(weight_init) 375 | 376 | def init_weights(self, mode=''): 377 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 378 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 379 | trunc_normal_(self.pos_embed, std=.02) 380 | if self.dist_token is not None: 381 | trunc_normal_(self.dist_token, std=.02) 382 | if mode.startswith('jax'): 383 | # leave cls token as zeros to match jax impl 384 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) 385 | else: 386 | trunc_normal_(self.cls_token, std=.02) 387 | self.apply(_init_vit_weights) 388 | 389 | def _init_weights(self, m): 390 | # this fn left here for compat with downstream users 391 | _init_vit_weights(m) 392 | 393 | @torch.jit.ignore() 394 | def load_pretrained(self, checkpoint_path, prefix=''): 395 | _load_weights(self, checkpoint_path, prefix) 396 | 397 | @torch.jit.ignore 398 | def no_weight_decay(self): 399 | return {'pos_embed', 'cls_token', 'dist_token'} 400 | 401 | def get_classifier(self): 402 | raise NotImplementedError() 403 | if self.dist_token is None: 404 | return self.head 405 | else: 406 | return self.head, self.head_dist 407 | 408 | def reset_classifier(self, num_classes, global_pool=''): 409 | raise NotImplementedError() 410 | self.num_classes = num_classes 411 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 412 | if self.num_tokens == 2: 413 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 414 | 415 | def forward_features(self, x, t=None, s=None): 416 | msk = [] 417 | x = self.patch_embed(x) 418 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 419 | 420 | if self.dist_token is None: 421 | x = torch.cat((cls_token, x), dim=1) 422 | else: 423 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 424 | x = self.pos_drop(x + self.pos_embed) 425 | 426 | if t is not None: 427 | x, _, msk, _ = self.blocks(x, t, msk, s) 428 | x = self.list_norm[t](x) 429 | return self.pre_logits(x[:, 0]), list(itertools.chain(*msk)) 430 | else: 431 | x = self.norm(self.blocks(x)) 432 | return self.pre_logits(x[:, 0]) 433 | 434 | 435 | def forward(self, x, t=None, s=None): 436 | if s is not None: 437 | x, msk = self.forward_features(x, t=t, s=s) 438 | x = self.head[t](x) 439 | return x, msk 440 | else: 441 | x = 
self.forward_features(x) 442 | return self.head(x) 443 | 444 | def forward_classifier(self, x, t=None): 445 | if t is not None: 446 | return self.head[t](x) 447 | else: 448 | return self.head(x) 449 | 450 | def append_embeddings(self): 451 | # append head 452 | self.head.append(nn.Linear(self.embed_dim, self.num_classes)) 453 | 454 | self.list_norm.append(deepcopy(self.norm)) 455 | for b in self.blocks: 456 | b.adapter1.append_embeddings() 457 | b.adapter2.append_embeddings() 458 | 459 | b.list_norm1.append(deepcopy(b.norm1)) 460 | b.list_norm2.append(deepcopy(b.norm2)) 461 | 462 | def head_parameters(self): 463 | return [p for n, p in self.named_parameters() if 'head' in n] 464 | 465 | def adapter_parameters(self): 466 | return [p for n, p in self.named_parameters() if 'adapter' in n or 'list_norm' in n or 'head' in n] 467 | 468 | class ViTFrozenHead(MyVisionTransformer): 469 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 470 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 471 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 472 | act_layer=None, weight_init='', latent=None): 473 | super(ViTFrozenHead, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dim, depth, 474 | num_heads, mlp_ratio, qkv_bias, representation_size, distilled, 475 | drop_rate, attn_drop_rate, drop_path_rate, embed_layer, norm_layer, 476 | act_layer, weight_init, latent) 477 | 478 | def adapter_parameters(self): 479 | return [p for n, p in self.named_parameters() if 'adapter' in n or 'list_norm' in n] 480 | 481 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): 482 | """ ViT weight initialization 483 | * When called without n, head_bias, jax_impl args it will behave exactly the same 484 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
485 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl 486 | """ 487 | if isinstance(module, nn.Linear): 488 | if name.startswith('head'): 489 | nn.init.zeros_(module.weight) 490 | nn.init.constant_(module.bias, head_bias) 491 | elif name.startswith('pre_logits'): 492 | lecun_normal_(module.weight) 493 | nn.init.zeros_(module.bias) 494 | else: 495 | if jax_impl: 496 | nn.init.xavier_uniform_(module.weight) 497 | if module.bias is not None: 498 | if 'mlp' in name: 499 | nn.init.normal_(module.bias, std=1e-6) 500 | else: 501 | nn.init.zeros_(module.bias) 502 | else: 503 | trunc_normal_(module.weight, std=.02) 504 | if module.bias is not None: 505 | nn.init.zeros_(module.bias) 506 | elif jax_impl and isinstance(module, nn.Conv2d): 507 | # NOTE conv was left to pytorch default in my original init 508 | lecun_normal_(module.weight) 509 | if module.bias is not None: 510 | nn.init.zeros_(module.bias) 511 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): 512 | nn.init.zeros_(module.bias) 513 | nn.init.ones_(module.weight) 514 | 515 | 516 | @torch.no_grad() 517 | def _load_weights(model: MyVisionTransformer, checkpoint_path: str, prefix: str = ''): 518 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 519 | """ 520 | import numpy as np 521 | 522 | def _n2p(w, t=True): 523 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 524 | w = w.flatten() 525 | if t: 526 | if w.ndim == 4: 527 | w = w.transpose([3, 2, 0, 1]) 528 | elif w.ndim == 3: 529 | w = w.transpose([2, 0, 1]) 530 | elif w.ndim == 2: 531 | w = w.transpose([1, 0]) 532 | return torch.from_numpy(w) 533 | 534 | w = np.load(checkpoint_path) 535 | if not prefix and 'opt/target/embedding/kernel' in w: 536 | prefix = 'opt/target/' 537 | 538 | if hasattr(model.patch_embed, 'backbone'): 539 | # hybrid 540 | backbone = model.patch_embed.backbone 541 | stem_only = not hasattr(backbone, 'stem') 542 | stem = backbone if stem_only else backbone.stem 543 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 544 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 545 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 546 | if not stem_only: 547 | for i, stage in enumerate(backbone.stages): 548 | for j, block in enumerate(stage.blocks): 549 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 550 | for r in range(3): 551 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 552 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 553 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 554 | if block.downsample is not None: 555 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 556 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 557 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 558 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 559 | else: 560 | embed_conv_w = adapt_input_conv( 561 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 562 | model.patch_embed.proj.weight.copy_(embed_conv_w) 563 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 564 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 565 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 566 | if pos_embed_w.shape != model.pos_embed.shape: 567 | pos_embed_w = resize_pos_embed( # resize pos 
embedding when different size from pretrained weights 568 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 569 | model.pos_embed.copy_(pos_embed_w) 570 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 571 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 572 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 573 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 574 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 575 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 576 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 577 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 578 | for i, block in enumerate(model.blocks.children()): 579 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 580 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 581 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 582 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 583 | block.attn.qkv.weight.copy_(torch.cat([ 584 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 585 | block.attn.qkv.bias.copy_(torch.cat([ 586 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 587 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 588 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 589 | for r in range(2): 590 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 591 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 592 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 593 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 594 | 595 | 596 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): 597 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 598 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 599 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 600 | ntok_new = posemb_new.shape[1] 601 | if num_tokens: 602 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] 603 | ntok_new -= num_tokens 604 | else: 605 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 606 | gs_old = int(math.sqrt(len(posemb_grid))) 607 | if not len(gs_new): # backwards compatibility 608 | gs_new = [int(math.sqrt(ntok_new))] * 2 609 | assert len(gs_new) >= 2 610 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) 611 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 612 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') 613 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 614 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 615 | return posemb 616 | 617 | 618 | def checkpoint_filter_fn(state_dict, model): 619 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 620 | out_dict = {} 621 | if 'model' in state_dict: 622 | # For deit models 623 | state_dict = state_dict['model'] 624 | for k, v in state_dict.items(): 625 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 626 | # For old models that I trained prior to conv based patchification 627 | O, I, H, W = model.patch_embed.proj.weight.shape 628 | v = v.reshape(O, -1, H, W) 629 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 630 | # To resize pos embedding when using model at different size from pretrained weights 631 | v = resize_pos_embed( 632 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 633 | out_dict[k] = v 634 | return out_dict 635 | 636 | 637 | def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): 638 | default_cfg = default_cfg or default_cfgs[variant] 639 | if kwargs.get('features_only', None): 640 | raise RuntimeError('features_only not implemented for Vision Transformer models.') 641 | 642 | # NOTE this extra code to support handling of repr size for in21k pretrained models 643 | default_num_classes = default_cfg['num_classes'] 644 | num_classes = kwargs.get('num_classes', default_num_classes) 645 | repr_size = kwargs.pop('representation_size', None) 646 | if repr_size is not None and num_classes != default_num_classes: 647 | # Remove representation layer if fine-tuning. This may not always be the desired action, 648 | # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? 
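# --- Hedged editor's sketch, not part of the original file. The `latent` and `hat` keyword
# arguments used elsewhere in this repo are not consumed by this factory; they simply ride
# along in **kwargs through build_model_with_cfg into MyVisionTransformer.__init__. Assuming
# a CUDA device (the HAT mask embeddings are created on 'cuda') and `images` being a batch of
# (B, 3, 224, 224) tensors, a HAT-enabled encoder could be built and queried per task roughly
# as follows; the repo's real entry point is presumably main.py, which is not shown here:
#
#   model = deit_small_patch16_224(pretrained=False, num_classes=20, latent=128, hat=True)
#   model.append_embeddings()       # task 0: adds a head, per-task norms and HAT mask embeddings
#   model = model.cuda()
#   logits, masks = model(images.cuda(), t=0, s=400.0)   # task-conditioned forward; s scales the sigmoid gates
#
# With hat=False the unmasked adapter path and a single nn.Linear head are used instead.
# The values 20, 128 and 400.0 above are illustrative assumptions, not the repo's settings.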
649 | _logger.warning("Removing representation layer for fine-tuning.") 650 | repr_size = None 651 | 652 | model = build_model_with_cfg( 653 | MyVisionTransformer, variant, pretrained, 654 | default_cfg=default_cfg, 655 | representation_size=repr_size, 656 | pretrained_filter_fn=checkpoint_filter_fn, 657 | pretrained_custom_load='npz' in default_cfg['url'], 658 | **kwargs) 659 | return model 660 | 661 | 662 | @register_model 663 | def vit_tiny_patch16_224(pretrained=False, **kwargs): 664 | """ ViT-Tiny (Vit-Ti/16) 665 | """ 666 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 667 | model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) 668 | return model 669 | 670 | 671 | @register_model 672 | def vit_tiny_patch16_384(pretrained=False, **kwargs): 673 | """ ViT-Tiny (Vit-Ti/16) @ 384x384. 674 | """ 675 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 676 | model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) 677 | return model 678 | 679 | 680 | @register_model 681 | def vit_small_patch32_224(pretrained=False, **kwargs): 682 | """ ViT-Small (ViT-S/32) 683 | """ 684 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 685 | model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) 686 | return model 687 | 688 | 689 | @register_model 690 | def vit_small_patch32_384(pretrained=False, **kwargs): 691 | """ ViT-Small (ViT-S/32) at 384x384. 692 | """ 693 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 694 | model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) 695 | return model 696 | 697 | 698 | @register_model 699 | def vit_small_patch16_224(pretrained=False, **kwargs): 700 | """ ViT-Small (ViT-S/16) 701 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 702 | """ 703 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 704 | model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) 705 | return model 706 | 707 | 708 | @register_model 709 | def vit_small_patch16_384(pretrained=False, **kwargs): 710 | """ ViT-Small (ViT-S/16) 711 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 712 | """ 713 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 714 | model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) 715 | return model 716 | 717 | 718 | @register_model 719 | def vit_base_patch32_224(pretrained=False, **kwargs): 720 | """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. 721 | """ 722 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 723 | model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) 724 | return model 725 | 726 | 727 | @register_model 728 | def vit_base_patch32_384(pretrained=False, **kwargs): 729 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 730 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
731 | """ 732 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 733 | model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) 734 | return model 735 | 736 | 737 | @register_model 738 | def vit_base_patch16_224(pretrained=False, **kwargs): 739 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 740 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 741 | """ 742 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 743 | model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) 744 | return model 745 | 746 | 747 | @register_model 748 | def vit_base_patch16_384(pretrained=False, **kwargs): 749 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 750 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 751 | """ 752 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 753 | model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) 754 | return model 755 | 756 | 757 | @register_model 758 | def vit_large_patch32_224(pretrained=False, **kwargs): 759 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. 760 | """ 761 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 762 | model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) 763 | return model 764 | 765 | 766 | @register_model 767 | def vit_large_patch32_384(pretrained=False, **kwargs): 768 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 769 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 770 | """ 771 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 772 | model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) 773 | return model 774 | 775 | 776 | @register_model 777 | def vit_large_patch16_224(pretrained=False, **kwargs): 778 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 779 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 780 | """ 781 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 782 | model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) 783 | return model 784 | 785 | 786 | @register_model 787 | def vit_large_patch16_384(pretrained=False, **kwargs): 788 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 789 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 790 | """ 791 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 792 | model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) 793 | return model 794 | 795 | 796 | @register_model 797 | def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): 798 | """ ViT-Tiny (Vit-Ti/16). 799 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
800 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 801 | """ 802 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 803 | model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 804 | return model 805 | 806 | 807 | @register_model 808 | def vit_small_patch32_224_in21k(pretrained=False, **kwargs): 809 | """ ViT-Small (ViT-S/16) 810 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 811 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 812 | """ 813 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 814 | model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 815 | return model 816 | 817 | 818 | @register_model 819 | def vit_small_patch16_224_in21k(pretrained=False, **kwargs): 820 | """ ViT-Small (ViT-S/16) 821 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 822 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 823 | """ 824 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 825 | model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 826 | return model 827 | 828 | 829 | @register_model 830 | def vit_base_patch32_224_in21k(pretrained=False, **kwargs): 831 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 832 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 833 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 834 | """ 835 | model_kwargs = dict( 836 | patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 837 | model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 838 | return model 839 | 840 | 841 | @register_model 842 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs): 843 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 844 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 845 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 846 | """ 847 | model_kwargs = dict( 848 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 849 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 850 | return model 851 | 852 | 853 | @register_model 854 | def vit_large_patch32_224_in21k(pretrained=False, **kwargs): 855 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 856 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 857 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights 858 | """ 859 | model_kwargs = dict( 860 | patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) 861 | model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 862 | return model 863 | 864 | 865 | @register_model 866 | def vit_large_patch16_224_in21k(pretrained=False, **kwargs): 867 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 
868 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 869 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 870 | """ 871 | model_kwargs = dict( 872 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 873 | model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 874 | return model 875 | 876 | 877 | @register_model 878 | def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): 879 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). 880 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 881 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights 882 | """ 883 | model_kwargs = dict( 884 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) 885 | model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) 886 | return model 887 | 888 | 889 | @register_model 890 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 891 | """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 892 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 893 | """ 894 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 895 | model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) 896 | return model 897 | 898 | 899 | @register_model 900 | def deit_small_patch16_224(pretrained=False, **kwargs): 901 | """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 902 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 903 | """ 904 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 905 | model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) 906 | return model 907 | 908 | @register_model 909 | def deit_base_patch16_224(pretrained=False, **kwargs): 910 | """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 911 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 912 | """ 913 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 914 | model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs) 915 | return model 916 | 917 | 918 | @register_model 919 | def deit_base_patch16_384(pretrained=False, **kwargs): 920 | """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). 921 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 922 | """ 923 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 924 | model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs) 925 | return model 926 | 927 | 928 | @register_model 929 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): 930 | """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 931 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 
932 | """ 933 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 934 | model = _create_vision_transformer( 935 | 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 936 | return model 937 | 938 | 939 | @register_model 940 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs): 941 | """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 942 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 943 | """ 944 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 945 | model = _create_vision_transformer( 946 | 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 947 | return model 948 | 949 | 950 | @register_model 951 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs): 952 | """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 953 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 954 | """ 955 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 956 | model = _create_vision_transformer( 957 | 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 958 | return model 959 | 960 | 961 | @register_model 962 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs): 963 | """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). 964 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 965 | """ 966 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 967 | model = _create_vision_transformer( 968 | 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) 969 | return model 970 | 971 | 972 | @register_model 973 | def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs): 974 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 975 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K 976 | """ 977 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) 978 | model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs) 979 | return model 980 | 981 | 982 | @register_model 983 | def vit_base_patch16_224_miil(pretrained=False, **kwargs): 984 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
985 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K 986 | """ 987 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) 988 | model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) 989 | return model 990 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision==0.8.0 2 | transformers==4.11.3 3 | scikit-learn 4 | accelerate==0.5.1 5 | datasets 6 | tensorboard 7 | jsonlines 8 | timm==0.4.12 9 | ftfy -------------------------------------------------------------------------------- /scripts/clip.sh: -------------------------------------------------------------------------------- 1 | seed=(2023 111 222 333 444 555 666 777 888 999) 2 | cuda_id=0 3 | ve="vit_base_patch16_clip_224.openai" 4 | bs=('vit_base_patch16_clip_224.openai_C10_5T_hat' 'vit_base_patch16_clip_224.openai_C100_10T_hat' 'vit_base_patch16_clip_224.openai_C100_20T_hat' 'vit_base_patch16_clip_224.openai_T_5T_hat' 'vit_base_patch16_clip_224.openai_T_10T_hat') 5 | seqfile=('C10_5T' 'C100_10T' 'C100_20T' 'T_5T' 'T_10T') 6 | learning_rate=(0.005 0.001 0.005 0.005 0.005) 7 | num_train_epochs=(20 40 40 15 10) 8 | base_dir="ckpt" 9 | final_task=(4 9 19 4 9) 10 | latent=(64 128 128 128 128) 11 | buffersize=(200 2000 2000 2000 2000) 12 | 13 | for round in 0 1 2 3 4; 14 | do 15 | for class_order in 0; 16 | do 17 | for i in $(seq 0 4); 18 | do 19 | for ft_task in $(seq 0 ${final_task[$i]}); 20 | do 21 | CUDA_VISIBLE_DEVICES=$cuda_id python main.py \ 22 | --task ${ft_task} \ 23 | --idrandom 0 \ 24 | --visual_encoder $ve \ 25 | --baseline "${bs[$i]}" \ 26 | --seed ${seed[$round]} \ 27 | --batch_size 64 \ 28 | --sequence_file "${seqfile[$i]}" \ 29 | --learning_rate ${learning_rate[$i]} \ 30 | --num_train_epochs ${num_train_epochs[$i]} \ 31 | --base_dir ckpt \ 32 | --class_order ${class_order} \ 33 | --latent ${latent[$i]} \ 34 | --replay_buffer_size ${buffersize[$i]} \ 35 | --training 36 | done 37 | for ft_task in $(seq 1 ${final_task[$i]}); 38 | do 39 | CUDA_VISIBLE_DEVICES=$cuda_id python eval.py \ 40 | --task ${ft_task} \ 41 | --idrandom 0 \ 42 | --visual_encoder $ve \ 43 | --baseline "${bs[$i]}" \ 44 | --seed ${seed[$round]} \ 45 | --batch_size 64 \ 46 | --sequence_file "${seqfile[$i]}" \ 47 | --base_dir ckpt \ 48 | --class_order ${class_order} \ 49 | --latent ${latent[$i]} \ 50 | --replay_buffer_size ${buffersize[$i]} 51 | done 52 | done 53 | done 54 | done -------------------------------------------------------------------------------- /scripts/deit_small.sh: -------------------------------------------------------------------------------- 1 | seed=(2023 111 222 333 444 555 666 777 888 999) 2 | cuda_id=1 3 | ve="deit_small_patch16_224" 4 | bs=('deit_small_patch16_224_C10_5T_hat' 'deit_small_patch16_224_C100_10T_hat' 'deit_small_patch16_224_C100_20T_hat' 'deit_small_patch16_224_T_5T_hat' 'deit_small_patch16_224_T_10T_hat') 5 | seqfile=('C10_5T' 'C100_10T' 'C100_20T' 'T_5T' 'T_10T') 6 | learning_rate=(0.005 0.001 0.005 0.005 0.005) 7 | num_train_epochs=(20 40 40 15 10) 8 | base_dir="ckpt" 9 | final_task=(4 9 19 4 9) 10 | latent=(64 128 128 128 128) 11 | buffersize=(200 2000 2000 2000 2000) 12 | 13 | for round in 0 1 2 3 4; 14 | do 15 | for class_order in 0; 16 | do 17 | for i in $(seq 0 4); 18 | do 19 | for ft_task in $(seq 0 ${final_task[$i]}); 20 | do 21 | CUDA_VISIBLE_DEVICES=$cuda_id 
python main.py \ 22 | --task ${ft_task} \ 23 | --idrandom 0 \ 24 | --visual_encoder $ve \ 25 | --baseline "${bs[$i]}" \ 26 | --seed ${seed[$round]} \ 27 | --batch_size 64 \ 28 | --sequence_file "${seqfile[$i]}" \ 29 | --learning_rate ${learning_rate[$i]} \ 30 | --num_train_epochs ${num_train_epochs[$i]} \ 31 | --base_dir ckpt \ 32 | --class_order ${class_order} \ 33 | --latent ${latent[$i]} \ 34 | --replay_buffer_size ${buffersize[$i]} \ 35 | --training 36 | done 37 | for ft_task in $(seq 1 ${final_task[$i]}); 38 | do 39 | CUDA_VISIBLE_DEVICES=$cuda_id python eval.py \ 40 | --task ${ft_task} \ 41 | --idrandom 0 \ 42 | --visual_encoder $ve \ 43 | --baseline "${bs[$i]}" \ 44 | --seed ${seed[$round]} \ 45 | --batch_size 64 \ 46 | --sequence_file "${seqfile[$i]}" \ 47 | --base_dir ckpt \ 48 | --class_order ${class_order} \ 49 | --latent ${latent[$i]} \ 50 | --replay_buffer_size ${buffersize[$i]} 51 | done 52 | done 53 | done 54 | done -------------------------------------------------------------------------------- /scripts/deit_small_in661.sh: -------------------------------------------------------------------------------- 1 | seed=(2023 111 222 333 444 555 666 777 888 999) 2 | cuda_id=0 3 | ve="deit_small_patch16_224_in661" 4 | bs=('deit_small_patch16_224_in661_C10_5T_hat' 'deit_small_patch16_224_in661_C100_10T_hat' 'deit_small_patch16_224_in661_C100_20T_hat' 'deit_small_patch16_224_in661_T_5T_hat' 'deit_small_patch16_224_in661_T_10T_hat') 5 | seqfile=('C10_5T' 'C100_10T' 'C100_20T' 'T_5T' 'T_10T') 6 | learning_rate=(0.005 0.001 0.005 0.005 0.005) 7 | num_train_epochs=(20 40 40 15 10) 8 | base_dir="ckpt" 9 | final_task=(4 9 19 4 9) 10 | latent=(64 128 128 128 128) 11 | buffersize=(200 2000 2000 2000 2000) 12 | 13 | for round in 0 1 2 3 4; 14 | do 15 | for class_order in 0; 16 | do 17 | for i in "${!bs[@]}"; 18 | do 19 | for ft_task in $(seq 0 ${final_task[$i]}); 20 | do 21 | CUDA_VISIBLE_DEVICES=$cuda_id python main.py \ 22 | --task ${ft_task} \ 23 | --idrandom 0 \ 24 | --visual_encoder $ve \ 25 | --baseline "${bs[$i]}" \ 26 | --seed ${seed[$round]} \ 27 | --batch_size 64 \ 28 | --sequence_file "${seqfile[$i]}" \ 29 | --learning_rate ${learning_rate[$i]} \ 30 | --num_train_epochs ${num_train_epochs[$i]} \ 31 | --base_dir ckpt \ 32 | --class_order ${class_order} \ 33 | --latent ${latent[$i]} \ 34 | --replay_buffer_size ${buffersize[$i]} \ 35 | --training 36 | done 37 | for ft_task in $(seq 1 ${final_task[$i]}); 38 | do 39 | CUDA_VISIBLE_DEVICES=$cuda_id python eval.py \ 40 | --task ${ft_task} \ 41 | --idrandom 0 \ 42 | --visual_encoder $ve \ 43 | --baseline "${bs[$i]}" \ 44 | --seed ${seed[$round]} \ 45 | --batch_size 64 \ 46 | --sequence_file "${seqfile[$i]}" \ 47 | --base_dir ckpt \ 48 | --class_order ${class_order} \ 49 | --latent ${latent[$i]} \ 50 | --replay_buffer_size ${buffersize[$i]} 51 | done 52 | done 53 | done 54 | done -------------------------------------------------------------------------------- /scripts/deit_tiny.sh: -------------------------------------------------------------------------------- 1 | seed=(2023 111 222 333 444 555 666 777 888 999) 2 | cuda_id=4 3 | ve="deit_tiny_patch16_224" 4 | bs=('deit_tiny_patch16_224_C10_5T_hat' 'deit_tiny_patch16_224_C100_10T_hat' 'deit_tiny_patch16_224_C100_20T_hat' 'deit_tiny_patch16_224_T_5T_hat' 'deit_tiny_patch16_224_T_10T_hat') 5 | seqfile=('C10_5T' 'C100_10T' 'C100_20T' 'T_5T' 'T_10T') 6 | learning_rate=(0.005 0.001 0.005 0.005 0.005) 7 | num_train_epochs=(20 40 40 15 10) 8 | base_dir="ckpt" 9 | final_task=(4 9 19 4 9) 10 | 
latent=(64 128 128 128 128) 11 | buffersize=(200 2000 2000 2000 2000) 12 | 13 | for round in 0 1 2 3 4; 14 | do 15 | for class_order in 0; 16 | do 17 | for i in $(seq 0 4); 18 | do 19 | for ft_task in $(seq 0 ${final_task[$i]}); 20 | do 21 | CUDA_VISIBLE_DEVICES=$cuda_id python main.py \ 22 | --task ${ft_task} \ 23 | --idrandom 0 \ 24 | --visual_encoder $ve \ 25 | --baseline "${bs[$i]}" \ 26 | --seed ${seed[$round]} \ 27 | --batch_size 64 \ 28 | --sequence_file "${seqfile[$i]}" \ 29 | --learning_rate ${learning_rate[$i]} \ 30 | --num_train_epochs ${num_train_epochs[$i]} \ 31 | --base_dir ckpt \ 32 | --class_order ${class_order} \ 33 | --latent ${latent[$i]} \ 34 | --replay_buffer_size ${buffersize[$i]} \ 35 | --training 36 | done 37 | for ft_task in $(seq 1 ${final_task[$i]}); 38 | do 39 | CUDA_VISIBLE_DEVICES=$cuda_id python eval.py \ 40 | --task ${ft_task} \ 41 | --idrandom 0 \ 42 | --visual_encoder $ve \ 43 | --baseline "${bs[$i]}" \ 44 | --seed ${seed[$round]} \ 45 | --batch_size 64 \ 46 | --sequence_file "${seqfile[$i]}" \ 47 | --base_dir ckpt \ 48 | --class_order ${class_order} \ 49 | --latent ${latent[$i]} \ 50 | --replay_buffer_size ${buffersize[$i]} 51 | done 52 | done 53 | done 54 | done -------------------------------------------------------------------------------- /scripts/dino.sh: -------------------------------------------------------------------------------- 1 | seed=(2023 111 222 333 444 555 666 777 888 999) 2 | cuda_id=5 3 | ve="vit_small_patch16_224.dino" 4 | bs=('vit_small_patch16_224.dino_C10_5T_hat' 'vit_small_patch16_224.dino_C100_10T_hat' 'vit_small_patch16_224.dino_C100_20T_hat' 'vit_small_patch16_224.dino_T_5T_hat' 'vit_small_patch16_224.dino_T_10T_hat') 5 | seqfile=('C10_5T' 'C100_10T' 'C100_20T' 'T_5T' 'T_10T') 6 | learning_rate=(0.005 0.001 0.005 0.005 0.005) 7 | num_train_epochs=(20 40 40 15 10) 8 | base_dir="ckpt" 9 | final_task=(4 9 19 4 9) 10 | latent=(64 128 128 128 128) 11 | buffersize=(200 2000 2000 2000 2000) 12 | 13 | for round in 0 1 2 3 4; 14 | do 15 | for class_order in 0; 16 | do 17 | for i in "${!bs[@]}"; 18 | do 19 | for ft_task in $(seq 0 ${final_task[$i]}); 20 | do 21 | CUDA_VISIBLE_DEVICES=$cuda_id python main.py \ 22 | --task ${ft_task} \ 23 | --idrandom 0 \ 24 | --visual_encoder $ve \ 25 | --baseline "${bs[$i]}" \ 26 | --seed ${seed[$round]} \ 27 | --batch_size 64 \ 28 | --sequence_file "${seqfile[$i]}" \ 29 | --learning_rate ${learning_rate[$i]} \ 30 | --num_train_epochs ${num_train_epochs[$i]} \ 31 | --base_dir ckpt \ 32 | --class_order ${class_order} \ 33 | --latent ${latent[$i]} \ 34 | --replay_buffer_size ${buffersize[$i]} \ 35 | --training 36 | done 37 | for ft_task in $(seq 1 ${final_task[$i]}); 38 | do 39 | CUDA_VISIBLE_DEVICES=$cuda_id python eval.py \ 40 | --task ${ft_task} \ 41 | --idrandom 0 \ 42 | --visual_encoder $ve \ 43 | --baseline "${bs[$i]}" \ 44 | --seed ${seed[$round]} \ 45 | --batch_size 64 \ 46 | --sequence_file "${seqfile[$i]}" \ 47 | --base_dir ckpt \ 48 | --class_order ${class_order} \ 49 | --latent ${latent[$i]} \ 50 | --replay_buffer_size ${buffersize[$i]} 51 | done 52 | done 53 | done 54 | done -------------------------------------------------------------------------------- /scripts/mae.sh: -------------------------------------------------------------------------------- 1 | seed=(2023 111 222 333 444 555 666 777 888 999) 2 | cuda_id=1 3 | ve="vit_base_patch16_224.mae" 4 | bs=('vit_base_patch16_224.mae_C10_5T_hat' 'vit_base_patch16_224.mae_C100_10T_hat' 'vit_base_patch16_224.mae_C100_20T_hat' 
'vit_base_patch16_224.mae_T_5T_hat' 'vit_base_patch16_224.mae_T_10T_hat') 5 | seqfile=('C10_5T' 'C100_10T' 'C100_20T' 'T_5T' 'T_10T') 6 | learning_rate=(0.005 0.001 0.005 0.005 0.005) 7 | num_train_epochs=(20 40 40 15 10) 8 | base_dir="ckpt" 9 | final_task=(4 9 19 4 9) 10 | latent=(64 128 128 128 128) 11 | buffersize=(200 2000 2000 2000 2000) 12 | 13 | for round in 0 1 2 3 4; 14 | do 15 | for class_order in 0; 16 | do 17 | for i in $(seq 4 4); 18 | do 19 | for ft_task in $(seq 0 ${final_task[$i]}); 20 | do 21 | CUDA_VISIBLE_DEVICES=$cuda_id python main.py \ 22 | --task ${ft_task} \ 23 | --idrandom 0 \ 24 | --visual_encoder $ve \ 25 | --baseline "${bs[$i]}" \ 26 | --seed ${seed[$round]} \ 27 | --batch_size 64 \ 28 | --sequence_file "${seqfile[$i]}" \ 29 | --learning_rate ${learning_rate[$i]} \ 30 | --num_train_epochs ${num_train_epochs[$i]} \ 31 | --base_dir ckpt \ 32 | --class_order ${class_order} \ 33 | --latent ${latent[$i]} \ 34 | --replay_buffer_size ${buffersize[$i]} \ 35 | --training 36 | done 37 | for ft_task in $(seq 1 ${final_task[$i]}); 38 | do 39 | CUDA_VISIBLE_DEVICES=$cuda_id python eval.py \ 40 | --task ${ft_task} \ 41 | --idrandom 0 \ 42 | --visual_encoder $ve \ 43 | --baseline "${bs[$i]}" \ 44 | --seed ${seed[$round]} \ 45 | --batch_size 64 \ 46 | --sequence_file "${seqfile[$i]}" \ 47 | --base_dir ckpt \ 48 | --class_order ${class_order} \ 49 | --latent ${latent[$i]} \ 50 | --replay_buffer_size ${buffersize[$i]} 51 | done 52 | done 53 | done 54 | done -------------------------------------------------------------------------------- /scripts/vit_small.sh: -------------------------------------------------------------------------------- 1 | seed=(2023 111 222 333 444 555 666 777 888 999) 2 | cuda_id=4 3 | ve="vit_small_patch16_224" 4 | bs=('vit_small_patch16_224_C10_5T_hat' 'vit_small_patch16_224_C100_10T_hat' 'vit_small_patch16_224_C100_20T_hat' 'vit_small_patch16_224_T_5T_hat' 'vit_small_patch16_224_T_10T_hat') 5 | seqfile=('C10_5T' 'C100_10T' 'C100_20T' 'T_5T' 'T_10T') 6 | learning_rate=(0.005 0.001 0.005 0.005 0.005) 7 | num_train_epochs=(20 40 40 15 10) 8 | base_dir="ckpt" 9 | final_task=(4 9 19 4 9) 10 | latent=(64 128 128 128 128) 11 | buffersize=(200 2000 2000 2000 2000) 12 | 13 | for round in 0 1 2 3 4; 14 | do 15 | for class_order in 0; 16 | do 17 | for i in $(seq 0 4); 18 | do 19 | for ft_task in $(seq 0 ${final_task[$i]}); 20 | do 21 | CUDA_VISIBLE_DEVICES=$cuda_id python main.py \ 22 | --task ${ft_task} \ 23 | --idrandom 0 \ 24 | --visual_encoder $ve \ 25 | --baseline "${bs[$i]}" \ 26 | --seed ${seed[$round]} \ 27 | --batch_size 64 \ 28 | --sequence_file "${seqfile[$i]}" \ 29 | --learning_rate ${learning_rate[$i]} \ 30 | --num_train_epochs ${num_train_epochs[$i]} \ 31 | --base_dir ckpt \ 32 | --class_order ${class_order} \ 33 | --latent ${latent[$i]} \ 34 | --replay_buffer_size ${buffersize[$i]} \ 35 | --training 36 | done 37 | for ft_task in $(seq 0 ${final_task[$i]}); 38 | do 39 | CUDA_VISIBLE_DEVICES=$cuda_id python eval.py \ 40 | --task ${ft_task} \ 41 | --idrandom 0 \ 42 | --visual_encoder $ve \ 43 | --baseline "${bs[$i]}" \ 44 | --seed ${seed[$round]} \ 45 | --batch_size 64 \ 46 | --sequence_file "${seqfile[$i]}" \ 47 | --base_dir ckpt \ 48 | --class_order ${class_order} \ 49 | --latent ${latent[$i]} \ 50 | --replay_buffer_size ${buffersize[$i]} 51 | done 52 | done 53 | done 54 | done -------------------------------------------------------------------------------- /scripts/vit_tiny.sh: 
-------------------------------------------------------------------------------- 1 | seed=(2023 111 222 333 444 555 666 777 888 999) 2 | cuda_id=4 3 | ve="vit_tiny_patch16_224" 4 | bs=('vit_tiny_patch16_224_C10_5T_hat' 'vit_tiny_patch16_224_C100_10T_hat' 'vit_tiny_patch16_224_C100_20T_hat' 'vit_tiny_patch16_224_T_5T_hat' 'vit_tiny_patch16_224_T_10T_hat') 5 | seqfile=('C10_5T' 'C100_10T' 'C100_20T' 'T_5T' 'T_10T') 6 | learning_rate=(0.005 0.001 0.005 0.005 0.005) 7 | num_train_epochs=(20 40 40 15 10) 8 | base_dir="ckpt" 9 | final_task=(4 9 19 4 9) 10 | latent=(64 128 128 128 128) 11 | buffersize=(200 2000 2000 2000 2000) 12 | 13 | for round in 0 1 2 3 4; 14 | do 15 | for class_order in 0; 16 | do 17 | for i in $(seq 0 4); 18 | do 19 | for ft_task in $(seq 0 ${final_task[$i]}); 20 | do 21 | CUDA_VISIBLE_DEVICES=$cuda_id python main.py \ 22 | --task ${ft_task} \ 23 | --idrandom 0 \ 24 | --visual_encoder $ve \ 25 | --baseline "${bs[$i]}" \ 26 | --seed ${seed[$round]} \ 27 | --batch_size 64 \ 28 | --sequence_file "${seqfile[$i]}" \ 29 | --learning_rate ${learning_rate[$i]} \ 30 | --num_train_epochs ${num_train_epochs[$i]} \ 31 | --base_dir ckpt \ 32 | --class_order ${class_order} \ 33 | --latent ${latent[$i]} \ 34 | --replay_buffer_size ${buffersize[$i]} \ 35 | --training 36 | done 37 | for ft_task in $(seq 1 ${final_task[$i]}); 38 | do 39 | CUDA_VISIBLE_DEVICES=$cuda_id python eval.py \ 40 | --task ${ft_task} \ 41 | --idrandom 0 \ 42 | --visual_encoder $ve \ 43 | --baseline "${bs[$i]}" \ 44 | --seed ${seed[$round]} \ 45 | --batch_size 64 \ 46 | --sequence_file "${seqfile[$i]}" \ 47 | --base_dir ckpt \ 48 | --class_order ${class_order} \ 49 | --latent ${latent[$i]} \ 50 | --replay_buffer_size ${buffersize[$i]} 51 | done 52 | done 53 | done 54 | done -------------------------------------------------------------------------------- /sequence/C100_10T: -------------------------------------------------------------------------------- 1 | C100-10T-0 C100-10T-1 C100-10T-2 C100-10T-3 C100-10T-4 C100-10T-5 C100-10T-6 C100-10T-7 C100-10T-8 C100-10T-9 2 | C100-10T-3 C100-10T-4 C100-10T-7 C100-10T-6 C100-10T-9 C100-10T-2 C100-10T-8 C100-10T-5 C100-10T-1 C100-10T-0 3 | C100-10T-1 C100-10T-4 C100-10T-6 C100-10T-3 C100-10T-8 C100-10T-0 C100-10T-2 C100-10T-9 C100-10T-5 C100-10T-7 4 | C100-10T-7 C100-10T-9 C100-10T-5 C100-10T-6 C100-10T-8 C100-10T-2 C100-10T-3 C100-10T-0 C100-10T-4 C100-10T-1 5 | C100-10T-0 C100-10T-1 C100-10T-9 C100-10T-8 C100-10T-3 C100-10T-2 C100-10T-5 C100-10T-6 C100-10T-7 C100-10T-4 6 | C100-10T-7 C100-10T-9 C100-10T-1 C100-10T-0 C100-10T-2 C100-10T-5 C100-10T-4 C100-10T-6 C100-10T-8 C100-10T-3 7 | C100-10T-8 C100-10T-4 C100-10T-6 C100-10T-5 C100-10T-9 C100-10T-1 C100-10T-2 C100-10T-0 C100-10T-7 C100-10T-3 8 | C100-10T-2 C100-10T-7 C100-10T-8 C100-10T-9 C100-10T-3 C100-10T-0 C100-10T-4 C100-10T-1 C100-10T-5 C100-10T-6 9 | C100-10T-7 C100-10T-3 C100-10T-9 C100-10T-1 C100-10T-5 C100-10T-6 C100-10T-8 C100-10T-4 C100-10T-0 C100-10T-2 10 | C100-10T-8 C100-10T-3 C100-10T-7 C100-10T-5 C100-10T-0 C100-10T-6 C100-10T-2 C100-10T-1 C100-10T-9 C100-10T-4 11 | -------------------------------------------------------------------------------- /sequence/C100_20T: -------------------------------------------------------------------------------- 1 | C100-20T-0 C100-20T-1 C100-20T-2 C100-20T-3 C100-20T-4 C100-20T-5 C100-20T-6 C100-20T-7 C100-20T-8 C100-20T-9 C100-20T-10 C100-20T-11 C100-20T-12 C100-20T-13 C100-20T-14 C100-20T-15 C100-20T-16 C100-20T-17 C100-20T-18 C100-20T-19 2 | C100-20T-3 C100-20T-8 
C100-20T-13 C100-20T-2 C100-20T-0 C100-20T-10 C100-20T-19 C100-20T-1 C100-20T-14 C100-20T-6 C100-20T-15 C100-20T-4 C100-20T-17 C100-20T-11 C100-20T-7 C100-20T-12 C100-20T-5 C100-20T-9 C100-20T-16 C100-20T-18 3 | C100-20T-16 C100-20T-18 C100-20T-11 C100-20T-9 C100-20T-1 C100-20T-17 C100-20T-19 C100-20T-6 C100-20T-4 C100-20T-3 C100-20T-5 C100-20T-15 C100-20T-0 C100-20T-14 C100-20T-10 C100-20T-2 C100-20T-12 C100-20T-13 C100-20T-7 C100-20T-8 4 | C100-20T-4 C100-20T-16 C100-20T-19 C100-20T-6 C100-20T-13 C100-20T-14 C100-20T-18 C100-20T-0 C100-20T-12 C100-20T-7 C100-20T-3 C100-20T-2 C100-20T-17 C100-20T-1 C100-20T-15 C100-20T-9 C100-20T-8 C100-20T-10 C100-20T-11 C100-20T-5 5 | C100-20T-17 C100-20T-10 C100-20T-5 C100-20T-4 C100-20T-14 C100-20T-11 C100-20T-16 C100-20T-8 C100-20T-2 C100-20T-15 C100-20T-13 C100-20T-9 C100-20T-1 C100-20T-19 C100-20T-6 C100-20T-7 C100-20T-18 C100-20T-0 C100-20T-3 C100-20T-12 6 | C100-20T-13 C100-20T-19 C100-20T-15 C100-20T-4 C100-20T-9 C100-20T-3 C100-20T-18 C100-20T-14 C100-20T-5 C100-20T-6 C100-20T-1 C100-20T-8 C100-20T-16 C100-20T-17 C100-20T-0 C100-20T-2 C100-20T-10 C100-20T-11 C100-20T-7 C100-20T-12 7 | C100-20T-17 C100-20T-15 C100-20T-1 C100-20T-10 C100-20T-16 C100-20T-9 C100-20T-12 C100-20T-6 C100-20T-7 C100-20T-19 C100-20T-13 C100-20T-0 C100-20T-4 C100-20T-14 C100-20T-11 C100-20T-18 C100-20T-5 C100-20T-2 C100-20T-8 C100-20T-3 8 | C100-20T-10 C100-20T-5 C100-20T-13 C100-20T-7 C100-20T-14 C100-20T-3 C100-20T-1 C100-20T-18 C100-20T-9 C100-20T-4 C100-20T-12 C100-20T-17 C100-20T-16 C100-20T-6 C100-20T-0 C100-20T-11 C100-20T-8 C100-20T-2 C100-20T-19 C100-20T-15 9 | C100-20T-18 C100-20T-8 C100-20T-9 C100-20T-5 C100-20T-14 C100-20T-19 C100-20T-15 C100-20T-6 C100-20T-11 C100-20T-7 C100-20T-10 C100-20T-2 C100-20T-4 C100-20T-3 C100-20T-13 C100-20T-17 C100-20T-1 C100-20T-16 C100-20T-0 C100-20T-12 10 | C100-20T-9 C100-20T-7 C100-20T-12 C100-20T-11 C100-20T-14 C100-20T-0 C100-20T-3 C100-20T-15 C100-20T-18 C100-20T-13 C100-20T-5 C100-20T-4 C100-20T-6 C100-20T-17 C100-20T-2 C100-20T-16 C100-20T-10 C100-20T-1 C100-20T-19 C100-20T-8 11 | -------------------------------------------------------------------------------- /sequence/C10_5T: -------------------------------------------------------------------------------- 1 | C10-5T-0 C10-5T-1 C10-5T-2 C10-5T-3 C10-5T-4 2 | C10-5T-0 C10-5T-1 C10-5T-4 C10-5T-3 C10-5T-2 3 | C10-5T-2 C10-5T-0 C10-5T-4 C10-5T-3 C10-5T-1 4 | C10-5T-2 C10-5T-3 C10-5T-0 C10-5T-1 C10-5T-4 5 | C10-5T-2 C10-5T-4 C10-5T-1 C10-5T-3 C10-5T-0 6 | C10-5T-4 C10-5T-0 C10-5T-2 C10-5T-1 C10-5T-3 7 | C10-5T-1 C10-5T-3 C10-5T-0 C10-5T-4 C10-5T-2 8 | C10-5T-4 C10-5T-3 C10-5T-1 C10-5T-2 C10-5T-0 9 | C10-5T-4 C10-5T-1 C10-5T-0 C10-5T-2 C10-5T-3 10 | C10-5T-3 C10-5T-0 C10-5T-4 C10-5T-1 C10-5T-2 11 | -------------------------------------------------------------------------------- /sequence/T_10T: -------------------------------------------------------------------------------- 1 | T-10T-0 T-10T-1 T-10T-2 T-10T-3 T-10T-4 T-10T-5 T-10T-6 T-10T-7 T-10T-8 T-10T-9 2 | T-10T-5 T-10T-7 T-10T-4 T-10T-3 T-10T-1 T-10T-8 T-10T-6 T-10T-9 T-10T-0 T-10T-2 3 | T-10T-9 T-10T-2 T-10T-1 T-10T-3 T-10T-0 T-10T-7 T-10T-8 T-10T-6 T-10T-4 T-10T-5 4 | T-10T-0 T-10T-7 T-10T-5 T-10T-6 T-10T-2 T-10T-9 T-10T-8 T-10T-4 T-10T-3 T-10T-1 5 | T-10T-0 T-10T-4 T-10T-9 T-10T-1 T-10T-7 T-10T-6 T-10T-2 T-10T-5 T-10T-3 T-10T-8 6 | T-10T-7 T-10T-6 T-10T-0 T-10T-1 T-10T-2 T-10T-5 T-10T-8 T-10T-3 T-10T-9 T-10T-4 7 | T-10T-4 T-10T-3 T-10T-2 T-10T-9 T-10T-7 T-10T-5 T-10T-1 T-10T-6 T-10T-0 T-10T-8 8 | T-10T-9 T-10T-0 
T-10T-6 T-10T-3 T-10T-7 T-10T-1 T-10T-8 T-10T-4 T-10T-2 T-10T-5 9 | T-10T-6 T-10T-9 T-10T-0 T-10T-4 T-10T-3 T-10T-8 T-10T-7 T-10T-1 T-10T-2 T-10T-5 10 | T-10T-1 T-10T-6 T-10T-9 T-10T-4 T-10T-0 T-10T-7 T-10T-8 T-10T-2 T-10T-3 T-10T-5 11 | -------------------------------------------------------------------------------- /sequence/T_5T: -------------------------------------------------------------------------------- 1 | T-5T-0 T-5T-1 T-5T-2 T-5T-3 T-5T-4 2 | T-5T-4 T-5T-0 T-5T-2 T-5T-1 T-5T-3 3 | T-5T-3 T-5T-4 T-5T-2 T-5T-1 T-5T-0 4 | T-5T-4 T-5T-1 T-5T-0 T-5T-3 T-5T-2 5 | T-5T-3 T-5T-4 T-5T-0 T-5T-1 T-5T-2 6 | T-5T-0 T-5T-1 T-5T-3 T-5T-2 T-5T-4 7 | T-5T-3 T-5T-0 T-5T-1 T-5T-4 T-5T-2 8 | T-5T-4 T-5T-3 T-5T-0 T-5T-1 T-5T-2 9 | T-5T-4 T-5T-0 T-5T-1 T-5T-3 T-5T-2 10 | T-5T-1 T-5T-3 T-5T-4 T-5T-2 T-5T-0 11 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linhaowei1/TPL/fc7aa86e0b197048f4ede9a7c914bb026aaa13ec/utils/__init__.py -------------------------------------------------------------------------------- /utils/baseline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import json 4 | import torch.nn.functional as F 5 | from utils import utils 6 | import torch.nn as nn 7 | 8 | 9 | @torch.no_grad() 10 | def mds(args, test_logits, test_hidden, loader, task_mask): 11 | test_samples = torch.Tensor(test_hidden) 12 | score_in = utils.maha_score(args, test_samples, args.precision_list, args.feat_mean_list, task_mask) 13 | return score_in 14 | 15 | @torch.no_grad() 16 | def mls(args, test_logits, test_hidden, loader, task_mask): 17 | logits = test_logits 18 | return torch.tensor(logits).max(-1)[0] 19 | 20 | def calculate_mask(self, w): 21 | contrib = self.mean_act[None, :] * w.data.squeeze().cpu().numpy() 22 | self.thresh = np.percentile(contrib, self.p) 23 | mask = torch.Tensor((contrib > self.thresh)).cuda() 24 | self.masked_w = w * mask 25 | 26 | def TPLR(args, test_logits, test_hidden, loader, task_mask): 27 | 28 | logit_score = torch.tensor(test_logits) 29 | logit_score = torch.max(logit_score, dim=1)[0] 30 | 31 | test_samples = torch.Tensor(test_hidden) 32 | p_in = utils.maha_score(args, test_samples, args.precision_list, args.feat_mean_list, task_mask) 33 | 34 | D_out = args.index_out[task_mask].search(utils.normalize(test_samples).astype(np.float32), args.K) 35 | p_out = torch.tensor(D_out[0][:, -1]) 36 | 37 | 38 | logit_score = logit_score / args.mls_scale[task_mask] 39 | p_in = p_in / args.mds_scale[task_mask] 40 | p_out = p_out / args.knn_scale[task_mask] 41 | 42 | e1 = p_in + p_out 43 | e2 = logit_score 44 | 45 | composition = -torch.logsumexp(torch.stack((-e1, -e2), dim=0), dim=0) 46 | 47 | return composition 48 | 49 | 50 | def baseline(args, results): 51 | 52 | metric = {} 53 | ood_label = {} 54 | ood_score = {} 55 | sum_ = 0 56 | 57 | for eval_t in range(args.task + 1): 58 | 59 | metric[eval_t] = {} 60 | ood_label[eval_t] = {} 61 | ood_score[eval_t] = {} 62 | 63 | logits = np.transpose(results[eval_t]['logits'], (1, 0, 2)) # (task_mask, sample, logit) 64 | softmax = torch.softmax(torch.from_numpy(logits / 1.0), dim=-1) 65 | 66 | for task_mask in range(args.task+1): 67 | 68 | score_in = None # (samples) 69 | 70 | test_logits = results[eval_t]['logits'][task_mask] 71 | test_hidden = results[eval_t]['hidden'][task_mask] 72 | loader = 
args.test_loaders[eval_t] 73 | 74 | 75 | score_in = TPLR(args, test_logits, test_hidden, loader, task_mask) 76 | 77 | # calibration 78 | score_in = args.calib_b[task_mask] + args.calib_w[task_mask] * score_in 79 | 80 | ood_score[eval_t][task_mask] = score_in.cpu().numpy().tolist() 81 | ood_label[eval_t][task_mask] = [1] * len(ood_score[eval_t][task_mask]) if eval_t == task_mask else [-1] * len(ood_score[eval_t][task_mask]) 82 | 83 | tp_logits = torch.stack([torch.tensor(ood_score[eval_t][task_mask]) for task_mask in range(args.task + 1)], dim=-1) # (sample, task_num) 84 | tp_softmax = torch.softmax(tp_logits / 0.05, -1) 85 | task_prediction = torch.max(tp_logits, dim=1)[1] 86 | 87 | prediction = (softmax * tp_softmax.unsqueeze(-1)).view(tp_logits.shape[0], -1).max(-1)[1].cpu().numpy().tolist() 88 | 89 | metric[eval_t]['tp_acc'] = utils.acc(task_prediction, np.array(results[eval_t]['references']) // args.class_num) 90 | metric[eval_t]['acc'] = utils.acc(prediction, results[eval_t]['references']) 91 | 92 | sum_ += metric[eval_t]['acc'] 93 | 94 | auc_avg = 0.0 95 | fpr_avg = 0.0 96 | aupr_avg = 0.0 97 | 98 | for task_mask in range(args.task + 1): 99 | 100 | ind_score = ood_score[task_mask][task_mask] 101 | ind_label = ood_label[task_mask][task_mask] 102 | for eval_t in range(args.task + 1): 103 | 104 | if eval_t == task_mask: 105 | continue 106 | 107 | ood_s = ood_score[eval_t][task_mask] 108 | ood_l = ood_label[eval_t][task_mask] 109 | 110 | predictions = np.array(ind_score + ood_s) 111 | references = np.array(ind_label + ood_l) 112 | 113 | auc_avg += utils.auroc(predictions, references) 114 | fpr_avg += utils.fpr95(predictions, references) 115 | aupr_avg += utils.aupr(predictions, references) 116 | 117 | metric['auroc'] = auc_avg / ((args.task + 1) * args.task) 118 | metric['fpr@95'] = fpr_avg / ((args.task + 1) * args.task) 119 | metric['aupr'] = aupr_avg / ((args.task + 1) * args.task) 120 | 121 | print("baseline: ", baseline) 122 | print(metric) 123 | metric['average'] = sum_ / (args.task+1) 124 | print(sum_ / (args.task + 1)) 125 | 126 | import os 127 | 128 | with open(os.path.join(args.output_dir, f'{baseline}_results'), 'a') as f: 129 | f.write(json.dumps(metric) + '\n') 130 | 131 | for eval_t in range(args.task + 1): 132 | utils.write_result_eval(metric[eval_t]['acc'], eval_t, args) 133 | 134 | def scaling(args): 135 | 136 | for eval_t in range(args.task + 1): 137 | mls_score = torch.max(torch.tensor(args.train_logits[eval_t]), dim=1)[0] 138 | mds_score = utils.maha_score(args, torch.tensor(args.train_hidden[eval_t]), args.precision_list, args.feat_mean_list, eval_t) 139 | 140 | args.mls_scale[eval_t] = mls_score.mean().data 141 | args.mds_scale[eval_t] = mds_score.mean().data 142 | 143 | print(args.mls_scale) 144 | print(args.mds_scale) 145 | 146 | 147 | def calibration(args): 148 | 149 | tp_score = [] 150 | 151 | for task_mask in range(args.task + 1): 152 | 153 | test_logits = args.logits_dict[task_mask] 154 | test_hidden = args.features_dict[task_mask] 155 | loader = args.replay_loader 156 | 157 | score_in = TPLR(args, test_logits, test_hidden, loader, task_mask) 158 | 159 | tp_score.append(score_in) 160 | 161 | tp_score = torch.stack(tp_score, dim=1) 162 | tp_label = torch.tensor(args.replay_labels).cuda() // args.class_num 163 | 164 | 165 | args.calib_w = args.calib_w.cuda() 166 | args.calib_b = args.calib_b.cuda() 167 | tp_score = tp_score.cuda() 168 | args.calib_w.requires_grad = True 169 | args.calib_b.requires_grad = True 170 | 171 | optimizer = 
torch.optim.SGD([args.calib_w, args.calib_b], lr=0.01, momentum=0.8) 172 | 173 | tp_label = torch.tensor(args.replay_labels).cuda() // args.class_num 174 | 175 | for _ in range(100): 176 | 177 | cal_score = tp_score * args.calib_w + args.calib_b 178 | loss = F.cross_entropy(cal_score, tp_label) 179 | optimizer.zero_grad() 180 | loss.backward() 181 | optimizer.step() 182 | 183 | args.calib_w = args.calib_w.cpu().detach() 184 | args.calib_b = args.calib_b.cpu().detach() 185 | print(args.calib_w) 186 | print(args.calib_b) -------------------------------------------------------------------------------- /utils/sgd_hat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import required 3 | from torch.optim import SGD 4 | import numpy as np 5 | 6 | class SGD_hat(SGD): 7 | def __init__(self, params, lr=required, momentum=0, dampening=0, 8 | weight_decay=0, nesterov=False): 9 | super(SGD_hat, self).__init__(params, lr, momentum, dampening, 10 | weight_decay, nesterov) 11 | 12 | @torch.no_grad() 13 | def step(self, closure=None, hat=False): 14 | """Performs a single optimization step. 15 | 16 | Arguments: 17 | closure (callable, optional): A closure that reevaluates the model 18 | and returns the loss. 19 | """ 20 | loss = None 21 | if closure is not None: 22 | with torch.enable_grad(): 23 | loss = closure() 24 | 25 | for group in self.param_groups: 26 | weight_decay = group['weight_decay'] 27 | momentum = group['momentum'] 28 | dampening = group['dampening'] 29 | nesterov = group['nesterov'] 30 | 31 | for p in group['params']: 32 | if p.grad is None: 33 | continue 34 | d_p = p.grad 35 | if weight_decay != 0: 36 | d_p = d_p.add(p, alpha=weight_decay) 37 | if momentum != 0: 38 | param_state = self.state[p] 39 | if 'momentum_buffer' not in param_state: 40 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 41 | else: 42 | buf = param_state['momentum_buffer'] 43 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 44 | if nesterov: 45 | d_p = d_p.add(buf, alpha=momentum) 46 | else: 47 | d_p = buf 48 | if hat: 49 | if p.hat is not None: 50 | d_p = d_p * p.hat 51 | 52 | p.add_(d_p, alpha=-group['lr']) 53 | 54 | return loss 55 | 56 | def HAT_reg(args, masks): 57 | """ masks and self.mask_pre must have values in the same order """ 58 | reg, count = 0., 0. 59 | if args.mask_pre is not None: 60 | for m, mp in zip(masks, args.mask_pre.values()): 61 | aux = 1. - mp 62 | reg += (m * aux).sum() 63 | count += aux.sum() 64 | else: 65 | for m in masks: 66 | reg += m.sum() 67 | count += np.prod(m.size()).item() 68 | reg /= count 69 | return args.reg_lambda * reg 70 | 71 | def compensation(model, args, thres_cosh=50, s=1): 72 | """ Equation before Eq. (4) in the paper """ 73 | for n, p in model.named_parameters(): 74 | if 'ec' in n: 75 | if p.grad is not None: 76 | num = torch.cosh(torch.clamp(s * p.data, -thres_cosh, thres_cosh)) + 1 77 | den = torch.cosh(p.data) + 1 78 | p.grad *= args.smax / s * num / den 79 | 80 | def compensation_clamp(model, thres_emb=6): 81 | # Constrain embeddings 82 | for n, p in model.named_parameters(): 83 | if 'ec' in n: 84 | if p.grad is not None: 85 | p.data.copy_(torch.clamp(p.data, -thres_emb, thres_emb)) 86 | 87 | def cum_mask(smax, t, model, mask_pre): 88 | """ 89 | Keep track of mask values. 
90 | This will be used later as a regularizer in the optimization 91 | """ 92 | try: 93 | model = model.module 94 | except AttributeError: 95 | model = model 96 | 97 | task_id = torch.tensor([t]).cuda() 98 | mask = {} 99 | for n, _ in model.named_parameters(): 100 | names = n.split('.') 101 | checker = [i for i in ['ec0', 'ec1', 'ec2'] if i in names] 102 | if names[0] == 'module': 103 | names = names[1:] 104 | if checker: 105 | if 'adapter' in n: 106 | gc1, gc2 = model.__getattr__(names[0])[int(names[1])].__getattr__(names[2]).mask(task_id, s=smax) 107 | if checker[0] == 'ec1': 108 | n = '.'.join(n.split('.')[:-1]) 109 | mask[n] = gc1.detach() 110 | mask[n].requires_grad = False 111 | elif checker[0] == 'ec2': 112 | n = '.'.join(n.split('.')[:-1]) 113 | mask[n] = gc2.detach() 114 | mask[n].requires_grad = False 115 | 116 | if mask_pre is None: 117 | mask_pre = {} 118 | for n in mask.keys(): 119 | mask_pre[n] = mask[n] 120 | else: 121 | for n in mask.keys(): 122 | mask_pre[n] = torch.max(mask_pre[n], mask[n]) 123 | return mask_pre 124 | 125 | def freeze_mask(P, t, model, mask_pre): 126 | """ 127 | Eq (2) in the paper. self.mask_back is a dictionary whose keys are 128 | the convolutions' parameter names. Each value of a key is a matrix, whose elements are 129 | approximately binary. 130 | """ 131 | try: 132 | model = model.module 133 | except AttributeError: 134 | model = model 135 | 136 | mask_back = {} 137 | for n, p in model.named_parameters(): 138 | names = n.split('.') 139 | if 'adapter' in n: # adapter1 or adapter2. adapter.ec1, adapter.ec2 140 | # e.g. n is blocks.1.adapter1.fc1.weight 141 | if 'fc1.weight' in n: 142 | mask_back[n] = 1 - mask_pre['.'.join(names[:-2]) + '.ec1'].data.view(-1, 1).expand_as(p) 143 | elif 'fc1.bias' in n: 144 | mask_back[n] = 1 - mask_pre['.'.join(names[:-2]) + '.ec1'].data.view(-1) 145 | elif 'fc2.weight' in n: 146 | post = mask_pre['.'.join(names[:-2]) + '.ec2'].data.view(-1, 1).expand_as(p) 147 | pre = mask_pre['.'.join(names[:-2]) + '.ec1'].data.view(1, -1).expand_as(p) 148 | mask_back[n] = 1 - torch.min(post, pre) 149 | elif 'fc2.bias' in n: 150 | mask_back[n] = 1 - mask_pre['.'.join(names[:-2]) + '.ec2'].view(-1) 151 | return mask_back 152 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | import numpy as np 4 | import torch 5 | import sklearn.covariance 6 | from sklearn import metrics 7 | import json 8 | import faiss 9 | import torch.nn.functional as F 10 | from tqdm.auto import tqdm 11 | 12 | 13 | def normalize(v): 14 | v = np.array(v) 15 | norm = np.linalg.norm(v, axis=-1).reshape(-1, 1) 16 | return v / (norm + 1e-9) 17 | 18 | 19 | def standardization(data): 20 | mu = np.mean(data, axis=0) 21 | sigma = np.std(data, axis=0) 22 | return (data - mu) / sigma 23 | 24 | 25 | def write_result(results, eval_t, args): 26 | 27 | progressive_main_path = os.path.join( 28 | args.output_dir + '/../', 'progressive_main_' + str(args.seed) 29 | ) 30 | progressive_til_path = os.path.join( 31 | args.output_dir + '/../', 'progressive_til_' + str(args.seed) 32 | ) 33 | progressive_tp_path = os.path.join( 34 | args.output_dir + '/../', 'progressive_tp_' + str(args.seed) 35 | ) 36 | 37 | if os.path.exists(progressive_main_path): 38 | eval_main = np.loadtxt(progressive_main_path) 39 | else: 40 | eval_main = np.zeros((args.ntasks, args.ntasks), dtype=np.float32) 41 | if 
os.path.exists(progressive_til_path): 42 | eval_til = np.loadtxt(progressive_til_path) 43 | else: 44 | eval_til = np.zeros((args.ntasks, args.ntasks), dtype=np.float32) 45 | if os.path.exists(progressive_tp_path): 46 | eval_tp = np.loadtxt(progressive_tp_path) 47 | else: 48 | eval_tp = np.zeros((args.ntasks, args.ntasks), dtype=np.float32) 49 | 50 | try: 51 | eval_main[args.task][eval_t] = results['accuracy'] 52 | eval_til[args.task][eval_t] = results['accuracy'] 53 | eval_tp[args.task][eval_t] = results['accuracy'] 54 | except: 55 | eval_main[args.task][eval_t] = results['cil_accuracy'] 56 | eval_til[args.task][eval_t] = results['til_accuracy'] 57 | eval_tp[args.task][eval_t] = results['TP_accuracy'] 58 | 59 | np.savetxt(progressive_main_path, eval_main, '%.4f', delimiter='\t') 60 | np.savetxt(progressive_til_path, eval_til, '%.4f', delimiter='\t') 61 | np.savetxt(progressive_tp_path, eval_tp, '%.4f', delimiter='\t') 62 | 63 | if args.task == args.ntasks - 1: 64 | final_main = os.path.join(args.output_dir + '/../', 'final_main_' + str(args.seed)) 65 | forward_main = os.path.join(args.output_dir + '/../', 'forward_main_' + str(args.seed)) 66 | 67 | final_til = os.path.join(args.output_dir + '/../', 'final_til_' + str(args.seed)) 68 | forward_til = os.path.join(args.output_dir + '/../', 'forward_til_' + str(args.seed)) 69 | 70 | final_tp = os.path.join(args.output_dir + '/../', 'final_tp_' + str(args.seed)) 71 | forward_tp = os.path.join(args.output_dir + '/../', 'forward_tp_' + str(args.seed)) 72 | 73 | with open(final_main, 'w') as final_main_file, open(forward_main, 'w') as forward_main_file: 74 | for j in range(eval_main.shape[1]): 75 | final_main_file.writelines(str(eval_main[-1][j]) + '\n') 76 | forward_main_file.writelines(str(eval_main[j][j]) + '\n') 77 | 78 | with open(final_til, 'w') as final_til_file, open(forward_til, 'w') as forward_til_file: 79 | for j in range(eval_til.shape[1]): 80 | final_til_file.writelines(str(eval_til[-1][j]) + '\n') 81 | forward_til_file.writelines(str(eval_til[j][j]) + '\n') 82 | 83 | with open(final_tp, 'w') as final_tp_file, open(forward_tp, 'w') as forward_tp_file: 84 | for j in range(eval_tp.shape[1]): 85 | final_tp_file.writelines(str(eval_tp[-1][j]) + '\n') 86 | forward_tp_file.writelines(str(eval_tp[j][j]) + '\n') 87 | 88 | 89 | def write_result_eval(results, eval_t, args): 90 | 91 | progressive_main_path = os.path.join( 92 | args.output_dir + '/../', 'progressive_main_' + str(args.seed) 93 | ) 94 | 95 | if os.path.exists(progressive_main_path): 96 | eval_main = np.loadtxt(progressive_main_path) 97 | else: 98 | eval_main = np.zeros((args.ntasks, args.ntasks), dtype=np.float32) 99 | 100 | eval_main[args.task][eval_t] = results 101 | 102 | np.savetxt(progressive_main_path, eval_main, '%.4f', delimiter='\t') 103 | 104 | if args.task == args.ntasks - 1: 105 | final_main = os.path.join(args.output_dir + '/../', 'final_main_' + str(args.seed)) 106 | forward_main = os.path.join(args.output_dir + '/../', 'forward_main_' + str(args.seed)) 107 | 108 | with open(final_main, 'w') as final_main_file, open(forward_main, 'w') as forward_main_file: 109 | for j in range(eval_main.shape[1]): 110 | final_main_file.writelines(str(eval_main[-1][j]) + '\n') 111 | forward_main_file.writelines(str(eval_main[j][j]) + '\n') 112 | 113 | 114 | def prepare_sequence_eval(args): 115 | with open(os.path.join('./sequence', args.sequence_file), 'r') as f: 116 | data = f.readlines()[args.idrandom] 117 | data = data.split() 118 | 119 | args.all_tasks = data 120 | args.ntasks = 
len(data) 121 | 122 | ckpt = args.base_dir + '/seq' + str(args.class_order) + "/seed" + str(args.seed) + \ 123 | '/' + str(args.baseline) + '/' + str(data[args.task]) + '/model' 124 | args.output_dir = args.base_dir + "/seq" + \ 125 | str(args.class_order) + "/seed" + str(args.seed) + "/" + str(args.baseline) + '/' + str(data[args.task]) 126 | 127 | args.prev_output = None 128 | args.model_name_or_path = ckpt 129 | 130 | print('output_dir: ', args.output_dir) 131 | print('prev_output: ', args.prev_output) 132 | print('args.model_name_or_path: ', args.model_name_or_path) 133 | 134 | return args 135 | 136 | 137 | def prepare_sequence_train(args): 138 | with open(os.path.join('./sequence', args.sequence_file), 'r') as f: 139 | data = f.readlines()[args.idrandom] 140 | data = data.split() 141 | 142 | args.task_name = data[args.task] 143 | args.all_tasks = data 144 | args.ntasks = len(data) 145 | args.output_dir = args.base_dir + '/seq' + \ 146 | str(args.class_order) + "/seed" + str(args.seed) + '/' + str(args.baseline) + '/' + str(data[args.task]) 147 | ckpt = args.base_dir + '/seq' + str(args.class_order) + "/seed" + str(args.seed) + \ 148 | '/' + str(args.baseline) + '/' + str(data[args.task-1]) + '/model' 149 | 150 | if args.task > 0: 151 | args.prev_output = args.base_dir + "/seq" + \ 152 | str(args.class_order) + "/seed" + str(args.seed) + "/" + str(args.baseline) + '/' + str(data[args.task-1]) 153 | args.model_name_or_path = ckpt 154 | else: 155 | args.prev_output = None 156 | args.model_name_or_path = None 157 | 158 | print('output_dir: ', args.output_dir) 159 | print('prev_output: ', args.prev_output) 160 | print('args.model_name_or_path: ', args.model_name_or_path) 161 | 162 | return args 163 | 164 | 165 | def load_in661_pretrain(args, target_model): 166 | """ 167 | target_model: the model we want to replace the parameters (most likely un-trained) 168 | """ 169 | if os.path.isfile(f'{args.base_dir}/deit_in661/best_checkpoint.pth'): 170 | checkpoint = torch.load(f'{args.base_dir}/deit_in661/best_checkpoint.pth', map_location='cpu') 171 | else: 172 | raise NotImplementedError("Cannot find pre-trained model") 173 | target = target_model.state_dict() 174 | pretrain = checkpoint['model'] 175 | transfer = {k: v for k, v in pretrain.items() if k in target and 'head' not in k} 176 | target.update(transfer) 177 | target_model.load_state_dict(target) 178 | 179 | 180 | def lookfor_model(args): 181 | 182 | ## load visual encoder ## 183 | if 'deit_small_patch16' in args.visual_encoder: 184 | from networks.vit_hat import deit_small_patch16_224 185 | model = deit_small_patch16_224(pretrained=False, num_classes=args.class_num * 186 | args.ntasks, latent=args.latent, args=args, hat='hat' in args.baseline) 187 | elif 'vit_small_patch16' in args.visual_encoder: 188 | from networks.vit_hat import vit_small_patch16_224 189 | model = vit_small_patch16_224(pretrained=False, num_classes=args.class_num * 190 | args.ntasks, latent=args.latent, args=args, hat='hat' in args.baseline) 191 | elif 'vit_base_patch16' in args.visual_encoder: 192 | from networks.vit_hat import vit_base_patch16_224 193 | model = vit_base_patch16_224(pretrained=False, num_classes=args.class_num * 194 | args.ntasks, latent=args.latent, args=args, hat='hat' in args.baseline) 195 | elif 'vit_tiny_patch16' in args.visual_encoder: 196 | from networks.vit_hat import vit_tiny_patch16_224 197 | model = vit_tiny_patch16_224(pretrained=False, num_classes=args.class_num * 198 | args.ntasks, latent=args.latent, args=args, hat='hat' in 
args.baseline) 199 | elif 'deit_tiny_patch16' in args.visual_encoder: 200 | from networks.vit_hat import deit_tiny_patch16_224 201 | model = deit_tiny_patch16_224(pretrained=False, num_classes=args.class_num * 202 | args.ntasks, latent=args.latent, args=args, hat='hat' in args.baseline) 203 | else: 204 | raise NotImplementedError 205 | 206 | checkpoint = torch.load(f'{args.base_dir}/pretrained/{args.visual_encoder}.pth', map_location='cpu') 207 | target = model.state_dict() 208 | 209 | if 'model' in checkpoint.keys(): 210 | checkpoint = checkpoint['model'] 211 | 212 | transfer = {k: v for k, v in checkpoint.items() if k in target and 'head' not in k} 213 | target.update(transfer) 214 | model.load_state_dict(target) 215 | 216 | 217 | ## load adapter or hat mask## 218 | if 'hat' in args.baseline: 219 | for _ in range(args.task): 220 | model.append_embeddings() 221 | 222 | if not args.training: # inference for the t-th task 223 | model.append_embeddings() 224 | 225 | if args.task > 0: # load the trained weights 226 | model.load_state_dict(torch.load(os.path.join(args.model_name_or_path), map_location='cpu')) 227 | 228 | if args.training: # training for the t-th task 229 | model.append_embeddings() 230 | 231 | if 'derpp' in args.baseline: 232 | if args.task > 0: 233 | model.load_state_dict(torch.load(os.path.join(args.model_name_or_path), map_location='cpu')) 234 | args.teacher_model = deepcopy(model) # used for get representation for replay sample 235 | 236 | return model 237 | 238 | 239 | def auroc(predictions, references): 240 | fpr, tpr, _ = metrics.roc_curve(references, predictions, pos_label=1) 241 | return metrics.auc(fpr, tpr) 242 | 243 | 244 | def acc(predictions, references): 245 | acc = metrics.accuracy_score(references, predictions) 246 | return acc 247 | 248 | def aupr(predictions, references): 249 | 250 | ind_indicator = np.zeros_like(references) 251 | ind_indicator[references != -1] = 1 252 | 253 | precision_in, recall_in, thresholds_in \ 254 | = metrics.precision_recall_curve(ind_indicator, predictions) 255 | 256 | aupr_in = metrics.auc(recall_in, precision_in) 257 | 258 | return aupr_in 259 | 260 | def fpr95(predictions, references, tpr=0.95): 261 | gt = np.ones_like(references) 262 | gt[references == -1] = 0 263 | fpr_list, tpr_list, threshold_list = metrics.roc_curve(gt, predictions) 264 | fpr = fpr_list[np.argmax(tpr_list >= tpr)] 265 | thresh = threshold_list[np.argmax(tpr_list >= tpr)] 266 | return fpr 267 | 268 | @torch.no_grad() 269 | def Mahainit(args, train_hidden, train_labels): 270 | group_lasso = sklearn.covariance.EmpiricalCovariance(assume_centered=False) 271 | feat_mean_list = {} # ntasks x ntasks x num_classes 272 | precision_list = {} 273 | feat_list = {} 274 | for train_t in tqdm(range(args.task + 1)): 275 | # prepare feat_mean_list, precision_lst 276 | 277 | feat_list[train_t] = {} 278 | feat_mean_list[train_t] = {} 279 | precision_list[train_t] = {} 280 | # for task_t in range(args.task + 1): 281 | # feat_list[train_t] = {} 282 | for feature, label in zip(train_hidden[train_t], train_labels[train_t]): 283 | feature = np.array(feature).reshape([-1, len(feature)]) 284 | if label not in feat_list[train_t].keys(): 285 | feat_list[train_t][label] = feature 286 | else: 287 | feat_list[train_t][label] = np.concatenate( 288 | (feat_list[train_t][label], feature), axis=0) 289 | 290 | feat_mean_list[train_t] = [ 291 | np.mean(feat_list[train_t][i], axis=0).tolist() for i in range(args.class_num)] 292 | precision_list[train_t] = None 293 | X = None 294 | for k in 
range(args.class_num): 295 | if X is None: 296 | X = feat_list[train_t][k] - feat_mean_list[train_t][k] 297 | else: 298 | X = np.concatenate((X, feat_list[train_t][k] - feat_mean_list[train_t][k]), axis=0) 299 | # find inverse 300 | group_lasso.fit(X) 301 | precision = group_lasso.precision_ 302 | precision_list[train_t] = precision.tolist() 303 | 304 | return feat_mean_list, precision_list 305 | 306 | 307 | def maha_score(args, test_sample, precision_list, feat_mean_list, task_mask): 308 | for class_idx in range(args.class_num): 309 | zero_f = test_sample - torch.Tensor(feat_mean_list[str(task_mask)][class_idx]) 310 | term_gau = 20.0 / torch.mm(torch.mm(zero_f, torch.Tensor(precision_list[str(task_mask)])), 311 | zero_f.t()).diag() 312 | if class_idx == 0: 313 | noise_gaussian_score = term_gau.view(-1, 1) 314 | else: 315 | noise_gaussian_score = torch.cat((noise_gaussian_score, term_gau.view(-1, 1)), 1) 316 | 317 | score_in, _ = torch.max(noise_gaussian_score, dim=1) 318 | return score_in 319 | 320 | 321 | def load_maha(args, train_hidden, train_labels): 322 | print("start mahainit...") 323 | try: 324 | with open(os.path.join(args.output_dir, 'feat_mean_list'), 'r') as f: 325 | feat_mean_list = json.load(f) 326 | with open(os.path.join(args.output_dir, 'precision_list'), 'r') as f: 327 | precision_list = json.load(f) 328 | except: 329 | feat_mean_list, precision_list = Mahainit(args, train_hidden, train_labels) 330 | with open(os.path.join(args.output_dir, 'feat_mean_list'), 'w') as f: 331 | json.dump(feat_mean_list, f) 332 | with open(os.path.join(args.output_dir, 'precision_list'), 'w') as f: 333 | json.dump(precision_list, f) 334 | with open(os.path.join(args.output_dir, 'feat_mean_list'), 'r') as f: 335 | feat_mean_list = json.load(f) 336 | with open(os.path.join(args.output_dir, 'precision_list'), 'r') as f: 337 | precision_list = json.load(f) 338 | print("finish mahainit!!") 339 | return feat_mean_list, precision_list 340 | 341 | 342 | --------------------------------------------------------------------------------
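
For readers tracing `utils/baseline.py`, the `TPLR` function builds a per-task score from three scaled terms: the maximum logit (divided by `mls_scale`), the Mahalanobis feature score (divided by `mds_scale`), and a KNN-distance term taken from `args.index_out` (divided by `knn_scale`). The two feature terms are summed into one evidence and merged with the logit evidence through a negative log-sum-exp. The snippet below is a minimal sketch of that final composition step only, with made-up numbers; the tensors and their values are illustrative and not taken from the repository.

```python
import torch

# Hypothetical, already-scaled scores for 3 test samples (illustrative values only).
p_in = torch.tensor([1.2, 0.3, 0.6])         # Mahalanobis score / mds_scale
p_out = torch.tensor([0.8, 0.2, 0.4])        # KNN-distance term / knn_scale
logit_score = torch.tensor([1.5, 0.8, 3.0])  # max logit / mls_scale

e1 = p_in + p_out   # feature-based evidence
e2 = logit_score    # logit-based evidence

# -logsumexp(-e1, -e2) acts as a smooth minimum of the two evidences:
# it is always <= min(e1, e2) and approaches it as the gap between them grows.
composition = -torch.logsumexp(torch.stack((-e1, -e2), dim=0), dim=0)
print(composition)  # roughly [1.03, -0.05, 0.87] for the numbers above
```

In `baseline()`, this per-task score is then affinely calibrated with `calib_w` and `calib_b` (fitted in `calibration()` on replay data), and the task with the highest calibrated score is predicted.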
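
Similarly, the `sequence/*` files above are plain-text task orders: each row is one random ordering of the task splits, `--sequence_file` selects the file, and `--idrandom` selects the row (see `prepare_sequence_train` and `prepare_sequence_eval` in `utils/utils.py`). A minimal sketch of that lookup, assuming it is run from the repository root with the default `./sequence` directory:

```python
# Mirrors how prepare_sequence_train reads a task order; paths are illustrative.
with open('./sequence/C10_5T') as f:
    data = f.readlines()[0].split()  # idrandom = 0 -> first row of the file
print(data)       # ['C10-5T-0', 'C10-5T-1', 'C10-5T-2', 'C10-5T-3', 'C10-5T-4']
print(len(data))  # 5 tasks, so --task ranges over 0..4 (final_task=4 in the scripts)
```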