├── .gitignore
├── contextualize_calibration.py
├── fewshot.py
├── fewshot_softpilot.py
├── filter_method.py
├── readme.md
├── scripts
│   ├── run_fewshot.sh
│   ├── run_pilot.sh
│   └── run_zeroshot.sh
└── zeroshot.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | ckpts
3 | results*.txt
--------------------------------------------------------------------------------
/contextualize_calibration.py:
--------------------------------------------------------------------------------
 1 | 
 2 | from yacs.config import CfgNode
 3 | from openprompt.data_utils import FewShotSampler
 4 | from torch.utils.data.dataset import Dataset
 5 | from transformers.data.processors.utils import InputExample
 6 | from openprompt.pipeline_base import PromptDataLoader, PromptModel, PromptForClassification
 7 | from typing import *
 8 | import torch
 9 | # from openprompt.utils.custom_tqdm import tqdm
10 | from tqdm import tqdm
11 | 
12 | 
13 | 
14 | def calibrate(prompt_model: PromptForClassification, dataloader: PromptDataLoader) -> torch.Tensor:
15 |     r"""Contextualized calibration: collect the model's context-free logits over a support set.
16 | 
17 |     Args:
18 |         prompt_model (:obj:`PromptForClassification`): the PromptForClassification model.
19 |         dataloader (:obj:`PromptDataLoader`): the dataloader used to compute the calibration logits; it can be a virtual one, i.e., one containing only a template-only example.
20 | 
21 |     Return:
22 |         (:obj:`torch.Tensor`) A tensor of shape (vocab_size,) or (mask_num, vocab_size): the logits calculated for each word in the vocabulary.
23 |     """
24 |     all_logits = []
25 |     prompt_model.eval()
26 |     for batch in tqdm(dataloader, desc='ContextCali'):
27 |         batch = batch.to(prompt_model.device)
28 |         logits = prompt_model.forward_without_verbalize(batch)
29 |         all_logits.append(logits.detach())
30 |     all_logits = torch.cat(all_logits, dim=0)
31 |     return all_logits
32 | 
33 | 
--------------------------------------------------------------------------------
/fewshot.py:
--------------------------------------------------------------------------------
 1 | 
 2 | from tqdm import tqdm
 3 | from openprompt.data_utils.text_classification_dataset import AgnewsProcessor, DBpediaProcessor, ImdbProcessor, AmazonProcessor
 4 | from openprompt.data_utils.huggingface_dataset import YahooAnswersTopicsProcessor
 5 | import torch
 6 | from openprompt.data_utils.utils import InputExample
 7 | import argparse
 8 | import numpy as np
 9 | 
10 | from openprompt import PromptDataLoader
11 | from openprompt.prompts import ManualVerbalizer, KnowledgeableVerbalizer, SoftVerbalizer, AutomaticVerbalizer
12 | from openprompt.prompts import ManualTemplate
13 | 
14 | 
15 | parser = argparse.ArgumentParser("")
16 | 
17 | parser.add_argument("--model", type=str, default='roberta')
18 | parser.add_argument("--model_name_or_path", default='../plm_cache/roberta-large')
19 | parser.add_argument("--result_file", type=str, default="sfs_scripts/results_fewshot_manual_kpt.txt")
20 | parser.add_argument("--openprompt_path", type=str, default="OpenPrompt")
21 | 
22 | parser.add_argument("--shot", type=int, default=5)
23 | parser.add_argument("--seed", type=int, default=144)
24 | parser.add_argument("--plm_eval_mode", action="store_true")
25 | parser.add_argument("--verbalizer", type=str)
26 | parser.add_argument("--calibration", action="store_true")
27 | parser.add_argument("--filter", default="none", type=str)
28 | parser.add_argument("--template_id", type=int)
29 | parser.add_argument("--dataset", type=str)
30 | 
31 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
32 | 
parser.add_argument("--max_epochs", type=int, default=5) 33 | parser.add_argument("--kptw_lr", default=0.06, type=float) 34 | parser.add_argument("--pred_temp", default=1.0, type=float) 35 | parser.add_argument("--max_token_split", default=-1, type=int) 36 | args = parser.parse_args() 37 | 38 | import random 39 | this_run_unicode = str(random.randint(0, 1e10)) 40 | 41 | from openprompt.utils.reproduciblity import set_seed 42 | set_seed(args.seed) 43 | 44 | from openprompt.plms import load_plm 45 | plm, tokenizer, model_config, WrapperClass = load_plm(args.model, args.model_name_or_path) 46 | 47 | dataset = {} 48 | 49 | if args.dataset == "agnews": 50 | dataset['train'] = AgnewsProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/agnews/") 51 | dataset['test'] = AgnewsProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/agnews/") 52 | class_labels =AgnewsProcessor().get_labels() 53 | scriptsbase = "TextClassification/agnews" 54 | scriptformat = "txt" 55 | cutoff=0.5 56 | max_seq_l = 128 57 | batch_s = 30 58 | elif args.dataset == "dbpedia": 59 | dataset['train'] = DBpediaProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/dbpedia/") 60 | dataset['test'] = DBpediaProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/dbpedia/") 61 | class_labels =DBpediaProcessor().get_labels() 62 | scriptsbase = "TextClassification/dbpedia" 63 | scriptformat = "txt" 64 | cutoff=0.5 65 | max_seq_l = 128 66 | batch_s = 30 67 | elif args.dataset == "yahoo": 68 | dataset['train'] = YahooAnswersTopicsProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/yahoo_answers_topics/") 69 | dataset['test'] = YahooAnswersTopicsProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/yahoo_answers_topics/") 70 | class_labels =YahooAnswersTopicsProcessor().get_labels() 71 | scriptsbase = "TextClassification/yahoo_answers_topics" 72 | scriptformat = "json" 73 | cutoff=0.5 74 | max_seq_l = 128 75 | batch_s = 30 76 | elif args.dataset == "imdb": 77 | dataset['train'] = ImdbProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/imdb/") 78 | dataset['test'] = ImdbProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/imdb/") 79 | class_labels = ImdbProcessor().get_labels() 80 | scriptsbase = "TextClassification/imdb" 81 | scriptformat = "txt" 82 | cutoff=0 83 | max_seq_l = 512 84 | batch_s = 5 85 | elif args.dataset == "amazon": 86 | dataset['train'] = AmazonProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/amazon/") 87 | dataset['test'] = AmazonProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/amazon/") 88 | class_labels = AmazonProcessor().get_labels() 89 | scriptsbase = "TextClassification/amazon" 90 | scriptformat = "txt" 91 | cutoff=0 92 | max_seq_l = 512 93 | batch_s = 5 94 | else: 95 | raise NotImplementedError 96 | 97 | 98 | mytemplate = ManualTemplate(tokenizer=tokenizer).from_file(f"{args.openprompt_path}/scripts/{scriptsbase}/manual_template.txt", choice=args.template_id) 99 | 100 | 101 | if args.verbalizer == "kpt": 102 | myverbalizer = KnowledgeableVerbalizer(tokenizer, classes=class_labels, candidate_frac=cutoff, pred_temp=args.pred_temp, max_token_split=args.max_token_split).from_file(f"{args.openprompt_path}/scripts/{scriptsbase}/knowledgeable_verbalizer.{scriptformat}") 103 | elif args.verbalizer == "manual": 104 | 
myverbalizer = ManualVerbalizer(tokenizer, classes=class_labels).from_file(f"{args.openprompt_path}/scripts/{scriptsbase}/manual_verbalizer.{scriptformat}")
105 | elif args.verbalizer == "soft":
106 |     myverbalizer = SoftVerbalizer(tokenizer, model=plm, classes=class_labels).from_file(f"{args.openprompt_path}/scripts/{scriptsbase}/manual_verbalizer.{scriptformat}")
107 | elif args.verbalizer == "auto":
108 |     myverbalizer = AutomaticVerbalizer(tokenizer, classes=class_labels)
109 | 
110 | 
111 | # (contextual) calibration
112 | if args.verbalizer in ["kpt", "manual"]:
113 |     if args.calibration or args.filter != "none":
114 |         from openprompt.data_utils.data_sampler import FewShotSampler
115 |         support_sampler = FewShotSampler(num_examples_total=200, also_sample_dev=False)
116 |         dataset['support'] = support_sampler(dataset['train'], seed=args.seed)
117 | 
118 |         # for example in dataset['support']:
119 |         #     example.label = -1  # remove the labels of the support set for clarity
120 |         support_dataloader = PromptDataLoader(dataset=dataset["support"], template=mytemplate, tokenizer=tokenizer,
121 |             tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3,
122 |             batch_size=batch_s, shuffle=False, teacher_forcing=False, predict_eos_token=False,
123 |             truncate_method="tail")
124 | 
125 | 
126 | from openprompt import PromptForClassification
127 | use_cuda = True
128 | prompt_model = PromptForClassification(plm=plm, template=mytemplate, verbalizer=myverbalizer, freeze_plm=False, plm_eval_mode=args.plm_eval_mode)
129 | if use_cuda:
130 |     prompt_model = prompt_model.cuda()
131 | 
132 | 
133 | 
134 | # calibration and label-word refinement
135 | # if args.calibration:
136 | if args.verbalizer in ["kpt", "manual"]:
137 |     if args.calibration or args.filter != "none":
138 |         org_label_words_num = [len(prompt_model.verbalizer.label_words[i]) for i in range(len(class_labels))]
139 |         from contextualize_calibration import calibrate
140 |         # calculate the calibration logits
141 |         cc_logits = calibrate(prompt_model, support_dataloader)
142 |         print("the calibration logits are", cc_logits)
143 |         print("original label words num {}".format(org_label_words_num))
144 | 
145 |         if args.calibration:
146 |             myverbalizer.register_calibrate_logits(cc_logits.mean(dim=0))
147 |             new_label_words_num = [len(myverbalizer.label_words[i]) for i in range(len(class_labels))]
148 |             print("After calibration-based refinement, number of label words per class: {}".format(new_label_words_num))
149 | 
150 | 
151 |         from filter_method import *
152 |         if args.filter == "tfidf_filter":
153 |             tfidf_filter(myverbalizer, cc_logits, class_labels)
154 |         elif args.filter == "none":
155 |             pass
156 |         else:
157 |             raise NotImplementedError
158 | 
159 | 
160 | # register_calibrate_logits makes the verbalizer divide out the calibration probability when producing label logits;
161 | # currently, only ManualVerbalizer and KnowledgeableVerbalizer support calibration.
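# A rough sketch of the calibration arithmetic described above (illustration only,
# not OpenPrompt's actual implementation; the tensor names below are hypothetical).
# Each label word's probability is divided by its context-free prior, estimated on
# the unlabeled support set, before per-class scores are aggregated:
#
#   def calibrated_class_scores(mask_logits, cc_logits, label_word_ids):
#       # mask_logits: (vocab_size,) logits at the <mask> position for one example
#       # cc_logits:   (vocab_size,) mean support-set logits returned by calibrate()
#       prior = cc_logits.softmax(dim=-1)
#       probs = mask_logits.softmax(dim=-1) / prior  # contextualized calibration
#       return torch.stack([probs[ids].mean() for ids in label_word_ids])  # one score per class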
162 | 163 | from openprompt.data_utils.data_sampler import FewShotSampler 164 | sampler = FewShotSampler(num_examples_per_label=args.shot, also_sample_dev=True, num_examples_per_label_dev=args.shot) 165 | dataset['train'], dataset['validation'] = sampler(dataset['train'], seed=args.seed) 166 | 167 | 168 | train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer, 169 | tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 170 | batch_size=batch_s,shuffle=True, teacher_forcing=False, predict_eos_token=False, 171 | truncate_method="tail") 172 | 173 | validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate, tokenizer=tokenizer, 174 | tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 175 | batch_size=batch_s,shuffle=False, teacher_forcing=False, predict_eos_token=False, 176 | truncate_method="tail") 177 | 178 | # zero-shot test 179 | test_dataloader = PromptDataLoader(dataset=dataset["test"], template=mytemplate, tokenizer=tokenizer, 180 | tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 181 | batch_size=batch_s,shuffle=False, teacher_forcing=False, predict_eos_token=False, 182 | truncate_method="tail") 183 | 184 | 185 | def evaluate(prompt_model, dataloader, desc): 186 | prompt_model.eval() 187 | allpreds = [] 188 | alllabels = [] 189 | pbar = tqdm(dataloader, desc=desc) 190 | for step, inputs in enumerate(pbar): 191 | if use_cuda: 192 | inputs = inputs.cuda() 193 | logits = prompt_model(inputs) 194 | labels = inputs['label'] 195 | alllabels.extend(labels.cpu().tolist()) 196 | allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist()) 197 | acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds) 198 | return acc 199 | ############ 200 | ############# 201 | ############### 202 | 203 | from transformers import AdamW, get_linear_schedule_with_warmup 204 | loss_func = torch.nn.CrossEntropyLoss() 205 | 206 | 207 | def prompt_initialize(verbalizer, prompt_model, init_dataloader): 208 | dataloader = init_dataloader 209 | with torch.no_grad(): 210 | for batch in tqdm(dataloader, desc="Init_using_{}".format("train")): 211 | batch = batch.cuda() 212 | logits = prompt_model(batch) 213 | verbalizer.optimize_to_initialize() 214 | 215 | 216 | if args.verbalizer == "soft": 217 | 218 | 219 | no_decay = ['bias', 'LayerNorm.weight'] 220 | 221 | # it's always good practice to set no decay to biase and LayerNorm parameters 222 | optimizer_grouped_parameters1 = [ 223 | {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 224 | {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 225 | ] 226 | 227 | # Using different optimizer for prompt parameters and model parameters 228 | 229 | optimizer_grouped_parameters2 = [ 230 | {'params': prompt_model.verbalizer.group_parameters_1, "lr":3e-5}, 231 | {'params': prompt_model.verbalizer.group_parameters_2, "lr":3e-4}, 232 | ] 233 | 234 | 235 | optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5) 236 | optimizer2 = AdamW(optimizer_grouped_parameters2) 237 | 238 | tot_step = len(train_dataloader) // args.gradient_accumulation_steps * args.max_epochs 239 | scheduler1 = get_linear_schedule_with_warmup( 240 | optimizer1, 241 | num_warmup_steps=0, num_training_steps=tot_step) 242 | 243 | scheduler2 = get_linear_schedule_with_warmup( 244 
| optimizer2, 245 | num_warmup_steps=0, num_training_steps=tot_step) 246 | 247 | elif args.verbalizer == "auto": 248 | prompt_initialize(myverbalizer, prompt_model, train_dataloader) 249 | 250 | no_decay = ['bias', 'LayerNorm.weight'] 251 | 252 | # it's always good practice to set no decay to biase and LayerNorm parameters 253 | optimizer_grouped_parameters1 = [ 254 | {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 255 | {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 256 | ] 257 | 258 | # Using different optimizer for prompt parameters and model parameters 259 | 260 | optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5) 261 | 262 | tot_step = len(train_dataloader) // args.gradient_accumulation_steps * args.max_epochs 263 | scheduler1 = get_linear_schedule_with_warmup( 264 | optimizer1, 265 | num_warmup_steps=0, num_training_steps=tot_step) 266 | 267 | optimizer2 = None 268 | scheduler2 = None 269 | 270 | elif args.verbalizer == "kpt": 271 | no_decay = ['bias', 'LayerNorm.weight'] 272 | 273 | # it's always good practice to set no decay to biase and LayerNorm parameters 274 | optimizer_grouped_parameters1 = [ 275 | {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 276 | {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 277 | ] 278 | 279 | # Using different optimizer for prompt parameters and model parameters 280 | 281 | # optimizer_grouped_parameters2 = [ 282 | # {'params': , "lr":1e-1}, 283 | # ] 284 | optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5) 285 | optimizer2 = AdamW(prompt_model.verbalizer.parameters(), lr=args.kptw_lr) 286 | # print(optimizer_grouped_parameters2) 287 | 288 | tot_step = len(train_dataloader) // args.gradient_accumulation_steps * args.max_epochs 289 | scheduler1 = get_linear_schedule_with_warmup( 290 | optimizer1, 291 | num_warmup_steps=0, num_training_steps=tot_step) 292 | 293 | # scheduler2 = get_linear_schedule_with_warmup( 294 | # optimizer2, 295 | # num_warmup_steps=0, num_training_steps=tot_step) 296 | scheduler2 = None 297 | 298 | elif args.verbalizer == "manual": 299 | no_decay = ['bias', 'LayerNorm.weight'] 300 | 301 | # it's always good practice to set no decay to biase and LayerNorm parameters 302 | optimizer_grouped_parameters1 = [ 303 | {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 304 | {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 305 | ] 306 | 307 | # Using different optimizer for prompt parameters and model parameters 308 | 309 | optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5) 310 | 311 | tot_step = len(train_dataloader) // args.gradient_accumulation_steps * args.max_epochs 312 | scheduler1 = get_linear_schedule_with_warmup( 313 | optimizer1, 314 | num_warmup_steps=0, num_training_steps=tot_step) 315 | 316 | optimizer2 = None 317 | scheduler2 = None 318 | 319 | 320 | tot_loss = 0 321 | log_loss = 0 322 | best_val_acc = 0 323 | for epoch in range(args.max_epochs): 324 | tot_loss = 0 325 | prompt_model.train() 326 | for step, inputs in enumerate(train_dataloader): 327 | if use_cuda: 328 | inputs = inputs.cuda() 329 | logits = prompt_model(inputs) 330 | labels = inputs['label'] 331 | 
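        # NOTE: a hedged sketch of how --gradient_accumulation_steps could be honored in
        # this loop (as written, the optimizers step on every batch, even though the flag
        # already enters the tot_step computation for the schedulers above):
        #   loss = loss_func(logits, labels) / args.gradient_accumulation_steps
        #   loss.backward()
        #   if (step + 1) % args.gradient_accumulation_steps == 0:
        #       optimizer1.step(); scheduler1.step(); optimizer1.zero_grad()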
loss = loss_func(logits, labels) 332 | loss.backward() 333 | torch.nn.utils.clip_grad_norm_(prompt_model.parameters(), 1.0) 334 | tot_loss += loss.item() 335 | optimizer1.step() 336 | scheduler1.step() 337 | optimizer1.zero_grad() 338 | if optimizer2 is not None: 339 | optimizer2.step() 340 | optimizer2.zero_grad() 341 | if scheduler2 is not None: 342 | scheduler2.step() 343 | 344 | val_acc = evaluate(prompt_model, validation_dataloader, desc="Valid") 345 | if val_acc>=best_val_acc: 346 | torch.save(prompt_model.state_dict(),f"ckpts/{this_run_unicode}.ckpt") 347 | best_val_acc = val_acc 348 | print("Epoch {}, val_acc {}".format(epoch, val_acc), flush=True) 349 | 350 | prompt_model.load_state_dict(torch.load(f"ckpts/{this_run_unicode}.ckpt")) 351 | prompt_model = prompt_model.cuda() 352 | test_acc = evaluate(prompt_model, test_dataloader, desc="Test") 353 | 354 | 355 | 356 | 357 | 358 | content_write = "="*20+"\n" 359 | content_write += f"dataset {args.dataset}\t" 360 | content_write += f"temp {args.template_id}\t" 361 | content_write += f"seed {args.seed}\t" 362 | content_write += f"shot {args.shot}\t" 363 | content_write += f"verb {args.verbalizer}\t" 364 | content_write += f"cali {args.calibration}\t" 365 | content_write += f"filt {args.filter}\t" 366 | content_write += f"maxsplit {args.max_token_split}\t" 367 | content_write += f"kptw_lr {args.kptw_lr}\t" 368 | content_write += "\n" 369 | content_write += f"Acc: {test_acc}" 370 | content_write += "\n\n" 371 | 372 | print(content_write) 373 | 374 | with open(f"{args.result_file}", "a") as fout: 375 | fout.write(content_write) 376 | 377 | import os 378 | os.remove(f"ckpts/{this_run_unicode}.ckpt") -------------------------------------------------------------------------------- /fewshot_softpilot.py: -------------------------------------------------------------------------------- 1 | 2 | from tqdm import tqdm 3 | from openprompt.data_utils.text_classification_dataset import AgnewsProcessor, DBpediaProcessor, ImdbProcessor, AmazonProcessor 4 | from openprompt.data_utils.huggingface_dataset import YahooAnswersTopicsProcessor 5 | import torch 6 | from openprompt.data_utils.utils import InputExample 7 | import argparse 8 | import numpy as np 9 | 10 | from openprompt import PromptDataLoader 11 | from openprompt.prompts import ManualVerbalizer, KnowledgeableVerbalizer, SoftVerbalizer, AutomaticVerbalizer 12 | from openprompt.prompts import ManualTemplate 13 | 14 | 15 | parser = argparse.ArgumentParser("") 16 | parser.add_argument("--shot", type=int, default=5) 17 | parser.add_argument("--seed", type=int, default=144) 18 | 19 | parser.add_argument("--plm_eval_mode", action="store_true") 20 | parser.add_argument("--model", type=str, default='roberta') # tested model are gpt2/t5 21 | parser.add_argument("--model_name_or_path", default='../../plm_cache/roberta-large') 22 | parser.add_argument("--openprompt_path", type=str, default="OpenPrompt") 23 | 24 | parser.add_argument("--verbalizer", type=str) 25 | parser.add_argument("--calibration", action="store_true") 26 | parser.add_argument("--not_manual", action="store_true") 27 | parser.add_argument("--filter", default="none", type=str) 28 | parser.add_argument("--template_id", type=int) 29 | parser.add_argument("--dataset",type=str) 30 | parser.add_argument("--result_file", type=str, default="../sfs_scripts/results_fewshot_manual_kpt.txt") 31 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 32 | parser.add_argument("--max_epochs", type=int, default=5) 33 | 
parser.add_argument("--kptw_lr", default=0.06, type=float) 34 | parser.add_argument("--pred_temp", default=1.0, type=float) 35 | parser.add_argument("--max_token_split", default=-1, type=int) 36 | args = parser.parse_args() 37 | 38 | import random 39 | this_run_unicode = str(random.randint(0, 1e10)) 40 | 41 | from openprompt.utils.reproduciblity import set_seed 42 | set_seed(args.seed) 43 | 44 | from openprompt.plms import load_plm 45 | plm, tokenizer, model_config, WrapperClass = load_plm(args.model, args.model_name_or_path) 46 | 47 | dataset = {} 48 | 49 | if args.dataset == "agnews": 50 | dataset['train'] = AgnewsProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/agnews/") 51 | dataset['test'] = AgnewsProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/agnews/") 52 | class_labels =AgnewsProcessor().get_labels() 53 | scriptsbase = "TextClassification/agnews" 54 | scriptformat = "txt" 55 | cutoff=0.5 56 | max_seq_l = 128 57 | batch_s = 30 58 | elif args.dataset == "dbpedia": 59 | dataset['train'] = DBpediaProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/dbpedia/") 60 | dataset['test'] = DBpediaProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/dbpedia/") 61 | class_labels =DBpediaProcessor().get_labels() 62 | scriptsbase = "TextClassification/dbpedia" 63 | scriptformat = "txt" 64 | cutoff=0.5 65 | max_seq_l = 128 66 | batch_s = 30 67 | elif args.dataset == "yahoo": 68 | dataset['train'] = YahooAnswersTopicsProcessor().get_train_examples() 69 | dataset['test'] = YahooAnswersTopicsProcessor().get_test_examples() 70 | class_labels =YahooAnswersTopicsProcessor().get_labels() 71 | scriptsbase = "TextClassification/yahoo_answers_topics" 72 | scriptformat = "json" 73 | cutoff=0.5 74 | max_seq_l = 128 75 | batch_s = 30 76 | elif args.dataset == "imdb": 77 | dataset['train'] = ImdbProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/imdb/") 78 | dataset['test'] = ImdbProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/imdb/") 79 | class_labels = ImdbProcessor().get_labels() 80 | scriptsbase = "TextClassification/imdb" 81 | scriptformat = "txt" 82 | cutoff=0 83 | max_seq_l = 512 84 | batch_s = 5 85 | elif args.dataset == "amazon": 86 | dataset['train'] = AmazonProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/amazon/") 87 | dataset['test'] = AmazonProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/amazon/") 88 | class_labels = AmazonProcessor().get_labels() 89 | scriptsbase = "TextClassification/amazon" 90 | scriptformat = "txt" 91 | cutoff=0 92 | max_seq_l = 512 93 | batch_s = 5 94 | else: 95 | raise NotImplementedError 96 | 97 | 98 | mytemplate = ManualTemplate(tokenizer=tokenizer).from_file(f"{args.openprompt_path}/scripts/{scriptsbase}/manual_template.txt", choice=args.template_id) 99 | 100 | 101 | if args.verbalizer == "kpt": 102 | myverbalizer = KnowledgeableVerbalizer(tokenizer, classes=class_labels, candidate_frac=cutoff, pred_temp=args.pred_temp, max_token_split=args.max_token_split).from_file(f"{args.openprompt_path}/scripts/{scriptsbase}/knowledgeable_verbalizer.{scriptformat}") 103 | elif args.verbalizer == "manual": 104 | myverbalizer = ManualVerbalizer(tokenizer, classes=class_labels).from_file(f"{args.openprompt_path}/scripts/{scriptsbase}/manual_verbalizer.{scriptformat}") 105 | elif args.verbalizer == "soft": 106 | if 
args.not_manual: 107 | myverbalizer = SoftVerbalizer(tokenizer, model=plm, classes=class_labels)#.from_file(f"{args.openprompt_path}/scripts/{scriptsbase}/manual_verbalizer.{scriptformat}") 108 | else: 109 | myverbalizer = SoftVerbalizer(tokenizer, model=plm, classes=class_labels).from_file(f"{args.openprompt_path}/scripts/{scriptsbase}/manual_verbalizer.{scriptformat}") 110 | elif args.verbalizer == "auto": 111 | myverbalizer = AutomaticVerbalizer(tokenizer, classes=class_labels) 112 | 113 | 114 | # (contextual) calibration 115 | if args.verbalizer in ["kpt","manual"]: 116 | if args.calibration or args.filter != "none": 117 | from openprompt.data_utils.data_sampler import FewShotSampler 118 | support_sampler = FewShotSampler(num_examples_total=200, also_sample_dev=False) 119 | dataset['support'] = support_sampler(dataset['train'], seed=args.seed) 120 | 121 | # for example in dataset['support']: 122 | # example.label = -1 # remove the labels of support set for clarification 123 | support_dataloader = PromptDataLoader(dataset=dataset["support"], template=mytemplate, tokenizer=tokenizer, 124 | tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 125 | batch_size=batch_s,shuffle=False, teacher_forcing=False, predict_eos_token=False, 126 | truncate_method="tail") 127 | 128 | 129 | from openprompt import PromptForClassification 130 | use_cuda = True 131 | prompt_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=False, plm_eval_mode=args.plm_eval_mode) 132 | if use_cuda: 133 | prompt_model= prompt_model.cuda() 134 | 135 | 136 | 137 | # HP 138 | # if args.calibration: 139 | if args.verbalizer in ["kpt","manual"]: 140 | if args.calibration or args.filter != "none": 141 | org_label_words_num = [len(prompt_model.verbalizer.label_words[i]) for i in range(len(class_labels))] 142 | from contextualize_calibration import calibrate 143 | # calculate the calibration logits 144 | cc_logits = calibrate(prompt_model, support_dataloader) 145 | print("the calibration logits is", cc_logits) 146 | print("origial label words num {}".format(org_label_words_num)) 147 | 148 | if args.calibration: 149 | myverbalizer.register_calibrate_logits(cc_logits.mean(dim=0)) 150 | new_label_words_num = [len(myverbalizer.label_words[i]) for i in range(len(class_labels))] 151 | print("After filtering, number of label words per class: {}".format(new_label_words_num)) 152 | 153 | 154 | from filter_method import * 155 | 156 | if args.filter == "tfidf_filter": 157 | tfidf_filter(myverbalizer, cc_logits, class_labels) 158 | elif args.filter == "none": 159 | pass 160 | else: 161 | raise NotImplementedError 162 | 163 | 164 | # register the logits to the verbalizer so that the verbalizer will divide the calibration probability in producing label logits 165 | # currently, only ManualVerbalizer and KnowledgeableVerbalizer support calibration. 
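# The pilot knob in this script is --not_manual: with the flag, the SoftVerbalizer built
# above starts from randomly initialized class embeddings; without it, the embeddings are
# seeded from the manual_verbalizer file. A loose sketch of the difference (hypothetical
# names; the real initialization lives inside OpenPrompt's SoftVerbalizer):
#   random init:  W = torch.randn(num_classes, hidden_size)
#   manual init:  W = word_embeddings[label_word_ids].mean(dim=1)  # mean embedding per class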
166 | 167 | from openprompt.data_utils.data_sampler import FewShotSampler 168 | sampler = FewShotSampler(num_examples_per_label=args.shot, also_sample_dev=True, num_examples_per_label_dev=args.shot) 169 | dataset['train'], dataset['validation'] = sampler(dataset['train'], seed=args.seed) 170 | 171 | 172 | train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer, 173 | tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 174 | batch_size=batch_s,shuffle=True, teacher_forcing=False, predict_eos_token=False, 175 | truncate_method="tail") 176 | 177 | validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate, tokenizer=tokenizer, 178 | tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 179 | batch_size=batch_s,shuffle=False, teacher_forcing=False, predict_eos_token=False, 180 | truncate_method="tail") 181 | 182 | # zero-shot test 183 | test_dataloader = PromptDataLoader(dataset=dataset["test"], template=mytemplate, tokenizer=tokenizer, 184 | tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 185 | batch_size=batch_s,shuffle=False, teacher_forcing=False, predict_eos_token=False, 186 | truncate_method="tail") 187 | 188 | 189 | def evaluate(prompt_model, dataloader, desc): 190 | prompt_model.eval() 191 | allpreds = [] 192 | alllabels = [] 193 | pbar = tqdm(dataloader, desc=desc) 194 | for step, inputs in enumerate(pbar): 195 | if use_cuda: 196 | inputs = inputs.cuda() 197 | logits = prompt_model(inputs) 198 | labels = inputs['label'] 199 | alllabels.extend(labels.cpu().tolist()) 200 | allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist()) 201 | acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds) 202 | return acc 203 | ############ 204 | ############# 205 | ############### 206 | 207 | from transformers import AdamW, get_linear_schedule_with_warmup 208 | loss_func = torch.nn.CrossEntropyLoss() 209 | 210 | 211 | def prompt_initialize(verbalizer, prompt_model, init_dataloader): 212 | dataloader = init_dataloader 213 | with torch.no_grad(): 214 | for batch in tqdm(dataloader, desc="Init_using_{}".format("train")): 215 | batch = batch.cuda() 216 | logits = prompt_model(batch) 217 | verbalizer.optimize_to_initialize() 218 | 219 | 220 | if args.verbalizer == "soft": 221 | 222 | 223 | no_decay = ['bias', 'LayerNorm.weight'] 224 | 225 | # it's always good practice to set no decay to biase and LayerNorm parameters 226 | optimizer_grouped_parameters1 = [ 227 | {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 228 | {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 229 | ] 230 | 231 | # Using different optimizer for prompt parameters and model parameters 232 | 233 | optimizer_grouped_parameters2 = [ 234 | {'params': prompt_model.verbalizer.group_parameters_1, "lr":3e-5}, 235 | {'params': prompt_model.verbalizer.group_parameters_2, "lr":3e-4}, 236 | ] 237 | 238 | 239 | optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5) 240 | optimizer2 = AdamW(optimizer_grouped_parameters2) 241 | 242 | tot_step = len(train_dataloader) // args.gradient_accumulation_steps * args.max_epochs 243 | scheduler1 = get_linear_schedule_with_warmup( 244 | optimizer1, 245 | num_warmup_steps=0, num_training_steps=tot_step) 246 | 247 | scheduler2 = get_linear_schedule_with_warmup( 248 
| optimizer2, 249 | num_warmup_steps=0, num_training_steps=tot_step) 250 | 251 | elif args.verbalizer == "auto": 252 | prompt_initialize(myverbalizer, prompt_model, train_dataloader) 253 | 254 | no_decay = ['bias', 'LayerNorm.weight'] 255 | 256 | # it's always good practice to set no decay to biase and LayerNorm parameters 257 | optimizer_grouped_parameters1 = [ 258 | {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 259 | {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 260 | ] 261 | 262 | # Using different optimizer for prompt parameters and model parameters 263 | 264 | optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5) 265 | 266 | tot_step = len(train_dataloader) // args.gradient_accumulation_steps * args.max_epochs 267 | scheduler1 = get_linear_schedule_with_warmup( 268 | optimizer1, 269 | num_warmup_steps=0, num_training_steps=tot_step) 270 | 271 | optimizer2 = None 272 | scheduler2 = None 273 | 274 | elif args.verbalizer == "kpt": 275 | no_decay = ['bias', 'LayerNorm.weight'] 276 | 277 | # it's always good practice to set no decay to biase and LayerNorm parameters 278 | optimizer_grouped_parameters1 = [ 279 | {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 280 | {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 281 | ] 282 | 283 | # Using different optimizer for prompt parameters and model parameters 284 | 285 | # optimizer_grouped_parameters2 = [ 286 | # {'params': , "lr":1e-1}, 287 | # ] 288 | optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5) 289 | optimizer2 = AdamW(prompt_model.verbalizer.parameters(), lr=args.kptw_lr) 290 | # print(optimizer_grouped_parameters2) 291 | 292 | tot_step = len(train_dataloader) // args.gradient_accumulation_steps * args.max_epochs 293 | scheduler1 = get_linear_schedule_with_warmup( 294 | optimizer1, 295 | num_warmup_steps=0, num_training_steps=tot_step) 296 | 297 | # scheduler2 = get_linear_schedule_with_warmup( 298 | # optimizer2, 299 | # num_warmup_steps=0, num_training_steps=tot_step) 300 | scheduler2 = None 301 | 302 | elif args.verbalizer == "manual": 303 | no_decay = ['bias', 'LayerNorm.weight'] 304 | 305 | # it's always good practice to set no decay to biase and LayerNorm parameters 306 | optimizer_grouped_parameters1 = [ 307 | {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 308 | {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 309 | ] 310 | 311 | # Using different optimizer for prompt parameters and model parameters 312 | 313 | optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5) 314 | 315 | tot_step = len(train_dataloader) // args.gradient_accumulation_steps * args.max_epochs 316 | scheduler1 = get_linear_schedule_with_warmup( 317 | optimizer1, 318 | num_warmup_steps=0, num_training_steps=tot_step) 319 | 320 | optimizer2 = None 321 | scheduler2 = None 322 | 323 | 324 | tot_loss = 0 325 | log_loss = 0 326 | best_val_acc = 0 327 | for epoch in range(args.max_epochs): 328 | tot_loss = 0 329 | prompt_model.train() 330 | for step, inputs in enumerate(train_dataloader): 331 | if use_cuda: 332 | inputs = inputs.cuda() 333 | logits = prompt_model(inputs) 334 | labels = inputs['label'] 335 | 
loss = loss_func(logits, labels) 336 | loss.backward() 337 | torch.nn.utils.clip_grad_norm_(prompt_model.parameters(), 1.0) 338 | tot_loss += loss.item() 339 | optimizer1.step() 340 | scheduler1.step() 341 | optimizer1.zero_grad() 342 | if optimizer2 is not None: 343 | optimizer2.step() 344 | optimizer2.zero_grad() 345 | if scheduler2 is not None: 346 | scheduler2.step() 347 | 348 | val_acc = evaluate(prompt_model, validation_dataloader, desc="Valid") 349 | if val_acc>=best_val_acc: 350 | torch.save(prompt_model.state_dict(),f"ckpts/{this_run_unicode}.ckpt") 351 | best_val_acc = val_acc 352 | print("Epoch {}, val_acc {}".format(epoch, val_acc), flush=True) 353 | 354 | # print("verbalizer weights", myverbalizer.label_words_weights, flush=True) 355 | prompt_model.load_state_dict(torch.load(f"ckpts/{this_run_unicode}.ckpt")) 356 | prompt_model = prompt_model.cuda() 357 | test_acc = evaluate(prompt_model, test_dataloader, desc="Test") 358 | 359 | 360 | 361 | 362 | ############ 363 | ############# 364 | ############### 365 | 366 | 367 | 368 | 369 | # roughly ~0.853 when using template 0 370 | 371 | 372 | 373 | content_write = "="*20+"\n" 374 | content_write += f"dataset {args.dataset}\t" 375 | content_write += f"temp {args.template_id}\t" 376 | content_write += f"seed {args.seed}\t" 377 | content_write += f"shot {args.shot}\t" 378 | content_write += f"verb {args.verbalizer}\t" 379 | content_write += f"cali {args.calibration}\t" 380 | content_write += f"filt {args.filter}\t" 381 | content_write += f"maxsplit {args.max_token_split}\t" 382 | content_write += f"kptw_lr {args.kptw_lr}\t" 383 | content_write += f"not_manual {args.not_manual}\t" 384 | content_write += "\n" 385 | content_write += f"Acc: {test_acc}" 386 | content_write += "\n\n" 387 | 388 | print(content_write) 389 | 390 | with open(f"{args.result_file}", "a") as fout: 391 | fout.write(content_write) 392 | 393 | import os 394 | os.remove(f"ckpts/{this_run_unicode}.ckpt") -------------------------------------------------------------------------------- /filter_method.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.metrics.pairwise import cosine_similarity 3 | import numpy as np 4 | 5 | 6 | 7 | def tfidf_filter(myverbalizer, cc_logits, class_labels): 8 | myrecord = "" 9 | class_num = len(class_labels) 10 | norm_ord = 10/(class_num-2+1e-2) +1 11 | print("norm_ord", norm_ord) 12 | context_size = cc_logits.shape[0] 13 | tobeproject = cc_logits.transpose(0,1).unsqueeze(0) 14 | ret = [] 15 | for i in range(tobeproject.shape[-1]): 16 | ret.append(myverbalizer.project(tobeproject[:,:,i]).unsqueeze(-1)) 17 | ret = torch.cat(ret, dim=-1) 18 | label_words_cc_logits = ret.squeeze() 19 | 20 | label_words_cc_logits = label_words_cc_logits - label_words_cc_logits.mean(dim=-1,keepdims=True)#, dim=-1) 21 | 22 | first_label_logits = label_words_cc_logits[:,0,:] 23 | orgshape = label_words_cc_logits.shape 24 | label_words_cc_logits = label_words_cc_logits.reshape(-1,context_size) 25 | sim_mat = cosine_similarity(label_words_cc_logits.cpu().numpy(),first_label_logits.cpu().numpy() ).reshape(*orgshape[:-1],first_label_logits.shape[0]) 26 | sim_mat = sim_mat - 10000.0* (1-myverbalizer.label_words_mask.unsqueeze(-1).cpu().numpy()) 27 | 28 | new_label_words = [] 29 | max_lbw_num_pclass = myverbalizer.label_words_mask.shape[-1] 30 | outputers = [] 31 | for class_id in range(len(myverbalizer.label_words)): 32 | tfidf_scores = [] 33 | tf_scores = [] 34 | idf_scores = [] 35 | num_words_in_class = 
len(myverbalizer.label_words[class_id])
36 |         for in_class_id in range(max_lbw_num_pclass):
37 |             if myverbalizer.label_words_mask[class_id, in_class_id] > 0:
38 |                 word_sim_scores = sim_mat[class_id, in_class_id]
39 |                 tf_score = word_sim_scores[class_id]
40 |                 idf_score_source = np.concatenate([word_sim_scores[:class_id], word_sim_scores[class_id+1:]])
41 |                 idf_score = 1 / (np.linalg.norm(idf_score_source, ord=norm_ord) / np.power((class_num-1), 1/norm_ord))
42 |                 tfidf_score = tf_score * idf_score
43 |                 if tf_score < 0:
44 |                     tfidf_score = -100
45 |                 tfidf_scores.append(tfidf_score)
46 |                 tf_scores.append(tf_score)
47 |                 idf_scores.append(idf_score)
48 | 
49 |         outputer = list(zip(myverbalizer.label_words[class_id],
50 |                             tfidf_scores,
51 |                             tf_scores,
52 |                             idf_scores))
53 | 
54 |         outputer = sorted(outputer, key=lambda x: -x[1])
55 |         outputers.append(outputer)
56 | 
57 |     cut_optimality = []
58 |     max_outputer_len = max([len(outputers[class_id]) for class_id in range(len(outputers))])
59 |     for cut_potent in range(max_outputer_len):
60 |         cut_rate = cut_potent / max_outputer_len
61 |         loss = 0
62 |         for class_id in range(len(myverbalizer.label_words)):
63 |             cut_potent_this_class = int(cut_rate * len(outputers[class_id]))
64 |             if len(outputers[class_id]) <= cut_potent_this_class:
65 |                 boundary_score = outputers[class_id][-1][1]
66 |             else:
67 |                 boundary_score = outputers[class_id][cut_potent_this_class][1]
68 |             loss += (boundary_score - 1) ** 2
69 |         cut_optimality.append([cut_rate, loss])
70 |     optimal_cut_rate = sorted(cut_optimality, key=lambda x: x[1])[0][0]
71 |     print("optimal cut rate is {}".format(optimal_cut_rate))
72 |     for class_id in range(len(myverbalizer.label_words)):
73 |         cut = int(len(outputers[class_id]) * optimal_cut_rate)
74 |         if cut == 0:
75 |             cut = 1
76 |         new_l = [x[0] for x in outputers[class_id][:cut]]
77 |         removed_words = [x[0] for x in outputers[class_id][cut:]]
78 |         myrecord += f"Class {class_id} {new_l}\n"
79 |         myrecord += f"Class {class_id} rm: {removed_words}\n"
80 |         new_label_words.append(new_l)
81 |     myverbalizer.label_words = new_label_words
82 |     myverbalizer = myverbalizer.cuda()
83 |     noww_label_words_num = [len(myverbalizer.label_words[i]) for i in range(len(class_labels))]
84 |     myrecord += f"Phase 3 {noww_label_words_num}\n"
85 |     return myrecord
86 | 
87 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
 1 | 
 2 | # KPT source code
 3 | 
 4 | Here is the source code for our ACL 2022 paper
 5 | [Knowledgeable Prompt-tuning: Incorporating Knowledge into Prompt Verbalizer for Text Classification](https://arxiv.org/abs/2108.02035).
 6 | 
 7 | ## Install OpenPrompt
 8 | 
 9 | 
10 | Please install it via `git clone`; this keeps the dataset-downloading scripts available.
11 | 
12 | ```bash
13 | git clone git@github.com:thunlp/OpenPrompt.git
14 | cd OpenPrompt
15 | python setup.py install
16 | ```
17 | 
18 | 
19 | ## Download the datasets
20 | ```bash
21 | cd OpenPrompt/datasets
22 | bash download_text_classification.sh
23 | ```
24 | 
25 | ## Run the scripts
26 | For the few-shot experiments:
27 | ```bash
28 | bash scripts/run_fewshot.sh
29 | ```
30 | For the zero-shot experiments:
31 | ```bash
32 | bash scripts/run_zeroshot.sh
33 | ```
34 | For the pilot experiment in the appendix:
35 | ```bash
36 | bash scripts/run_pilot.sh
37 | ```
38 | 
39 | The possible argument values are listed in the comments inside each script;
40 | choose the combination that fits your needs.
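
## Quick start from Python (optional)

If you prefer driving the pipeline from Python rather than the shell scripts, the sketch below mirrors the zero-shot setup in `zeroshot.py` for AG's News with a manual verbalizer. It is a minimal illustration that assumes the default relative paths used above; adjust them to your environment.

```python
# Minimal zero-shot setup, condensed from zeroshot.py (illustrative, not a new API).
from openprompt.plms import load_plm
from openprompt.prompts import ManualTemplate, ManualVerbalizer
from openprompt.data_utils.text_classification_dataset import AgnewsProcessor
from openprompt import PromptDataLoader, PromptForClassification

plm, tokenizer, model_config, WrapperClass = load_plm("roberta", "roberta-large")
processor = AgnewsProcessor()
test_data = processor.get_test_examples("OpenPrompt/datasets/TextClassification/agnews/")
template = ManualTemplate(tokenizer=tokenizer).from_file(
    "OpenPrompt/scripts/TextClassification/agnews/manual_template.txt", choice=0)
verbalizer = ManualVerbalizer(tokenizer, classes=processor.get_labels()).from_file(
    "OpenPrompt/scripts/TextClassification/agnews/manual_verbalizer.txt")
prompt_model = PromptForClassification(plm=plm, template=template, verbalizer=verbalizer).cuda()
```

From here, build a `PromptDataLoader` over `test_data` and take the argmax of the model's logits, exactly as `zeroshot.py` does.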
--------------------------------------------------------------------------------
/scripts/run_fewshot.sh:
--------------------------------------------------------------------------------
 1 | PYTHONPATH=python
 2 | BASEPATH="./"
 3 | DATASET=amazon # agnews dbpedia imdb amazon yahoo
 4 | TEMPLATEID=0 # 1 2 3
 5 | SEED=144 # 145 146 147 148
 6 | SHOT=5 # 0 1 10 20
 7 | VERBALIZER=kpt # manual soft auto
 8 | FILTER=tfidf_filter # none
 9 | KPTWLR=0.0 # 0.06
10 | MAXTOKENSPLIT=-1 # 1
11 | MODEL_NAME_OR_PATH="../plm_cache/roberta-large"
12 | RESULTPATH="results_fewshot.txt"
13 | OPENPROMPTPATH="OpenPrompt"
14 | 
15 | cd $BASEPATH
16 | 
17 | CUDA_VISIBLE_DEVICES=7 $PYTHONPATH fewshot.py \
18 |     --model_name_or_path $MODEL_NAME_OR_PATH \
19 |     --result_file $RESULTPATH \
20 |     --openprompt_path $OPENPROMPTPATH \
21 |     --dataset $DATASET \
22 |     --template_id $TEMPLATEID \
23 |     --seed $SEED \
24 |     --shot $SHOT \
25 |     --verbalizer $VERBALIZER \
26 |     --filter $FILTER \
27 |     --max_token_split $MAXTOKENSPLIT \
28 |     --kptw_lr $KPTWLR
--------------------------------------------------------------------------------
/scripts/run_pilot.sh:
--------------------------------------------------------------------------------
 1 | PYTHONPATH=python
 2 | BASEPATH="./"
 3 | DATASET=agnews # dbpedia imdb amazon yahoo
 4 | TEMPLATEID=0 # 1 2 3
 5 | SEED=144 # 145 146 147 148
 6 | SHOT=5 # 0 1 10 20
 7 | VERBALIZER=soft
 8 | MODEL_NAME_OR_PATH="../plm_cache/roberta-large"
 9 | RESULTPATH="results_fewshot_softpilot.txt"
10 | OPENPROMPTPATH="OpenPrompt"
11 | 
12 | cd $BASEPATH
13 | 
14 | CUDA_VISIBLE_DEVICES=0 $PYTHONPATH fewshot_softpilot.py \
15 |     --model_name_or_path $MODEL_NAME_OR_PATH \
16 |     --result_file $RESULTPATH \
17 |     --openprompt_path $OPENPROMPTPATH \
18 |     --dataset $DATASET \
19 |     --template_id $TEMPLATEID \
20 |     --seed $SEED \
21 |     --shot $SHOT \
22 |     --verbalizer $VERBALIZER \
23 |     --not_manual
--------------------------------------------------------------------------------
/scripts/run_zeroshot.sh:
--------------------------------------------------------------------------------
 1 | PYTHONPATH=python3
 2 | BASEPATH="./"
 3 | DATASET=yahoo # agnews dbpedia imdb amazon yahoo
 4 | TEMPLATEID=0 # 1 2 3
 5 | SEED=144 # 145 146 147 148
 6 | VERBALIZER=kpt # manual
 7 | CALIBRATION="--calibration" # ""
 8 | FILTER=tfidf_filter # none
 9 | MODEL_NAME_OR_PATH="../plm_cache/roberta-large"
10 | RESULTPATH="results_zeroshot.txt"
11 | OPENPROMPTPATH="OpenPrompt"
12 | 
13 | cd $BASEPATH
14 | 
15 | CUDA_VISIBLE_DEVICES=0 $PYTHONPATH zeroshot.py \
16 |     --model_name_or_path $MODEL_NAME_OR_PATH \
17 |     --result_file $RESULTPATH \
18 |     --openprompt_path $OPENPROMPTPATH \
19 |     --dataset $DATASET \
20 |     --template_id $TEMPLATEID \
21 |     --seed $SEED \
22 |     --verbalizer $VERBALIZER $CALIBRATION \
23 |     --filter $FILTER
--------------------------------------------------------------------------------
/zeroshot.py:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 | from tqdm import tqdm
 4 | from openprompt.data_utils.text_classification_dataset import AgnewsProcessor, DBpediaProcessor, ImdbProcessor, AmazonProcessor
 5 | from openprompt.data_utils.huggingface_dataset import YahooAnswersTopicsProcessor
 6 | import torch
 7 | from openprompt.data_utils.utils import InputExample
 8 | import argparse
 9 | import numpy as np
10 | 
11 | from openprompt import PromptDataLoader
12 | from openprompt.prompts import
ManualVerbalizer, KnowledgeableVerbalizer 13 | from openprompt.prompts import ManualTemplate 14 | 15 | 16 | parser = argparse.ArgumentParser("") 17 | parser.add_argument("--shot", type=int, default=0) 18 | parser.add_argument("--seed", type=int, default=144) 19 | 20 | parser.add_argument("--plm_eval_mode", action="store_true") 21 | parser.add_argument("--model", type=str, default='roberta') 22 | parser.add_argument("--model_name_or_path", default='../plm_cache/roberta-large') 23 | parser.add_argument("--result_file", type=str, default="sfs_scripts/results_fewshot_manual_kpt.txt") 24 | parser.add_argument("--openprompt_path", type=str, default="OpenPrompt") 25 | 26 | parser.add_argument("--verbalizer", type=str) 27 | parser.add_argument("--calibration", action="store_true") 28 | parser.add_argument("--nocut", action="store_true") 29 | parser.add_argument("--filter", default="none", type=str) 30 | parser.add_argument("--template_id", type=int) 31 | parser.add_argument("--max_token_split", default=-1, type=int) 32 | parser.add_argument("--dataset",type=str) 33 | parser.add_argument("--write_filter_record", action="store_true") 34 | args = parser.parse_args() 35 | 36 | from openprompt.utils.reproduciblity import set_seed 37 | set_seed(args.seed) 38 | 39 | from openprompt.plms import load_plm 40 | plm, tokenizer, model_config, WrapperClass = load_plm(args.model, args.model_name_or_path) 41 | 42 | dataset = {} 43 | 44 | if args.dataset == "agnews": 45 | dataset['train'] = AgnewsProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/agnews/") 46 | dataset['test'] = AgnewsProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/agnews/") 47 | class_labels =AgnewsProcessor().get_labels() 48 | scriptsbase = "TextClassification/agnews" 49 | scriptformat = "txt" 50 | cutoff=0.5 if (not args.nocut) else 0.0 51 | max_seq_l = 128 52 | batch_s = 30 53 | elif args.dataset == "dbpedia": 54 | dataset['train'] = DBpediaProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/dbpedia/") 55 | dataset['test'] = DBpediaProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/dbpedia/") 56 | class_labels =DBpediaProcessor().get_labels() 57 | scriptsbase = "TextClassification/dbpedia" 58 | scriptformat = "txt" 59 | cutoff=0.5 if (not args.nocut) else 0.0 60 | max_seq_l = 128 61 | batch_s = 30 62 | elif args.dataset == "yahoo": 63 | dataset['train'] = YahooAnswersTopicsProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/yahoo_answers_topics/") 64 | dataset['test'] = YahooAnswersTopicsProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/yahoo_answers_topics/") 65 | class_labels =YahooAnswersTopicsProcessor().get_labels() 66 | scriptsbase = "TextClassification/yahoo_answers_topics" 67 | scriptformat = "json" 68 | cutoff=0.5 if (not args.nocut) else 0.0 69 | max_seq_l = 128 70 | batch_s = 30 71 | elif args.dataset == "imdb": 72 | dataset['train'] = ImdbProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/imdb/") 73 | dataset['test'] = ImdbProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/imdb/") 74 | class_labels = ImdbProcessor().get_labels() 75 | scriptsbase = "TextClassification/imdb" 76 | scriptformat = "txt" 77 | cutoff=0 78 | max_seq_l = 512 79 | batch_s = 5 80 | elif args.dataset == "amazon": 81 | dataset['train'] = 
AmazonProcessor().get_train_examples(f"{args.openprompt_path}/datasets/TextClassification/amazon/") 82 | dataset['test'] = AmazonProcessor().get_test_examples(f"{args.openprompt_path}/datasets/TextClassification/amazon/") 83 | class_labels = AmazonProcessor().get_labels() 84 | scriptsbase = "TextClassification/amazon" 85 | scriptformat = "txt" 86 | cutoff=0 87 | max_seq_l = 512 88 | batch_s = 5 89 | else: 90 | raise NotImplementedError 91 | 92 | 93 | mytemplate = ManualTemplate(tokenizer=tokenizer).from_file(f"{args.openprompt_path}/scripts/{scriptsbase}/manual_template.txt", choice=args.template_id) 94 | 95 | 96 | if args.verbalizer == "kpt": 97 | myverbalizer = KnowledgeableVerbalizer(tokenizer, classes=class_labels, candidate_frac=cutoff, max_token_split=args.max_token_split).from_file(f"{args.openprompt_path}/scripts/{scriptsbase}/knowledgeable_verbalizer.{scriptformat}") 98 | elif args.verbalizer == "manual": 99 | myverbalizer = ManualVerbalizer(tokenizer, classes=class_labels).from_file(f"{args.openprompt_path}/scripts/{scriptsbase}/manual_verbalizer.{scriptformat}") 100 | elif args.verbalizer == "soft": 101 | raise NotImplementedError 102 | elif args.verbalizer == "auto": 103 | raise NotImplementedError 104 | 105 | # (contextual) calibration 106 | if args.calibration: 107 | from openprompt.data_utils.data_sampler import FewShotSampler 108 | support_sampler = FewShotSampler(num_examples_total=200, also_sample_dev=False) 109 | dataset['support'] = support_sampler(dataset['train'], seed=args.seed) 110 | 111 | for example in dataset['support']: 112 | example.label = -1 # remove the labels of support set for clarification 113 | support_dataloader = PromptDataLoader(dataset=dataset["support"], template=mytemplate, tokenizer=tokenizer, 114 | tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 115 | batch_size=batch_s,shuffle=False, teacher_forcing=False, predict_eos_token=False, 116 | truncate_method="tail") 117 | 118 | 119 | from openprompt import PromptForClassification 120 | use_cuda = True 121 | prompt_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=False, plm_eval_mode=args.plm_eval_mode) 122 | if use_cuda: 123 | prompt_model= prompt_model.cuda() 124 | 125 | 126 | myrecord = "" 127 | # HP 128 | if args.calibration: 129 | org_label_words_num = [len(prompt_model.verbalizer.label_words[i]) for i in range(len(class_labels))] 130 | from contextualize_calibration import calibrate 131 | # calculate the calibration logits 132 | cc_logits = calibrate(prompt_model, support_dataloader) 133 | print("the calibration logits is", cc_logits) 134 | myrecord += "Phase 1 {}\n".format(org_label_words_num) 135 | 136 | myverbalizer.register_calibrate_logits(cc_logits.mean(dim=0)) 137 | new_label_words_num = [len(myverbalizer.label_words[i]) for i in range(len(class_labels))] 138 | myrecord += "Phase 2 {}\n".format(new_label_words_num) 139 | 140 | 141 | from filter_method import * 142 | if args.filter == "tfidf_filter": 143 | record = tfidf_filter(myverbalizer, cc_logits, class_labels) 144 | myrecord += record 145 | elif args.filter == "none": 146 | pass 147 | else: 148 | raise NotImplementedError 149 | 150 | 151 | # register the logits to the verbalizer so that the verbalizer will divide the calibration probability in producing label logits 152 | # currently, only ManualVerbalizer and KnowledgeableVerbalizer support calibration. 
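# For reference, the tfidf_filter invoked above scores each label word by analogy with
# tf-idf (a paraphrase of filter_method.py, not additional functionality):
#   tf    = cosine similarity between the word's calibrated logits over the support set
#           and those of its own class's first label word (the class name itself);
#   idf   = the inverse generalized norm of its similarities to the other classes;
#   score = tf * idf, with words whose tf < 0 dropped outright, and a shared cut rate
#           chosen so that per-class boundary scores stay close to 1.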
153 | 154 | # 155 | if args.write_filter_record: 156 | record_prefix = "="*20+"\n" 157 | record_prefix += f"dataset {args.dataset}\t" 158 | record_prefix += f"temp {args.template_id}\t" 159 | record_prefix += f"seed {args.seed}\t" 160 | record_prefix += f"cali {args.calibration}\t" 161 | record_prefix += f"filt {args.filter}\t" 162 | record_prefix += "\n" 163 | myrecord = record_prefix +myrecord 164 | with open("../sfs_scripts/filter_record_file.txt",'a') as fout_rec: 165 | fout_rec.write(myrecord) 166 | exit() 167 | 168 | 169 | # zero-shot test 170 | test_dataloader = PromptDataLoader(dataset=dataset["test"], template=mytemplate, tokenizer=tokenizer, 171 | tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 172 | batch_size=batch_s,shuffle=False, teacher_forcing=False, predict_eos_token=False, 173 | truncate_method="tail") 174 | allpreds = [] 175 | alllabels = [] 176 | pbar = tqdm(test_dataloader) 177 | for step, inputs in enumerate(pbar): 178 | if use_cuda: 179 | inputs = inputs.cuda() 180 | logits = prompt_model(inputs) 181 | labels = inputs['label'] 182 | alllabels.extend(labels.cpu().tolist()) 183 | allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist()) 184 | acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds) 185 | 186 | 187 | # roughly ~0.853 when using template 0 188 | 189 | 190 | 191 | content_write = "="*20+"\n" 192 | content_write += f"dataset {args.dataset}\t" 193 | content_write += f"temp {args.template_id}\t" 194 | content_write += f"seed {args.seed}\t" 195 | content_write += f"verb {args.verbalizer}\t" 196 | content_write += f"cali {args.calibration}\t" 197 | content_write += f"filt {args.filter}\t" 198 | content_write += f"nocut {args.nocut}\t" 199 | content_write += f"maxsplit {args.max_token_split}\t" 200 | content_write += "\n" 201 | content_write += f"Acc: {acc}" 202 | content_write += "\n\n" 203 | 204 | print(content_write) 205 | 206 | with open(f"{args.result_file}", "a") as fout: 207 | fout.write(content_write) --------------------------------------------------------------------------------