├── README.md ├── data └── .gitignore ├── icl.py ├── kNNPrompting.png ├── knn_prompting.py ├── llm └── .gitignore ├── run_icl.sh ├── run_knnprompting.sh └── utils ├── __init__.py ├── anchor.py ├── dataset.py └── template.py /README.md: -------------------------------------------------------------------------------- 1 | # KNNPrompting 2 | Released code for our ICLR23 paper: [KNN Prompting: Beyond-Context Learning with Calibration-Free Nearest Neighbor Inference](https://openreview.net/forum?id=fe2S7736sNS) 3 | 4 |
5 | Framework of kNNPrompting 6 |
7 | 8 | ## Preparation 9 | ### Environment 10 | The code is tested under torch==1.12.0 and transformers==4.20.1, though the specific version requirements are not strict: if the code runs without errors, you are set. 11 | ### Model 12 | Prepare your LLM ([gpt2](https://huggingface.co/gpt2-xl/tree/main) or opt) in `./llm/`; I personally prefer to download the model files myself and configure the local path in the scripts. 13 | ### Data 14 | [Download](https://drive.google.com/file/d/1Yh2blPkJvMtdm5xWKoHr2fLp2i2Bn5Ir/view?usp=share_link) the datasets and unzip them in `./data`.\ 15 | The structure of the project looks like: 16 | ``` 17 | . 18 | ├── run_icl.sh 19 | ├── run_knnprompting.sh 20 | ├── icl.py 21 | ├── knn_prompting.py 22 | ├── utils 23 | │ ├── anchor.py 24 | │ ├── dataset.py 25 | │ ├── __init__.py 26 | │ └── template.py 27 | ├── llm 28 | │   └── gpt2-xl 29 | │   ├── config.json 30 | │   ├── merges.txt 31 | │   ├── pytorch_model.bin 32 | │   ├── tokenizer.json 33 | │   └── vocab.json 34 | └── data 35 |    └── sst2 36 |       ├── dev_subsample.jsonl 37 |       ├── test.jsonl 38 |       └── train.jsonl 39 | ``` 40 | 41 | ## Run 42 | Run kNNPrompting or In-Context Learning as follows; check the configuration in the scripts, including dataset, llm, seed, etc. 43 | ``` 44 | bash run_knnprompting.sh 45 | ``` 46 | or 47 | ``` 48 | bash run_icl.sh 49 | ``` 50 | ## Results 51 | As the entire framework is training-free, you should get **exact** results w.r.t. random seeds as follows (invariant to different environments): 52 | 53 | | Seed | 1 | 2 | 3 | 4 | 5 | 54 | | ----------------------------------- | ------ | ------ | ------ | ------ | ------ | 55 | | **In-Context Learning** (gpt2-xl) | 0.8438 | 0.8125 | 0.7227 | 0.8633 | 0.8242 | 56 | | **KNN Prompting** (gpt2-xl, N=1024) | 0.8711 | 0.8867 | 0.8906 | 0.8711 | 0.8906 | 57 | 58 | Full results are listed in the paper (see Table 8 and others). 59 | 60 | ## Citation 61 | * If you have any questions, feel free to open an issue. 
62 | * If you find this repo useful, please cite us as: 63 | ``` 64 | @inproceedings{ 65 | xu2023knn, 66 | title={\$k\${NN} Prompting: Beyond-Context Learning with Calibration-Free Nearest Neighbor Inference}, 67 | author={Benfeng Xu and Quan Wang and Zhendong Mao and Yajuan Lyu and Qiaoqiao She and Yongdong Zhang}, 68 | booktitle={The Eleventh International Conference on Learning Representations }, 69 | year={2023}, 70 | url={https://openreview.net/forum?id=fe2S7736sNS} 71 | } 72 | ``` -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /icl.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from datetime import datetime 4 | from time import sleep 5 | import logging 6 | import argparse 7 | from tqdm import tqdm 8 | import csv 9 | import os 10 | 11 | import torch 12 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 13 | 14 | from utils.dataset import * 15 | from utils.template import * 16 | 17 | 18 | os.environ["TOKENIZERS_PARALLELISM"] = "false" # To suppress warnings about parallelism in tokenizers 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description="In-Context Learning baseline.") 24 | parser.add_argument( 25 | "--llm_dir", 26 | type=str, 27 | default=None, 28 | ) 29 | parser.add_argument( 30 | "--data_dir", 31 | type=str, 32 | default=None, 33 | ) 34 | parser.add_argument( 35 | "--dataset", 36 | type=str, 37 | default=None, 38 | ) 39 | parser.add_argument( 40 | "--seed", 41 | type=int, 42 | default=None, 43 | ) 44 | parser.add_argument( 45 | "--n_train_shot", 46 | type=int, 47 | default=None, 48 | ) 49 | parser.add_argument( 50 | "--output_dir", 51 | type=str, 52 | 
def llm_gen(model, prompt, tokenizer, max_context_len):
    """Return the logits the model produces at the last position of ``prompt``.

    The prompt is tokenized as a single sequence and moved to the model's
    device. If it is longer than ``max_context_len`` tokens, only the trailing
    ``max_context_len`` tokens (and the matching attention mask entries) are
    fed to the model, so the end of the prompt is kept and the start dropped.

    Returns:
        CPU tensor of logits taken at the last input position
        (``[1, 50257]`` for a single prompt, per the original comment).
    """
    encoded = tokenizer.encode_plus(prompt, return_tensors="pt", padding=True).to(device=model.device)
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']

    # Over-long prompt: keep the tail, which holds the query to be answered.
    if input_ids.shape[1] > max_context_len:
        input_ids = input_ids[:, -max_context_len:]
        attention_mask = attention_mask[:, -max_context_len:]

    with torch.no_grad():
        output = model.forward(input_ids=input_ids,
                               attention_mask=attention_mask,
                               return_dict=True)
    all_logits = output.logits.detach().cpu()

    # the output prob is shifted by -1, so the prediction for the next token
    # lives at the last input token position
    return all_logits[:, -1, :]
