├── BERT_codebase ├── continual_learning_one_head.py ├── continual_learning_utils.py ├── dataset_utils.py ├── good_id_yahoo_test2.npy ├── good_id_yahoo_train2.npy ├── model_utils.py ├── run_prog_prompts.sh ├── train_cl.py ├── train_cl2.py └── train_soft_prompt.py ├── LICENSE ├── README.md ├── T5_codebase ├── prompt_debug.ipynb ├── t5_continual.py ├── t5_dataset.py ├── train_prompt.py └── train_t5_cl.py ├── datasets ├── README.md └── src │ └── data │ └── amazon │ └── Archive.zip ├── environment.yaml └── images ├── illustration.png └── test.png /BERT_codebase/continual_learning_one_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pandas as pd 4 | import numpy as np 5 | from tqdm.auto import tqdm 6 | import logging, os, argparse 7 | import matplotlib.pyplot as plt 8 | 9 | import dataset_utils, model_utils, continual_learning_utils 10 | from itertools import cycle 11 | from transformers import AdamW 12 | from copy import deepcopy 13 | from datasets import load_metric 14 | 15 | from transformers import AdamW, get_constant_schedule_with_warmup 16 | 17 | def change_string(str): 18 | #creating negative samples for NSP by randomly splitting positive samples 19 | #and swapping two halves 20 | if 102 in str: 21 | str.remove(102) 22 | if 102 in str: 23 | str.remove(102) 24 | 25 | len1 = len(str) 26 | if len1 == 1: 27 | cut = 1 28 | else: 29 | cut = np.random.randint(1, len1) 30 | str = str[cut:] + [102] + str[:cut] + [102] 31 | return str 32 | 33 | 34 | def get_permutation_batch(src, src_mask, device, seq_len=512): 35 | #create negative samples for Next Sentence Prediction 36 | batch_size = src.size(0) 37 | length = src.size(1) 38 | dst = [] 39 | dst_mask = [] 40 | lbl = [] 41 | for i in range(batch_size): 42 | cur = src[i] 43 | mask = src_mask[i].tolist() 44 | first_pad = (cur.tolist() + [0]).index(0) 45 | cur = cur[1:first_pad].tolist() 46 | cur = change_string(cur) 47 | lbl.append(1) 48 | 49 | padding = [0] * (length - len(cur) - 1) 50 | inp = torch.tensor([101] + cur + padding) 51 | dst.append(inp[:seq_len]) 52 | dst_mask.append(torch.tensor(mask)) 53 | return torch.stack(dst).to(device), torch.stack(dst_mask).to(device), torch.tensor(lbl).to(device) 54 | 55 | 56 | 57 | def compute_class_offsets(tasks, task_classes): 58 | ''' 59 | :param tasks: a list of the names of tasks, e.g. ["amazon", "yahoo"] 60 | :param task_classes: the corresponding numbers of classes, e.g. [5, 10] 61 | :return: the class # offsets, e.g. [0, 5] 62 | Here we merge the labels of yelp and amazon, i.e. the class # offsets 63 | for ["amazon", "yahoo", "yelp"] will be [0, 5, 0] 64 | ''' 65 | task_num = len(tasks) 66 | offsets = [0] * task_num 67 | prev = -1 68 | total_classes = 0 69 | for i in range(task_num): 70 | if tasks[i] in ["amazon", "yelp_review_full"]: 71 | if prev == -1: 72 | prev = i 73 | offsets[i] = total_classes 74 | total_classes += task_classes[i] 75 | else: 76 | offsets[i] = offsets[prev] 77 | else: 78 | offsets[i] = total_classes 79 | total_classes += task_classes[i] 80 | return total_classes, offsets 81 | 82 | 83 | 84 | def pass_batch_reg(self, batch, device, cls_idx=0, only_output_loss=False): 85 | #if self.cls_idx_override!=None: 86 | #cls_idx = self.cls_idx_override 87 | 88 | model = self.model 89 | optimizer = self.optimizer 90 | scheduler = self.scheduler 91 | tokenizer = self.tokenizer 92 | 93 | batch = {k: v.to(device) for k, v in batch.items()} 94 | 95 | out = model.bert(**{'input_ids': batch['input_ids'], 96 | 'attention_mask': batch['attention_mask'], 97 | 'token_type_ids': batch['token_type_ids'], 98 | }) 99 | cls_output = out.last_hidden_state[:,cls_idx,:].to(device) 100 | 101 | if only_output_loss: 102 | # just returning the loss for subsequent operations (e.g. sum) 103 | # outputs = model.heads[task](outputs = out, 104 | # cls_output = cls_output, 105 | # return_dict = True, 106 | # labels = batch['labels']) 107 | 108 | # loss = outputs.loss 109 | #return loss, cls_output 110 | return cls_output 111 | 112 | else: 113 | # performing optimization step here 114 | loss.backward() 115 | # only allowing updates for added special token if required 116 | if self.freeze_weights == 1 and self.freeze_except == 'word_embeddings': 117 | k = len(self.special_tokens_list) 118 | model.bert.embeddings.word_embeddings.weight.grad[:-k] = 0 119 | 120 | optimizer.step() 121 | scheduler.step() 122 | optimizer.zero_grad() 123 | 124 | 125 | 126 | class Predictor(torch.nn.Module): 127 | def __init__(self, num_class, hidden_size): 128 | super(Predictor, self).__init__() 129 | 130 | self.num_class = num_class 131 | 132 | self.dis = torch.nn.Sequential( 133 | torch.nn.Linear(hidden_size, self.num_class) 134 | ) 135 | 136 | def forward(self, z): 137 | return self.dis(z) 138 | 139 | 140 | 141 | class ContinualLearnerIDBR: 142 | def __init__(self, 143 | model_name, 144 | task_list, 145 | batch_size=8, 146 | select_k_per_class=-1, 147 | memory_perc=0, 148 | #block_attn=0, 149 | freeze_weights=0, 150 | freeze_except='word_embeddings', 151 | lr=3e-5, #2e-5 152 | seq_len=512, 153 | cls_idx_override=None, 154 | early_stopping=True, 155 | offsets=[], 156 | total_classes=-1, 157 | hidden_size=128, 158 | tasks_data_dict=None, 159 | regcoe=0.5, 160 | regcoe_rply=5.0, 161 | ): 162 | 163 | self.task_to_num_labels = { 164 | 'cola': 2, 165 | 'rte': 2, 166 | 'mrpc': 2, 167 | 'qqp': 2, 168 | 'sst2': 2, 169 | 'qnli': 2, 170 | 'mnli': 3, 171 | 172 | 'scicite': 3, 173 | 'imdb': 2, 174 | 175 | 'cb': 3, 176 | 'copa': 2, 177 | 'wic': 2, 178 | 'boolq': 2, 179 | 'multirc': 2, 180 | 181 | 'yelp_review_full': 5, 182 | 'ag_news': 4, 183 | 'yahoo_answers_topics': 10, 184 | 'amazon': 5, 185 | 'dbpedia_14': 14, 186 | 'dbpedia': 14, 187 | 188 | 'yelp': 5, 189 | 'ag': 4, 190 | 'yahoo': 10, 191 | } 192 | self.glue_datasets = ['cola', 'sst2', 'mrpc', 'qqp', 'stsb', 'mnli', \ 193 | 'mnli_mismatched', 'mnli_matched', 'qnli', 'rte', 'wnli', 'ax'] 194 | self.superglue = ['cb', 'copa', 'wic', 'wsc', 'boolq', 'record', 'multirc'] 195 | 196 | self.task_list = task_list 197 | num_labels_list = [self.task_to_num_labels[t] for t in self.task_list] 198 | self.total_classes, self.offsets = compute_class_offsets(self.task_list, num_labels_list) 199 | 200 | self.num_labels = [self.task_to_num_labels[t] for t in self.task_list] 201 | self.freeze_weights = freeze_weights 202 | self.lr = lr 203 | self.task_learning_rate = 5e-4 204 | self.seq_len = seq_len 205 | self.batch_size = batch_size 206 | self.cls_idx_override = cls_idx_override 207 | self.select_k_per_class = select_k_per_class 208 | self.memory_perc = memory_perc 209 | self.freeze_except = freeze_except 210 | self.early_stopping = early_stopping 211 | 212 | if torch.cuda.is_available(): 213 | self.device = torch.device("cuda") 214 | else: 215 | self.device = torch.device("cpu") 216 | 217 | self.model_name = model_name #"bert-base-uncased" 218 | self.trainer = model_utils.ModelForCL(self.model_name, 219 | tasks=self.task_list, 220 | num_labels=self.num_labels, 221 | #blockwise_causal_attention= (block_attn==1), 222 | freeze_weights= (self.freeze_weights==1), 223 | freeze_except=self.freeze_except, 224 | lr=self.lr, 225 | num_repeats=0, #self.num_repeats, 226 | max_length=self.seq_len, # max sequence length in #tokens 227 | cls_idx_override=self.cls_idx_override, 228 | ) 229 | self.trainer.pass_batch_reg = lambda batch, device, cls_idx, only_output_loss: \ 230 | pass_batch_reg(self.trainer, batch, device, cls_idx, only_output_loss) 231 | #self.trainer.pass_batch_reg = MethodType(lambda batch, task, device, cls_idx, only_output_loss: 232 | # pass_batch_reg(batch, task, device, 233 | # cls_idx=cls_idx, only_output_loss=only_output_loss)) 234 | 235 | #self.trainer.model.add_classification_head('giant', num_labels=total_classes) 236 | self.trainer.model.to(self.device) # model to cuda 237 | self.tokenizer = self.trainer.tokenizer 238 | if tasks_data_dict==None: 239 | self.tasks_data_dict = self.get_tasks_data_dict(self.select_k_per_class, memory_perc=self.memory_perc) 240 | else: 241 | print('Data is ready ', list(tasks_data_dict)) 242 | self.tasks_data_dict = tasks_data_dict 243 | #### ADDING REG UTILS #### 244 | self.base_model = deepcopy(self.trainer.model) # bert before training 245 | 246 | self.hidden_size = hidden_size 247 | self.General_Encoder = nn.Sequential( 248 | nn.Linear(768, self.hidden_size), 249 | nn.Tanh() 250 | ) 251 | 252 | self.Specific_Encoder = nn.Sequential( 253 | nn.Linear(768, self.hidden_size), 254 | nn.Tanh() 255 | ) 256 | 257 | n_class = self.total_classes 258 | self.cls_classifier = nn.Sequential( 259 | nn.Linear(self.hidden_size * 2, n_class) 260 | ) 261 | 262 | n_tasks = len(self.task_list) 263 | self.task_classifier = nn.Sequential( 264 | nn.Linear(self.hidden_size, n_tasks) 265 | ) 266 | 267 | self.cls_CR = torch.nn.CrossEntropyLoss() 268 | self.predictor = Predictor(2, hidden_size=self.hidden_size).to(self.device) # NSP loss predictor 269 | self.nsp_CR = torch.nn.CrossEntropyLoss() 270 | 271 | if self.early_stopping: 272 | self.best_model = deepcopy(self.trainer.model.state_dict()) # saving best model 273 | self.best_GenEnc = deepcopy(self.General_Encoder.state_dict()) 274 | self.best_SpeEnc = deepcopy(self.Specific_Encoder.state_dict()) 275 | self.best_cls_classifier = deepcopy(self.cls_classifier.state_dict()) 276 | self.best_task_classifier = deepcopy(self.task_classifier.state_dict()) 277 | self.best_predictor = deepcopy(self.predictor.state_dict()) 278 | self.best_acc = 0.0 # best avg accuracy on seen tasks 279 | 280 | self.trainer.optimizer = AdamW( 281 | [ 282 | {"params": self.trainer.model.bert.parameters(), "lr": self.lr, "weight_decay": 0.01}, 283 | {"params": self.General_Encoder.parameters(), "lr": self.lr, "weight_decay": 0.01}, 284 | {"params": self.Specific_Encoder.parameters(), "lr": self.lr, "weight_decay": 0.01}, 285 | {"params": self.cls_classifier.parameters(), "lr": self.lr, "weight_decay": 0.01}, 286 | {"params": self.task_classifier.parameters(), "lr": self.task_learning_rate, "weight_decay": 0.01}, 287 | ] 288 | ) 289 | 290 | self.trainer.optimizer_P = AdamW( 291 | [ 292 | {"params": self.predictor.parameters(), "lr": self.lr, "weight_decay": 0.01}, 293 | ] 294 | ) 295 | 296 | self.trainer.scheduler = get_constant_schedule_with_warmup(self.trainer.optimizer, 1000) 297 | self.trainer.scheduler_P = get_constant_schedule_with_warmup(self.trainer.optimizer_P, 1000) 298 | 299 | self.regcoe = regcoe 300 | self.regcoe_rply = regcoe_rply 301 | 302 | ##### ####### ##### #### 303 | 304 | 305 | # def get_tasks_data_dict(self, k, memory_perc=0): 306 | # # if k==-1: use all data, otherwise use k examples from class 307 | # trainer = self.trainer 308 | # tasks_data_dict = {} 309 | 310 | # k_val = -1 if k==-1 else int(k*0.15) 311 | # for j, task in enumerate(self.task_list): 312 | # tasks_data_dict[task] = {} 313 | # print(task) 314 | # du = dataset_utils.Dataset(task=task, tokenizer=trainer.tokenizer, idbr_preprocessing=True) # turn on idbr_preprocessing flag 315 | # data_params = {'repeats': trainer.num_repeats, 316 | # 'batch_size': self.batch_size, 317 | # 'max_length': trainer.max_length, 318 | # 'label_offset': self.offsets[j], 319 | # #'select_k_per_class': k 320 | # } 321 | # benchmark = 'glue' if task in self.glue_datasets else None 322 | # val_split = 'validation' if task in self.glue_datasets else 'test' 323 | 324 | # dataloader_train = du.get_dataset(benchmark=benchmark, split='train', select_k_per_class=k, **data_params) 325 | # if memory_perc>0: 326 | # k_mem = int(len(dataloader_train)*memory_perc) 327 | # dataloader_mem = du.get_dataset(benchmark=benchmark, split='train', 328 | # select_k_per_class=k_mem, **data_params) 329 | 330 | # if k!=-1: 331 | # if task == 'dbpedia': k_val = int(k*0.1) 332 | # elif task != 'dbpedia': k_val = int(k*0.15) 333 | # dataloader_val, dataloader_test = du.get_dataset(benchmark=benchmark, split=val_split, 334 | # select_k_per_class=k_val, return_test_subset=True, 335 | # **data_params) 336 | 337 | # tasks_data_dict[task]['train'] = dataloader_train 338 | # if memory_perc>0: tasks_data_dict[task]['train_mem'] = dataloader_mem # for data replay 339 | # tasks_data_dict[task]['val'] = dataloader_val 340 | # tasks_data_dict[task]['test'] = dataloader_test 341 | 342 | # return tasks_data_dict 343 | 344 | def get_tasks_data_dict(self, k, memory_perc=0): 345 | # if k==-1: use all data, otherwise use k examples from class 346 | trainer = self.trainer 347 | tasks_data_dict = {} 348 | current_task_progressive_prompt = [] 349 | 350 | k_val = -1 if k==-1 else max(int(k*0.15), 500) 351 | for task in self.task_list: 352 | tasks_data_dict[task] = {} 353 | print(task) 354 | current_task_progressive_prompt += trainer.prefix_tokens_list[task] 355 | print(current_task_progressive_prompt) 356 | du = dataset_utils.Dataset(task=task, tokenizer=trainer.tokenizer) 357 | tid = self.task_list.index(task) # task id 358 | data_params = {'repeats': 0, #trainer.num_repeats, 359 | 'batch_size': self.batch_size, 360 | 'max_length': trainer.max_length, 361 | 'prefix_tokens_list': [], 362 | 'prefix_len': 0, # prompt * task_quantity 363 | 'do_repeats': False, # only applies to the first task in repeated set-up (if we format it according to repeats) 364 | } 365 | benchmark = 'glue' if task in self.glue_datasets else 'super_glue' if task in self.superglue else None 366 | val_split = 'validation' if (task in self.glue_datasets or task in self.superglue) else 'test' 367 | dataloader_train = du.get_dataset(benchmark=benchmark, split='train', select_k_per_class=k, **data_params) 368 | if memory_perc>0: 369 | k_mem = max( int(len(dataloader_train)*self.batch_size*memory_perc), 1) # no less than 1 370 | dataloader_mem = du.get_dataset(benchmark=benchmark, split='train', 371 | select_k_per_class=k_mem, **data_params) 372 | 373 | if k!=-1: 374 | if task in ['dbpedia', 'sst2']: k_val = int(k*0.1) 375 | else: k_val = int(k*0.15) 376 | dataloader_val, dataloader_test = du.get_dataset(benchmark=benchmark, split=val_split, 377 | select_k_per_class=k_val, return_test_subset=True, 378 | **data_params) 379 | 380 | tasks_data_dict[task]['train'] = dataloader_train 381 | if memory_perc>0: tasks_data_dict[task]['train_mem'] = dataloader_mem # for data replay 382 | tasks_data_dict[task]['val'] = dataloader_val 383 | tasks_data_dict[task]['test'] = dataloader_test 384 | 385 | return tasks_data_dict 386 | 387 | 388 | # returns metric corresponding to the task 389 | def task_to_metric_key(self, task): 390 | if task not in self.glue_datasets: 391 | return 'accuracy' 392 | 393 | if task in ['qqp', 'mrpc']: 394 | return 'f1' 395 | 396 | elif 'mnli' in task or task == 'cola': 397 | return 'matthews_correlation' 398 | 399 | else: 400 | return 'accuracy' 401 | 402 | 403 | def eval_repr_split(self, trainer, dataloader_val, task, metric, cls_idx=0): 404 | if trainer.cls_idx_override!=None: 405 | cls_idx = trainer.cls_idx_override 406 | model = trainer.model 407 | tokenizer = trainer.tokenizer 408 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 409 | model.eval().to(device) 410 | 411 | if metric==None: 412 | #if task in self.glue_datasets: 413 | if task in ['qqp', 'mrpc', 'cola'] or 'mnli' in task: 414 | metric = load_metric('glue', task) 415 | else: 416 | metric = load_metric('accuracy') 417 | print(metric.name) 418 | #if trainer.num_repeats>=1: 419 | # pos = trainer.get_position_ids(dataloader_val, device) 420 | 421 | for i, batch in enumerate(tqdm(dataloader_val)): 422 | batch = {k: v.to(device) for k, v in batch.items()} 423 | with torch.no_grad(): 424 | # Forward pass, calculate logit predictions 425 | inp_dict = {'input_ids': batch['input_ids'], 426 | 'attention_mask': batch['attention_mask'], 427 | 'token_type_ids': batch['token_type_ids'], 428 | } 429 | 430 | out = model.bert(**inp_dict) 431 | cls_output = out.last_hidden_state[:,cls_idx,:].to(device) 432 | # outputs = model.heads[task](outputs = out, 433 | # cls_output = cls_output, 434 | # return_dict = True, 435 | # labels = batch['labels']) 436 | 437 | # predictions = torch.argmax(outputs.logits, dim=-1) 438 | general_features = self.General_Encoder(cls_output) 439 | specific_features = self.Specific_Encoder(cls_output) 440 | features = torch.cat([general_features, specific_features], dim=1) 441 | cls_pred = self.cls_classifier(features) 442 | 443 | # Calculate classification loss 444 | _, pred_cls = cls_pred.max(1) 445 | #y = batch['labels'] 446 | #correct_cls = pred_cls.eq(y.view_as(pred_cls)).sum().item() 447 | metric.add_batch(predictions=pred_cls, references=batch['labels']) 448 | 449 | try: 450 | result = metric.compute() 451 | metric_key = self.task_to_metric_key(task) # we want to return value float (not dict metric -> value) 452 | result = result[metric_key] 453 | except: result=0.0 # could not compute (maybe all labels are the same and f1 not computing) 454 | return result 455 | 456 | 457 | def eval_on_tasks(self, val_scores, cls_idx=0, split='val', repr_split=True): 458 | self.trainer.model.eval() 459 | self.cls_classifier.eval() 460 | self.task_classifier.eval() 461 | self.General_Encoder.eval() 462 | self.Specific_Encoder.eval() 463 | self.predictor.eval() 464 | 465 | for task in list(self.tasks_data_dict): 466 | dataloader_val = self.tasks_data_dict[task][split] 467 | if repr_split: 468 | result = self.eval_repr_split(self.trainer, dataloader_val, task, None, cls_idx) 469 | else: 470 | result = self.trainer.eval(dataloader_val, 'giant', None, cls_idx) 471 | print(task, ' result = ',result) 472 | val_scores[task].append(result) 473 | 474 | return val_scores 475 | 476 | 477 | 478 | def update_best_model(self, val_scores, new_task = False): 479 | #idx = list(self.tasks_data_dict).index(curr_task) # look tasks up to curr task 480 | #seen_tasks = [t for t in list(self.tasks_data_dict)[:idx+1]] 481 | seen_tasks = list(self.tasks_data_dict) 482 | avg_acc = np.mean([val_scores[task][-1] for task in seen_tasks]) 483 | # only update if we are starting a new task OR if acc gets better 484 | if avg_acc > self.best_acc: # or new_task: 485 | print('NEW BEST MODEL acc=', avg_acc) 486 | self.best_acc = avg_acc 487 | self.best_model = deepcopy(self.trainer.model.state_dict()) 488 | self.best_GenEnc = deepcopy(self.General_Encoder.state_dict()) 489 | self.best_SpeEnc = deepcopy(self.Specific_Encoder.state_dict()) 490 | self.best_cls_classifier = deepcopy(self.cls_classifier.state_dict()) 491 | self.best_task_classifier = deepcopy(self.task_classifier.state_dict()) 492 | self.best_predictor = deepcopy(self.predictor.state_dict()) 493 | 494 | 495 | def get_loss_from_representation(self, bert_cls_embedding, batch, 496 | regspe, reggen, tskcoe, nspcoe, disen, 497 | task, data_replay_freq, 498 | cls_idx=0): 499 | if self.cls_idx_override!=None: 500 | cls_idx = self.cls_idx_override 501 | 502 | task_id = list(self.tasks_data_dict).index(task) 503 | replay = data_replay_freq!=-1 504 | 505 | if disen: 506 | x, mask = batch['input_ids'], batch['attention_mask'] 507 | p_x, p_mask, p_lbl = get_permutation_batch(x, mask, self.device, seq_len=self.seq_len) 508 | 509 | #x = torch.cat([x, p_x], dim=0) 510 | #mask = torch.cat([mask, p_mask], dim=0) 511 | r_lbl = torch.zeros_like(p_lbl) 512 | nsp_lbl = torch.cat([r_lbl, p_lbl], dim=0) 513 | 514 | y = torch.cat([batch['labels'], batch['labels']], dim=0) 515 | t = torch.tensor([task_id]*self.batch_size*2).to(self.device) # correct task ids 516 | 517 | p_out = self.trainer.model.bert(**{'input_ids': p_x, 518 | 'attention_mask': p_mask, 519 | 'token_type_ids': batch['token_type_ids'], 520 | }) 521 | p_cls_output = p_out.last_hidden_state[:,cls_idx,:].to(self.device) 522 | bert_cls_embedding = torch.cat([bert_cls_embedding, p_cls_output], dim=0) 523 | else: 524 | y = batch['labels'] 525 | t = torch.tensor([task_id]*self.batch_size).to(self.device) # correct task ids 526 | 527 | general_features = self.General_Encoder(bert_cls_embedding) 528 | specific_features = self.Specific_Encoder(bert_cls_embedding) 529 | 530 | features = torch.cat([general_features, specific_features], dim=1) 531 | cls_pred = self.cls_classifier(features) 532 | task_pred = self.task_classifier(specific_features) 533 | 534 | # Calculate classification loss 535 | _, pred_cls = cls_pred.max(1) 536 | #y = batch['labels'] 537 | #correct_cls = pred_cls.eq(y.view_as(pred_cls)).sum().item() 538 | cls_loss = self.cls_CR(cls_pred, y) 539 | 540 | 541 | reg_loss = torch.tensor(0.0).to(self.device) 542 | task_loss = torch.tensor(0.0).to(self.device) 543 | nsp_loss = torch.tensor(0.0).to(self.device) 544 | 545 | if task_id >0 and regspe>0 and reggen>0: 546 | # Calculate reg loss 547 | base_out = self.base_model.bert(**{'input_ids': batch['input_ids'], 548 | 'attention_mask': batch['attention_mask'], 549 | 'token_type_ids': batch['token_type_ids'], 550 | }) 551 | base_emb = base_out.last_hidden_state[:,cls_idx,:].to(self.device) 552 | 553 | old_g_fea = self.General_Encoder(base_emb) 554 | old_s_fea = self.Specific_Encoder(base_emb) 555 | lim = old_s_fea.shape[0] # previously was self.batch_size 556 | reg_loss += regspe * torch.nn.functional.mse_loss(specific_features[:lim], old_s_fea) + \ 557 | reggen * torch.nn.functional.mse_loss(general_features[:lim], old_g_fea) 558 | if replay and task_id > 0: 559 | reg_loss *= self.regcoe_rply 560 | elif not replay and task_id > 0: 561 | reg_loss *= self.regcoe 562 | 563 | # Calculate task loss only when in replay batch 564 | if task_id > 0 and replay and tskcoe>0: 565 | task_pred = task_pred[:, :task_id + 1] 566 | _, pred_task = task_pred.max(1) 567 | # correct_task = pred_task.eq(t.view_as(pred_task)).sum().item() 568 | task_loss += tskcoe * self.cls_CR(task_pred, t[:task_pred.shape[0]]) 569 | 570 | # Calculate nsp loss 571 | if disen and nspcoe>0: 572 | nsp_output = self.predictor(general_features) 573 | nsp_loss += nspcoe * self.nsp_CR(nsp_output, nsp_lbl) 574 | #_, nsp_pred = nsp_output.max(1) 575 | #nsp_correct = nsp_pred.eq(nsp_lbl.view_as(nsp_pred)).sum().item() 576 | #nsp_acc = nsp_correct * 1.0 / (batch_size * 2.0) 577 | 578 | loss = cls_loss + reg_loss + task_loss + nsp_loss 579 | return loss 580 | 581 | 582 | def train_on_one_task(self, 583 | task, 584 | data_replay_freq = -1, # if -1 no data replay, else replay after N samples 585 | num_epochs = 5, 586 | regspe=0.5, 587 | reggen=0.5, 588 | tskcoe=1.0, 589 | nspcoe=1.0, 590 | disen=True): 591 | 592 | val_scores = {x: [] for x in list(self.tasks_data_dict)} 593 | device = self.device 594 | 595 | self.Specific_Encoder.to(self.device) 596 | self.General_Encoder.to(self.device) 597 | self.cls_classifier.to(self.device) 598 | self.task_classifier.to(self.device) 599 | self.trainer.model.to(self.device) 600 | self.predictor.to(self.device) 601 | 602 | optimizer = self.trainer.optimizer 603 | scheduler = self.trainer.scheduler 604 | 605 | for epoch in range(num_epochs): 606 | print(epoch) 607 | self.trainer.model.train() 608 | self.cls_classifier.train() 609 | self.task_classifier.train() 610 | self.General_Encoder.train() 611 | self.Specific_Encoder.train() 612 | self.predictor.train() 613 | 614 | if data_replay_freq != -1: 615 | print('Creating generators for previous tasks ...') 616 | tasks_to_generators = {} 617 | curr_task_num = list(self.tasks_data_dict).index(task) 618 | for idx in np.arange(curr_task_num): 619 | prev_task = list(self.tasks_data_dict)[idx] 620 | print(prev_task) 621 | tasks_to_generators[prev_task] = iter(self.tasks_data_dict[prev_task]['train_mem']) 622 | 623 | for i, batch in enumerate(tqdm(self.tasks_data_dict[task]['train'])): 624 | batch = {k: v.to(device) for k, v in batch.items()} 625 | # we will ignore the default loss 626 | bert_cls_embedding = self.trainer.pass_batch_reg(batch, self.device, 627 | cls_idx=0, only_output_loss=True) 628 | bert_cls_embedding = bert_cls_embedding.to(self.device) 629 | 630 | #### ADDING REG UTILS #### 631 | loss = self.get_loss_from_representation(bert_cls_embedding, batch, 632 | regspe, reggen, tskcoe, nspcoe, disen, 633 | task, data_replay_freq) 634 | loss.backward() 635 | optimizer.step() 636 | scheduler.step() 637 | optimizer.zero_grad() 638 | #### #### #### 639 | 640 | # performing data replay on all previous tasks 641 | if data_replay_freq != -1 and i%data_replay_freq == 0: 642 | for prev_task in tasks_to_generators: 643 | generator_mem1 = tasks_to_generators[prev_task] 644 | try: 645 | # Samples the batch 646 | b = next(generator_mem1) 647 | except StopIteration: 648 | # restart the generator if the previous generator is exhausted. 649 | generator_mem1 = iter(self.tasks_data_dict[prev_task]['train_mem']) 650 | tasks_to_generators[prev_task] = generator_mem1 651 | b = next(generator_mem1) 652 | 653 | b = {k: v.to(device) for k, v in b.items()} 654 | #self.trainer.pass_batch(b, 'giant', self.device, cls_idx=0) 655 | bert_cls_embedding = self.trainer.pass_batch_reg(b, self.device, 656 | cls_idx=0, only_output_loss=True) 657 | bert_cls_embedding = bert_cls_embedding.to(self.device) 658 | loss = self.get_loss_from_representation(bert_cls_embedding, b, 659 | regspe, reggen, tskcoe, nspcoe, disen, 660 | task, data_replay_freq) 661 | loss.backward() 662 | optimizer.step() 663 | scheduler.step() 664 | optimizer.zero_grad() 665 | ###################### 666 | 667 | #if i%250 == 0 and i>0: # check val accuracy every 250 iterations 668 | val_scores = self.eval_on_tasks(val_scores, cls_idx=0) 669 | if self.early_stopping: 670 | self.update_best_model(val_scores) 671 | 672 | return val_scores 673 | 674 | 675 | def continual_training(self, 676 | #tasks=[], 677 | num_epochs=5, 678 | data_replay_freq=-1, 679 | regspe=0.5, 680 | reggen=0.5, 681 | tskcoe=1.0, 682 | nspcoe=1.0, 683 | disen=True): 684 | results_dict = {} 685 | print('Continual training') 686 | for i, task in enumerate(self.task_list): 687 | if i>0 and self.early_stopping: 688 | self.update_best_model(val_scores, new_task=True) 689 | 690 | print('\n\nTASK ', task) 691 | val_scores = self.train_on_one_task(task, 692 | num_epochs=num_epochs, 693 | data_replay_freq=data_replay_freq, 694 | regspe=regspe, 695 | reggen=reggen, 696 | tskcoe=tskcoe, 697 | nspcoe=nspcoe, 698 | disen=disen) 699 | results_dict[i] = val_scores 700 | # loading the best model across all epochs (based on val acc) 701 | # in case of early stopping 702 | if self.early_stopping: 703 | self.trainer.model.load_state_dict(deepcopy(self.best_model)) 704 | self.General_Encoder.load_state_dict(deepcopy(self.best_GenEnc)) 705 | self.Specific_Encoder.load_state_dict(deepcopy(self.best_SpeEnc)) 706 | self.cls_classifier.load_state_dict(deepcopy(self.best_cls_classifier)) 707 | self.task_classifier.load_state_dict(deepcopy(self.best_task_classifier)) 708 | self.predictor.load_state_dict(deepcopy(self.best_predictor)) 709 | 710 | # update regularization model 711 | self.base_model = deepcopy(self.trainer.model) # bert before training 712 | 713 | # final eval on test set 714 | test_scores = {x: [] for x in list(self.tasks_data_dict)} 715 | test_scores = self.eval_on_tasks(test_scores, cls_idx=0, split='test') 716 | results_dict['test'] = test_scores 717 | return results_dict 718 | 719 | 720 | 721 | 722 | def multi_task_training(self, num_epochs=5, cls_idx=0): 723 | tasks_data_dict = self.tasks_data_dict 724 | val_scores = {x: [] for x in list(tasks_data_dict)} 725 | # getting index of the largest dataset (other datasets will be cycled) 726 | task_lengths = [len(tasks_data_dict[t]['train'])*self.batch_size for t in list(tasks_data_dict)] 727 | idx_biggest_task = np.argmax(task_lengths) 728 | n_tasks = len(list(tasks_data_dict)) 729 | 730 | results_dict = {} 731 | device = self.device 732 | 733 | self.cls_classifier.to(self.device) 734 | self.trainer.model.to(self.device) 735 | self.task_classifier.to(self.device) 736 | self.General_Encoder.to(self.device) 737 | self.Specific_Encoder.to(self.device) 738 | self.predictor.to(self.device) 739 | 740 | for epoch in range(num_epochs): 741 | print(epoch) 742 | self.trainer.model.train() 743 | self.cls_classifier.train() 744 | self.task_classifier.train() 745 | self.General_Encoder.train() 746 | self.Specific_Encoder.train() 747 | self.predictor.train() 748 | 749 | dataloaders_list = [tasks_data_dict[t]['train'] if j==idx_biggest_task else cycle(tasks_data_dict[t]['train']) \ 750 | for j, t in enumerate(tasks_data_dict)] 751 | mlt_dataloader = zip(*dataloaders_list) 752 | 753 | max_task = np.max([len(tasks_data_dict[t]['train']) for t in list(tasks_data_dict)]) 754 | pbar = tqdm(total=max_task) 755 | for i, batch_combined in enumerate(mlt_dataloader): 756 | loss_combined = 0 757 | 758 | for task_num in range(n_tasks): 759 | task = list(self.tasks_data_dict)[task_num] 760 | batch = {k: v.to(device) for k, v in batch_combined[task_num].items()} 761 | bert_cls_embedding = self.trainer.pass_batch_reg(batch, device, cls_idx=cls_idx, only_output_loss=True) 762 | bert_cls_embedding = bert_cls_embedding.to(device) 763 | loss = self.get_loss_from_representation(bert_cls_embedding, batch, 764 | 0, 0, 0, 0, False, task, -1) 765 | loss_combined += loss 766 | 767 | loss_combined.backward() 768 | 769 | # only allowing updates for added special token if required 770 | #if self.trainer.freeze_weights == 1 and self.trainer.freeze_except == 'word_embeddings': 771 | #k = len(trainer.special_tokens_list) 772 | #model.bert.embeddings.word_embeddings.weight.grad[:-k] = 0 773 | #model.bert.embeddings.word_embeddings.weight.grad[:-1] = 0 774 | 775 | self.trainer.optimizer.step() 776 | self.trainer.scheduler.step() 777 | self.trainer.optimizer.zero_grad() 778 | pbar.update(1) 779 | 780 | val_scores = self.eval_on_tasks(val_scores, cls_idx=cls_idx) 781 | if self.early_stopping: 782 | self.update_best_model(val_scores) 783 | 784 | results_dict[epoch] = val_scores 785 | pbar.close() 786 | 787 | # final eval on test set 788 | if self.early_stopping: 789 | self.trainer.model.load_state_dict(deepcopy(self.best_model)) 790 | self.General_Encoder.load_state_dict(deepcopy(self.best_GenEnc)) 791 | self.Specific_Encoder.load_state_dict(deepcopy(self.best_SpeEnc)) 792 | self.cls_classifier.load_state_dict(deepcopy(self.best_cls_classifier)) 793 | self.task_classifier.load_state_dict(deepcopy(self.best_task_classifier)) 794 | self.predictor.load_state_dict(deepcopy(self.best_predictor)) 795 | test_scores = {x: [] for x in list(self.tasks_data_dict)} 796 | test_scores = self.eval_on_tasks(test_scores, cls_idx=cls_idx, split='test') 797 | results_dict['test'] = test_scores 798 | 799 | return results_dict 800 | -------------------------------------------------------------------------------- /BERT_codebase/continual_learning_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import pandas as pd 4 | import numpy as np 5 | from tqdm.auto import tqdm 6 | import logging, os, argparse 7 | import matplotlib.pyplot as plt 8 | 9 | import dataset_utils, model_utils 10 | from itertools import cycle 11 | from copy import deepcopy 12 | 13 | 14 | class ResMLP(torch.nn.Module): 15 | def __init__(self, bottleneck_size, module_type='MLP1'): 16 | super().__init__() 17 | if module_type=='MLP1': 18 | self.module = nn.Sequential( 19 | nn.Linear(768, bottleneck_size), 20 | nn.ReLU(), 21 | nn.Linear(bottleneck_size, 768), 22 | ) 23 | 24 | elif module_type=='MLP2': 25 | self.module = nn.Sequential( 26 | nn.Linear(768, bottleneck_size), 27 | nn.ReLU(), 28 | nn.Linear(bottleneck_size, bottleneck_size), 29 | nn.Tanh(), 30 | nn.Linear(bottleneck_size, 768), 31 | ) 32 | 33 | elif module_type=='transformer': 34 | device = 'cuda' 35 | self.encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=2, dropout=0.05).to(device) 36 | self.module = nn.TransformerEncoder(self.encoder_layer, num_layers=2).to(device) 37 | 38 | def forward(self, inputs): 39 | return self.module(inputs) + inputs 40 | 41 | 42 | def get_prefix_net(bottleneck_size = 800, network_type='MLP1'): 43 | 44 | if network_type == 'MLP1': 45 | prefix_MLP = nn.Sequential( 46 | nn.Linear(768, bottleneck_size), 47 | nn.ReLU(), 48 | #nn.Linear(bottleneck_size, bottleneck_size), 49 | #nn.Tanh(), 50 | nn.Linear(bottleneck_size, 768), 51 | ) 52 | 53 | elif network_type == 'MLP2': 54 | prefix_MLP = nn.Sequential( 55 | nn.Linear(768, bottleneck_size), 56 | nn.ReLU(), 57 | nn.Linear(bottleneck_size, bottleneck_size), 58 | nn.Tanh(), 59 | nn.Linear(bottleneck_size, 768), 60 | ) 61 | 62 | elif network_type == 'transformer': 63 | device = 'cuda' 64 | encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=2, dropout=0.05).to(device) 65 | prefix_MLP = nn.TransformerEncoder(encoder_layer, num_layers=2).to(device) 66 | 67 | elif 'residual' in network_type: 68 | prefix_MLP = ResMLP(bottleneck_size, module_type=network_type.split('_')[1]) 69 | 70 | return prefix_MLP 71 | 72 | 73 | 74 | class ContinualLearner: 75 | def __init__(self, 76 | model_name, 77 | task_list, 78 | batch_size=8, 79 | select_k_per_class=-1, 80 | memory_perc=0, 81 | #block_attn=0, 82 | prefix_len=0, 83 | freeze_weights=0, 84 | freeze_except='word_embeddings', 85 | lr=2e-5, 86 | seq_len=512, 87 | cls_idx_override=None, 88 | early_stopping=True, 89 | prefix_MLP='None', 90 | do_repeats=False, # default setting is without repeats 91 | bottleneck_size=800, # bottleneck size in case of using MLP reparametrization 92 | same_prompt=False, 93 | ): 94 | 95 | self.task_to_num_labels = { 96 | 'cola': 2, 97 | 'rte': 2, 98 | 'mrpc': 2, 99 | 'qqp': 2, 100 | 'sst2': 2, 101 | 'qnli': 2, 102 | 'mnli': 3, 103 | 104 | 'scicite': 3, 105 | 'imdb': 2, 106 | 107 | 'cb': 3, 108 | 'copa': 2, 109 | 'wic': 2, 110 | 'boolq': 2, 111 | 'multirc': 2, 112 | 113 | 'yelp_review_full': 5, 114 | 'ag_news': 4, 115 | 'yahoo_answers_topics': 10, 116 | 'dbpedia_14': 14, 117 | 'amazon': 5, 118 | 119 | 'yelp': 5, 120 | 'ag': 4, 121 | 'yahoo': 10, 122 | 'proc_yahoo': 10, 123 | 'dbpedia': 14, 124 | } 125 | self.glue_datasets = ['cola', 'sst2', 'mrpc', 'qqp', 'stsb', 'mnli', \ 126 | 'mnli_mismatched', 'mnli_matched', 'qnli', 'rte', 'wnli', 'ax'] 127 | self.superglue = ['cb', 'copa', 'wic', 'wsc', 'boolq', 'record', 'multirc'] 128 | 129 | self.task_list = task_list 130 | #self.special_tokens_list = ["[CLS"+str(i+1)+"]" for i in range(len(self.task_list))] #[ "[CLS1]" ], 131 | #self.special_tokens_list = [] 132 | self.num_labels = [self.task_to_num_labels[t] for t in self.task_list] 133 | self.freeze_weights = freeze_weights 134 | self.prefix_len = prefix_len 135 | self.lr = lr 136 | self.seq_len = seq_len 137 | self.batch_size = batch_size 138 | self.cls_idx_override = cls_idx_override 139 | self.same_prompt = same_prompt 140 | self.select_k_per_class = select_k_per_class 141 | self.memory_perc = memory_perc 142 | self.freeze_except = freeze_except 143 | self.early_stopping = early_stopping 144 | self.do_repeats = do_repeats 145 | 146 | if torch.cuda.is_available(): 147 | self.device = torch.device("cuda") 148 | else: 149 | self.device = torch.device("cpu") 150 | 151 | self.model_name = model_name #"bert-base-uncased" 152 | num_repeats = len(self.task_list)-1 if self.do_repeats else 0 153 | print('Max repeats = ', num_repeats) 154 | 155 | if prefix_MLP == 'None': 156 | prefix_MLPs = None 157 | else: 158 | print('Using MLP reparametrization with bottleneck = ', bottleneck_size) 159 | prefix_MLPs = {t: get_prefix_net(bottleneck_size = bottleneck_size, network_type=prefix_MLP) for t in self.task_list} 160 | 161 | self.trainer = model_utils.ModelForCL(self.model_name, 162 | tasks=self.task_list, 163 | num_labels=self.num_labels, 164 | #blockwise_causal_attention= (block_attn==1), 165 | prefix_len=self.prefix_len, 166 | freeze_weights= (self.freeze_weights==1), 167 | freeze_except=self.freeze_except, 168 | lr=self.lr, 169 | num_repeats=num_repeats, # default 0 170 | max_length=self.seq_len, # max sequence length in #tokens 171 | cls_idx_override=self.cls_idx_override, 172 | prefix_MLPs=prefix_MLPs, 173 | same_prompt=self.same_prompt, 174 | ) 175 | self.trainer.model.to(self.device) # model to cuda 176 | if prefix_MLPs!=None: 177 | for t in self.task_list: 178 | self.trainer.prefix_MLPs[t].to(self.device) 179 | if self.early_stopping: 180 | self.best_model = deepcopy(self.trainer.model.state_dict()) # saving best model 181 | self.best_acc = 0.0 # best avg accuracy on seen tasks 182 | self.tokenizer = self.trainer.tokenizer 183 | self.tasks_data_dict = self.get_tasks_data_dict(self.select_k_per_class, memory_perc=self.memory_perc) 184 | 185 | 186 | def get_tasks_data_dict(self, k, memory_perc=0): 187 | # if k==-1: use all data, otherwise use k examples from class 188 | trainer = self.trainer 189 | tasks_data_dict = {} 190 | current_task_progressive_prompt = [] 191 | 192 | k_val = -1 if k==-1 else max(int(k*0.15), 500) 193 | for task in self.task_list: 194 | tasks_data_dict[task] = {} 195 | print(task) 196 | if self.same_prompt: # same prompt for all tasks 197 | current_task_progressive_prompt = trainer.prefix_tokens_list[0] 198 | else: 199 | current_task_progressive_prompt += trainer.prefix_tokens_list[task] 200 | print(current_task_progressive_prompt) 201 | du = dataset_utils.Dataset(task=task, tokenizer=trainer.tokenizer) 202 | if self.same_prompt: 203 | tid = 0 204 | else: 205 | tid = self.task_list.index(task) # task id 206 | data_params = {'repeats': 0 if not self.do_repeats else tid, #trainer.num_repeats, 207 | 'batch_size': self.batch_size, 208 | 'max_length': trainer.max_length, 209 | 'prefix_tokens_list': current_task_progressive_prompt, 210 | 'prefix_len': self.prefix_len * (tid+1), # prompt * task_quantity 211 | 'do_repeats': self.do_repeats, # only applies to the first task in repeated set-up (if we format it according to repeats) 212 | } 213 | benchmark = 'glue' if task in self.glue_datasets else 'super_glue' if task in self.superglue else None 214 | val_split = 'validation' if (task in self.glue_datasets or task in self.superglue) else 'test' 215 | dataloader_train = du.get_dataset(benchmark=benchmark, split='train', select_k_per_class=k, **data_params) 216 | if memory_perc>0: 217 | k_mem = max(int(len(dataloader_train)*memory_perc), 2) 218 | dataloader_mem = du.get_dataset(benchmark=benchmark, split='train', 219 | select_k_per_class=k_mem, **data_params) 220 | 221 | if k!=-1: 222 | if task in ['dbpedia', 'sst2']: k_val = int(k*0.1) 223 | else: k_val = int(k*0.15) 224 | dataloader_val, dataloader_test = du.get_dataset(benchmark=benchmark, split=val_split, 225 | select_k_per_class=k_val, return_test_subset=True, 226 | **data_params) 227 | 228 | tasks_data_dict[task]['train'] = dataloader_train 229 | if memory_perc>0: tasks_data_dict[task]['train_mem'] = dataloader_mem # for data replay 230 | tasks_data_dict[task]['val'] = dataloader_val 231 | tasks_data_dict[task]['test'] = dataloader_test 232 | 233 | return tasks_data_dict 234 | 235 | 236 | def change_attention_mask_for_tasks(self, task_num): 237 | #if self.do_repeats: 238 | # overriding attention mask for the next task 239 | print('updating attention mask for tasks #', task_num) 240 | num_repeats = task_num 241 | repeat_length = self.prefix_len + self.seq_len 242 | self.trainer.override_attention_mask(num_repeats, repeat_length, self.device) 243 | 244 | 245 | 246 | def eval_on_tasks(self, val_scores, split='val', prompt_tuning=True, original_task_id=None, tasks_to_eval=None): 247 | self.trainer.model.eval() 248 | if self.prefix_len>0 and self.trainer.prefix_MLPs != None: 249 | for task in list(self.tasks_data_dict): # put all MLPs into eval mode 250 | self.trainer.prefix_MLPs[task].eval() 251 | 252 | if tasks_to_eval==None: # if not specified, eval on all tasks 253 | tasks_to_eval = self.tasks_data_dict 254 | 255 | for task in list(tasks_to_eval): 256 | dataloader_val = self.tasks_data_dict[task][split] 257 | if prompt_tuning: # special eval for prompts (we use custom pos ids) 258 | if self.same_prompt: 259 | tid=0 260 | else: 261 | tid = self.task_list.index(task) 262 | if self.do_repeats: 263 | self.change_attention_mask_for_tasks(tid) # change attention mask in case of "repeats" set up 264 | pos_id = self.get_pos_id(tid) 265 | cls_idx = self.seq_len + self.prefix_len*tid if not self.do_repeats \ 266 | else self.seq_len + (self.seq_len + self.prefix_len) * tid 267 | #else (self.seq_len + self.prefix_len)*tid 268 | result = self.trainer.eval_with_prompt(dataloader_val, task, None, 269 | cls_idx=cls_idx, 270 | custom_pos_ids=True, 271 | pos_ids=pos_id) 272 | else: # regular eval for fine-tuning 273 | cls_idx = 0 274 | result = self.trainer.eval(dataloader_val, task, None, cls_idx) 275 | print(task, ' result = ',result) 276 | val_scores[task].append(result) 277 | 278 | # restore original attention mask for the current task after evaluation 279 | if self.do_repeats and original_task_id!=None: 280 | self.change_attention_mask_for_tasks(original_task_id) 281 | print('restored attn mask for task ', original_task_id) 282 | 283 | return val_scores 284 | 285 | 286 | 287 | def update_best_model(self, curr_task, val_scores, tasks_to_eval=None): 288 | #idx = list(self.tasks_data_dict).index(curr_task) # look tasks up to curr task 289 | #seen_tasks = [t for t in list(self.tasks_data_dict)[:idx+1]] 290 | #avg_acc = np.mean([val_scores[task][-1] for task in seen_tasks]) 291 | if tasks_to_eval==None: 292 | tasks_to_eval = self.task_list 293 | avg_acc = np.mean([val_scores[task][-1] for task in tasks_to_eval]) 294 | # only update if we are starting a new task OR if acc gets better 295 | if avg_acc > self.best_acc: 296 | print('NEW BEST MODEL acc=', avg_acc) 297 | self.best_acc = avg_acc 298 | self.best_model = deepcopy(self.trainer.model.state_dict()) 299 | 300 | 301 | # for prompt tuning set up we use custom position ids 302 | # Hello world :) [pad] [pad] ... [pad] [pre0_1] [pre0_2] [pre0_3] [pre1_1] [pre1_2] [pre1_3] 303 | # 1 2 3 4 5 6 7 400 0 401 402 0 403 404 304 | # def get_pos_id(self, task_id): 305 | # s = self.seq_len+1 306 | # pos_id = list(np.arange(1, s)) 307 | # for k in range(task_id+1): 308 | # pos_id += [0] + list(np.arange(s, s + self.prefix_len-1)) 309 | # s = s + self.prefix_len - 1 310 | # return torch.tensor(pos_id) 311 | 312 | # for prompt tuning set up we use custom position ids 313 | def get_pos_id(self, task_id): 314 | s = self.seq_len+1 315 | 316 | if not self.do_repeats: 317 | # (progressive) prompt tuning set-up 318 | # Hello world :) [pad] [pad] ... [pad] [pre0_1] [pre0_2] [pre0_3] [pre1_1] [pre1_2] [pre1_3] 319 | # 1 2 3 4 5 6 7 400 0 401 402 0 403 404 320 | pos_id = list(np.arange(1, s)) 321 | for k in range(task_id+1): 322 | pos_id += [0] + list(np.arange(s, s + self.prefix_len-1)) 323 | s = s + self.prefix_len - 1 324 | return torch.tensor(pos_id) 325 | 326 | else: 327 | # prompt tuning with repeats 328 | # Hello world :) [pad] [pad] ... [pad] [pre0_1] [pre0_2] [pre0_3] | Hello world :) [pad] [pad] ... [pad][pre1_1] [pre1_2] [pre1_3] 329 | # 1 2 3 4 5 6 7 400 0 401 402 | 1 2 3 4 5 6 7 400 0 401 402 330 | pos_id = list(np.arange(1, s)) + [0] + list(np.arange(s, s + self.prefix_len-1)) 331 | pos_id *= task_id+1 332 | return torch.tensor(pos_id) 333 | 334 | 335 | 336 | def create_memory_replay_generators(self, task, split='train_mem'): # creating previous tasks memory buffers 337 | print('Creating generators for previous tasks ...') 338 | tasks_to_generators = {} 339 | curr_task_num = list(self.tasks_data_dict).index(task) 340 | for idx in np.arange(curr_task_num): 341 | prev_task = list(self.tasks_data_dict)[idx] 342 | print(prev_task) 343 | tasks_to_generators[prev_task] = iter(self.tasks_data_dict[prev_task][split]) 344 | return tasks_to_generators 345 | 346 | 347 | def memory_replay(self, tasks_to_generators, cls_idx): 348 | # for each memory buffer in tasks_to_generators perform memory replay 349 | for prev_task in tasks_to_generators: 350 | generator_mem1 = tasks_to_generators[prev_task] 351 | try: 352 | # Samples the batch 353 | b = next(generator_mem1) 354 | except StopIteration: 355 | # restart the generator if the previous generator is exhausted. 356 | generator_mem1 = iter(self.tasks_data_dict[prev_task]['train_mem']) 357 | tasks_to_generators[prev_task] = generator_mem1 358 | b = next(generator_mem1) 359 | 360 | b = {k: v.to(self.device) for k, v in b.items()} 361 | self.trainer.pass_batch(b, prev_task, self.device, cls_idx=cls_idx) 362 | 363 | 364 | def train_on_one_task(self, 365 | task, 366 | data_replay_freq = -1, # if -1 no data replay, else replay after N samples 367 | prompt_tuning = True, 368 | num_epochs = 5): 369 | self.best_acc = 0.0 # our baseline accuracy is 0 370 | val_scores = {x: [] for x in list(self.tasks_data_dict)} 371 | device = self.device 372 | 373 | if prompt_tuning: 374 | if self.same_prompt: 375 | task_id = 0 376 | else: 377 | task_id = self.task_list.index(task) 378 | pos_id = self.get_pos_id(task_id) 379 | cls_idx = self.seq_len + self.prefix_len*task_id if not self.do_repeats \ 380 | else self.seq_len + (self.seq_len + self.prefix_len) * task_id 381 | else: 382 | task_id = None # we do not need task id for eval in case of regular fine-tuning 383 | 384 | for epoch in range(num_epochs): 385 | print(epoch) 386 | self.trainer.model.train().to(device) 387 | if self.prefix_len>0 and self.trainer.prefix_MLPs != None: 388 | self.trainer.prefix_MLPs[task].train().to(device) 389 | 390 | if data_replay_freq != -1: 391 | tasks_to_generators = self.create_memory_replay_generators(task, split='train_mem') 392 | 393 | for i, batch in enumerate(tqdm(self.tasks_data_dict[task]['train'])): 394 | batch = {k: v.to(device) for k, v in batch.items()} 395 | 396 | if prompt_tuning: # tune only soft prompt 397 | batch['position_ids'] = pos_id.to(self.device) # custom pos ids for prompts 398 | self.trainer.pass_batch_with_prompt(batch, task, self.device, 399 | prefix_len=self.prefix_len, 400 | cls_idx=cls_idx, 401 | custom_pos_ids=True) 402 | else: # regular fine-tuning 403 | self.trainer.pass_batch(batch, task, self.device, cls_idx=0) 404 | 405 | # performing data replay on all previous tasks 406 | if data_replay_freq != -1 and i%data_replay_freq == 0: 407 | self.memory_replay(tasks_to_generators, cls_idx=0) 408 | ###################### 409 | 410 | # eval only on curr task (others are static) 411 | val_scores = self.eval_on_tasks(val_scores, split='val', prompt_tuning=prompt_tuning, original_task_id=task_id, tasks_to_eval=[task]) 412 | if self.early_stopping: 413 | self.update_best_model(task, val_scores, tasks_to_eval=[task]) # update best model based on curr task acc improvement 414 | 415 | return val_scores 416 | 417 | 418 | 419 | def continual_training(self, 420 | #tasks=[], 421 | num_epochs=5, 422 | data_replay_freq=-1, 423 | prompt_tuning=True, 424 | prompt_init='None', 425 | save_prompt_path='None', 426 | save_results_path='None', 427 | ): 428 | results_dict = {} 429 | print('Continual training') 430 | if self.trainer.prefix_MLPs != None and prompt_tuning: 431 | cl_params = ['mlp', 'head'] 432 | else: 433 | cl_params = ['head'] 434 | 435 | for i, task in enumerate(self.task_list): 436 | self.trainer.freeze_unfreeze_mlps([x for x in self.task_list if x!=task], blocks=cl_params, requires_grad=False) # freezing MLPs & head for all tasks 437 | self.trainer.freeze_unfreeze_mlps([task], blocks=cl_params, requires_grad=True) # unfreezing current task MLP & head 438 | 439 | print('\n\nTASK ', task) 440 | val_scores = self.train_on_one_task(task, 441 | num_epochs=num_epochs, 442 | data_replay_freq=data_replay_freq, 443 | prompt_tuning=prompt_tuning) 444 | results_dict[i] = val_scores 445 | # loading the best model across all epochs (based on val acc) 446 | # in case of early stopping 447 | if self.early_stopping: 448 | self.trainer.model.load_state_dict(deepcopy(self.best_model)) 449 | 450 | if prompt_tuning: 451 | # update wte matrix so we don't override the learned prompt with non-trained model 452 | self.trainer.update_baseline_model_emb() 453 | if save_prompt_path != 'None': 454 | self.trainer.save_curr_task_emb(task_idx_curr=i, save_path=os.path.join(save_prompt_path, 'prompt_'+self.task_list[i])) 455 | if prompt_init != 'None' and task != self.task_list[-1]: # smart prompt initialization for the next prompts 456 | print('Initializing new task prompt from the currently finished task') 457 | self.trainer.init_new_prompt(task_idx_curr=i) 458 | 459 | if self.do_repeats: 460 | # overriding attention mask for the next task 461 | self.change_attention_mask_for_tasks(i+1) 462 | 463 | if save_results_path != 'None': 464 | test_scores = {x: [] for x in list(self.tasks_data_dict)} 465 | test_scores = self.eval_on_tasks(test_scores, split='test', prompt_tuning=prompt_tuning, original_task_id=len(self.task_list)-1) 466 | results_dict['test'] = test_scores 467 | np.save(os.path.join(save_results_path, 'results_dict_'+str(task)+'.npy'), results_dict) 468 | 469 | # final eval on test set 470 | test_scores = {x: [] for x in list(self.tasks_data_dict)} 471 | test_scores = self.eval_on_tasks(test_scores, split='test', prompt_tuning=prompt_tuning, original_task_id=len(self.task_list)-1) 472 | results_dict['test'] = test_scores 473 | return results_dict 474 | 475 | 476 | 477 | def multi_task_training(self, num_epochs=5, cls_idx=0): 478 | tasks_data_dict = self.tasks_data_dict 479 | val_scores = {x: [] for x in list(tasks_data_dict)} 480 | # getting index of the largest dataset (other datasets will be cycled) 481 | task_lengths = [len(tasks_data_dict[t]['train'])*self.batch_size for t in list(tasks_data_dict)] 482 | idx_biggest_task = np.argmax(task_lengths) 483 | n_tasks = len(list(tasks_data_dict)) 484 | 485 | results_dict = {} 486 | val_scores = {x: [] for x in list(self.tasks_data_dict)} 487 | device = self.device 488 | 489 | for epoch in range(num_epochs): 490 | print(epoch) 491 | 492 | dataloaders_list = [tasks_data_dict[t]['train'] if j==idx_biggest_task else cycle(tasks_data_dict[t]['train']) \ 493 | for j, t in enumerate(tasks_data_dict)] 494 | mlt_dataloader = zip(*dataloaders_list) 495 | 496 | max_task = np.max([len(tasks_data_dict[t]['train']) for t in list(tasks_data_dict)]) 497 | pbar = tqdm(total=max_task) 498 | for i, batch_combined in enumerate(mlt_dataloader): 499 | loss_combined = 0 500 | 501 | for task_num in range(n_tasks): 502 | batch = {k: v.to(device) for k, v in batch_combined[task_num].items()} 503 | loss = self.trainer.pass_batch(batch, list(tasks_data_dict)[task_num], self.device, cls_idx=cls_idx, only_output_loss=True) 504 | loss_combined += loss 505 | 506 | loss_combined.backward() 507 | 508 | # only allowing updates for added special token if required 509 | if self.trainer.freeze_weights == 1 and self.trainer.freeze_except == 'word_embeddings': 510 | k = len(trainer.special_tokens_list) 511 | model.bert.embeddings.word_embeddings.weight.grad[:-k] = 0 512 | #model.bert.embeddings.word_embeddings.weight.grad[:-1] = 0 513 | 514 | self.trainer.optimizer.step() 515 | self.trainer.scheduler.step() 516 | self.trainer.optimizer.zero_grad() 517 | pbar.update(1) 518 | 519 | val_scores = self.eval_on_tasks(val_scores, prompt_tuning=False, original_task_id=None) 520 | results_dict[epoch] = val_scores 521 | pbar.close() 522 | 523 | # final eval on test set 524 | test_scores = {x: [] for x in list(self.tasks_data_dict)} 525 | test_scores = self.eval_on_tasks(test_scores, split='test', prompt_tuning=False, original_task_id=None) 526 | results_dict['test'] = test_scores 527 | 528 | return results_dict 529 | -------------------------------------------------------------------------------- /BERT_codebase/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import datasets 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | 7 | class Dataset: 8 | def __init__(self, tokenizer, task, idbr_preprocessing=False): 9 | self.task = task 10 | self.tokenizer = tokenizer 11 | self.idbr_preprocessing = idbr_preprocessing # do idbr-style sentence preprocessing (split sentence in 2 parts and add sep token) 12 | 13 | self.task_to_keys = { 14 | "cola": ("sentence", None), 15 | "mnli": ("premise", "hypothesis"), 16 | "mnli-mm": ("premise", "hypothesis"), 17 | "mrpc": ("sentence1", "sentence2"), 18 | #"qnli": ("question", "sentence"), 19 | "qnli": ("text1", "text2"), 20 | "qqp": ("question1", "question2"), 21 | "rte": ("sentence1", "sentence2"), 22 | "sst2": ("sentence", None), 23 | "stsb": ("sentence1", "sentence2"), 24 | "wnli": ("sentence1", "sentence2"), 25 | 26 | "scicite": ("sectionName", "string"), 27 | "imdb": ("text", None), 28 | 29 | "cb": ("premise", "hypothesis"), 30 | "boolq": ("passage", "question"), 31 | "copa": ('choice1', 'choice2', 'premise', 'question'), 32 | "wic": ("start1", "end1", "sentence1", "start2", "end2", "sentence2", "word"), 33 | "wsc": ("span1_text", "span1_index", "span2_text", "span2_index", "text"), 34 | "multirc": ("question", "answer", "paragraph"), 35 | 36 | "ag_news": ("text", None), 37 | "yelp_review_full": ("text", None), 38 | "yahoo_answers_topics": ("question_content", "best_answer"), 39 | "dbpedia_14": ("title", "content"), 40 | 41 | "ag": ("content", None), 42 | "yelp": ("content", None), 43 | "yahoo": ("content", None), 44 | "dbpedia": ("content", None), 45 | "amazon": ("content", None), 46 | } 47 | 48 | # self.sentence1_key, self.sentence2_key = self.task_to_keys[task] 49 | 50 | def preprocess_function(self, examples, max_length=100): 51 | #sentence1_key, sentence2_key = self.task_to_keys[self.task] 52 | sentence_keys = self.task_to_keys[self.task] 53 | 54 | if self.idbr_preprocessing and sentence_keys[1]==None: 55 | sentence1_key, sentence2_key = self.task_to_keys[self.task] 56 | tokenized_res = self.tokenizer.tokenize(examples[sentence1_key]) 57 | ids = self.tokenizer.convert_tokens_to_ids(tokenized_res)[:max_length-3] 58 | len1 = len(ids) // 2 59 | x = [101] + ids[:len1] + [102] + ids[len1:] + [102] 60 | mask = [1] * len(x) 61 | 62 | padding = [0] * (max_length - len(x)) 63 | x += padding 64 | mask += padding 65 | 66 | d = {'input_ids': x, 'attention_mask': mask, 'token_type_ids': [0]*max_length} 67 | 68 | assert len(x) == max_length 69 | assert len(mask) == max_length 70 | 71 | return d 72 | 73 | if self.task == "yahoo_answers_topics": 74 | return self.tokenizer(examples["question_title"] + '[SEP]' + examples["question_content"] + '[SEP]' + \ 75 | examples["best_answer"], 76 | truncation=True, 77 | padding='max_length', 78 | max_length=max_length) 79 | 80 | if self.task in ["copa", "wic", "wsc", "multirc"]: 81 | text = ('[SEP]').join([str(examples[x]) for x in self.task_to_keys[self.task]]) 82 | return self.tokenizer(text, 83 | truncation=True, 84 | padding='max_length', 85 | max_length=max_length) 86 | 87 | # for all other tasks we have 2 sentence keys 88 | sentence1_key, sentence2_key = self.task_to_keys[self.task] 89 | if sentence2_key is None: 90 | return self.tokenizer(examples[sentence1_key], 91 | truncation=True, 92 | #padding=False, 93 | padding='max_length', 94 | max_length=max_length) 95 | return self.tokenizer(examples[sentence1_key], examples[sentence2_key], 96 | truncation=True, 97 | #padding=False, 98 | padding='max_length', 99 | max_length=max_length) 100 | 101 | 102 | 103 | def encode_with_repeats(self, examples, prefix_tokens_list=[], repeats=0, prefix_len=0, max_length=100, do_repeats=False): 104 | # max length defines input length not counting prompt (in case of repeats, each repeat will be of max length) 105 | #text = examples['sentence'] 106 | #inputs = tokenizer(text, truncation=True, padding='max_length', max_length=max_length) 107 | if prefix_len==0: 108 | inputs = self.preprocess_function(examples, max_length=max_length) # add 1 since we will remove CLS token 109 | 110 | if prefix_len>0: 111 | soft_prompt = [] 112 | inputs = self.preprocess_function(examples, max_length=max_length+1) # add 1 since we will remove CLS token 113 | for j in range(prefix_len): 114 | tokenized = self.tokenizer.tokenize(prefix_tokens_list[j]) 115 | #cls1_id = self.tokenizer.convert_tokens_to_ids(tokenized)[0] 116 | soft_prompt.append(self.tokenizer.convert_tokens_to_ids(tokenized)[0]) 117 | 118 | if repeats == 0 and not do_repeats: 119 | #print('We got soft prompt: ', soft_prompt) 120 | L = max_length + prefix_len # total len 121 | # for n in range(len(inputs['input_ids'])): 122 | # inputs['input_ids'][n] = inputs['input_ids'][n][1:][:max_length] + soft_prompt # ignore CLS token since we have one in prefix 123 | # for key in ['token_type_ids', 'attention_mask']: 124 | # inputs[key][n] = inputs[key][n][1:][:max_length] + inputs[key][n][0:1]*prefix_len 125 | 126 | # if batched == False 127 | inputs['input_ids'] = inputs['input_ids'][1:][:max_length] + soft_prompt # ignore CLS token since we have one in prefix 128 | for key in ['token_type_ids', 'attention_mask']: 129 | inputs[key] = inputs[key][1:][:max_length] + inputs[key][0:1]*prefix_len 130 | 131 | if repeats>0 or do_repeats: # if we have 1+ repeats or current sentence is the 0th repeat (should be formatted accordingly) 132 | assert prefix_len>0 133 | # cls_j_id = {} 134 | # for j in range(repeats): 135 | # tokenized = self.tokenizer.tokenize(prefix_tokens_list[j]) 136 | # cls_j_id[j] = self.tokenizer.convert_tokens_to_ids(tokenized)[0] 137 | 138 | #for n in range(len(inputs['input_ids'])): # we are not currently using batched version 139 | repeat = inputs['input_ids'].copy()[1:][:max_length] 140 | repeat_full = [] 141 | single_prefix_len = prefix_len//(repeats+1) 142 | for j in range(repeats+1): 143 | repeat_full += soft_prompt[single_prefix_len*j : single_prefix_len*(j+1)] + repeat.copy() 144 | inputs['input_ids'] = repeat_full 145 | 146 | for key in ['token_type_ids', 'attention_mask']: 147 | repeat_block = inputs[key][0:1]*single_prefix_len + inputs[key][1:][:max_length] 148 | inputs[key] = repeat_block*(repeats+1) 149 | 150 | return inputs 151 | 152 | 153 | def prepare_dataset(self, 154 | dataset, 155 | prefix_tokens_list=[], 156 | prefix_len=0, 157 | repeats=0, 158 | batch_size=8, 159 | max_length=100, 160 | task=None, 161 | label_offset=0, 162 | do_repeats=False): # input = dataset loaded from hugging face 163 | encoded_dataset = dataset.map(lambda x: self.encode_with_repeats(x, 164 | repeats=repeats, 165 | prefix_tokens_list=prefix_tokens_list, 166 | prefix_len=prefix_len, 167 | max_length=max_length, 168 | do_repeats=do_repeats), batched=False) 169 | if task==None: 170 | task = self.task 171 | label_key = 'label' if 'yahoo_' not in task else 'topic' 172 | 173 | if label_offset==0: # no change to the original label 174 | dataset2 = encoded_dataset.map(lambda examples: {'labels': examples[label_key]}, batched=True) 175 | else: # adding offset (for one head training) 176 | dataset2 = encoded_dataset.map(lambda examples: {'labels': examples[label_key] + label_offset}) 177 | dataset2.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels']) 178 | dataloader = torch.utils.data.DataLoader(dataset2, batch_size=batch_size) 179 | 180 | return dataloader 181 | 182 | 183 | def select_subset_ds(self, ds, k=2000, task=None, seed=0): 184 | if task==None: 185 | task = self.task 186 | 187 | if self.task=='stsb': # stsb has continuos labels 188 | idx_total = np.random.choice(np.arange(ds.shape[0]), min(k, ds.shape[0]), replace=False) 189 | else: 190 | label_key = 'label' if 'yahoo_' not in task else 'topic' 191 | N = len(ds[label_key]) 192 | idx_total = np.array([], dtype='int64') 193 | 194 | for l in set(ds[label_key]): 195 | idx = np.where(np.array(ds[label_key]) == l)[0] 196 | idx_total = np.concatenate([idx_total, np.random.choice(idx, min(k, idx.shape[0]), replace=False)]) 197 | 198 | np.random.seed(seed) 199 | np.random.shuffle(idx_total) 200 | return ds.select(idx_total) 201 | 202 | 203 | def get_dataset(self, 204 | #task='cola', 205 | benchmark=None, 206 | prefix_tokens_list=[], 207 | prefix_len=0, 208 | split='train', 209 | repeats=0, 210 | batch_size=8, 211 | select_k_per_class=-1, 212 | max_length=100, 213 | return_test_subset=False, 214 | label_offset=0, 215 | seed=42, 216 | do_repeats=False): 217 | #dataset = load_dataset('glue', 'cola', split='train') 218 | task = self.task 219 | # if task == 'amazon': 220 | # # amazon reviews (with 5 starts) is only available from original paper's google drive 221 | # df = pd.read_csv('downloaded_data/amazon_review_full_csv/'+split+'.csv', header=None) 222 | # df = df.rename(columns={0: "label", 1: "title", 2: "content"}) 223 | # df['label'] = df['label'] - 1 224 | # dataset = datasets.Dataset.from_pandas(df) 225 | 226 | if task in ['cola', 'rte', 'mrpc', 'cb', 'copa', 'wsc'] and select_k_per_class>250: 227 | select_k_per_class = -1 # too small datasets for selection 228 | 229 | if task in ['ag', 'yahoo', 'yelp', 'amazon', 'dbpedia']: 230 | df = pd.read_csv('../datasets/src/data/'+task+'/'+split+'.csv', header=None) 231 | df = df.rename(columns={0: "label", 1: "title", 2: "content"}) 232 | df['label'] = df['label'] - 1 233 | dataset = datasets.Dataset.from_pandas(df) 234 | 235 | elif task == 'mnli': 236 | dataset = load_dataset('LysandreJik/glue-mnli-train', split=split) 237 | elif task == 'qnli': 238 | dataset = load_dataset('SetFit/qnli', split=split) 239 | elif task == 'stsb': 240 | dataset = load_dataset('stsb_multi_mt', name='en', split=split if split=='train' else 'dev') 241 | else: 242 | if benchmark != None: 243 | dataset = load_dataset(benchmark, task, split=split) 244 | else: 245 | dataset = load_dataset(task, split=split) 246 | 247 | if self.task == "yahoo_answers_topics": 248 | # for yahoo dataset we need to filter out empty rows (no question) 249 | if split=='train': 250 | good_id = np.load('good_id_yahoo_train.npy') 251 | dataset = dataset.select(good_id) 252 | elif split=='test': 253 | good_id = np.load('good_id_yahoo_test.npy') 254 | dataset = dataset.select(good_id) 255 | 256 | dataset = dataset.shuffle(seed=seed) 257 | 258 | if select_k_per_class != -1 and task not in ['copa', 'cb', 'wic']: 259 | k = select_k_per_class 260 | if return_test_subset: 261 | k *= 2 262 | dataset = self.select_subset_ds(dataset, k=k) 263 | 264 | if not return_test_subset: 265 | # returning one dataset 266 | dataset_final = self.prepare_dataset(dataset, 267 | repeats=repeats, 268 | batch_size=batch_size, 269 | max_length=max_length, 270 | prefix_tokens_list=prefix_tokens_list, 271 | prefix_len=prefix_len, 272 | label_offset=label_offset, 273 | do_repeats=do_repeats) 274 | return dataset_final 275 | 276 | else: 277 | # splitting current dataset into 2: val and test 278 | N = len(dataset) 279 | dataset_val = dataset.select(np.arange(0, N//2)) 280 | dataset_test = dataset.select(np.arange(N//2, N)) 281 | 282 | dataset_final_val = self.prepare_dataset(dataset_val, repeats=repeats, do_repeats=do_repeats, 283 | batch_size=batch_size, max_length=max_length, 284 | prefix_tokens_list=prefix_tokens_list, prefix_len=prefix_len, 285 | label_offset=label_offset) 286 | dataset_final_test = self.prepare_dataset(dataset_test, repeats=repeats, do_repeats=do_repeats, 287 | batch_size=batch_size, max_length=max_length, 288 | prefix_tokens_list=prefix_tokens_list, prefix_len=prefix_len, 289 | label_offset=label_offset) 290 | return dataset_final_val, dataset_final_test 291 | -------------------------------------------------------------------------------- /BERT_codebase/good_id_yahoo_test2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arazd/ProgressivePrompts/01572d6a73c0576b070ceee00dbe4f5bc278423f/BERT_codebase/good_id_yahoo_test2.npy -------------------------------------------------------------------------------- /BERT_codebase/good_id_yahoo_train2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arazd/ProgressivePrompts/01572d6a73c0576b070ceee00dbe4f5bc278423f/BERT_codebase/good_id_yahoo_train2.npy -------------------------------------------------------------------------------- /BERT_codebase/model_utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoAdapterModel 2 | import numpy as np 3 | from transformers.adapters import PrefixTuningConfig 4 | from datasets import load_dataset 5 | import torch 6 | from tqdm import tqdm 7 | from transformers.models.bert.modeling_bert import BERT_INPUTS_DOCSTRING, BERT_START_DOCSTRING, BertModel, BertPreTrainedModel 8 | #from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward 9 | #from transformers.context import AdapterSetup 10 | from transformers.adapters.heads import ( 11 | BertStyleMaskedLMHead, 12 | BiaffineParsingHead, 13 | CausalLMHead, 14 | ClassificationHead, 15 | ModelWithFlexibleHeadsAdaptersMixin, 16 | MultiLabelClassificationHead, 17 | MultipleChoiceHead, 18 | QuestionAnsweringHead, 19 | TaggingHead, 20 | ) 21 | from transformers.adapters import BertAdapterModel 22 | from transformers import AdamW, get_constant_schedule_with_warmup 23 | from types import MethodType # to update attention calculation of BERT 24 | from torch import Tensor, device, nn 25 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 26 | import warnings 27 | from datasets import load_metric 28 | 29 | from copy import deepcopy 30 | 31 | glue_datasets = ['cola', 'sst2', 'mrpc', 'qqp', 'stsb', 'mnli', 'mnli_mismatched', 'mnli_matched', 'qnli', 'rte', 'wnli', 'ax'] 32 | 33 | def get_mask_arr(k = 2, block_size = 100, device = None): 34 | n = k+1 # number of "stacked" input blocks 35 | mask_arr = np.ones([block_size*n, block_size*n]) 36 | 37 | for i in range(n): 38 | for m in range(i+1, n): 39 | mask_arr[i*block_size : (i+1)*block_size, 40 | m*block_size : (m+1)*block_size] = 0 41 | return torch.Tensor(mask_arr).to(device) 42 | 43 | 44 | 45 | # overriding attention mask for BERT 46 | # usage: self.get_extended_attention_mask(attention_mask, input_shape, device) 47 | def get_extended_attention_mask2( 48 | self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None, 49 | k = 2, block_size = 100, blockwise_causal_mask = None, 50 | ) -> Tensor: 51 | 52 | """ 53 | Makes broadcastable attention and causal masks so that future and masked tokens are ignored. 54 | Arguments: 55 | attention_mask (`torch.Tensor`): 56 | Mask with ones indicating tokens to attend to, zeros for tokens to ignore. 57 | input_shape (`Tuple[int]`): 58 | The shape of the input to the model. 59 | Returns: 60 | `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. 61 | """ 62 | if not (attention_mask.dim() == 2 and self.config.is_decoder): 63 | # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` 64 | if device is not None: 65 | warnings.warn( 66 | "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning 67 | ) 68 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 69 | # ourselves in which case we just need to make it broadcastable to all heads. 70 | if attention_mask.dim() == 3: 71 | extended_attention_mask = attention_mask[:, None, :, :] 72 | elif attention_mask.dim() == 2: 73 | # Provided a padding mask of dimensions [batch_size, seq_length] 74 | # - if the model is a decoder, apply a causal mask in addition to the padding mask 75 | # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] 76 | if self.config.is_decoder: 77 | extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( 78 | input_shape, attention_mask, device 79 | ) 80 | else: 81 | extended_attention_mask = attention_mask[:, None, None, :] 82 | 83 | # OUR MODIFICATION FOR CUSTOM MASK 84 | if k != None and block_size != None: 85 | #print('applying blockwise attention') 86 | # for blockwise_causal_mask, 87 | # we broadcast [seq_len, seq_len] to [batch_size, num_heads, seq_length, seq_length] 88 | blockwise_causal_mask = get_mask_arr(k = k, block_size = block_size, device = device)[None, None, :, :] 89 | 90 | if blockwise_causal_mask != None: 91 | #print('blockwise_causal_mask ', blockwise_causal_mask.shape) 92 | #print('extended_attention_mask ', extended_attention_mask.shape) 93 | with torch.no_grad(): 94 | extended_attention_mask = blockwise_causal_mask * extended_attention_mask.to(device) 95 | else: 96 | raise ValueError( 97 | f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" 98 | ) 99 | 100 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 101 | # masked positions, this operation will create a tensor which is 0.0 for 102 | # positions we want to attend and -10000.0 for masked positions. 103 | # Since we are adding it to the raw scores before the softmax, this is 104 | # effectively the same as removing these entirely. 105 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility 106 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 107 | return extended_attention_mask 108 | 109 | 110 | 111 | class ModelForCL: 112 | def __init__(self, 113 | model_name, 114 | tasks=['cola'], 115 | num_labels=[2], 116 | #blockwise_causal_attention=False, 117 | prefix_len=0, 118 | freeze_weights=False, 119 | freeze_except='word_embeddings', 120 | lr=2e-5, 121 | num_repeats=0, 122 | max_length=150, # max sequence length in #tokens 123 | cls_idx_override=None, 124 | prefix_MLPs=None, 125 | same_prompt=False, # whether to use the same prompt for all tasks 126 | ): 127 | self.model_name = model_name # "bert-base-uncased" 128 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 129 | self.model = AutoAdapterModel.from_pretrained(model_name) 130 | self.prefix_MLPs = prefix_MLPs 131 | 132 | self.tasks = tasks 133 | self.num_tasks = len(tasks) 134 | self.same_prompt = same_prompt 135 | #self.current_task = tasks[0] 136 | self.num_repeats = num_repeats 137 | self.max_length = max_length 138 | self.cls_idx_override = cls_idx_override 139 | # overriding attention mechanism to make block-wise causal attention 140 | #self.blockwise_causal_attention = blockwise_causal_attention 141 | 142 | cls_tok = self.tokenizer.tokenize("[CLS]") 143 | self.cls_id = self.tokenizer.convert_tokens_to_ids(cls_tok)[0] 144 | 145 | # freezing weights except word embeddings 146 | self.freeze_except = freeze_except 147 | self.freeze_weights = freeze_weights 148 | if self.freeze_weights: 149 | self.do_freeze_weights(except_condition=freeze_except) 150 | 151 | #self.model.add_classification_head(self.current_task, num_labels=num_labels[0]) 152 | for i in range(self.num_tasks): 153 | self.model.add_classification_head(self.tasks[i], num_labels=num_labels[i]) 154 | 155 | # adding special prefix tokens (CHANGE for new tasks) 156 | self.prefix_len = prefix_len 157 | 158 | if prefix_len > 0: 159 | self.prefix_tokens_list = {} 160 | 161 | if self.same_prompt: # assume we have just 1 task (i.e. 1 prompt for each task) 162 | self.prefix_tokens_list[0] = self.add_prompt_tokens(prefix_len, prompt_name='PRE0_') 163 | else: 164 | for i in range(self.num_tasks): 165 | # new prefix for each task 166 | # Task 1 = PRE1_1, ... PRE1_30 ; Task 2 = PRE2_1, ... PRE2_30 167 | self.prefix_tokens_list[self.tasks[i]] = self.add_prompt_tokens(prefix_len, prompt_name='PRE'+str(i)+'_') 168 | else: 169 | self.prefix_tokens_list = {self.tasks[i]: [] for i in range(self.num_tasks)} # empty prompt for each task 170 | 171 | #self.optimizer = torch.optim.AdamW(params=self.model.parameters(), lr=lr) 172 | params_group = [{"params": self.model.parameters(), "lr": lr, "weight_decay": 0.01},] 173 | if self.prefix_MLPs != None: 174 | for t in prefix_MLPs: # append parameters of each task MLP 175 | params_group.append({"params": self.prefix_MLPs[t].parameters(), "lr": lr, "weight_decay": 0.01}) 176 | self.optimizer = AdamW(params_group) 177 | self.scheduler = get_constant_schedule_with_warmup(self.optimizer, 1000) 178 | 179 | # save prompt emb for all tasks 180 | #self.saved_embs = deepcopy(self.model.bert.embeddings.word_embeddings.weight[:-self.prefix_len*self.num_tasks].cpu().detach().numpy()) 181 | self.saved_embs = deepcopy(self.model.bert.embeddings.word_embeddings.weight.cpu().detach().numpy()) 182 | 183 | 184 | 185 | def add_prompt_tokens(self, prefix_len, prompt_name='PRE'): 186 | tokenizer = self.tokenizer 187 | model = self.model 188 | N = model.bert.embeddings.word_embeddings.weight.shape[0] # wte shape before resize 189 | 190 | # tokens_list - ['[PRE1]', '[PRE2]', '[PRE3]'] 191 | tokens_list = ['['+ prompt_name + str(i) + ']' for i in np.arange(1, prefix_len+1)] 192 | special_tokens_dict = {'additional_special_tokens': tokens_list} 193 | num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) 194 | model.bert.resize_token_embeddings(len(tokenizer)) 195 | 196 | model.bert.embeddings.word_embeddings.weight.requires_grad = False 197 | for i in range(len(tokens_list)): 198 | with torch.no_grad(): 199 | # initialize pre1 as CLS token 200 | if i==0: 201 | j = self.cls_id 202 | # initalize pre2, pre3 ... with random word embedding 203 | else: 204 | j = np.random.randint(N) 205 | #model.bert.embeddings.word_embeddings.weight[N+i, :] = \ 206 | #model.bert.embeddings.word_embeddings.weight[j] 207 | w = deepcopy(model.bert.embeddings.word_embeddings.weight[j].detach().cpu().numpy()) 208 | model.bert.embeddings.word_embeddings.weight[N+i] = torch.from_numpy(w) 209 | model.bert.embeddings.word_embeddings.weight.requires_grad = True 210 | return tokens_list 211 | 212 | 213 | # freeze all weights except for word emb (or other condition specified) 214 | def do_freeze_weights(self, except_condition='word_embeddings'): 215 | for name, param in self.model.bert.named_parameters(): 216 | if param.requires_grad == True and except_condition not in name: 217 | param.requires_grad = False 218 | 219 | # freeze / unfreeze MLPs for given tasks (when requires_grad==False then freezing) 220 | def freeze_unfreeze_mlps(self, tasks, blocks=['mlp', 'head'], requires_grad=False): 221 | for x in blocks: # we only freeze/unfreeze cls heads and MLPs for CL setting 222 | assert x in ['mlp', 'head'] 223 | 224 | param_groups = [] 225 | for name in blocks: 226 | if name=='mlp': 227 | assert self.prefix_MLPs != None 228 | param_groups.append(self.prefix_MLPs) 229 | if name=='head': 230 | param_groups.append(self.model.heads) 231 | 232 | for t in tasks: 233 | for p_group in param_groups: 234 | #for name, param in self.prefix_MLPs[t].named_parameters(): 235 | for name, param in p_group[t].named_parameters(): 236 | if param.requires_grad != requires_grad: 237 | param.requires_grad = requires_grad 238 | param.grad = None # remove old gradient 239 | 240 | 241 | # overriding attention mask for blockwise causal attention: 242 | def override_attention_mask(self, num_repeats, repeat_length, device): 243 | blockwise_causal_mask = get_mask_arr(k=num_repeats, block_size=repeat_length, device=device) 244 | self.model.bert.get_extended_attention_mask = MethodType(lambda self, attention_mask, input_shape, device: \ 245 | get_extended_attention_mask2(self, attention_mask, input_shape, device, 246 | k = None, 247 | block_size = None, 248 | blockwise_causal_mask = blockwise_causal_mask,), 249 | self.model.bert) 250 | return blockwise_causal_mask 251 | 252 | # get custom position ids for sentence with repeats 253 | # [cls] input [cls1] input 254 | # 0 1 2 ... N 0 N+1 ... 255 | def get_position_ids(self, dataloader, device): 256 | tokenizer = self.tokenizer 257 | b = next(iter(dataloader)) 258 | pos = list(range( len(b['input_ids'][0]) )) 259 | #pos[100] = 0 260 | for tok in self.special_tokens_list: 261 | tokenized = tokenizer.tokenize(tok) 262 | tokenized_ids = tokenizer.convert_tokens_to_ids(tokenized) 263 | clsN_pos = int(torch.where(b['input_ids'][0] == tokenized_ids[0])[0].cpu()) 264 | pos[clsN_pos] = 0 265 | 266 | pos = torch.tensor(pos) 267 | pos = torch.cat([pos.view(1,-1)]* len(b['input_ids']), axis=0).to(device) 268 | return pos 269 | 270 | 271 | def train(self, dataloader, task, epochs=5, cls_idx=100, dataloader_val=None, metric=None): 272 | #cls_idx=self.max_length # CHANGE FOR TASK 3+ 273 | if self.cls_idx_override!=None: 274 | cls_idx = self.cls_idx_override 275 | print('Using CLS idx ', cls_idx) 276 | model = self.model 277 | optimizer = self.optimizer 278 | scheduler = self.scheduler 279 | tokenizer = self.tokenizer 280 | # fine-tuning with 2 repeats 281 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 282 | model.train().to(device) 283 | 284 | if self.num_repeats>=1: 285 | pos = self.get_position_ids(dataloader, device) 286 | 287 | val_scores = [] 288 | 289 | for epoch in range(epochs): 290 | for i, batch in enumerate(tqdm(dataloader)): 291 | batch = {k: v.to(device) for k, v in batch.items()} 292 | 293 | if self.num_repeats < 1 or cls_idx==0: ## num_repeats <1 or ==1? 294 | outputs = model(**batch) 295 | 296 | else: 297 | out = model.bert(**{'input_ids': batch['input_ids'], 298 | 'attention_mask': batch['attention_mask'], 299 | 'token_type_ids': batch['token_type_ids'], 300 | 'position_ids': pos[:len(batch['input_ids'])], 301 | }) 302 | cls_output = out.last_hidden_state[:,cls_idx,:].to(device) 303 | outputs = model.heads[task](outputs = out, 304 | cls_output = cls_output, 305 | return_dict = True, 306 | labels = batch['labels']) 307 | 308 | loss = outputs.loss 309 | loss.backward() 310 | 311 | # only allowing updates for added special token if required 312 | if self.freeze_weights == 1 and self.freeze_except == 'word_embeddings': 313 | k = len(self.special_tokens_list) 314 | model.bert.embeddings.word_embeddings.weight.grad[:-k] = 0 315 | #model.bert.embeddings.word_embeddings.weight.grad[:-1] = 0 316 | optimizer.step() 317 | scheduler.step() 318 | optimizer.zero_grad() 319 | 320 | if dataloader_val != None: 321 | result = self.eval(dataloader_val, self.current_task, metric, cls_idx) 322 | print('result = ',result) 323 | #metric_key = list(result)[0] # append results as value floats, instead of dict metric -> value 324 | metric_key = self.task_to_metric_key(task) 325 | val_scores.append(result[metric_key]) 326 | 327 | return val_scores 328 | 329 | 330 | 331 | def pass_batch(self, batch, task, device, cls_idx=0, only_output_loss=False): 332 | if self.cls_idx_override!=None: 333 | cls_idx = self.cls_idx_override 334 | #print('Using CLS idx ', cls_idx) 335 | model = self.model 336 | optimizer = self.optimizer 337 | scheduler = self.scheduler 338 | tokenizer = self.tokenizer 339 | 340 | batch = {k: v.to(device) for k, v in batch.items()} 341 | 342 | out = model.bert(**{'input_ids': batch['input_ids'], 343 | 'attention_mask': batch['attention_mask'], 344 | 'token_type_ids': batch['token_type_ids'], 345 | #'position_ids': pos[:len(batch['input_ids'])], 346 | }) 347 | cls_output = out.last_hidden_state[:,cls_idx,:].to(device) 348 | 349 | outputs = model.heads[task](outputs = out, 350 | cls_output = cls_output, 351 | return_dict = True, 352 | labels = batch['labels']) 353 | 354 | loss = outputs.loss 355 | 356 | if only_output_loss: 357 | # just returning the loss for subsequent operations (e.g. sum) 358 | return loss 359 | 360 | else: 361 | # performing optimization step here 362 | loss.backward() 363 | optimizer.step() 364 | scheduler.step() 365 | optimizer.zero_grad() 366 | 367 | 368 | 369 | def pass_batch_with_prompt(self, batch, task, device, prefix_len=None, cls_idx=0, custom_pos_ids=True): 370 | if self.cls_idx_override!=None: 371 | cls_idx = self.cls_idx_override # position of cls in sentence (usually 0) 372 | 373 | if prefix_len == None: 374 | prefix_len = self.prefix_len 375 | #print('Using CLS idx ', cls_idx) 376 | model = self.model 377 | optimizer = self.optimizer 378 | scheduler = self.scheduler 379 | tokenizer = self.tokenizer 380 | prefix_MLP = None if self.prefix_MLPs==None else self.prefix_MLPs[task] 381 | if self.same_prompt: 382 | task_idx = 0 # constast task idx bcs of shared prompt 383 | else: 384 | task_idx = self.tasks.index(task) 385 | 386 | batch_keys = ['input_ids', 'token_type_ids'] 387 | if custom_pos_ids: 388 | batch_keys.append('position_ids') # loop through custom position ids 389 | emb = model.bert.embeddings(**{k: batch[k] for k in batch_keys}) 390 | if prefix_len>0 and prefix_MLP != None: 391 | # WORKS FOR NO REPEATS CASE 392 | #emb[:, :prefix_len, :] = prefix_MLP(emb[:, :prefix_len, :].clone().to(device)) 393 | pos1, pos2 = self.max_length + prefix_len*task_idx, self.max_length + prefix_len*(task_idx+1) 394 | emb[:, pos1:pos2, :] = prefix_MLP(emb[:, pos1:pos2, :].clone().to(device)) 395 | 396 | extended_attention_mask: torch.Tensor = model.bert.get_extended_attention_mask(batch['attention_mask'], 397 | batch['input_ids'].shape, 398 | device=device) 399 | out = model.bert.encoder(emb, attention_mask=extended_attention_mask) 400 | cls_output = out.last_hidden_state[:,cls_idx,:].to(device) 401 | outputs = model.heads[task](outputs = out, 402 | cls_output = cls_output, 403 | return_dict = True, 404 | labels = batch['labels']) 405 | 406 | loss = outputs.loss 407 | loss.backward() 408 | 409 | # only allowing updates for added special token if required 410 | #if freeze_except == 'word_embeddings': 411 | if prefix_len>0: 412 | emb_size = model.bert.embeddings.word_embeddings.weight.shape[0] 413 | k1, k2 = emb_size - prefix_len * (self.num_tasks - task_idx), emb_size - prefix_len * (self.num_tasks - task_idx -1) 414 | model.bert.embeddings.word_embeddings.weight.grad[:k1] = 0 415 | model.bert.embeddings.word_embeddings.weight.grad[k2:] = 0 416 | 417 | optimizer.step() 418 | scheduler.step() 419 | 420 | 421 | #model.bert.embeddings.word_embeddings.weight[:-prefix_len] = torch.from_numpy(self.saved_embs) 422 | if prefix_len>0: 423 | model.bert.embeddings.word_embeddings.weight.requires_grad = False 424 | model.bert.embeddings.word_embeddings.weight[:k1] = torch.from_numpy(self.saved_embs[:k1]) # restore all emb except curr task 425 | model.bert.embeddings.word_embeddings.weight[k2:] = torch.from_numpy(self.saved_embs[k2:]) 426 | model.bert.embeddings.word_embeddings.weight.requires_grad = True 427 | optimizer.zero_grad() 428 | 429 | 430 | 431 | def update_baseline_model_emb(self): 432 | # update word emb matrix for continual prompt tuning 433 | # so that we don't "forget" learned prompts during re-setting 434 | self.saved_embs = deepcopy(self.model.bert.embeddings.word_embeddings.weight.cpu().detach().numpy()) 435 | 436 | 437 | # initialize new task prompt from previous task prompts 438 | def init_new_prompt(self, task_idx_curr): 439 | prefix_len = self.prefix_len 440 | model = self.model 441 | emb_size = model.bert.embeddings.word_embeddings.weight.shape[0] 442 | k1_curr, k2_curr = emb_size - prefix_len * (self.num_tasks - task_idx_curr), emb_size - prefix_len * (self.num_tasks - task_idx_curr -1) 443 | task_idx_next = task_idx_curr+1 444 | k1_next, k2_next = emb_size - prefix_len * (self.num_tasks - task_idx_next), emb_size - prefix_len * (self.num_tasks - task_idx_next -1) 445 | 446 | model.bert.embeddings.word_embeddings.weight.requires_grad = False 447 | model.bert.embeddings.word_embeddings.weight[k1_next:k2_next] = torch.from_numpy(self.saved_embs[k1_curr:k2_curr]) # init new task emb from curr task emb 448 | model.bert.embeddings.word_embeddings.weight.requires_grad = True 449 | 450 | 451 | def save_curr_task_emb(self, task_idx_curr, save_path): 452 | prefix_len = self.prefix_len 453 | model = self.model 454 | emb_size = model.bert.embeddings.word_embeddings.weight.shape[0] 455 | k1_curr, k2_curr = emb_size - prefix_len * (self.num_tasks - task_idx_curr), emb_size - prefix_len * (self.num_tasks - task_idx_curr -1) 456 | np.save(save_path, self.saved_embs[k1_curr:k2_curr]) # save np array with curr task emb 457 | 458 | 459 | # modifies output by passing soft prompt embs throught MLP 460 | def get_bert_output_with_prompt(self, batch, prefix_len, device): 461 | model = self.model 462 | emb = model.bert.embeddings(**{k: batch[k] for k in ['input_ids', 'token_type_ids']}) 463 | emb[:, :prefix_len, :] = self.prefix_MLP(emb[:, :prefix_len, :].to(device)) 464 | 465 | extended_attention_mask: torch.Tensor = model.bert.get_extended_attention_mask(batch['attention_mask'], 466 | batch['input_ids'].shape, 467 | device=device) 468 | out = model.bert.encoder(emb, attention_mask=extended_attention_mask) 469 | return out 470 | 471 | 472 | 473 | # returns metric corresponding to the task 474 | def task_to_metric_key(self, task): 475 | if task not in glue_datasets: 476 | return 'accuracy' 477 | 478 | if task in ['qqp', 'mrpc']: 479 | return 'f1' 480 | 481 | #elif 'mnli' in task or task == 'cola': 482 | elif task == 'cola': 483 | return 'matthews_correlation' 484 | 485 | else: 486 | return 'accuracy' 487 | 488 | 489 | 490 | def eval(self, dataloader_val, task, metric, cls_idx=100): 491 | #ls_idx=self.max_length # CHANGE FOR TASK 3+ 492 | if self.cls_idx_override!=None: 493 | cls_idx = self.cls_idx_override 494 | model = self.model 495 | tokenizer = self.tokenizer 496 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 497 | model.eval().to(device) 498 | 499 | if metric==None: 500 | if task in glue_datasets: 501 | metric = load_metric('glue', task) 502 | else: 503 | metric = load_metric('accuracy') 504 | 505 | if self.num_repeats>=1: 506 | pos = self.get_position_ids(dataloader_val, device) 507 | 508 | for i, batch in enumerate(tqdm(dataloader_val)): 509 | batch = {k: v.to(device) for k, v in batch.items()} 510 | with torch.no_grad(): 511 | # Forward pass, calculate logit predictions 512 | inp_dict = {'input_ids': batch['input_ids'], 513 | 'attention_mask': batch['attention_mask'], 514 | 'token_type_ids': batch['token_type_ids'], 515 | } 516 | if self.num_repeats < 1 or cls_idx==0: 517 | # out = model.bert(**{'input_ids': batch['input_ids'], 518 | # 'attention_mask': batch['attention_mask'], 519 | # 'token_type_ids': batch['token_type_ids'], 520 | # }) 521 | pass 522 | else: 523 | inp_dict['position_ids'] = pos[:len(batch['input_ids'])] 524 | out = model.bert(**inp_dict) 525 | cls_output = out.last_hidden_state[:,cls_idx,:].to(device) 526 | outputs = model.heads[task](outputs = out, 527 | cls_output = cls_output, 528 | return_dict = True, 529 | labels = batch['labels']) 530 | 531 | predictions = torch.argmax(outputs.logits, dim=-1) 532 | metric.add_batch(predictions=predictions, references=batch['labels']) 533 | 534 | result = metric.compute() 535 | #metric_key = list(result)[0] # we want to return value float (not dict metric -> value) 536 | metric_key = self.task_to_metric_key(task) 537 | return result[metric_key] 538 | 539 | 540 | 541 | 542 | def eval_with_prompt(self, dataloader_val, task, metric, cls_idx=0, prefix_len=None, custom_pos_ids=True, pos_ids=None): 543 | if custom_pos_ids: assert pos_ids != None 544 | 545 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 546 | if self.cls_idx_override!=None: 547 | cls_idx = self.cls_idx_override 548 | print('Using cls_idx', cls_idx) 549 | 550 | if prefix_len==None: 551 | prefix_len = self.prefix_len 552 | 553 | model = self.model 554 | prefix_MLP = None if self.prefix_MLPs==None else self.prefix_MLPs[task] 555 | if self.same_prompt: 556 | task_idx = 0 557 | else: 558 | task_idx = self.tasks.index(task) 559 | model.eval().to(device) 560 | if prefix_MLP!=None: 561 | prefix_MLP.eval().to(device) 562 | 563 | #metric = load_metric('glue', task) 564 | if task in glue_datasets: 565 | metric = load_metric('glue', task) 566 | else: 567 | metric = load_metric('accuracy') 568 | 569 | for i, batch in enumerate(tqdm(dataloader_val)): 570 | batch = {k: v.to(device) for k, v in batch.items()} 571 | with torch.no_grad(): 572 | batch_keys = ['input_ids', 'token_type_ids'] 573 | if custom_pos_ids: 574 | batch['position_ids'] = pos_ids.to(device) 575 | batch_keys.append('position_ids') # loop through custom position ids 576 | emb = model.bert.embeddings(**{k: batch[k] for k in batch_keys}) 577 | 578 | if prefix_len>0 and prefix_MLP != None: 579 | pos1, pos2 = self.max_length + prefix_len*task_idx, self.max_length + prefix_len*(task_idx+1) 580 | emb[:, pos1:pos2, :] = prefix_MLP(emb[:, pos1:pos2, :].clone().to(device)) 581 | #emb[:, :prefix_len, :] = prefix_MLP(emb[:, :prefix_len, :].to(device)) 582 | 583 | extended_attention_mask: torch.Tensor = model.bert.get_extended_attention_mask(batch['attention_mask'], 584 | batch['input_ids'].shape, 585 | device=device) 586 | out = model.bert.encoder(emb, attention_mask=extended_attention_mask) 587 | 588 | cls_output = out.last_hidden_state[:,cls_idx,:].to(device) 589 | outputs = model.heads[task](outputs = out, 590 | cls_output = cls_output, 591 | return_dict = True, 592 | labels = batch['labels']) 593 | 594 | predictions = torch.argmax(outputs.logits, dim=-1) 595 | metric.add_batch(predictions=predictions, references=batch['labels']) 596 | 597 | res = metric.compute() 598 | #metric_key = list(res)[0] # we want to return value float (not dict metric -> value) 599 | metric_key = self.task_to_metric_key(task) 600 | return res[metric_key] 601 | -------------------------------------------------------------------------------- /BERT_codebase/run_prog_prompts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=cl_nlp 3 | #SBATCH --partition=hipri 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --account=all 6 | #SBATCH --time=1-00:00:00 7 | #SBATCH --output=/data/home/%u/CL/prog_prompt_%j.log 8 | 9 | source ~/miniconda/bin/activate  10 | conda init 11 | source activate nlp 12 | 13 | HPARAMS=( 14 | "--task_list ag_news yelp_review_full amazon yahoo_answers_topics dbpedia --save_name prog_prompt_len10_1_order4 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 10" 15 | "--task_list ag_news yelp_review_full amazon yahoo_answers_topics dbpedia --save_name prog_prompt_len10_2_order4 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 10" 16 | "--task_list ag_news yelp_review_full amazon yahoo_answers_topics dbpedia --save_name prog_prompt_len10_3_order4 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 10" 17 | 18 | "--task_list ag_news yelp_review_full amazon yahoo_answers_topics dbpedia --save_name prog_prompt_len5_1_order4 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 5" 19 | "--task_list ag_news yelp_review_full amazon yahoo_answers_topics dbpedia --save_name prog_prompt_len5_2_order4 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 5" 20 | "--task_list ag_news yelp_review_full amazon yahoo_answers_topics dbpedia --save_name prog_prompt_len5_3_order4 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 5" 21 | 22 | 23 | "--task_list yelp_review_full yahoo_answers_topics amazon dbpedia ag_news --save_name prog_prompt_len10_1_order5 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 10" 24 | "--task_list yelp_review_full yahoo_answers_topics amazon dbpedia ag_news --save_name prog_prompt_len10_2_order5 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 10" 25 | "--task_list yelp_review_full yahoo_answers_topics amazon dbpedia ag_news --save_name prog_prompt_len10_3_order5 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 10" 26 | 27 | "--task_list yelp_review_full yahoo_answers_topics amazon dbpedia ag_news --save_name prog_prompt_len5_1_order5 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 5" 28 | "--task_list yelp_review_full yahoo_answers_topics amazon dbpedia ag_news --save_name prog_prompt_len5_2_order5 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 5" 29 | "--task_list yelp_review_full yahoo_answers_topics amazon dbpedia ag_news --save_name prog_prompt_len5_3_order5 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 5" 30 | 31 | 32 | "--task_list dbpedia yahoo_answers_topics ag_news amazon yelp_review_full --save_name prog_prompt_len10_1_order6 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 10" 33 | "--task_list dbpedia yahoo_answers_topics ag_news amazon yelp_review_full --save_name prog_prompt_len10_2_order6 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 10" 34 | "--task_list dbpedia yahoo_answers_topics ag_news amazon yelp_review_full --save_name prog_prompt_len10_3_order6 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 10" 35 | 36 | "--task_list dbpedia yahoo_answers_topics ag_news amazon yelp_review_full --save_name prog_prompt_len5_1_order6 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 5" 37 | "--task_list dbpedia yahoo_answers_topics ag_news amazon yelp_review_full --save_name prog_prompt_len5_2_order6 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 5" 38 | "--task_list dbpedia yahoo_answers_topics ag_news amazon yelp_review_full --save_name prog_prompt_len5_3_order6 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 5" 39 | 40 | 41 | # "--task_list yahoo_answers_topics ag_news yelp_review_full --save_name prog_prompt_len10_1_order3 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 10" 42 | # "--task_list yahoo_answers_topics ag_news yelp_review_full --save_name prog_prompt_len10_2_order3 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 10" 43 | # "--task_list yahoo_answers_topics ag_news yelp_review_full --save_name prog_prompt_len10_3_order3 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 10" 44 | 45 | # "--task_list yahoo_answers_topics ag_news yelp_review_full --save_name prog_prompt_len5_1_order3 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 5" 46 | # "--task_list yahoo_answers_topics ag_news yelp_review_full --save_name prog_prompt_len5_2_order3 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 5" 47 | # "--task_list yahoo_answers_topics ag_news yelp_review_full --save_name prog_prompt_len5_3_order3 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings --prompt_tuning 1 --prefix_len 5" 48 | 49 | 50 | ) 51 | 52 | cmd="python train_cl2.py ${HPARAMS[SLURM_ARRAY_TASK_ID]} \ 53 | --seq_len 450 --select_k_per_class 2000 --early_stopping 1 --one_head 0" 54 | 55 | echo $cmd 56 | eval $cmd 57 | -------------------------------------------------------------------------------- /BERT_codebase/train_cl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import numpy as np 4 | from tqdm.auto import tqdm 5 | import logging, os, argparse 6 | 7 | import dataset_utils, model_utils 8 | 9 | 10 | task_to_num_labels = { 11 | 'cola': 2, 12 | 'sst2': 2, 13 | 'yelp_review_full': 5, 14 | 'ag_news': 4, 15 | 'yahoo_answers_topics': 10 16 | } 17 | glue_datasets = ['cola', 'sst2', 'mrpc', 'qqp', 'stsb', 'mnli', 'mnli_mismatched', 'mnli_matched', 'qnli', 'rte', 'wnli', 'ax'] 18 | 19 | 20 | def train_on_task(trainer, 21 | task, 22 | num_epochs, 23 | batch_size, 24 | k=-1, 25 | cls_idx=0, 26 | ): 27 | print('Starting task ', task) 28 | du = dataset_utils.Dataset(task=task, tokenizer=trainer.tokenizer) 29 | data_params = {'repeats': trainer.num_repeats, 30 | 'batch_size': batch_size, 31 | 'max_length': trainer.max_length, 32 | 'special_tokens_list': trainer.special_tokens_list, 33 | #'select_k_per_class': k 34 | } 35 | benchmark = 'glue' if task in glue_datasets else None 36 | val_split = 'validation' if task in glue_datasets else 'test' 37 | dataloader_train = du.get_dataset(benchmark=benchmark, split='train', select_k_per_class=k, **data_params) 38 | k_val = -1 if k==-1 else int(k*0.1) 39 | dataloader_val = du.get_dataset(benchmark=benchmark, split=val_split, select_k_per_class=k_val, **data_params) 40 | 41 | print("Trainig set size = ", len(dataloader_train)*batch_size) 42 | scores_list = trainer.train(dataloader_train, 43 | task, 44 | epochs=num_epochs, 45 | cls_idx=cls_idx, 46 | #cls_idx=trainer.max_length, # CHANGE THIS CLS IDX 47 | dataloader_val=dataloader_val) 48 | 49 | return scores_list 50 | 51 | 52 | 53 | def continual_training(trainer, 54 | tasks=[], 55 | num_epochs=5, 56 | batch_size=8, 57 | k=-1): 58 | results_dict = {} 59 | for i, task in enumerate(tasks): 60 | print('TASK ', task) 61 | val_scores = train_on_task(trainer, task, num_epochs, batch_size, k=k) 62 | results_dict[task] = val_scores 63 | return results_dict 64 | 65 | 66 | 67 | def main(args): 68 | if torch.cuda.is_available(): 69 | device = torch.device("cuda") 70 | else: 71 | device = torch.device("cpu") 72 | 73 | save_path = os.path.join(args.save_dir, args.save_name) 74 | if not os.path.exists(save_path): 75 | os.mkdir(save_path) 76 | log_file = os.path.join(save_path, 'LR_logs') 77 | logging.basicConfig(filename=log_file,level=logging.DEBUG) 78 | 79 | logging.info("starting training script") 80 | 81 | task_list = args.task_list # ['cola'] 82 | special_tokens_list = ["[CLS"+str(i+1)+"]" for i in range(len(task_list))] #[ "[CLS1]" ], 83 | num_labels = [task_to_num_labels[t] for t in task_list] # CHANGE THIS !!!! 84 | 85 | model_name = args.model_name #"bert-base-uncased" 86 | trainer = model_utils.ModelForCL( model_name, 87 | tasks=task_list, 88 | num_labels=num_labels, 89 | blockwise_causal_attention= (args.block_attn==1), 90 | special_tokens_list=special_tokens_list, 91 | init_token="[CLS]", 92 | freeze_weights= (args.freeze_weights==1), 93 | freeze_except=args.freeze_except, #by default 'word_embeddings', 94 | lr=args.lr, 95 | num_repeats=args.num_repeats, 96 | max_length=args.seq_len, # max sequence length in #tokens 97 | cls_idx_override=args.cls_idx_override, 98 | same_prompt=args.same_prompt==1, 99 | ) 100 | 101 | # task = trainer.current_task 102 | # du = dataset_utils.Dataset(task=task, tokenizer=trainer.tokenizer) 103 | # data_params = {'repeats': trainer.num_repeats, 104 | # 'batch_size': args.batch_size, 105 | # 'max_length': trainer.max_length, 106 | # 'special_tokens_list': trainer.special_tokens_list} 107 | # dataloader_train = du.get_dataset(benchmark='glue', split='train', **data_params) 108 | # dataloader_val = du.get_dataset(benchmark='glue', split='validation', **data_params) 109 | # 110 | # 111 | # scores_list = trainer.train(dataloader_train, 112 | # task, 113 | # epochs=args.num_epochs, 114 | # cls_idx=trainer.max_length, 115 | # dataloader_val=dataloader_val) 116 | # 117 | # if not os.path.exists( os.path.join(save_path, 'val_scores.npy') ): 118 | # np.save(os.path.join(save_path, 'val_scores.npy'), scores_list) 119 | # 120 | # else: 121 | # np.save(os.path.join(save_path, 'val_scores2.npy'), scores_list) 122 | 123 | results_dict = continual_training( trainer, 124 | tasks=task_list, 125 | num_epochs=args.num_epochs, 126 | batch_size=args.batch_size, 127 | k=args.select_k_per_class) 128 | np.save(os.path.join(save_path, 'results_dict.npy'), results_dict) 129 | 130 | 131 | 132 | 133 | if __name__ == "__main__": 134 | parser = argparse.ArgumentParser( 135 | description='NLP training script in PyTorch' 136 | ) 137 | 138 | parser.add_argument( 139 | '--save_dir', 140 | type=str, 141 | help='base directory of all models / features (should not be changed)', 142 | default='/data/home/arazdai/CL' #'/scratch/hdd001/home/anastasia/CL/' 143 | ) 144 | 145 | parser.add_argument( 146 | '--save_name', 147 | type=str, 148 | help='folder name to save', 149 | required=True 150 | ) 151 | 152 | parser.add_argument( 153 | '--model_name', 154 | type=str, 155 | help='Name of the model used for training', 156 | default="bert-base-uncased" 157 | ) 158 | 159 | parser.add_argument( 160 | '--num_epochs', 161 | type=int, 162 | help='Number of epochs to train model', 163 | default=5 164 | ) 165 | 166 | parser.add_argument( 167 | '--batch_size', 168 | type=int, 169 | help='Batch size', 170 | default=8 171 | ) 172 | 173 | parser.add_argument( 174 | '--seq_len', 175 | type=int, 176 | help='Length of a single repeat (in #tokens)', 177 | default=150 178 | ) 179 | 180 | parser.add_argument( 181 | '--num_repeats', 182 | type=int, 183 | help='Number of sentence repeats', 184 | required=True 185 | ) 186 | 187 | parser.add_argument( 188 | '--lr', 189 | type=float, 190 | help='Learning rate', 191 | default=2e-5 192 | ) 193 | 194 | parser.add_argument( 195 | '--block_attn', 196 | type=int, 197 | help='Whether to use blockwise causal attention', 198 | default=0 199 | ) 200 | 201 | parser.add_argument( 202 | '--same_prompt', 203 | type=int, 204 | help='Whether to use the same prompt across all tasks', 205 | default=0 206 | ) 207 | 208 | parser.add_argument( 209 | '--select_k_per_class', 210 | type=int, 211 | help='Select k examples from each class (default -1, i.e. no changes to the original dataset)', 212 | default=-1 213 | ) 214 | 215 | parser.add_argument( 216 | '--freeze_weights', 217 | type=int, 218 | help='Whether to freeze model weigts (except word emb)', 219 | default=0 220 | ) 221 | 222 | parser.add_argument( 223 | '--freeze_except', 224 | type=str, 225 | help='If freeze_weights==1, freeze all weights except those that contain this keyword', 226 | default='word_embeddings' 227 | ) 228 | 229 | parser.add_argument( 230 | '--task_list', 231 | nargs='+', 232 | help='List of tasks for training', 233 | required=True 234 | ) 235 | 236 | parser.add_argument( 237 | '--cls_idx_override', 238 | type=int, 239 | help='Position of classification token; by default will use current task k token - CLSk', 240 | default=None 241 | ) 242 | 243 | main(parser.parse_args()) 244 | -------------------------------------------------------------------------------- /BERT_codebase/train_cl2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import numpy as np 4 | from tqdm.auto import tqdm 5 | import logging, os, argparse 6 | 7 | import dataset_utils, model_utils, continual_learning_utils, continual_learning_one_head 8 | 9 | 10 | 11 | def main(args): 12 | 13 | save_path = os.path.join(args.save_dir, args.save_name) 14 | if not os.path.exists(save_path): 15 | os.mkdir(save_path) 16 | log_file = os.path.join(save_path, 'LR_logs') 17 | 18 | if args.one_head == 0: 19 | prefix_MLP = args.prefix_MLP 20 | #prefix_MLP = None 21 | CL_class = continual_learning_utils.ContinualLearner(args.model_name, 22 | args.task_list, 23 | batch_size=args.batch_size, 24 | select_k_per_class=args.select_k_per_class, 25 | memory_perc=args.memory_perc, 26 | #block_attn=0, 27 | freeze_weights=args.freeze_weights, 28 | freeze_except=args.freeze_except, 29 | lr=args.lr, 30 | seq_len=args.seq_len, 31 | cls_idx_override=args.cls_idx_override, 32 | early_stopping= args.early_stopping==1, 33 | prefix_len=args.prefix_len, 34 | prefix_MLP=prefix_MLP, 35 | bottleneck_size=args.bottleneck_size, 36 | do_repeats= args.do_repeats==1, 37 | same_prompt=args.same_prompt==1, 38 | ) 39 | 40 | 41 | if args.multitask == 1: 42 | print("Multi-task learning") 43 | results_dict = CL_class.multi_task_training(num_epochs=args.num_epochs) 44 | else: 45 | results_dict = CL_class.continual_training(num_epochs=args.num_epochs, 46 | data_replay_freq=args.data_replay_freq, 47 | prompt_tuning=args.prompt_tuning==1, 48 | prompt_init=args.prompt_init, 49 | save_prompt_path='None' if args.save_prompt==0 else save_path, 50 | save_results_path=save_path, 51 | ) 52 | np.save(os.path.join(save_path, 'results_dict.npy'), results_dict) 53 | 54 | 55 | ## IDBR results 56 | else: 57 | CL_class = continual_learning_one_head.ContinualLearnerIDBR(args.model_name, 58 | args.task_list, 59 | batch_size=args.batch_size, 60 | select_k_per_class=args.select_k_per_class, 61 | memory_perc=args.memory_perc, 62 | #block_attn=0, 63 | freeze_weights=0, 64 | freeze_except=args.freeze_except, 65 | lr=args.lr, 66 | seq_len=args.seq_len, 67 | cls_idx_override=args.cls_idx_override, 68 | early_stopping= args.early_stopping==1, 69 | ) 70 | if args.multitask == 1: 71 | print("Multi-task learning IDBR style") 72 | results_dict = CL_class.multi_task_training(num_epochs=args.num_epochs) 73 | else: 74 | print("Training in IDBR style") 75 | results_dict = CL_class.continual_training(num_epochs=args.num_epochs, 76 | data_replay_freq=args.data_replay_freq, 77 | regspe=args.regspe, 78 | reggen=args.reggen, 79 | tskcoe=args.tskcoe, 80 | nspcoe=args.nspcoe, 81 | disen=args.disen==1) 82 | 83 | np.save(os.path.join(save_path, 'results_dict.npy'), results_dict) 84 | 85 | 86 | 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser( 90 | description='NLP training script in PyTorch' 91 | ) 92 | 93 | parser.add_argument( 94 | '--save_dir', 95 | type=str, 96 | help='base directory of all models / features (should not be changed)', 97 | default='/data/home/arazdai/CL/all_CL_results/' 98 | ) 99 | 100 | parser.add_argument( 101 | '--save_name', 102 | type=str, 103 | help='folder name to save', 104 | required=True 105 | ) 106 | 107 | parser.add_argument( 108 | '--model_name', 109 | type=str, 110 | help='Name of the model used for training', 111 | default="bert-base-uncased" 112 | ) 113 | 114 | parser.add_argument( 115 | '--num_epochs', 116 | type=int, 117 | help='Number of epochs to train model', 118 | default=5 119 | ) 120 | 121 | parser.add_argument( 122 | '--memory_perc', 123 | type=float, 124 | help='Percentage of examples from previous tasks to use for data replay; if 0 then no data replay', 125 | default=0 #0.01 126 | ) 127 | 128 | parser.add_argument( 129 | '--data_replay_freq', 130 | type=int, 131 | help='Data replay happens after every X samples (if -1 then no data replay)', 132 | default=-1 #9 133 | ) 134 | 135 | parser.add_argument( 136 | '--batch_size', 137 | type=int, 138 | help='Batch size', 139 | default=8 140 | ) 141 | 142 | parser.add_argument( 143 | '--seq_len', 144 | type=int, 145 | help='Length of a single repeat (in #tokens)', 146 | default=512 147 | ) 148 | 149 | 150 | parser.add_argument( 151 | '--prefix_len', 152 | type=int, 153 | help='Soft prompt length (same for each task)', 154 | default=10 155 | ) 156 | 157 | parser.add_argument( 158 | '--prefix_MLP', 159 | type=str, 160 | help='Whether to use MLP reparametrization on prompt embeddings (default = "None", no reparametrization)', 161 | default='None' 162 | ) 163 | 164 | parser.add_argument( 165 | '--bottleneck_size', 166 | type=int, 167 | help='Bottleneck size in case of MLP reparametrization', 168 | default=800 169 | ) 170 | 171 | 172 | parser.add_argument( 173 | '--num_repeats', 174 | type=int, 175 | help='Number of sentence repeats after the original sentence', 176 | default=0 177 | ) 178 | 179 | parser.add_argument( 180 | '--block_attn', 181 | type=int, 182 | help='Whether to use blockwise causal attention', 183 | default=0 184 | ) 185 | 186 | parser.add_argument( 187 | '--select_k_per_class', 188 | type=int, 189 | help='Select k examples from each class (default -1, i.e. no changes to the original dataset)', 190 | default=-1 191 | ) 192 | 193 | parser.add_argument( 194 | '--lr', 195 | type=float, 196 | help='Learning rate', 197 | default=3e-5 198 | ) 199 | 200 | parser.add_argument( 201 | '--freeze_weights', 202 | type=int, 203 | help='Whether to freeze model weigts (except word emb)', 204 | default=0 205 | ) 206 | 207 | parser.add_argument( 208 | '--freeze_except', 209 | type=str, 210 | help='If freeze_weights==1, freeze all weights except those that contain this keyword', 211 | default='word_embeddings' 212 | ) 213 | 214 | parser.add_argument( 215 | '--task_list', 216 | nargs='+', 217 | help='List of tasks for training', 218 | required=True 219 | ) 220 | 221 | parser.add_argument( 222 | '--prompt_tuning', 223 | type=int, 224 | help='Perform prompt tuning (1 - True, 0 - False)', 225 | default=0 226 | ) 227 | 228 | parser.add_argument( 229 | '--same_prompt', 230 | type=int, 231 | help='Whether to use the same prompt for all tasks (1 - True, 0 - False)', 232 | default=0 233 | ) 234 | 235 | parser.add_argument( 236 | '--save_prompt', 237 | type=int, 238 | help='Save prompts in np arrays after training (1 - True, 0 - False)', 239 | default=0 240 | ) 241 | 242 | parser.add_argument( 243 | '--prompt_init', 244 | type=str, 245 | help='Initialization of next task prompts, if None - init from random word emb in the vocabulary', 246 | default='None' 247 | ) 248 | 249 | 250 | parser.add_argument( 251 | '--do_repeats', 252 | type=int, 253 | help='Perform progressive prompt tuning with repeated input (1 - True, 0 - False)', 254 | default=0 255 | ) 256 | 257 | 258 | parser.add_argument( 259 | '--cls_idx_override', 260 | type=int, 261 | help='Position of classification token; by default will use current task k token - CLSk', 262 | default=None 263 | ) 264 | 265 | parser.add_argument( 266 | '--multitask', 267 | type=int, 268 | help='Perform multi-task learning (1 - True, 0 - False)', 269 | default=0 270 | ) 271 | 272 | parser.add_argument( 273 | '--early_stopping', 274 | type=int, 275 | help='Perform early_stopping to select model at each task (1 - True, 0 - False)', 276 | default=1 277 | ) 278 | 279 | ##### IDBR style utils ##### 280 | parser.add_argument( 281 | '--one_head', 282 | type=int, 283 | help='Perform training with one large head for all tasks like in IDBR paper (1 - True, 0 - False)', 284 | default=0 285 | ) 286 | 287 | parser.add_argument( 288 | '--regspe', 289 | type=float, 290 | help='Regularization coef for task-specific features', 291 | default=0.5 292 | ) 293 | 294 | parser.add_argument( 295 | '--reggen', 296 | type=float, 297 | help='Regularization coef for general features', 298 | default=0.5 299 | ) 300 | 301 | parser.add_argument( 302 | '--tskcoe', 303 | type=float, 304 | help='Task loss coef', 305 | default=0.0 306 | ) 307 | 308 | parser.add_argument( 309 | '--nspcoe', 310 | type=float, 311 | help='NSP loss coef', 312 | default=0.0 313 | ) 314 | 315 | parser.add_argument( 316 | '--disen', 317 | type=int, 318 | help='Perform IDBR training (disentanglement)', 319 | default=0 320 | ) 321 | 322 | main(parser.parse_args()) 323 | -------------------------------------------------------------------------------- /BERT_codebase/train_soft_prompt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import pandas as pd 4 | import numpy as np 5 | from tqdm.auto import tqdm 6 | import logging, os, argparse 7 | 8 | import dataset_utils, model_utils, continual_learning_utils, continual_learning_one_head 9 | 10 | 11 | def get_prefix_net(bottleneck_size = 800, network_type='MLP-1'): 12 | 13 | if network_type == 'MLP1': 14 | prefix_MLP = nn.Sequential( 15 | nn.Linear(768, bottleneck_size), 16 | nn.ReLU(), 17 | #nn.Linear(bottleneck_size, bottleneck_size), 18 | #nn.Tanh(), 19 | nn.Linear(bottleneck_size, 768), 20 | ) 21 | 22 | elif network_type == 'MLP2': 23 | prefix_MLP = nn.Sequential( 24 | nn.Linear(768, bottleneck_size), 25 | nn.ReLU(), 26 | nn.Linear(bottleneck_size, bottleneck_size), 27 | nn.Tanh(), 28 | nn.Linear(bottleneck_size, 768), 29 | ) 30 | 31 | elif network_type == 'transformer': 32 | device = 'cuda' 33 | encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=2, dropout=0.05).to(device) 34 | prefix_MLP = nn.TransformerEncoder(encoder_layer, num_layers=2).to(device) 35 | 36 | return prefix_MLP 37 | 38 | 39 | def train_prompt(args, CL_class, epochs = 50, cls_idx = 0, prompt_tuning=True): 40 | val_scores = [] 41 | task = args.task 42 | if torch.cuda.is_available(): 43 | device = torch.device("cuda") 44 | else: 45 | device = torch.device("cpu") 46 | 47 | dataloader_train, dataloader_val = CL_class.tasks_data_dict[task]['train'], CL_class.tasks_data_dict[task]['val'] 48 | #task = args.task #'cola' 49 | 50 | for epoch in range(epochs): 51 | print(epoch) 52 | CL_class.trainer.model.train().to(device) 53 | if args.prefix_len>0 and args.prefix_MLP != 'None': 54 | CL_class.trainer.prefix_MLP.train().to(device) 55 | 56 | for i, batch in enumerate(tqdm(dataloader_train)): 57 | batch = {k: v.to(device) for k, v in batch.items()} 58 | if prompt_tuning: # pass batch with prompt tuning 59 | CL_class.trainer.pass_batch_with_prompt(batch, task, device, prefix_len=args.prefix_len) 60 | else: # regular pass batch 61 | CL_class.trainer.pass_batch(batch, task, device, cls_idx=0) 62 | 63 | CL_class.trainer.model.eval() 64 | if args.prefix_len>0 and args.prefix_MLP != 'None': 65 | CL_class.trainer.prefix_MLP.eval() 66 | 67 | if prompt_tuning: 68 | result = CL_class.trainer.eval_with_prompt(dataloader_val, task, None, cls_idx=0) 69 | else: 70 | result = CL_class.trainer.eval(dataloader_val, task, None, cls_idx=0) 71 | print(' result = ',result, '\n') 72 | val_scores.append(result) 73 | return val_scores 74 | 75 | 76 | def main(args): 77 | 78 | save_path = os.path.join(args.save_dir, args.save_name) 79 | if not os.path.exists(save_path): 80 | os.mkdir(save_path) 81 | log_file = os.path.join(save_path, 'LR_logs') 82 | 83 | if args.prefix_MLP == 'None': 84 | prefix_net = None 85 | else: 86 | prefix_net = get_prefix_net(bottleneck_size = args.bottleneck, network_type=args.prefix_MLP) 87 | 88 | CL_class= continual_learning_utils.ContinualLearner(args.model_name, 89 | [args.task], # one task as a list of tasks 90 | batch_size=args.batch_size, 91 | select_k_per_class=args.select_k_per_class, 92 | memory_perc=args.memory_perc, 93 | block_attn=0, 94 | freeze_weights=args.freeze_weights, 95 | freeze_except=args.freeze_except, 96 | lr=args.lr, 97 | seq_len=args.seq_len, 98 | cls_idx_override=args.cls_idx_override, 99 | early_stopping= args.early_stopping==1, 100 | prefix_MLP=prefix_net, 101 | prefix_len=args.prefix_len, 102 | ) 103 | 104 | results = train_prompt(args, CL_class, epochs = args.num_epochs, cls_idx = 0, prompt_tuning= args.prompt_tuning==1) 105 | np.save(os.path.join(save_path, 'results.npy'), results) 106 | 107 | 108 | 109 | 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser( 114 | description='NLP training script in PyTorch' 115 | ) 116 | 117 | parser.add_argument( 118 | '--save_dir', 119 | type=str, 120 | help='base directory of all models / features (should not be changed)', 121 | default='/data/home/arazdai/CL/' 122 | ) 123 | 124 | parser.add_argument( 125 | '--save_name', 126 | type=str, 127 | help='folder name to save', 128 | required=True 129 | ) 130 | 131 | parser.add_argument( 132 | '--model_name', 133 | type=str, 134 | help='Name of the model used for training', 135 | default="bert-base-uncased" 136 | ) 137 | 138 | parser.add_argument( 139 | '--num_epochs', 140 | type=int, 141 | help='Number of epochs to train model', 142 | default=5 143 | ) 144 | 145 | parser.add_argument( 146 | '--memory_perc', 147 | type=float, 148 | help='Percentage of examples from previous tasks to use for data replay; if 0 then no data replay', 149 | default=0 #0.01 150 | ) 151 | 152 | parser.add_argument( 153 | '--data_replay_freq', 154 | type=int, 155 | help='Data replay happens after every X samples (if -1 then no data replay)', 156 | default=-1 #9 157 | ) 158 | 159 | parser.add_argument( 160 | '--batch_size', 161 | type=int, 162 | help='Batch size', 163 | default=8 164 | ) 165 | 166 | parser.add_argument( 167 | '--seq_len', 168 | type=int, 169 | help='Length of a single repeat (in #tokens)', 170 | default=512 171 | ) 172 | 173 | parser.add_argument( 174 | '--prompt_tuning', 175 | type=int, 176 | help='If 1, use prompt_batch_pass / eval (takes into account prefix MLP and wte), else use normal batch pass', 177 | default=1 178 | ) 179 | 180 | parser.add_argument( 181 | '--prefix_len', 182 | type=int, 183 | help='Length of the soft prompt to be tuned (default 0 - no soft prompt)', 184 | default=0 185 | ) 186 | 187 | parser.add_argument( 188 | '--num_repeats', 189 | type=int, 190 | help='Number of sentence repeats after the original sentence', 191 | default=0 192 | ) 193 | 194 | parser.add_argument( 195 | '--block_attn', 196 | type=int, 197 | help='Whether to use blockwise causal attention', 198 | default=0 199 | ) 200 | 201 | parser.add_argument( 202 | '--select_k_per_class', 203 | type=int, 204 | help='Select k examples from each class (default -1, i.e. no changes to the original dataset)', 205 | default=-1 206 | ) 207 | 208 | parser.add_argument( 209 | '--lr', 210 | type=float, 211 | help='Learning rate', 212 | default=3e-5 213 | ) 214 | 215 | parser.add_argument( 216 | '--freeze_weights', 217 | type=int, 218 | help='Whether to freeze model weigts (except word emb)', 219 | default=0 220 | ) 221 | 222 | parser.add_argument( 223 | '--freeze_except', 224 | type=str, 225 | help='If freeze_weights==1, freeze all weights except those that contain this keyword', 226 | default='word_embeddings' 227 | ) 228 | 229 | # parser.add_argument( 230 | # '--task_list', 231 | # nargs='+', 232 | # help='List of tasks for training', 233 | # required=True 234 | # ) 235 | 236 | parser.add_argument( 237 | '--task', 238 | type=str, 239 | help='Task to train on (only one)', 240 | required=True 241 | ) 242 | 243 | parser.add_argument( 244 | '--prefix_MLP', 245 | type=str, 246 | help='Type of network to pass soft prompt through', 247 | default='None' 248 | ) 249 | 250 | parser.add_argument( 251 | '--bottleneck', 252 | type=int, 253 | help='Bottleneck size of prefix MLP', 254 | default=800 255 | ) 256 | 257 | parser.add_argument( 258 | '--cls_idx_override', 259 | type=int, 260 | help='Position of classification token; by default will use current task k token - CLSk', 261 | default=None 262 | ) 263 | 264 | # parser.add_argument( 265 | # '--multitask', 266 | # type=int, 267 | # help='Perform multi-task learning (1 - True, 0 - False)', 268 | # default=0 269 | # ) 270 | 271 | parser.add_argument( 272 | '--early_stopping', 273 | type=int, 274 | help='Perform early_stopping to select model at each task (1 - True, 0 - False)', 275 | default=0 276 | ) 277 | 278 | main(parser.parse_args()) 279 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Progressive Prompts 2 | 3 | **Our work on Progressive Prompts is accepted to ICLR 2023!** 🎉 4 | 5 | This repo includes an original implementation of Anastasia Razdaibiedina, Yuning Mao, Rui Hou, Madian Khabsa, Mike Lewis and Amjad Almahairi. ["Progressive Prompts: Continual Learning for Language Models"](https://arxiv.org/abs/2301.12314), ICLR 2023. 6 | 7 | ### Table of contents 8 | * [Introduction](#star2-introduction) 9 | * [What's in this repository](#question-whats-in-this-repository) 10 | * [Installation](#wrench-installation) 11 | * [How to run](#zap-how-to-run) 12 | * [Contact](#raising_hand-questions) 13 | 14 | 15 | ## :star2: Introduction 16 | We introduce **Progressive Prompts** – a novel Continual Learning (CL) approach for language models. Our 17 | method is inspired by progressive networks ([A. Rusu et al., NeurIPS 2017](https://arxiv.org/pdf/1606.04671.pdf)), but is significantly more memory-efficient. In Progressive Prompts, we learn a separate set of virtual tokens, or ***soft prompt*** ([B. Lester et al., EMNLP 2021](https://arxiv.org/pdf/2104.08691.pdf)), for each incoming task and sequentially concatenate it with previously learned prompts. 18 | 19 | Our method can: 20 | 21 | 1) **alleviate catastrophic forgetting**; since it preserves the knowledge acquired by previous prompts, and 22 | 2) **transfer knowledge to future tasks**; since new prompts are sequentially concatenated with all prior prompts. 23 | 24 | ![Progressive Prompts schematics](/images/illustration.png) 25 | Figure: *Illustrating our proposed method **Progressive Prompts** and contrasting it with a simple 26 | adaptation of progressive networks using prompt tuning. In the simple adaptation of progressive 27 | networks we learn a separate prompt and repeat the frozen input embeddings for each new task. 28 | This setup requires repeating input tokens for each task. In Progressive Prompts we use the same 29 | input and progressively append new prompt for each new task. Prior task prompts are not modified 30 | by the addition of new prompts.* 31 | 32 | ## :question: What's in this repository 33 | 34 | This is our code structure: 35 | 36 | ``` 37 | |_T5_codebase/ 38 | |_t5_dataset.py --> T5 Dataset class for reading and processing datasets 39 | |_t5_continual.py --> Model class for T5 with prompt tuning and continual learning functions 40 | |_train_t5_cl.py --> Code to run continual learning experiments with T5 41 | 42 | |_BERT_codebase/ 43 | |_dataset_utils.py --> BERT Dataset class for reading and processing datasets 44 | |_model_utils.py --> Model class for BERT with prompt tuning and fine-tuning functions 45 | |_continual_learning_utils.py --> Continual Learner class for Progressive Prompts (with BERT) 46 | |_continual_learning_one_head.py --> Continual Learner class for regularization-based CL approaches for BERT 47 | |_train_cl2.py --> Code to run continual learning experiments with BERT 48 | 49 | |_datasets/src/data/ --> CL datasets from Zhang et. al., 2015 50 | |_amazon --> Amazon reviews (zip archive, since dataset is not available through HuggingFace datasets) 51 | (the rest of datasets can be either accessed through HuggingFace or downloaded by instructions below) 52 | ``` 53 | 54 | **Note**: we access most of the datasets for our experiments through HuggingFace datasets, including CL datasets from Zhang et. al., 2015. Since only one CL datasets from Zhang et. al. is not available on HuggingFace - Amazon Reviews, we uploaded its archived train / test data to ```datasets/src/data/amazon/```. To access the rest of CL datasets (Yelp, Yahoo, AG, DbPedia), you can either use their HuggingFace names in our training script or download them from [http://goo.gl/JyCnZq](http://goo.gl/JyCnZq) to ```datasets/src/data/```. 55 | 56 | ## :wrench: Installation 57 | 58 | Our implementation is based on PyTorch and HuggingFace (transformers + datasets). 59 | 60 | Requirements: 61 | * Python 3.8.5 62 | * Pytorch 1.10.0 63 | * transformers 4.20.0 64 | * datasets 2.3.2 65 | * tqdm, sklearn, numpy, pandas 66 | 67 | Step-by-step instructions to get you running Progressive Prompts: 68 | 69 | ### 1) Clone this repository to your local machine: 70 | 71 | ```bash 72 | git clone https://github.com/arazd/ProgressivePrompts 73 | ``` 74 | 75 | A folder called ```ProgressivePrompts``` with all the codebase should appear. 76 | 77 | ### 2) Install the required packages: 78 | 79 | Make sure that you have Anaconda installed. If not - follow this [miniconda installation](https://docs.conda.io/en/latest/miniconda.html). 80 | 81 | To run Progressive Prompts code on GPU, make sure that you have a CUDA capable GPU and the [drivers](https://www.nvidia.com/download/index.aspx?lang=en-us) for your GPU are up to date. In our implementation, we used and CUDA 11.0. 82 | 83 | You can re-create our conda enviroment from ```environment.yaml``` file: 84 | 85 | ```bash 86 | cd ProgressivePrompts 87 | conda env create -f environment.yaml 88 | ``` 89 | 90 | Your conda should start downloading and extracting packages. This can take ~15-20 minutes. 91 | 92 | ### 3) Activate the environment: 93 | 94 | Your environment should be called ```nlp```, and you can activate it now to run the scripts: 95 | 96 | ```bash 97 | conda activate nlp 98 | ``` 99 | 100 | ## :zap: How to run 101 | 102 | For example, to run Progressive Prompts with T5-large on four tasks (IMDb, CB, SST-2 and DbPedia): 103 | ```bash 104 | cd T5_codebase 105 | 106 | python train_t5_cl.py --task_list imdb cb sst2 dbpedia_14 --select_k_per_class 1000 \ 107 | --lr 0.3 --num_epochs 10 --freeze_weights 1 --prefix_len 10 \ 108 | --model_name t5-large --early_stopping 1 \ 109 | --save_name T5_experiment --save_dir my_path_to_save_directory 110 | ``` 111 | 112 | In the example above, we froze weights and trained a prompt of size 10 (per task) for 10 epochs. We also limited data to 1000 samples per class. 113 | For other arguments and their descriptions, please check ```T5_codebase/train_t5_cl.py``` file. 114 | 115 | 116 | To train Progressive Prompts on the same four tasks with BERT-base: 117 | ```bash 118 | cd BERT_codebase 119 | 120 | python train_cl2.py --task_list imdb cb sst2 dbpedia_14 --select_k_per_class 1000 \ 121 | --lr 3e-5 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings \ 122 | --prompt_tuning 1 --prefix_len 10 --seq_len 450 --one_head 0 \ 123 | --model_name bert-base-uncased --early_stopping 1 \ 124 | --save_name BERT_experiment --save_dir my_path_to_save_directory 125 | ``` 126 | 127 | Note how soft prompts for BERT need to be trained with smaller learning rate and higher number of epochs. 128 | We also have some other BERT-specific arguments, one_head controls whether to use a separate head for each task, freeze_except allows to freeze all weights except word embeddings (since we include prompt tokens into vocabulary for BERT implementation), seq_len controls max input length (without prompt), prompt_tuning flag signals if we are doing prompt tuning. 129 | For other arguments and their descriptions, please check ```BERT_codebase/train_cl2.py``` file. 130 | 131 | 139 | 140 | 141 | ## :raising_hand: Questions 142 | If you have any questions about the paper or code, please contact Anastasia Razdaibiedina (anastasia.razdaibiedina[at]mail.utoronto.ca) or open an issue. 143 | 144 | ## :books: Citation 145 | If you use our code in your research, please cite our work: 146 | ```bibtex 147 | @inproceedings{razdaibiedina2023progressive, 148 | title={Progressive Prompts: Continual Learning for Language Models}, 149 | author={Razdaibiedina, Anastasia and Mao, Yuning and Hou, Rui and Khabsa, Madian and Lewis, Mike and Almahairi, Amjad}, 150 | booktitle={International Conference on Learning Representations}, 151 | year={2023} 152 | } 153 | ``` 154 | 155 | 163 | -------------------------------------------------------------------------------- /T5_codebase/t5_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | from torch.utils.data import Dataset, DataLoader 5 | from datasets import load_dataset 6 | import datasets 7 | 8 | 9 | class T5Dataset: 10 | def __init__(self, tokenizer, task): 11 | """Dataset class for T5 model experiments. 12 | Args: 13 | task (str): Name of the downstream task. 14 | tokenizer (HuggingFace Tokenizer): T5 model tokenizer to use. 15 | """ 16 | 17 | self.tokenizer = tokenizer 18 | self.glue_datasets = ['cola', 'sst2', 'mrpc', 'qqp', 'stsb', 'mnli', \ 19 | 'mnli_mismatched', 'mnli_matched', 'qnli', 'rte', 'wnli', 'ax'] 20 | self.superglue_datasets = ['copa', 'boolq', 'wic', 'wsc', 'cb', 'record', 'multirc', 'rte_superglue', 'wsc_bool'] 21 | 22 | # Column keys used in the dataset 23 | self.task_to_keys = { 24 | "cola": ("sentence", None), 25 | "mnli": ("premise", "hypothesis"), 26 | "mnli-mm": ("premise", "hypothesis"), 27 | "mrpc": ("sentence1", "sentence2"), 28 | #"qnli": ("question", "sentence"), 29 | "qnli": ("text1", "text2"), 30 | "qqp": ("question1", "question2"), 31 | "rte": ("sentence1", "sentence2"), 32 | "sst2": ("sentence", None), 33 | "stsb": ("sentence1", "sentence2"), 34 | "wnli": ("sentence1", "sentence2"), 35 | 36 | "boolq": ("passage", "question"), 37 | "copa": ('choice1', 'choice2', 'premise', 'question'), 38 | "wic": ("start1", "end1", "sentence1", "start2", "end2", "sentence2", "word"), 39 | "wsc": ("span1_text", "span1_index", "span2_text", "span2_index", "text"), 40 | "wsc_bool": ("span1_text", "span1_index", "span2_text", "span2_index", "text"), 41 | "cb": ("premise", "hypothesis"), 42 | "record": ("passage", "query", "entities"), 43 | "multirc": ("question", "answer", "paragraph"), 44 | "rte_superglue": ("premise", "hypothesis"), 45 | 46 | "scicite": ("sectionName", "string"), 47 | "imdb": ("text", None), 48 | 49 | "ag_news": ("text", None), 50 | "yelp_review_full": ("text", None), 51 | "yahoo_answers_topics": ("question_content", "best_answer"), 52 | "dbpedia_14": ("title", "content"), 53 | 54 | "ag": ("content", None), 55 | "yelp": ("content", None), 56 | "yahoo": ("content", None), 57 | "dbpedia": ("content", None), 58 | "amazon": ("content", None), 59 | } 60 | 61 | # Label text for T5 tasks 62 | # (T5 has text-to-text format for text and labels) 63 | self.task_to_labels = { 64 | "cola": ("not_acceptable", "acceptable"), 65 | "mnli": ("entailment", "neutral", "contradiction"), 66 | "mnli-mm": (), 67 | "mrpc": ("not_equivalent", "equivalent"), 68 | "qnli": ("entailment", "not_entailment"), 69 | "qqp": ("not_duplicate", "duplicate"), 70 | "rte": ("entailment", "not_entailment"), 71 | "sst2": ("negative", "positive"), 72 | "stsb": (), 73 | "wnli": (), 74 | 75 | "boolq": ("false", "true"), 76 | "copa": ("false", "true"), 77 | "wic": ("false", "true"), 78 | "wsc_bool": ("false", "true"), 79 | "cb": ("entailment", "contradiction", "neutral"), 80 | "multirc": ("false", "true"), 81 | "rte_superglue": ("entailment", "not_entailment"), 82 | 83 | "scicite": (), 84 | "imdb": ("negative", "positive"), 85 | 86 | "ag_news": ("world", "sports", "business", "science"), 87 | "yelp_review_full": ("terrible", "bad", "middle", "good", "wonderful"), 88 | "yahoo_answers_topics": ("society and culture", "science", "health", "education and reference", 89 | "computers and internet", "sports", "business", "entertainment and music", 90 | "family and relationships", "politics and government"), 91 | "dbpedia_14": ("company", "educationalinstitution", "artist", "athlete", "officeholder", 92 | "meanoftransportation", "building", "naturalplace", "village", "animal", 93 | "plant", "album", "film", "writtenwork"), 94 | 95 | "ag": ("world", "sports", "business", "science"), 96 | "yelp": ("terrible", "bad", "middle", "good", "wonderful"), 97 | "yahoo": ("society and culture", "science", "health", "education and reference", 98 | "computers and internet", "sports", "business", "entertainment and music", 99 | "family and relationships", "politics and government"), 100 | "dbpedia": ("company", "educationalinstitution", "artist", "athlete", "officeholder", 101 | "meanoftransportation", "building", "naturalplace", "village", "animal", 102 | "plant", "album", "film", "writtenwork"), 103 | "amazon": ("terrible", "bad", "middle", "good", "wonderful"), 104 | } 105 | 106 | self.task = task 107 | self.label_key = 'label' 108 | if 'yahoo_' in task: self.label_key = 'topic' 109 | if 'stsb' in task: self.label_key = 'similarity_score' 110 | if task=='record': self.label_key = 'answers' 111 | 112 | 113 | # Helper function to save idx of multirc questions (needed later for test metric computation) 114 | def save_multirc_questions_idx(self, val_ds): 115 | idx = [] 116 | i = 0 117 | x_prev, y_prev= val_ds['paragraph'][0], val_ds['question'][0] 118 | 119 | for x,y in zip(val_ds['paragraph'], val_ds['question']): 120 | if x_prev!=x or y_prev!=y: 121 | i += 1 122 | x_prev = x 123 | y_prev = y 124 | idx.append(i) 125 | self.multirc_idx = np.array(idx) 126 | 127 | 128 | # Helper function to select a subset of k samples per class in a dataset 129 | def select_subset_ds(self, ds, k=2000, seed=0): 130 | if self.task in ['stsb', 'record', 'wsc']: # non-discrete labels 131 | idx_total = np.random.choice(np.arange(ds.shape[0]), min(k,ds.shape[0]), replace=False) 132 | 133 | else: 134 | label_key = self.label_key 135 | N = len(ds[label_key]) 136 | idx_total = np.array([], dtype='int64') 137 | 138 | for l in set(ds[label_key]): 139 | idx = np.where(np.array(ds[label_key]) == l)[0] 140 | idx_total = np.concatenate([idx_total, # we cannot take more samples than there are available 141 | np.random.choice(idx, min(k, idx.shape[0]), replace=False)]) 142 | 143 | np.random.seed(seed) 144 | np.random.shuffle(idx_total) 145 | return ds.select(idx_total) 146 | 147 | 148 | # WSC task function to preprocess raw input & label text into tokenized dictionary 149 | def process_wsc(self, wsc_row): 150 | text_proc = wsc_row['text'].split(' ') 151 | #text_proc[wsc_row['span1_index']] = '*' + text_proc[wsc_row['span1_index']] +'*' 152 | target = text_proc[wsc_row['span1_index']] 153 | text_proc[wsc_row['span2_index']] = '*' + text_proc[wsc_row['span2_index']] + '*' 154 | text_proc = (' ').join(text_proc) 155 | return text_proc, target 156 | 157 | 158 | # Function to preprocess raw input & label text into tokenized dictionary 159 | def preprocess_function(self, examples, task, 160 | max_length=512, max_length_target=2, 161 | prefix_list=[]): 162 | tokenizer = self.tokenizer 163 | keys = self.task_to_keys[task] 164 | label_key = self.label_key 165 | 166 | if keys[1]!=None: 167 | if task=='record': 168 | text = 'passage : ' + str(examples['passage']) + ' query: ' + str(examples['query']) + ' entities: ' + ('; ').join((examples['entities'])) 169 | elif task=='wsc': 170 | text, target = self.process_wsc(examples) 171 | else: 172 | text = '' 173 | for key in keys: 174 | text += key + ': ' + str(examples[key]) + ' ' 175 | else: 176 | text = examples[keys[0]] 177 | 178 | if len(prefix_list)>0: 179 | text = (' ').join(prefix_list) + ' ' + text 180 | source = tokenizer(text.strip()+' ', 181 | truncation=True, 182 | #padding=False, 183 | padding='max_length', 184 | max_length=max_length) 185 | 186 | if task=='stsb': 187 | target = str(examples[label_key])[:3] 188 | elif task=='record': 189 | target = '; '.join(examples[label_key]) 190 | elif task=='wsc': 191 | pass # already obtained target 192 | else: 193 | target = self.task_to_labels[task][examples[label_key]] 194 | target += ' ' 195 | target = tokenizer( 196 | target, max_length=max_length_target, pad_to_max_length=True, #return_tensors="pt" 197 | ) 198 | 199 | dict_final = {"source_ids": source['input_ids'], 200 | "source_mask": source['attention_mask'], 201 | "target_ids": target['input_ids'], 202 | "target_mask": target['attention_mask']} 203 | return dict_final 204 | 205 | 206 | 207 | def get_final_ds(self, 208 | task, 209 | split, 210 | batch_size, 211 | k=-1, 212 | seed=0, 213 | return_test=False, 214 | target_len=2, 215 | max_length=512, 216 | prefix_list=[]): 217 | """Function that returns final T5 dataloader. 218 | Args: 219 | task (str): Name of the downstream task. 220 | split (str): Which data split to use (train/validation/test). 221 | batch_size (int): Batch size to use in the dataloader. 222 | k (int, optional): Number of samples to use for each class. Defaults to -1, not sub-sample the data. 223 | seed (int, optional): Seed used for random shuffle. Defaults to 0. 224 | return_test (bool, optional): Whether to create a test split. 225 | When True, two Dataloaders are returned. Defaults to False. 226 | target_len (int, optional): Length of the model output (in tokens). Defaults to 2. 227 | max_length (int, optional): Length of the model input (in tokens). Defaults to 512. 228 | prefix_list (List[str], optional): List of prompt virtual tokens to pre-pend to the input. 229 | We do not encode soft prompt as extra virtual tokens in the latest implementation. 230 | Defaults to [], empty list. 231 | 232 | Returns: 233 | Dataloader: Torch Dataloader with preprocessed input text & label. 234 | """ 235 | 236 | if task in ['amazon']: # amazon not available with hugging face 237 | df = pd.read_csv('../datasets/src/data/'+task+'/'+split+'.csv', header=None) 238 | df = df.rename(columns={0: "label", 1: "title", 2: "content"}) 239 | df['label'] = df['label'] - 1 240 | dataset = datasets.Dataset.from_pandas(df) 241 | elif task == 'mnli': 242 | dataset = load_dataset('LysandreJik/glue-mnli-train', split=split) 243 | elif task == 'qnli': 244 | dataset = load_dataset('SetFit/qnli', split=split) 245 | elif task == 'stsb': 246 | dataset = load_dataset('stsb_multi_mt', name='en', split=split if split=='train' else 'dev') 247 | else: 248 | if task not in self.glue_datasets and task not in self.superglue_datasets: 249 | dataset = load_dataset(task, split=split) 250 | else: 251 | benchmark = 'glue' if task not in self.superglue_datasets else 'super_glue' 252 | dataset = load_dataset(benchmark, 253 | task.replace('_superglue', '').replace('_bool', ''), 254 | split=split) 255 | 256 | # For yahoo dataset we need to filter out empty rows 257 | # (i.e. where "question" field is empty) 258 | if self.task == "yahoo_answers_topics": 259 | if split=='train': 260 | good_id = np.load('good_id_yahoo_train.npy') 261 | dataset = dataset.select(good_id) 262 | elif split=='test': 263 | good_id = np.load('good_id_yahoo_test.npy') 264 | dataset = dataset.select(good_id) 265 | 266 | # Using Lester et al. setting for WSC task, e.g. 267 | # using only positive samples (for output generation) 268 | if self.task == 'wsc': 269 | idx = np.where(np.array(dataset['label']) == 1)[0] 270 | dataset = dataset.select(idx) 271 | 272 | # Selecting k subset of the samples (if requested) 273 | if k!=-1: 274 | dataset = self.select_subset_ds(dataset, k=k) 275 | 276 | if k==-1 and split!='train' and self.task=='multirc': 277 | # we do not shuffle full validation set of multirc 278 | # but we save idx of the same questions 279 | # which are used for multirc test metric computation 280 | self.save_multirc_questions_idx(dataset) 281 | else: 282 | dataset = dataset.shuffle(seed=seed) 283 | 284 | # Returning the selected data split (train/val/test) 285 | if return_test==False: 286 | encoded_dataset = dataset.map(lambda x: self.preprocess_function(x, task, 287 | max_length=max_length, 288 | max_length_target=target_len, 289 | prefix_list=prefix_list), 290 | batched=False) 291 | encoded_dataset.set_format(type='torch', columns=['source_ids', 'source_mask', 292 | 'target_ids', 'target_mask']) 293 | dataloader = DataLoader(encoded_dataset, batch_size=batch_size) 294 | 295 | return dataloader 296 | 297 | # Creating an extra test set from the selected data split 298 | else: 299 | N = len(dataset) 300 | dataset_val = dataset.select(np.arange(0, N//2)) 301 | dataset_test = dataset.select(np.arange(N//2, N)) 302 | 303 | dataloaders_val_test = [] 304 | for dataset in [dataset_val, dataset_test]: 305 | encoded_dataset = dataset.map(lambda x: self.preprocess_function(x, task, 306 | max_length=max_length, 307 | max_length_target=target_len, 308 | prefix_list=prefix_list), 309 | batched=False) 310 | encoded_dataset.set_format(type='torch', columns=['source_ids', 'source_mask', 311 | 'target_ids', 'target_mask']) 312 | dataloader = DataLoader(encoded_dataset, batch_size=batch_size) 313 | dataloaders_val_test.append(dataloader) 314 | 315 | return dataloaders_val_test 316 | -------------------------------------------------------------------------------- /T5_codebase/train_prompt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import pandas as pd 4 | import numpy as np 5 | from tqdm.auto import tqdm 6 | import logging, os, argparse 7 | 8 | import t5_model, t5_dataset 9 | from copy import deepcopy 10 | from transformers import AdamW 11 | 12 | 13 | class ResMLP(torch.nn.Module): 14 | def __init__(self, bottleneck_size, 15 | module_type='MLP1', 16 | emb_dimension=512, 17 | residual=True, 18 | dropout=0.0, 19 | #layer_norm=True 20 | ): 21 | super().__init__() 22 | if module_type=='MLP1': 23 | if dropout>0: 24 | self.module = nn.Sequential( 25 | nn.Linear(emb_dimension, bottleneck_size), 26 | nn.ReLU(), 27 | #nn.Tanh(), 28 | nn.Linear(bottleneck_size, emb_dimension), 29 | #nn.LayerNorm(emb_dimension), 30 | nn.Dropout(dropout) 31 | ) 32 | else: 33 | self.module = nn.Sequential( 34 | nn.Linear(emb_dimension, bottleneck_size), 35 | #nn.ReLU(), 36 | nn.Tanh(), 37 | nn.Linear(bottleneck_size, emb_dimension), 38 | ) 39 | 40 | elif module_type=='MLP2': 41 | self.module = nn.Sequential( 42 | nn.Linear(emb_dimension, bottleneck_size), 43 | nn.ReLU(), 44 | nn.Linear(bottleneck_size, bottleneck_size // 2), 45 | nn.Tanh(), 46 | nn.Linear(bottleneck_size // 2, emb_dimension), 47 | #nn.LayerNorm(emb_dimension), 48 | ) 49 | 50 | elif module_type=='transformer': 51 | device = 'cuda' 52 | self.encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dimension, nhead=2, dropout=0.05).to(device) 53 | self.module = nn.TransformerEncoder(self.encoder_layer, num_layers=2).to(device) 54 | 55 | self.residual = residual 56 | if self.residual: 57 | print('Using skip connection in MLP') 58 | 59 | def forward(self, inputs): 60 | if self.residual: 61 | return self.module(inputs) + inputs 62 | else: 63 | return self.module(inputs) 64 | 65 | 66 | 67 | def train_step_lester(trainer, batch, prefix_len, embed_prompt=False): 68 | model = trainer.model 69 | if embed_prompt: 70 | mlp = model.mlp 71 | tokenizer = trainer.tokenizer 72 | 73 | batch = {k: batch[k].to(trainer.device) for k in batch} 74 | lm_labels = batch["target_ids"] 75 | lm_labels[lm_labels[:, :] == tokenizer.pad_token_id] = -100 76 | 77 | inputs_embeds = model.encoder.embed_tokens(batch["source_ids"]) 78 | 79 | k = inputs_embeds.shape[0] 80 | if embed_prompt: 81 | prompt = mlp(model.prompt) 82 | else: 83 | prompt = model.prompt 84 | 85 | inputs_embeds = torch.concat([prompt.repeat(k, 1, 1), 86 | inputs_embeds], axis=1)[:,:512] 87 | 88 | source_mask_updated = torch.concat( (batch["source_mask"][0][0].repeat(k,prefix_len), 89 | batch["source_mask"]), axis=1)[:,:512] 90 | 91 | encoder_outputs = model.encoder( 92 | #input_ids=batch["source_ids"], 93 | attention_mask=source_mask_updated, #batch["source_mask"], 94 | #labels=lm_labels, 95 | #decoder_attention_mask=batch['target_mask'] 96 | #input_ids=input_ids, 97 | #attention_mask=attention_mask, 98 | inputs_embeds=inputs_embeds, 99 | head_mask=None, #head_mask, 100 | output_attentions=None, #output_attentions, 101 | output_hidden_states=None, #output_hidden_states, 102 | return_dict=None, #return_dict, 103 | ) 104 | 105 | outputs = model( 106 | input_ids=batch["source_ids"], 107 | attention_mask=source_mask_updated, #batch["source_mask"], 108 | labels=lm_labels, 109 | decoder_attention_mask=batch['target_mask'], 110 | encoder_outputs=encoder_outputs, 111 | ) 112 | loss = outputs[0] 113 | 114 | return loss 115 | 116 | 117 | 118 | def validate_lester(trainer, dataloader_val, task, embed_prompt, prefix_len, 119 | class_keys=['equivalent', 'different'], 120 | max_length=2, print_outputs=False): 121 | model = trainer.model 122 | if embed_prompt: 123 | mlp = model.mlp 124 | prompt = mlp(model.prompt) 125 | else: 126 | prompt = model.prompt 127 | 128 | tokenizer = trainer.tokenizer 129 | model.eval() 130 | 131 | corr, total = 0, 0 132 | loss_total = [] 133 | # try: 134 | # metric = datasets.load_metric('glue', task) 135 | # except: 136 | # metric = datasets.load_metric('accuracy') 137 | 138 | for i, batch in enumerate(tqdm(dataloader_val)): 139 | batch = {k:batch[k].to(trainer.device) for k in batch} 140 | 141 | inputs_embeds = model.encoder.embed_tokens(batch["source_ids"]).to(trainer.device) 142 | k = inputs_embeds.shape[0] 143 | 144 | inputs_embeds = torch.concat([prompt.repeat(k, 1, 1), 145 | inputs_embeds], axis=1)[:,:512] 146 | 147 | source_mask_updated = torch.concat( (batch["source_mask"][0][0].repeat(k,prefix_len), 148 | batch["source_mask"]), axis=1)[:,:512] 149 | 150 | 151 | encoder_outputs = model.encoder( 152 | #input_ids=batch["source_ids"], 153 | #attention_mask=batch["source_mask"], 154 | attention_mask=source_mask_updated, 155 | #labels=lm_labels, 156 | #decoder_attention_mask=batch['target_mask'] 157 | #input_ids=input_ids, 158 | #attention_mask=attention_mask, 159 | inputs_embeds=inputs_embeds, 160 | head_mask=None, #head_mask, 161 | output_attentions=None, #output_attentions, 162 | output_hidden_states=None, #output_hidden_states, 163 | return_dict=None, #return_dict, 164 | ) 165 | 166 | outs = model.generate( 167 | input_ids=batch["source_ids"], 168 | #attention_mask=batch["source_mask"], 169 | attention_mask=source_mask_updated, 170 | #labels=lm_labels, 171 | #decoder_attention_mask=batch['target_mask'], 172 | encoder_outputs=encoder_outputs, 173 | max_length=max_length, 174 | ) 175 | dec = [tokenizer.decode(ids) for ids in outs] 176 | texts = [tokenizer.decode(ids) for ids in batch['source_ids']] 177 | targets = [tokenizer.decode(ids) for ids in batch['target_ids']] 178 | 179 | #print(dec, texts, targets) 180 | corr += np.sum([trainer.process_str(x)==trainer.process_str(y) for x,y in zip(dec, targets)]) 181 | total += batch['source_ids'].shape[0] 182 | 183 | if i<10 and print_outputs: 184 | print(dec) 185 | print(targets) 186 | 187 | # CHANGE FOR MULTI CLASS!!! 188 | # metric.add_batch(predictions=[1 if class_keys[1] in x else 0 for x in dec], 189 | # references=[1 if class_keys[1] in x else 0 for x in targets]) 190 | 191 | # computing loss 192 | lm_labels = batch["target_ids"] 193 | lm_labels[lm_labels[:, :] == trainer.tokenizer.pad_token_id] = -100 194 | 195 | outputs = model( 196 | input_ids=batch["source_ids"], 197 | #attention_mask=batch["source_mask"], 198 | attention_mask=source_mask_updated, 199 | labels=lm_labels, 200 | decoder_attention_mask=batch['target_mask'], 201 | encoder_outputs=encoder_outputs, 202 | ) 203 | loss = outputs[0].detach().cpu().numpy() 204 | loss_total.append(loss) 205 | 206 | return corr/total, np.mean(loss_total) 207 | 208 | 209 | def train(TrainerT5, 210 | task, 211 | dataloader_train, 212 | dataloader_val, 213 | embed_prompt, 214 | class_keys, 215 | target_len, 216 | prefix_len, 217 | epochs=40, 218 | save_path=None): 219 | 220 | print('task = ', task) 221 | print("Using MLP? ", embed_prompt) 222 | model = TrainerT5.model 223 | model.to('cuda') 224 | 225 | #embed_prompt = True 226 | results_dict = {} 227 | results_dict['val'] = {'acc': [], 'loss': []} 228 | results_dict['train'] = {'acc': [], 'loss': []} 229 | 230 | for epoch in range(epochs): 231 | 232 | model.train() 233 | #mlp.train() 234 | 235 | for i, batch in enumerate(tqdm(dataloader_train)): 236 | batch = {k:batch[k].to('cuda') for k in batch} 237 | #loss = train_step_lester(TrainerT5, batch, TrainerT5.model.prompt, embed_prompt=embed_prompt) 238 | loss = train_step_lester(TrainerT5, batch, prefix_len, embed_prompt=embed_prompt) 239 | loss.backward() 240 | 241 | TrainerT5.optimizer.step() 242 | TrainerT5.optimizer.zero_grad() 243 | 244 | for dataloader, name in zip([dataloader_val, dataloader_train], 245 | ['val', 'train']): 246 | #for dataloader, name in zip([dataloader_val], 247 | # ['val']): 248 | 249 | acc, loss = validate_lester(TrainerT5, dataloader_val, task, 250 | embed_prompt, prefix_len, 251 | class_keys=class_keys, 252 | max_length=target_len, print_outputs=True) 253 | results_dict[name]['acc'].append(acc) 254 | results_dict[name]['loss'].append(loss) 255 | print(epoch, name, '->', acc, loss) 256 | #print('train acc ->', train_acc, train_f1) 257 | 258 | if save_path!=None and epoch%5==0: 259 | np.save(os.path.join(save_path, 'results_dict.npy'), results_dict) 260 | return results_dict 261 | 262 | 263 | 264 | 265 | def get_prompt(trainer, prompt_len): 266 | model = trainer.model 267 | N = model.encoder.embed_tokens.weight.shape[0] 268 | prompt_weigths = [] 269 | 270 | for i in range(prompt_len): 271 | with torch.no_grad(): 272 | j = np.random.randint(N) 273 | #j = 21 274 | w = deepcopy(model.encoder.embed_tokens.weight[j].detach().cpu().numpy()) 275 | prompt_weigths.append(w) 276 | prompt_weigths = np.array(prompt_weigths) 277 | return prompt_weigths 278 | 279 | 280 | 281 | def main(args): 282 | 283 | save_path = os.path.join(args.save_dir, args.save_name) 284 | if not os.path.exists(save_path): 285 | os.mkdir(save_path) 286 | 287 | TrainerT5= t5_model.PromptModelT5(model_name=args.model_name, 288 | prefix_len=0, 289 | freeze_weights=args.freeze_weights==1, 290 | freeze_except='xxxshared', # freeze all weights 291 | lr=args.lr, 292 | weight_decay=0.00, 293 | prompt_name='PRE', 294 | prefix_MLP='None', # using custom prefix MLP 295 | #mlp_bottleneck=args.mlp_bottleneck, 296 | #weight_decay_mlp=0.0, 297 | #mlp_lr=args.lr_mlp, 298 | #mlp_layer_norm=args.mlp_layer_norm==1, 299 | early_stopping=False, 300 | #opt=args.optimizer, 301 | ) 302 | 303 | prompt_weigths = get_prompt(TrainerT5, prompt_len=args.prefix_len) 304 | TrainerT5.model.prompt = nn.Parameter(torch.tensor(prompt_weigths, requires_grad=True)) 305 | print('created prompt: ', prompt_weigths.shape) 306 | 307 | if args.prefix_MLP != 'None': 308 | # adding MLP reparametrization for prompt 309 | print('Using MLP') 310 | TrainerT5.model.mlp = ResMLP(bottleneck_size=args.mlp_bottleneck, 311 | residual=args.residual_mlp==1, 312 | module_type=args.prefix_MLP, 313 | emb_dimension=prompt_weigths.shape[1], 314 | dropout=args.mlp_dropout, 315 | #layer_norm=False 316 | ) 317 | 318 | lr_mlp = args.lr_mlp if args.lr_mlp!=-1 else args.lr 319 | optimizer_grouped_parameters = [ 320 | { 321 | "params": [p for n, p in TrainerT5.model.named_parameters() if n=='prompt'], 322 | "weight_decay": 1e-5, 323 | "lr": args.lr, 324 | }, 325 | 326 | { 327 | "params": [p for n, p in TrainerT5.model.named_parameters() if 'mlp' in n], 328 | "weight_decay": 1e-5, 329 | "lr": lr_mlp, 330 | }, 331 | 332 | ] 333 | 334 | TrainerT5.optimizer = AdamW(optimizer_grouped_parameters, eps=1e-8) 335 | #TrainerT5.optimizer = Adafactor(optimizer_grouped_parameters)# eps=1e-8) 336 | #TrainerT5.optimizer 337 | 338 | task = args.task #'mrpc' 339 | target_len = args.target_len #2 340 | if task=='rte' or task=='mrpc': target_len=5 341 | 342 | ds2 = t5_dataset.T5Dataset(TrainerT5.tokenizer, task) 343 | dataloader_train = ds2.get_final_ds(task, 'train', batch_size=args.batch_size, k=args.select_k_per_class, 344 | target_len=target_len, prefix_list=[]) 345 | 346 | k_val = -1 if (args.select_k_per_class==-1 or task in ['mrpc', 'rte']) else int(0.2*args.select_k_per_class) 347 | dataloader_val = ds2.get_final_ds(task, 'validation', 348 | batch_size=args.batch_size, k=k_val, return_test=False, 349 | target_len=target_len, prefix_list=[]) 350 | 351 | class_keys = ds2.task_to_labels[task] 352 | results_dict = train(TrainerT5, 353 | task, 354 | dataloader_train, 355 | dataloader_val, 356 | embed_prompt=args.prefix_MLP != 'None', 357 | class_keys=class_keys, 358 | target_len=target_len, 359 | prefix_len=args.prefix_len, 360 | epochs=args.epochs, 361 | save_path=save_path) 362 | 363 | if args.early_stopping==1: 364 | TrainerT5.load_best_model() # for early stopping 365 | 366 | # test_acc, test_f1 = validate(TrainerT5, dataloader_test, task, ds.task_to_labels[task], target_len) 367 | # results_dict['test'] = {} 368 | # results_dict['test']['acc_direct'] = test_acc 369 | # for key in test_f1: 370 | # results_dict['test'][key] = test_f1[key] 371 | np.save(os.path.join(save_path, 'results_dict.npy'), results_dict) 372 | 373 | 374 | 375 | 376 | if __name__ == "__main__": 377 | parser = argparse.ArgumentParser( 378 | description='NLP training script in PyTorch' 379 | ) 380 | 381 | parser.add_argument( 382 | '--save_dir', 383 | type=str, 384 | help='base directory of all models / features (should not be changed)', 385 | default='/data/home/arazdai/T5_prompts/results/' 386 | ) 387 | 388 | parser.add_argument( 389 | '--save_name', 390 | type=str, 391 | help='folder name to save', 392 | required=True 393 | ) 394 | 395 | parser.add_argument( 396 | '--task', 397 | type=str, 398 | help='task to train t5 (e.g. mrpc, cola, rte)', 399 | required=True 400 | ) 401 | 402 | parser.add_argument( 403 | '--model_name', 404 | type=str, 405 | help='t5 model type', 406 | default='t5-small' 407 | ) 408 | 409 | parser.add_argument( 410 | '--batch_size', 411 | type=int, 412 | help='batch size', 413 | default=8 414 | ) 415 | 416 | 417 | # parser.add_argument( 418 | # '--optimizer', 419 | # type=str, 420 | # help='Which optimizer to use? (AdamW, LAMB etc.)', 421 | # default='AdamW' 422 | # ) 423 | 424 | parser.add_argument( 425 | '--select_k_per_class', 426 | type=int, 427 | help='Select k instances per class (default -1 = select all)', 428 | default=-1 429 | ) 430 | 431 | parser.add_argument( 432 | '--epochs', 433 | type=int, 434 | help='Number of epochs', 435 | default=50 436 | ) 437 | 438 | parser.add_argument( 439 | '--target_len', 440 | type=int, 441 | help='maximum length of the output (in tokens)', 442 | default=2 443 | ) 444 | 445 | parser.add_argument( 446 | '--prefix_len', 447 | type=int, 448 | help='prompt length (in tokens)', 449 | default=50 450 | ) 451 | 452 | parser.add_argument( 453 | '--early_stopping', 454 | type=int, 455 | help='Perform early stopping (1 - True, 0 - False)', 456 | default=1 457 | ) 458 | 459 | parser.add_argument( 460 | '--freeze_weights', 461 | type=int, 462 | help='Whether to freeze model weigts', 463 | default=0 464 | ) 465 | 466 | # parser.add_argument( 467 | # '--freeze_except', 468 | # type=str, 469 | # help='If freeze_weights==1, freeze all weights except those that contain this keyword', 470 | # default='shared' # shared stands for wte 471 | # ) 472 | 473 | 474 | 475 | 476 | parser.add_argument( 477 | '--lr', 478 | type=float, 479 | help='Learning rate', 480 | default=0.01 481 | ) 482 | 483 | parser.add_argument( 484 | '--mlp_dropout', 485 | type=float, 486 | help='Dropout rate of MLP (if 0 then no dropout)', 487 | default=0.0 488 | ) 489 | 490 | parser.add_argument( 491 | '--lr_mlp', 492 | type=float, 493 | help='Learning rate of MLP (if -1 then use the same LR as prompt)', 494 | default=-1 495 | ) 496 | 497 | parser.add_argument( 498 | '--prefix_MLP', 499 | type=str, 500 | help='Whether to do embeddings reparametrization with prefix MLP', 501 | default='None' 502 | ) 503 | 504 | parser.add_argument( 505 | '--mlp_bottleneck', 506 | type=int, 507 | help='MLP bottleneck size', 508 | default=1000 509 | ) 510 | 511 | parser.add_argument( 512 | '--residual_mlp', 513 | type=int, 514 | help='Whether to use skip connection in MLP', 515 | default=1 516 | ) 517 | 518 | 519 | # parser.add_argument( 520 | # '--mlp_layer_norm', 521 | # type=int, 522 | # help='Whether to use MLP layer norm (1 - use, 0 - not use)', 523 | # default=0 524 | # ) 525 | 526 | 527 | main(parser.parse_args()) 528 | -------------------------------------------------------------------------------- /T5_codebase/train_t5_cl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import numpy as np 4 | from tqdm.auto import tqdm 5 | import logging, os, argparse 6 | 7 | from t5_continual import T5ContinualLearner 8 | 9 | 10 | def main(args): 11 | save_path = os.path.join(args.save_dir, args.save_name) 12 | if not os.path.exists(save_path): 13 | os.mkdir(save_path) 14 | task_list = args.task_list 15 | 16 | model_name = args.model_name 17 | continual_learner = T5ContinualLearner(model_name, 18 | task_list, 19 | batch_size=args.batch_size, 20 | select_k_per_class=args.select_k_per_class, 21 | prefix_len=args.prefix_len, 22 | freeze_weights=args.freeze_weights==1, 23 | freeze_except=args.freeze_except, 24 | lr=args.lr, 25 | seq_len=args.seq_len, 26 | early_stopping=args.early_stopping==1, 27 | prefix_MLP=args.prefix_MLP, 28 | prefix_path=args.prefix_path if args.prefix_path!='' else None, 29 | mlp_layer_norm=args.mlp_layer_norm==1, 30 | bottleneck_size=args.bottleneck_size, 31 | get_test_subset=args.get_test_subset==1, 32 | memory_perc=args.memory_perc 33 | ) 34 | if args.get_test_subset==0: 35 | print("Not creating test subset") 36 | 37 | if args.multitask == 1: 38 | print('Multi task learning') 39 | results_dict = continual_learner.multi_task_training(num_epochs=args.num_epochs, save_path=save_path) 40 | np.save(os.path.join(save_path, 'results_dict.npy'), results_dict) 41 | 42 | else: 43 | if args.num_epochs<=50: 44 | eval_every_N = 1 45 | elif args.num_epochs>50 and args.num_epochs<=200: 46 | eval_every_N = 5 47 | elif args.num_epochs>200: 48 | eval_every_N = 10 49 | 50 | results_dict = continual_learner.train_continual(continual_learner.task_list, 51 | epochs=args.num_epochs, 52 | save_path=save_path, 53 | progressive=args.progressive==1, 54 | eval_every_N=eval_every_N, 55 | test_eval_after_every_task=args.test_eval_after_every_task==1, 56 | data_replay_freq=args.data_replay_freq, 57 | ) 58 | np.save(os.path.join(save_path, 'results_dict.npy'), results_dict) 59 | np.save(os.path.join(save_path, 'prompts.npy'), continual_learner.previous_prompts.detach().cpu().numpy()) 60 | 61 | 62 | 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser( 66 | description='NLP training script in PyTorch' 67 | ) 68 | 69 | parser.add_argument( 70 | '--save_dir', 71 | type=str, 72 | help='base directory of all models / features (should not be changed)', 73 | default='/data/home/arazdai/T5_prompts/T5_continual/' #'/scratch/hdd001/home/anastasia/CL/' 74 | ) 75 | 76 | parser.add_argument( 77 | '--save_name', 78 | type=str, 79 | help='folder name to save', 80 | required=True 81 | ) 82 | 83 | parser.add_argument( 84 | '--task_list', 85 | nargs='+', 86 | help='List of tasks for training', 87 | required=True 88 | ) 89 | 90 | parser.add_argument( 91 | '--model_name', 92 | type=str, 93 | help='Name of the model used for training', 94 | default="t5-base" 95 | ) 96 | 97 | parser.add_argument( 98 | '--num_epochs', 99 | type=int, 100 | help='Number of epochs to train model', 101 | default=5 102 | ) 103 | 104 | parser.add_argument( 105 | '--multitask', 106 | type=int, 107 | help='Whether to perform multi-task training', 108 | default=0 109 | ) 110 | 111 | parser.add_argument( 112 | '--batch_size', 113 | type=int, 114 | help='Batch size', 115 | default=8 116 | ) 117 | 118 | parser.add_argument( 119 | '--seq_len', 120 | type=int, 121 | help='Length of a single repeat (in #tokens)', 122 | default=512 123 | ) 124 | 125 | parser.add_argument( 126 | '--prefix_len', 127 | type=int, 128 | help='Length of prompt (in #tokens)', 129 | default=10 130 | ) 131 | 132 | parser.add_argument( 133 | '--prefix_path', 134 | type=str, 135 | help='path to a pre-trained progressive prefix (for superGLUE experiments)', 136 | default='' 137 | ) 138 | 139 | 140 | parser.add_argument( 141 | '--lr', 142 | type=float, 143 | help='Learning rate', 144 | default=0.3 145 | ) 146 | 147 | 148 | parser.add_argument( 149 | '--memory_perc', 150 | type=float, 151 | help='Memory perc', 152 | default=0.01 153 | ) 154 | 155 | parser.add_argument( 156 | '--data_replay_freq', 157 | type=float, 158 | help='Replay data every X iterations', 159 | default=-1 160 | ) 161 | 162 | parser.add_argument( 163 | '--select_k_per_class', 164 | type=int, 165 | help='Select k examples from each class (default -1, i.e. no changes to the original dataset)', 166 | default=-1 167 | ) 168 | 169 | parser.add_argument( 170 | '--test_eval_after_every_task', 171 | type=int, 172 | help='Whether to re-evaluate test accuracy after every task (0 - False, 1 - True)', 173 | default=0 174 | ) 175 | 176 | parser.add_argument( 177 | '--progressive', 178 | type=int, 179 | help='Whether to concatenate prompts in a progressive way (0 - False, 1 - True)', 180 | default=1 181 | ) 182 | 183 | parser.add_argument( 184 | '--freeze_weights', 185 | type=int, 186 | help='Whether to freeze model weigts (except word emb)', 187 | default=0 188 | ) 189 | 190 | parser.add_argument( 191 | '--freeze_except', 192 | type=str, 193 | help='If freeze_weights==1, freeze all weights except those that contain this keyword', 194 | default='xxxxxxx' # freeze all 195 | ) 196 | 197 | parser.add_argument( 198 | '--get_test_subset', 199 | type=int, 200 | help='Whether to create a separate test split', 201 | default=1 202 | ) 203 | 204 | parser.add_argument( 205 | '--early_stopping', 206 | type=int, 207 | help='If early_stopping==1, do early stopping based on val accuracy', 208 | default=1 # freeze all 209 | ) 210 | 211 | parser.add_argument( 212 | '--prefix_MLP', 213 | type=str, 214 | help='Type of MLP reparametrization (if None - use Lester original implementation)', 215 | default='None' # freeze all 216 | ) 217 | 218 | parser.add_argument( 219 | '--mlp_layer_norm', 220 | type=int, 221 | help='Do layer norm in MLP', 222 | default=1 # use layer norm 223 | ) 224 | 225 | parser.add_argument( 226 | '--bottleneck_size', 227 | type=int, 228 | help='MLP bottleneck size', 229 | default=800 230 | ) 231 | 232 | main(parser.parse_args()) 233 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | ## Continual Learning benchmark data 2 | 3 | Due to GitHub space limitations, we only uploaded one CL benchmark dataset - Amazon Reviews, since it's not available through HuggingFace. Please unzip the dataset file for usage. 4 | 5 | To access the rest of CL datasets, you can either: 6 | * Access them through their HuggingFace identifiers in our training script (AG: ag_news, Yahoo: yahoo_answers_topics, DbPedia: dbpedia_14, Yelp: yelp_review_full). Note that for Yahoo dataset we filtered rows with empty text fields following Zhang et al. (non-empty row idx are saved under "good_ids_yahoo"). 7 | * Download them from Zhang et. al., 2015 [http://goo.gl/JyCnZq](http://goo.gl/JyCnZq) and put corresponding folders into ```datasets/src/data/``` 8 | -------------------------------------------------------------------------------- /datasets/src/data/amazon/Archive.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arazd/ProgressivePrompts/01572d6a73c0576b070ceee00dbe4f5bc278423f/datasets/src/data/amazon/Archive.zip -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: nlp 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_kmp_llvm 10 | - argon2-cffi=20.1.0=py39h27cfd23_1 11 | - asttokens=2.0.5=pyhd3eb1b0_0 12 | - attrs=21.4.0=pyhd3eb1b0_0 13 | - backcall=0.2.0=pyhd3eb1b0_0 14 | - beautifulsoup4=4.11.1=py39h06a4308_0 15 | - blas=1.0=mkl 16 | - bleach=4.1.0=pyhd3eb1b0_0 17 | - bzip2=1.0.8=h7f98852_4 18 | - ca-certificates=2022.4.26=h06a4308_0 19 | - cffi=1.15.0=py39hd667e15_1 20 | - cudatoolkit=11.3.1=h2bc3f7f_2 21 | - dbus=1.13.18=hb2f20db_0 22 | - debugpy=1.5.1=py39h295c915_0 23 | - decorator=5.1.1=pyhd3eb1b0_0 24 | - defusedxml=0.7.1=pyhd3eb1b0_0 25 | - entrypoints=0.4=py39h06a4308_0 26 | - executing=0.8.3=pyhd3eb1b0_0 27 | - expat=2.4.4=h295c915_0 28 | - ffmpeg=4.3=hf484d3e_0 29 | - fontconfig=2.13.1=h6c09931_0 30 | - freetype=2.10.4=h0708190_1 31 | - giflib=5.2.1=h36c2ea0_2 32 | - glib=2.56.2=hd408876_0 33 | - gmp=6.2.1=h58526e2_0 34 | - gnutls=3.6.13=h85f3911_1 35 | - gst-plugins-base=1.14.0=hbbd80ab_1 36 | - gstreamer=1.14.0=hb453b48_1 37 | - icu=58.2=he6710b0_3 38 | - ipykernel=6.9.1=py39h06a4308_0 39 | - ipython=8.2.0=py39h06a4308_0 40 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 41 | - ipywidgets=7.6.5=pyhd3eb1b0_1 42 | - jedi=0.18.1=py39h06a4308_1 43 | - jinja2=3.0.3=pyhd3eb1b0_0 44 | - jpeg=9e=h166bdaf_1 45 | - jsonschema=4.4.0=py39h06a4308_0 46 | - jupyter=1.0.0=py39h06a4308_7 47 | - jupyter_client=7.2.2=py39h06a4308_0 48 | - jupyter_console=6.4.3=pyhd3eb1b0_0 49 | - jupyter_core=4.9.2=py39h06a4308_0 50 | - jupyterlab_pygments=0.1.2=py_0 51 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 52 | - lame=3.100=h7f98852_1001 53 | - lcms2=2.12=hddcbb42_0 54 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 55 | - lerc=3.0=h9c3ff4c_0 56 | - libdeflate=1.12=h166bdaf_0 57 | - libffi=3.3=he6710b0_2 58 | - libgcc-ng=12.1.0=h8d9b700_16 59 | - libiconv=1.17=h166bdaf_0 60 | - libnsl=2.0.0=h7f98852_0 61 | - libpng=1.6.37=h21135ba_2 62 | - libsodium=1.0.18=h7b6447c_0 63 | - libstdcxx-ng=12.1.0=ha89aaad_16 64 | - libtiff=4.4.0=hc85c160_1 65 | - libuuid=1.0.3=h7f8727e_2 66 | - libuv=1.43.0=h7f98852_0 67 | - libwebp=1.2.2=h3452ae3_0 68 | - libwebp-base=1.2.2=h7f98852_1 69 | - libxcb=1.13=h7f98852_1004 70 | - libxml2=2.9.12=h03d6c58_0 71 | - libzlib=1.2.12=h166bdaf_1 72 | - llvm-openmp=14.0.4=he0ac6c6_0 73 | - lz4-c=1.9.3=h9c3ff4c_1 74 | - markupsafe=2.0.1=py39h27cfd23_0 75 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 76 | - mistune=0.8.4=py39h27cfd23_1000 77 | - mkl=2021.4.0=h8d4b97c_729 78 | - mkl-service=2.4.0=py39h7e14d7c_0 79 | - mkl_fft=1.3.1=py39h0c7bc48_1 80 | - mkl_random=1.2.2=py39hde0f152_0 81 | - nbclient=0.5.13=py39h06a4308_0 82 | - nbconvert=6.4.4=py39h06a4308_0 83 | - nbformat=5.3.0=py39h06a4308_0 84 | - ncurses=6.3=h27087fc_1 85 | - nest-asyncio=1.5.5=py39h06a4308_0 86 | - nettle=3.6=he412f7d_0 87 | - notebook=6.4.8=py39h06a4308_0 88 | - numpy=1.22.3=py39he7a7128_0 89 | - numpy-base=1.22.3=py39hf524024_0 90 | - openh264=2.1.1=h780b84a_0 91 | - openjpeg=2.4.0=hb52868f_1 92 | - openssl=1.1.1n=h7f8727e_0 93 | - packaging=21.3=pyhd3eb1b0_0 94 | - pandocfilters=1.5.0=pyhd3eb1b0_0 95 | - parso=0.8.3=pyhd3eb1b0_0 96 | - pcre=8.45=h295c915_0 97 | - pexpect=4.8.0=pyhd3eb1b0_3 98 | - pickleshare=0.7.5=pyhd3eb1b0_1003 99 | - pillow=9.1.1=py39hae2aec6_1 100 | - pip=22.1.2=pyhd8ed1ab_0 101 | - prometheus_client=0.13.1=pyhd3eb1b0_0 102 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 103 | - prompt_toolkit=3.0.20=hd3eb1b0_0 104 | - pthread-stubs=0.4=h36c2ea0_1001 105 | - ptyprocess=0.7.0=pyhd3eb1b0_2 106 | - pure_eval=0.2.2=pyhd3eb1b0_0 107 | - pycparser=2.21=pyhd3eb1b0_0 108 | - pygments=2.11.2=pyhd3eb1b0_0 109 | - pyqt=5.9.2=py39h2531618_6 110 | - pyrsistent=0.18.0=py39heee7806_0 111 | - python=3.9.12=h12debd9_0 112 | - python-dateutil=2.8.2=pyhd3eb1b0_0 113 | - python-fastjsonschema=2.15.1=pyhd3eb1b0_0 114 | - python_abi=3.9=2_cp39 115 | - pytorch=1.10.1=py3.9_cuda11.3_cudnn8.2.0_0 116 | - pytorch-mutex=1.0=cuda 117 | - pyzmq=22.3.0=py39h295c915_2 118 | - qt=5.9.7=h5867ecd_1 119 | - qtconsole=5.3.0=pyhd3eb1b0_0 120 | - qtpy=2.0.1=pyhd3eb1b0_0 121 | - readline=8.1.2=h0f457ee_0 122 | - send2trash=1.8.0=pyhd3eb1b0_1 123 | - setuptools=62.3.4=py39hf3d152e_0 124 | - sip=4.19.13=py39h295c915_0 125 | - six=1.16.0=pyh6c4a22f_0 126 | - soupsieve=2.3.1=pyhd3eb1b0_0 127 | - sqlite=3.38.5=h4ff8645_0 128 | - stack_data=0.2.0=pyhd3eb1b0_0 129 | - tbb=2021.5.0=h924138e_1 130 | - terminado=0.13.1=py39h06a4308_0 131 | - testpath=0.5.0=pyhd3eb1b0_0 132 | - tk=8.6.12=h27826a3_0 133 | - torchaudio=0.10.1=py39_cu113 134 | - torchvision=0.11.2=py39_cu113 135 | - tornado=6.1=py39h27cfd23_0 136 | - traitlets=5.1.1=pyhd3eb1b0_0 137 | - typing_extensions=4.2.0=pyha770c72_1 138 | - tzdata=2022a=h191b570_0 139 | - wcwidth=0.2.5=pyhd3eb1b0_0 140 | - webencodings=0.5.1=py39h06a4308_1 141 | - wheel=0.37.1=pyhd8ed1ab_0 142 | - widgetsnbextension=3.5.2=py39h06a4308_0 143 | - xorg-libxau=1.0.9=h7f98852_0 144 | - xorg-libxdmcp=1.1.3=h7f98852_0 145 | - xz=5.2.5=h516909a_1 146 | - zeromq=4.3.4=h2531618_0 147 | - zlib=1.2.12=h166bdaf_1 148 | - zstd=1.5.2=h8a70e8d_1 149 | - pip: 150 | - adapter-transformers==3.0.1 151 | - aiohttp==3.8.1 152 | - aiosignal==1.2.0 153 | - async-timeout==4.0.2 154 | - certifi==2022.6.15 155 | - charset-normalizer==2.0.12 156 | - click==8.1.3 157 | - cycler==0.11.0 158 | - data==0.4 159 | - datasets==2.3.2 160 | - dill==0.3.5.1 161 | - filelock==3.7.1 162 | - fonttools==4.33.3 163 | - frozenlist==1.3.0 164 | - fsspec==2022.5.0 165 | - funcsigs==1.0.2 166 | - future==0.18.2 167 | - huggingface-hub==0.7.0 168 | - idna==3.3 169 | - joblib==1.1.0 170 | - kiwisolver==1.4.3 171 | - latex==0.7.0 172 | - matplotlib==3.5.2 173 | - multidict==6.0.2 174 | - multiprocess==0.70.13 175 | - nltk==3.7 176 | - pandas==1.4.2 177 | - pyarrow==8.0.0 178 | - pyparsing==3.0.9 179 | - pytorch-ranger==0.1.1 180 | - pytz==2022.1 181 | - pyyaml==6.0 182 | - regex==2022.6.2 183 | - requests==2.28.0 184 | - responses==0.18.0 185 | - sacremoses==0.0.53 186 | - scikit-learn==1.1.1 187 | - scipy==1.8.1 188 | - seaborn==0.12.0 189 | - sentencepiece==0.1.96 190 | - shutilwhich==1.1.0 191 | - sklearn==0.0 192 | - tempdir==0.7.1 193 | - threadpoolctl==3.1.0 194 | - tokenizers==0.12.1 195 | - torch-optimizer==0.3.0 196 | - tqdm==4.64.0 197 | - transformers==4.20.0 198 | - urllib3==1.26.9 199 | - xxhash==3.0.0 200 | - yarl==1.7.2 201 | prefix: /data/home/arazdai/miniconda/envs/nlp 202 | -------------------------------------------------------------------------------- /images/illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arazd/ProgressivePrompts/01572d6a73c0576b070ceee00dbe4f5bc278423f/images/illustration.png -------------------------------------------------------------------------------- /images/test.png: -------------------------------------------------------------------------------- 1 | 2 | --------------------------------------------------------------------------------