AutoConfig.from_pretrained(args.llm_dir) 105 | model = AutoModelForCausalLM.from_pretrained(args.llm_dir) 106 | model.to(device) 107 | model.eval() 108 | 109 | if 'gpt2' in args.llm_dir: 110 | max_context_len = 1024 111 | else: 112 | max_context_len = 2048 113 | 114 | # prepare dataset 115 | if args.dataset == 'sst2': 116 | AutoDataset = SST2Dataset 117 | elif args.dataset == 'subj': 118 | AutoDataset = SUBJDataset 119 | elif args.dataset == 'agnews': 120 | AutoDataset = AGNEWSDataset 121 | elif args.dataset == 'cb': 122 | AutoDataset = CBDataset 123 | elif args.dataset == 'cr': 124 | AutoDataset = CRDataset 125 | elif args.dataset == 'dbpedia': 126 | AutoDataset = DBPEDIADataset 127 | elif args.dataset == 'mpqa': 128 | AutoDataset = MPQADataset 129 | elif args.dataset == 'mr': 130 | AutoDataset = MRDataset 131 | elif args.dataset == 'rte': 132 | AutoDataset = RTEDataset 133 | elif args.dataset == 'trec': 134 | AutoDataset = TRECDataset 135 | 136 | dataset_dir = os.path.join(args.data_dir, args.dataset) 137 | train_data = AutoDataset(dataset_dir, mode='train') 138 | dev_data = AutoDataset(dataset_dir, mode='dev') 139 | 140 | # inference 141 | train_data.subsamplebyshot(args.n_train_shot, args.seed) 142 | logger.info(f"===== eval on {dev_data.__len__()} dev examples =====") 143 | prompt_prefix = make_prompt(train_data, args.dataset, mode='train') 144 | dev_labels = [] 145 | dev_pred = [] 146 | label2id = dev_data.label2id 147 | id2verb = train_data.id2verb 148 | for ins in tqdm(dev_data.data, total=dev_data.__len__()): 149 | dev_labels.append(label2id[ins['label']]) 150 | prompt = prompt_prefix + make_prompt(ins, args.dataset, mode='inference') 151 | gen_logits = llm_gen(model, prompt, tokenizer, max_context_len) 152 | dev_pred.append(parse_response(gen_logits, tokenizer, id2verb)) 153 | 154 | dev_correct = [1 if dev_labels[i] == dev_pred[i] else 0 for i in range(len(dev_labels))] 155 | acc = sum(dev_correct) / len(dev_labels) 156 | logger.info(f"Acc: {acc}") 157 | 
158 | # logging 159 | save_results_file = os.path.join(args.output_dir, 'results_icl.csv') 160 | csv_exists = os.path.isfile(save_results_file) 161 | with open(save_results_file, 'a+', newline='') as csvfile: 162 | csvwriter = csv.writer(csvfile) 163 | if not csv_exists: 164 | csvwriter.writerow(['dataset', 'llm', 'n_train_shot', 'seed', 'acc']) 165 | csvwriter.writerow([args.dataset, 166 | args.llm_dir, 167 | args.n_train_shot, 168 | args.seed, 169 | acc]) 170 | 171 | if __name__ == "__main__": 172 | main() -------------------------------------------------------------------------------- /kNNPrompting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenfengXu/KNNPrompting/050d7e455113c0afa82de1537210007c34e96e57/kNNPrompting.png -------------------------------------------------------------------------------- /knn_prompting.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from datetime import datetime 4 | from time import sleep 5 | import logging 6 | import argparse 7 | from tqdm import tqdm 8 | import csv 9 | import os 10 | 11 | import torch 12 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 13 | 14 | from utils.dataset import * 15 | from utils.template import * 16 | from utils.anchor import AnchorStore 17 | 18 | 19 | os.environ["TOKENIZERS_PARALLELISM"] = "false" # To avoid warnings about parallelism in tokenizers 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="KNN Prompting.") 25 | parser.add_argument( 26 | "--llm_dir", 27 | type=str, 28 | default=None, 29 | ) 30 | parser.add_argument( 31 | "--data_dir", 32 | type=str, 33 | default=None, 34 | ) 35 | parser.add_argument( 36 | "--dataset", 37 | type=str, 38 | default=None, 39 | ) 40 | parser.add_argument( 41 | "--seed", 42 | type=int, 43 | default=None, 44 | ) 45 | 
def main():
    """Run kNN Prompting end to end and append the accuracy to a results CSV.

    Stage 1 samples ``n_demo_shot`` demonstrations per class to build the
    prompt prefix, then samples ``n_train_shot - n_demo_shot`` further training
    examples per class (excluding the demonstrations) and stores the model's
    next-token distribution for each of them as an anchor.
    Stage 2 computes the same distribution for every dev example and labels it
    by a vote among its ``knn`` nearest anchors.

    Raises:
        ValueError: if a required argument is missing or ``--dataset`` is not
            one of the supported names.
        Exception: if ``n_demo_shot`` is not smaller than ``n_train_shot``.
    """
    args = parse_args()

    # Fail early with a readable message instead of a TypeError/NameError
    # further down (all parser defaults are None).
    required = ("llm_dir", "data_dir", "dataset", "n_train_shot", "n_demo_shot", "knn")
    missing = [name for name in required if getattr(args, name) is None]
    if missing:
        raise ValueError("Missing required argument(s): " + ", ".join("--" + name for name in missing))

    dataset_classes = {
        'sst2': SST2Dataset,
        'subj': SUBJDataset,
        'agnews': AGNEWSDataset,
        'cb': CBDataset,
        'cr': CRDataset,
        'dbpedia': DBPEDIADataset,
        'mpqa': MPQADataset,
        'mr': MRDataset,
        'rte': RTEDataset,
        'trec': TRECDataset,
    }
    if args.dataset not in dataset_classes:
        raise ValueError(f"Unsupported dataset '{args.dataset}'; choose one of {sorted(dataset_classes)}.")
    AutoDataset = dataset_classes[args.dataset]

    args.n_anchor_shot = args.n_train_shot - args.n_demo_shot
    if args.n_anchor_shot <= 0:
        raise Exception("Num. of demonstration must be set smaller than num. of training.")

    args.knn = min(args.knn, args.n_anchor_shot)  # knn can not exceed num. of anchors

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.setLevel(logging.INFO)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(args.llm_dir, use_fast=False)
    # set pad token ids for batched inference cus gpt2 does not have one
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model_config = AutoConfig.from_pretrained(args.llm_dir)
    model = AutoModelForCausalLM.from_pretrained(args.llm_dir)
    model.to(device)
    model.eval()

    max_context_len = 1024 if 'gpt2' in args.llm_dir else 2048

    # prepare dataset
    datadir = os.path.join(args.data_dir, args.dataset)
    train_data = AutoDataset(datadir, mode='train')
    dev_data = AutoDataset(datadir, mode='dev')
    anchor_data = AutoDataset(datadir, mode='train')

    # Stage1: Meta Test
    train_data.subsamplebyshot(args.n_demo_shot, args.seed)
    prompt_prefix = make_prompt(train_data, args.dataset, mode='train')
    # anchors must not overlap with the demonstrations already in the prompt
    anchor_data.subsamplebyshot(args.n_anchor_shot, args.seed, exclude=train_data.data)
    label2id = dev_data.label2id
    logger.info(f"===== build anchor store of {len(anchor_data)} anchor examples =====")
    anchor_store = AnchorStore(K=len(anchor_data),
                               dim=model_config.vocab_size,
                               knn=args.knn,
                               n_class=len(label2id))
    for ins in tqdm(anchor_data.data, total=len(anchor_data)):
        labels = label2id[ins['label']]
        prompt = prompt_prefix + make_prompt(ins, args.dataset, mode='inference')
        gen_logits = llm_gen(model, prompt, tokenizer, max_context_len)
        anchor_store.enqueue(torch.softmax(gen_logits, dim=-1), torch.tensor(labels))

    # Stage2: Formal Test
    logger.info(f"===== eval on {len(dev_data)} dev examples =====")
    dev_labels = []
    dev_pred = []
    for ins in tqdm(dev_data.data, total=len(dev_data)):
        dev_labels.append(label2id[ins['label']])
        prompt = prompt_prefix + make_prompt(ins, args.dataset, mode='inference')
        gen_logits = llm_gen(model, prompt, tokenizer, max_context_len)
        dev_pred.extend(anchor_store.knn_infer(torch.softmax(gen_logits, dim=-1)))

    n_correct = sum(1 for gold, pred in zip(dev_labels, dev_pred) if gold == pred)
    acc = n_correct / len(dev_labels)
    logger.info(f"Acc: {acc}")

    # logging
    if args.output_dir is None:
        # The accuracy has been logged above; without a directory there is
        # nowhere to append the CSV row.
        logger.warning("--output_dir is not set; skipping the results CSV.")
        return
    save_results_file = os.path.join(args.output_dir, 'results_knnprompting.csv')
    csv_exists = os.path.isfile(save_results_file)
    with open(save_results_file, 'a+', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        if not csv_exists:
            csvwriter.writerow(['dataset', 'llm', 'n_train_shot', 'n_demo_shot', 'n_anchor_shot', 'seed', 'knn', 'acc'])
        csvwriter.writerow([args.dataset,
                            args.llm_dir,
                            args.n_train_shot,
                            args.n_demo_shot,
                            args.n_anchor_shot,
                            args.seed,
                            args.knn,
                            acc])
# Run kNN Prompting on one dataset for seeds 1-5.
# Edit LLM / DATASET / N_TRAIN_SHOT / KNN below to change the configuration.
export CUDA_VISIBLE_DEVICES=0

LLM=gpt2-xl
LLM_DIR=./llm/${LLM}
DATA_DIR=./data/

# Succeeds when the first argument is exactly equal to one of the remaining
# arguments (exact match, unlike a regex test against the joined array).
contains() {
    local needle=$1
    shift
    local item
    for item in "$@"; do
        [[ "${item}" == "${needle}" ]] && return 0
    done
    return 1
}

# Set max demonstration shot w.r.t. context length
if [[ "${LLM}" == "gpt2-xl" ]] || [[ "${LLM}" == "gpt2-large" ]]; then
    # max context length = 1024
    array1=(mpqa)               # maxshot = 32
    array2=(sst2)               # maxshot = 16
    array3=(subj cr mr trec)    # maxshot = 8
    array4=(rte)                # maxshot = 4
    array5=(agnews cb)          # maxshot = 2
    array6=(dbpedia)            # maxshot = 1
else
    # max context length = 2048
    array1=(sst2 mpqa)
    array2=(subj cr mr trec)
    array3=(rte)
    array4=(agnews cb)
    array5=(none)
    array6=(dbpedia)
fi

# for DATASET in sst2 subj mpqa agnews cb cr dbpedia mr rte trec; do
DATASET=sst2

if contains "${DATASET}" "${array1[@]}"; then
    N_DEMO_SHOT=32
elif contains "${DATASET}" "${array2[@]}"; then
    N_DEMO_SHOT=16
elif contains "${DATASET}" "${array3[@]}"; then
    N_DEMO_SHOT=8
elif contains "${DATASET}" "${array4[@]}"; then
    N_DEMO_SHOT=4
elif contains "${DATASET}" "${array5[@]}"; then
    N_DEMO_SHOT=2
else
    # array6 and anything unlisted
    N_DEMO_SHOT=1
fi

N_TRAIN_SHOT=1024
KNN=3
for SEED in 1 2 3 4 5; do

    python3 knn_prompting.py \
        --llm_dir "${LLM_DIR}" \
        --dataset "${DATASET}" \
        --data_dir "${DATA_DIR}" \
        --n_train_shot "${N_TRAIN_SHOT}" \
        --n_demo_shot "${N_DEMO_SHOT}" \
        --seed "${SEED}" \
        --output_dir ./output \
        --knn "${KNN}"

done
class AnchorStore(nn.Module):
    """Fixed-capacity store of anchor distributions and labels for kNN inference.

    Anchors are appended with :meth:`enqueue`. :meth:`knn_infer` labels each
    query by its nearest stored anchors, where the distance is
    ``KL(anchor || query)`` averaged over the ``dim`` entries.

    Args:
        K: capacity, i.e. the maximum number of anchors.
        dim: size of each anchor/query distribution.
        knn: number of neighbours that vote on the label.
        n_class: number of classes; labels are expected in ``[0, n_class)``.
    """

    def __init__(self, K=1024, dim=50257, knn=1, n_class=2):
        super(AnchorStore, self).__init__()

        # Slots hold placeholder values until enqueue() overwrites them;
        # label -1 marks a slot that has never been written.
        self.register_buffer("queue_anchor", torch.randn(K, dim))
        self.register_buffer("queue_label", torch.zeros(K, dtype=torch.long))
        self.queue_label.fill_(-1)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        self.knn = knn
        self.n_class = n_class

    def enqueue(self, anchors, labels):
        """Append ``anchors`` (``[bs, dim]``) and their ``labels`` to the store.

        Raises:
            ValueError: if the batch does not fit in the remaining capacity.
                (A single-row write past the end would otherwise broadcast onto
                an empty slice and be dropped without any error.)
        """
        ptr = int(self.queue_ptr)
        bs = anchors.shape[0]
        capacity = self.queue_anchor.shape[0]
        if ptr + bs > capacity:
            raise ValueError(
                f"AnchorStore overflow: cannot enqueue {bs} anchor(s) at position {ptr} "
                f"with capacity {capacity}."
            )

        self.queue_anchor[ptr:ptr + bs, :] = anchors
        self.queue_label[ptr:ptr + bs] = labels
        self.queue_ptr[0] = ptr + bs

    def knn_infer(self, query):
        """Predict a class id for every row of ``query`` (``[n_query, dim]``).

        Only slots written through :meth:`enqueue` take part in the search.

        Returns:
            list[int]: one predicted label per query row.

        Raises:
            ValueError: if no anchor has been enqueued yet.
        """
        n_filled = min(int(self.queue_ptr), self.queue_anchor.shape[0])
        if n_filled == 0:
            raise ValueError("AnchorStore is empty: enqueue anchors before calling knn_infer.")
        anchors = self.queue_anchor[:n_filled, None, :]
        labels = self.queue_label[:n_filled]

        # pointwise.shape = [n_filled, n_query, dim]
        pointwise = anchors * (anchors.log() - query.log())
        # 0 * log(0) contributes 0 to a KL sum; without this an anchor entry that
        # is exactly 0 would turn the whole distance into NaN.
        pointwise.masked_fill_(anchors == 0, 0.0)
        # kl_distance.shape = [n_query, n_filled]
        kl_distance = pointwise.mean(dim=2).transpose(1, 0)

        # topk fails when k exceeds the number of candidates
        k = min(self.knn, n_filled)
        if k == 1:
            # directly return the nearest neighbor
            return labels[kl_distance.argmin(dim=1)].tolist()

        _, indices = torch.topk(kl_distance, k, dim=1, largest=False)
        neighbor_labels = labels[indices]
        # count for each category within k nearest neighbors, and return the dominant category
        # knn_cnt.shape = [n_query, self.n_class]
        knn_cnt = torch.zeros((query.shape[0], self.n_class))
        for i in range(self.n_class):
            knn_cnt[:, i] = (neighbor_labels == i).sum(dim=1)
        return knn_cnt.argmax(dim=1).tolist()
class BASEDataset:
    """JSONL-backed classification dataset with per-class few-shot subsampling.

    Each non-blank line of ``<data_dir>/<mode>.jsonl`` is parsed as one JSON
    instance and appended to ``self.data``. Subclasses override the label maps.
    """

    def __init__(
            self,
            data_dir,
            mode
    ):
        """Load ``<data_dir>/<mode>.jsonl``; ``mode='dev'`` reads ``dev_subsample.jsonl``.

        data key: sentence, label[0/1]
        """
        super().__init__()
        if mode == 'dev':
            mode = 'dev_subsample'
        data_file = os.path.join(data_dir, mode + '.jsonl')
        self.data = []
        # Explicit UTF-8 so loading does not depend on the platform's locale.
        with open(data_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # tolerate blank lines (e.g. a trailing empty line)
                    continue
                self.data.append(json.loads(line))
        # customize your own label map in inheritance
        self.id2label = {0: 'negative', 1: 'positive'}
        self.label2id = {'negative': 0, 'positive': 1}

    def __len__(self):
        """Return the number of instances currently held."""
        return len(self.data)

    def subsamplebyshot(self, n_shot, seed, exclude=None):
        """Replace ``self.data`` with up to ``n_shot`` random instances per class.

        Args:
            n_shot: number of instances to draw from each class; a class with
                fewer instances contributes all of them.
            seed: value passed to ``random.seed`` (this reseeds the global
                ``random`` generator, which makes the draw reproducible).
            exclude: optional instances to remove from ``self.data`` before
                sampling; each one must be present (``list.remove`` raises
                ``ValueError`` otherwise).
        """
        # exclude
        if exclude is not None:
            for ins in exclude:
                self.data.remove(ins)
        # aggregate data by each category (insertion order = first appearance,
        # which the sampling order below depends on)
        random.seed(seed)
        data_by_cls = {}
        for ins in self.data:
            data_by_cls.setdefault(self.label2id[ins['label']], []).append(ins)
        # evenly sample n examples from each category
        data_subsample = []
        for cls_data in data_by_cls.values():
            data_subsample.extend(random.sample(cls_data, min(n_shot, len(cls_data))))
        random.shuffle(data_subsample)
        self.data = data_subsample
class CBDataset(BASEDataset):
    """Three-way 'cb' dataset: raw labels mapped to class ids and verbalizers."""

    def __init__(
            self,
            data_dir,
            mode
    ):
        """data key: premise, hypothesis, label[contradiction/entailment/neutral]"""
        super().__init__(data_dir, mode)
        # (raw label, verbalizer) pairs; the position in this table is the class id
        verbalizers = (
            ('contradiction', 'false'),
            ('entailment', 'true'),
            ('neutral', 'neither'),
        )
        self.label2id = {label: idx for idx, (label, _) in enumerate(verbalizers)}
        self.label2verb = dict(verbalizers)
        self.id2verb = [verb for _, verb in verbalizers]
class RTEDataset(BASEDataset):
    """Two-way 'rte' dataset: raw labels mapped to class ids and verbalizers."""

    def __init__(
            self,
            data_dir,
            mode
    ):
        """data key: sentence_1, sentence_2, label[not_entailment/entailment]"""
        super().__init__(data_dir, mode)
        # (raw label, verbalizer) pairs; the position in this table is the class id
        verbalizers = (
            ('not_entailment', 'false'),
            ('entailment', 'true'),
        )
        self.label2id = {label: idx for idx, (label, _) in enumerate(verbalizers)}
        self.label2verb = dict(verbalizers)
        self.id2verb = [verb for _, verb in verbalizers]
def make_prompt(dataset, dataset_name, mode, indices=None):
    """Render a prompt for ``dataset_name`` using its template function.

    Args:
        dataset: in ``'inference'`` mode a single instance (dict); otherwise a
            dataset object exposing ``data`` and ``label2verb``.
        dataset_name: selects the template (``'sst2'``, ``'subj'``, ``'agnews'``,
            ``'cb'``, ``'cr'``, ``'dbpedia'``, ``'mpqa'``, ``'mr'``, ``'rte'``,
            ``'sst5'`` or ``'trec'``).
        mode: ``'inference'`` renders one unlabeled instance; ``'compose'``
            renders the instances returned by ``dataset.index(indices)`` as
            labeled demonstrations; any other value is passed to the template
            for every instance in ``dataset.data``.
        indices: only used in ``'compose'`` mode.

    Returns:
        str: the rendered prompt; each demonstration is followed by a newline.

    Raises:
        ValueError: if ``dataset_name`` has no template.
    """
    template_funcs = {
        'sst2': template_sst2,
        'subj': template_subj,
        'agnews': template_agnews,
        'cb': template_cb,
        'cr': template_cr,
        'dbpedia': template_dbpedia,
        'mpqa': template_mpqa,
        'mr': template_mr,
        'rte': template_rte,
        'sst5': template_sst5,
        'trec': template_trec,
    }
    if dataset_name not in template_funcs:
        raise ValueError(f"No prompt template for dataset '{dataset_name}'; choose one of {sorted(template_funcs)}.")
    template_func = template_funcs[dataset_name]

    if mode == 'inference':
        return template_func(dataset, None, mode)
    if mode == 'compose':  # inputs are different, list of examples instead of dataset class
        # NOTE(review): relies on ``dataset.index(indices)``; none of the dataset
        # classes visible alongside this function define ``index`` - confirm the caller.
        return ''.join(
            template_func(ins, dataset.label2verb[ins['label']], 'train') + '\n'
            for ins in dataset.index(indices)
        )
    return ''.join(
        template_func(ins, dataset.label2verb[ins['label']], mode) + '\n'
        for ins in dataset.data
    )
def template_cb(ins, label, mode):
    """Render one 'cb' instance as a premise/hypothesis/prediction prompt.

    In ``'train'`` mode the verbalized ``label`` and a trailing newline are
    appended; in any other mode the prompt stops right after ``prediction:``
    so the model fills in the answer.
    """
    if mode != 'train':
        return f"premise: {ins['premise']}\nhypothesis: {ins['hypothesis']}\nprediction:"
    return f"premise: {ins['premise']}\nhypothesis: {ins['hypothesis']}\nprediction: {label}\n"
def sent_sim_template(ins, dataset_name):
    """Return the text of ``ins`` to use for sentence-similarity comparison.

    Single-sentence datasets return ``ins['sentence']`` unchanged; the
    premise/hypothesis datasets ('cb', 'rte') are flattened into one string.

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset (previously the
            function fell through and returned ``None``).
    """
    # 'sst5' is included because its prompt template reads ins['sentence'] too.
    single_sentence = ('sst2', 'sst5', 'subj', 'mpqa', 'agnews', 'cr', 'dbpedia', 'mr', 'trec')
    if dataset_name in single_sentence:
        return ins['sentence']
    if dataset_name == 'cb':
        return f"premise: {ins['premise']}. hypothesis: {ins['hypothesis']}"
    if dataset_name == 'rte':
        return f"premise: {ins['sentence_1']}. hypothesis: {ins['sentence_2']}"
    raise ValueError(f"No sentence-similarity template for dataset '{dataset_name}'.")