├── README.md
├── data
│   └── .gitignore
├── icl.py
├── kNNPrompting.png
├── knn_prompting.py
├── llm
│   └── .gitignore
├── run_icl.sh
├── run_knnprompting.sh
└── utils
    ├── __init__.py
    ├── anchor.py
    ├── dataset.py
    └── template.py
/README.md:
--------------------------------------------------------------------------------
1 | # KNNPrompting
2 | Released code for our ICLR23 paper: [KNN Prompting: Beyond-Context Learning with Calibration-Free Nearest Neighbor Inference](https://openreview.net/forum?id=fe2S7736sNS)
3 |
4 |
5 |
![kNNPrompting](./kNNPrompting.png)
6 |
7 |
8 | ## Preparation
9 | ### Environment
10 | The code is tested under torch==1.12.0 and transformers==4.20.1, though the specific version requirement is not strict: if it runs without errors, you are set.
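If in doubt, here is a quick version check (optional; a minimal sketch, nothing the scripts depend on):
```
import torch, transformers
# versions the code was tested with; nearby versions should also work
print(torch.__version__)         # 1.12.0
print(transformers.__version__)  # 4.20.1
```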
11 | ### Model
12 | Prepare your LLM ([gpt2](https://huggingface.co/gpt2-xl/tree/main) or opt) in `./llm/`; we recommend downloading the weights yourself and configuring the local path in the scripts.
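For example, a minimal download snippet (one option among others; assumes hub access from `transformers` works in your environment):
```
from transformers import AutoTokenizer, AutoModelForCausalLM

# fetch gpt2-xl once from the HuggingFace hub and save it under the path used by the run scripts
AutoTokenizer.from_pretrained('gpt2-xl').save_pretrained('./llm/gpt2-xl')
AutoModelForCausalLM.from_pretrained('gpt2-xl').save_pretrained('./llm/gpt2-xl')
```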
13 | ### Data
14 | [Download](https://drive.google.com/file/d/1Yh2blPkJvMtdm5xWKoHr2fLp2i2Bn5Ir/view?usp=share_link) the datasets and unzip them into `./data`.\
15 | The structure of the project looks like:
16 | ```
17 | .
18 | ├── run_icl.sh
19 | ├── run_knnprompting.sh
20 | ├── icl.py
21 | ├── knn_prompting.py
22 | ├── utils
23 | │   ├── anchor.py
24 | │   ├── dataset.py
25 | │   ├── __init__.py
26 | │   └── template.py
27 | ├── llm
28 | │   └── gpt2-xl
29 | │       ├── config.json
30 | │       ├── merges.txt
31 | │       ├── pytorch_model.bin
32 | │       ├── tokenizer.json
33 | │       └── vocab.json
34 | └── data
35 |     └── sst2
36 |         ├── dev_subsample.jsonl
37 |         ├── test.jsonl
38 |         └── train.jsonl
39 | ```
40 |
41 | ## Run
42 | Run kNNPrompting or In-Context Learning as follows; check the configuration in the scripts, including dataset, LLM, seed, etc.
43 | ```
44 | bash run_knnprompting.sh
45 | ```
46 | or
47 | ```
48 | bash run_icl.sh
49 | ```
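For reference, the two-stage procedure that `knn_prompting.py` executes boils down to the following condensed sketch (SST-2 with gpt2-xl and the shot counts from `run_knnprompting.sh` are assumed; context-length truncation and result logging are omitted):
```
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.dataset import SST2Dataset
from utils.template import make_prompt
from utils.anchor import AnchorStore

tokenizer = AutoTokenizer.from_pretrained('./llm/gpt2-xl', use_fast=False)
model = AutoModelForCausalLM.from_pretrained('./llm/gpt2-xl').eval()

def next_token_distribution(prompt):
    # distribution over the vocabulary at the last input position
    inputs = tokenizer(prompt, return_tensors='pt')
    with torch.no_grad():
        logits = model(**inputs).logits[:, -1, :]
    return torch.softmax(logits, dim=-1)

demo = SST2Dataset('./data/sst2', mode='train')      # in-context demonstrations
anchor = SST2Dataset('./data/sst2', mode='train')    # anchor examples
dev = SST2Dataset('./data/sst2', mode='dev')
demo.subsamplebyshot(16, 1)                              # n_demo_shot = 16, seed = 1
anchor.subsamplebyshot(1024 - 16, 1, exclude=demo.data)  # n_anchor_shot = n_train_shot - n_demo_shot
prefix = make_prompt(demo, 'sst2', mode='train')

# Stage 1 (meta test): cache the LLM output distribution of every anchor example
store = AnchorStore(K=len(anchor.data), dim=model.config.vocab_size,
                    knn=3, n_class=len(demo.label2id))
for ins in anchor.data:
    dist = next_token_distribution(prefix + make_prompt(ins, 'sst2', mode='inference'))
    store.enqueue(dist, torch.tensor(demo.label2id[ins['label']]))

# Stage 2 (formal test): label each dev query by its k nearest anchors under KL divergence
for ins in dev.data:
    dist = next_token_distribution(prefix + make_prompt(ins, 'sst2', mode='inference'))
    pred = store.knn_infer(dist)[0]  # predicted class id
```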
50 | ## Results
51 | As the entire framework is training-free, you should get **exactly** the following results for each random seed (invariant across environments):
52 |
53 | | Seed | 1 | 2 | 3 | 4 | 5 |
54 | | ----------------------------------- | ------ | ------ | ------ | ------ | ------ |
55 | | **In-Context Learning** (gpt2-xl) | 0.8438 | 0.8125 | 0.7227 | 0.8633 | 0.8242 |
56 | | **KNN Prompting** (gpt2-xl, N=1024) | 0.8711 | 0.8867 | 0.8906 | 0.8711 | 0.8906 |
57 |
58 | Full results are listed in the paper (see Table 8 and others).
59 |
60 | ## Citation
61 | * If you have any questions, feel free to open an issue.
62 | * If you find this repo useful, please cite us as:
63 | ```
64 | @inproceedings{
65 | xu2023knn,
66 | title={$k$NN Prompting: Beyond-Context Learning with Calibration-Free Nearest Neighbor Inference},
67 | author={Benfeng Xu and Quan Wang and Zhendong Mao and Yajuan Lyu and Qiaoqiao She and Yongdong Zhang},
68 | booktitle={The Eleventh International Conference on Learning Representations},
69 | year={2023},
70 | url={https://openreview.net/forum?id=fe2S7736sNS}
71 | }
72 | ```
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/icl.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from datetime import datetime
4 | from time import sleep
5 | import logging
6 | import argparse
7 | from tqdm import tqdm
8 | import csv
9 | import os
10 |
11 | import torch
12 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
13 |
14 | from utils.dataset import *
15 | from utils.template import *
16 |
17 |
18 | os.environ["TOKENIZERS_PARALLELISM"] = "false" # To suppress warnings about parallelism in tokenizers
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(description="In-Context Learning baseline.")
24 | parser.add_argument(
25 | "--llm_dir",
26 | type=str,
27 | default=None,
28 | )
29 | parser.add_argument(
30 | "--data_dir",
31 | type=str,
32 | default=None,
33 | )
34 | parser.add_argument(
35 | "--dataset",
36 | type=str,
37 | default=None,
38 | )
39 | parser.add_argument(
40 | "--seed",
41 | type=int,
42 | default=None,
43 | )
44 | parser.add_argument(
45 | "--n_train_shot",
46 | type=int,
47 | default=None,
48 | )
49 | parser.add_argument(
50 | "--output_dir",
51 | type=str,
52 | default=None,
53 | )
54 | args = parser.parse_args()
55 | if args.output_dir is not None:
56 | os.makedirs(args.output_dir, exist_ok=True)
57 |
58 | return args
59 |
60 |
61 | def llm_gen(model, prompt, tokenizer, max_context_len):
62 | inputs = tokenizer.encode_plus(prompt, return_tensors="pt", padding=True).to(device=model.device)
63 | if inputs['input_ids'].shape[1] > max_context_len:
64 | inputs['input_ids'] = inputs['input_ids'][:, -max_context_len:]
65 | inputs['attention_mask'] = inputs['attention_mask'][:, -max_context_len:]
66 | with torch.no_grad():
67 | logits = model.forward(input_ids=inputs['input_ids'],
68 | attention_mask=inputs['attention_mask'],
69 | return_dict=True).logits.detach().cpu()
70 |     # logits are shifted by -1 w.r.t. the inputs, so the next-token prediction comes from the last input position
71 |     # gen_logits.shape = [1, vocab_size], i.e. [1, 50257] for gpt2
72 | gen_logits = logits[:, -1, :]
73 |
74 | return gen_logits
75 |
76 |
77 | def parse_response(gen_logits, tokenizer, id2verb):
78 | gen_prob = torch.softmax(gen_logits, dim=-1)
79 | prob_per_cls = []
80 | for label_verb in id2verb:
81 | label_verb_token_id = tokenizer.encode(' ' + label_verb)[-1] # note the space before label word
82 | prob_per_cls.append(gen_prob[:, label_verb_token_id])
83 | pred = torch.argmax(torch.cat(prob_per_cls, dim=0)).tolist()
84 | return pred
85 |
86 |
87 | def main():
88 | args = parse_args()
89 |
90 | logging.basicConfig(
91 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
92 | datefmt="%m/%d/%Y %H:%M:%S",
93 | level=logging.INFO,
94 | )
95 | logger.setLevel(logging.INFO)
96 |
97 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
98 |
99 | tokenizer = AutoTokenizer.from_pretrained(args.llm_dir)
100 |     # set pad token ids for batched inference because gpt2 does not have one
101 | tokenizer.padding_side = "left"
102 | tokenizer.pad_token = tokenizer.eos_token
103 | tokenizer.pad_token_id = tokenizer.eos_token_id
104 | model_config = AutoConfig.from_pretrained(args.llm_dir)
105 | model = AutoModelForCausalLM.from_pretrained(args.llm_dir)
106 | model.to(device)
107 | model.eval()
108 |
109 | if 'gpt2' in args.llm_dir:
110 | max_context_len = 1024
111 | else:
112 | max_context_len = 2048
113 |
114 | # prepare dataset
115 | if args.dataset == 'sst2':
116 | AutoDataset = SST2Dataset
117 | elif args.dataset == 'subj':
118 | AutoDataset = SUBJDataset
119 | elif args.dataset == 'agnews':
120 | AutoDataset = AGNEWSDataset
121 | elif args.dataset == 'cb':
122 | AutoDataset = CBDataset
123 | elif args.dataset == 'cr':
124 | AutoDataset = CRDataset
125 | elif args.dataset == 'dbpedia':
126 | AutoDataset = DBPEDIADataset
127 | elif args.dataset == 'mpqa':
128 | AutoDataset = MPQADataset
129 | elif args.dataset == 'mr':
130 | AutoDataset = MRDataset
131 | elif args.dataset == 'rte':
132 | AutoDataset = RTEDataset
133 | elif args.dataset == 'trec':
134 | AutoDataset = TRECDataset
135 |
136 | dataset_dir = os.path.join(args.data_dir, args.dataset)
137 | train_data = AutoDataset(dataset_dir, mode='train')
138 | dev_data = AutoDataset(dataset_dir, mode='dev')
139 |
140 | # inference
141 | train_data.subsamplebyshot(args.n_train_shot, args.seed)
142 | logger.info(f"===== eval on {dev_data.__len__()} dev examples =====")
143 | prompt_prefix = make_prompt(train_data, args.dataset, mode='train')
144 | dev_labels = []
145 | dev_pred = []
146 | label2id = dev_data.label2id
147 | id2verb = train_data.id2verb
148 | for ins in tqdm(dev_data.data, total=dev_data.__len__()):
149 | dev_labels.append(label2id[ins['label']])
150 | prompt = prompt_prefix + make_prompt(ins, args.dataset, mode='inference')
151 | gen_logits = llm_gen(model, prompt, tokenizer, max_context_len)
152 | dev_pred.append(parse_response(gen_logits, tokenizer, id2verb))
153 |
154 | dev_correct = [1 if dev_labels[i] == dev_pred[i] else 0 for i in range(len(dev_labels))]
155 | acc = sum(dev_correct) / len(dev_labels)
156 | logger.info(f"Acc: {acc}")
157 |
158 | # logging
159 | save_results_file = os.path.join(args.output_dir, 'results_icl.csv')
160 | csv_exists = os.path.isfile(save_results_file)
161 | with open(save_results_file, 'a+', newline='') as csvfile:
162 | csvwriter = csv.writer(csvfile)
163 | if not csv_exists:
164 | csvwriter.writerow(['dataset', 'llm', 'n_train_shot', 'seed', 'acc'])
165 | csvwriter.writerow([args.dataset,
166 | args.llm_dir,
167 | args.n_train_shot,
168 | args.seed,
169 | acc])
170 |
171 | if __name__ == "__main__":
172 | main()
--------------------------------------------------------------------------------
/kNNPrompting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BenfengXu/KNNPrompting/050d7e455113c0afa82de1537210007c34e96e57/kNNPrompting.png
--------------------------------------------------------------------------------
/knn_prompting.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from datetime import datetime
4 | from time import sleep
5 | import logging
6 | import argparse
7 | from tqdm import tqdm
8 | import csv
9 | import os
10 |
11 | import torch
12 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
13 |
14 | from utils.dataset import *
15 | from utils.template import *
16 | from utils.anchor import AnchorStore
17 |
18 |
19 | os.environ["TOKENIZERS_PARALLELISM"] = "false" # To avoid warnings about parallelism in tokenizers
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser(description="KNN Prompting.")
25 | parser.add_argument(
26 | "--llm_dir",
27 | type=str,
28 | default=None,
29 | )
30 | parser.add_argument(
31 | "--data_dir",
32 | type=str,
33 | default=None,
34 | )
35 | parser.add_argument(
36 | "--dataset",
37 | type=str,
38 | default=None,
39 | )
40 | parser.add_argument(
41 | "--seed",
42 | type=int,
43 | default=None,
44 | )
45 | parser.add_argument(
46 | "--n_train_shot",
47 | type=int,
48 | default=None,
49 | )
50 | parser.add_argument(
51 | "--n_demo_shot",
52 | type=int,
53 | default=None,
54 | )
55 | parser.add_argument(
56 | "--n_anchor_shot",
57 | type=int,
58 | default=None,
59 | )
60 | parser.add_argument(
61 | "--knn",
62 | type=int,
63 | default=None,
64 | )
65 | parser.add_argument(
66 | "--output_dir",
67 | type=str,
68 | default=None,
69 | )
70 | args = parser.parse_args()
71 | if args.output_dir is not None:
72 | os.makedirs(args.output_dir, exist_ok=True)
73 |
74 | return args
75 |
76 |
77 | def llm_gen(model, prompt, tokenizer, max_context_len):
78 | inputs = tokenizer.encode_plus(prompt, return_tensors="pt", padding=True).to(device=model.device)
79 | if inputs['input_ids'].shape[1] > max_context_len:
80 | inputs['input_ids'] = inputs['input_ids'][:, -max_context_len:]
81 | inputs['attention_mask'] = inputs['attention_mask'][:, -max_context_len:]
82 | with torch.no_grad():
83 | logits = model.forward(input_ids=inputs['input_ids'],
84 | attention_mask=inputs['attention_mask'],
85 | return_dict=True).logits.detach().cpu()
86 |     # logits are shifted by -1 w.r.t. the inputs, so the next-token prediction comes from the last input position
87 |     # gen_logits.shape = [1, vocab_size], i.e. [1, 50257] for gpt2
88 | gen_logits = logits[:, -1, :]
89 |
90 | return gen_logits
91 |
92 |
93 | def main():
94 | args = parse_args()
95 |
96 | args.n_anchor_shot = args.n_train_shot - args.n_demo_shot
97 | if args.n_anchor_shot <= 0:
98 |         raise ValueError("n_demo_shot must be smaller than n_train_shot.")
99 |
100 |     args.knn = min(args.knn, args.n_anchor_shot)  # knn cannot exceed the number of anchors
101 |
102 | logging.basicConfig(
103 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
104 | datefmt="%m/%d/%Y %H:%M:%S",
105 | level=logging.INFO,
106 | )
107 | logger.setLevel(logging.INFO)
108 |
109 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
110 |
111 | tokenizer = AutoTokenizer.from_pretrained(args.llm_dir, use_fast=False)
112 |     # set pad token ids for batched inference because gpt2 does not have one
113 | tokenizer.padding_side = "left"
114 | tokenizer.pad_token = tokenizer.eos_token
115 | tokenizer.pad_token_id = tokenizer.eos_token_id
116 | model_config = AutoConfig.from_pretrained(args.llm_dir)
117 | model = AutoModelForCausalLM.from_pretrained(args.llm_dir)
118 | model.to(device)
119 | model.eval()
120 |
121 | if 'gpt2' in args.llm_dir:
122 | max_context_len = 1024
123 | else:
124 | max_context_len = 2048
125 |
126 | # prepare dataset
127 | if args.dataset == 'sst2':
128 | AutoDataset = SST2Dataset
129 | elif args.dataset == 'subj':
130 | AutoDataset = SUBJDataset
131 | elif args.dataset == 'agnews':
132 | AutoDataset = AGNEWSDataset
133 | elif args.dataset == 'cb':
134 | AutoDataset = CBDataset
135 | elif args.dataset == 'cr':
136 | AutoDataset = CRDataset
137 | elif args.dataset == 'dbpedia':
138 | AutoDataset = DBPEDIADataset
139 | elif args.dataset == 'mpqa':
140 | AutoDataset = MPQADataset
141 | elif args.dataset == 'mr':
142 | AutoDataset = MRDataset
143 | elif args.dataset == 'rte':
144 | AutoDataset = RTEDataset
145 | elif args.dataset == 'trec':
146 | AutoDataset = TRECDataset
147 |
148 | datadir = os.path.join(args.data_dir, args.dataset)
149 | train_data = AutoDataset(datadir, mode='train')
150 | dev_data = AutoDataset(datadir, mode='dev')
151 |
152 | anchor_data = AutoDataset(datadir, mode='train')
153 |
154 | # Stage1: Meta Test
155 | train_data.subsamplebyshot(args.n_demo_shot, args.seed)
156 | prompt_prefix = make_prompt(train_data, args.dataset, mode='train')
157 | anchor_data.subsamplebyshot(args.n_anchor_shot, args.seed, exclude=train_data.data)
158 | label2id = dev_data.label2id
159 | id2verb = train_data.id2verb
160 | logger.info(f"===== build anchor store of {anchor_data.__len__()} anchor examples =====")
161 | anchor_store = AnchorStore(K=anchor_data.__len__(),
162 | dim=model_config.vocab_size,
163 | knn=args.knn,
164 | n_class=len(label2id))
165 | for ins in tqdm(anchor_data.data, total=anchor_data.__len__()):
166 | labels = label2id[ins['label']]
167 | prompt = prompt_prefix + make_prompt(ins, args.dataset, mode='inference')
168 | gen_logits = llm_gen(model, prompt, tokenizer, max_context_len)
169 | anchor_store.enqueue(torch.softmax(gen_logits, dim=-1), torch.tensor(labels))
170 |
171 | # Stage2: Formal Test
172 | logger.info(f"===== eval on {dev_data.__len__()} dev examples =====")
173 | dev_labels = []
174 | dev_pred = []
175 | for ins in tqdm(dev_data.data, total=dev_data.__len__()):
176 | dev_labels.append(label2id[ins['label']])
177 | prompt = prompt_prefix + make_prompt(ins, args.dataset, mode='inference')
178 | gen_logits = llm_gen(model, prompt, tokenizer, max_context_len)
179 | dev_pred.extend(anchor_store.knn_infer(torch.softmax(gen_logits, dim=-1)))
180 |
181 | dev_correct = [1 if dev_labels[i] == dev_pred[i] else 0 for i in range(len(dev_labels))]
182 | acc = sum(dev_correct) / len(dev_labels)
183 | logger.info(f"Acc: {acc}")
184 |
185 | # logging
186 |     save_results_file = os.path.join(args.output_dir, 'results_knnprompting.csv')
187 | csv_exists = os.path.isfile(save_results_file)
188 | with open(save_results_file, 'a+', newline='') as csvfile:
189 | csvwriter = csv.writer(csvfile)
190 | if not csv_exists:
191 | csvwriter.writerow(['dataset', 'llm', 'n_train_shot', 'n_demo_shot', 'n_anchor_shot', 'seed', 'knn', 'acc'])
192 | csvwriter.writerow([args.dataset,
193 | args.llm_dir,
194 | args.n_train_shot,
195 | args.n_demo_shot,
196 | args.n_anchor_shot,
197 | args.seed,
198 | args.knn,
199 | acc])
200 |
201 |
202 | if __name__ == "__main__":
203 | main()
204 |
--------------------------------------------------------------------------------
/llm/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/run_icl.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | LLM=gpt2-xl
4 | LLM_DIR=./llm/${LLM}
5 | DATA_DIR=./data/
6 | BATCHSIZE=1
7 |
8 | # Set maxshot w.r.t. context length
9 | if [[ "${LLM}" == "gpt2-xl" ]] || [[ "${LLM}" == "gpt2-large" ]]; then
10 | # max context length = 1024
11 | array1=(mpqa) # maxshot = 32
12 | array2=(sst2) # maxshot = 16
13 | array3=(subj cr mr trec) # maxshot = 8
14 | array4=(rte) # maxshot = 4
15 | array5=(agnews cb) # maxshot = 2
16 | array6=(dbpedia) # maxshot = 1
17 | else
18 | # max context length = 2048
19 | array1=(sst2 mpqa)
20 | array2=(subj cr mr trec)
21 | array3=(rte)
22 | array4=(agnews cb)
23 | array5=(none)
24 | array6=(dbpedia)
25 | fi
26 |
27 | # for DATASET in sst2 subj mpqa agnews cb cr dbpedia mr rte trec; do
28 |
29 | DATASET=sst2
30 |
31 | if [[ "${array1[@]}" =~ "${DATASET}" ]]; then
32 | NSHOT=32
33 | elif [[ "${array2[@]}" =~ "${DATASET}" ]]; then
34 | NSHOT=16
35 | elif [[ "${array3[@]}" =~ "${DATASET}" ]]; then
36 | NSHOT=8
37 | elif [[ "${array4[@]}" =~ "${DATASET}" ]]; then
38 | NSHOT=4
39 | elif [[ "${array5[@]}" =~ "${DATASET}" ]]; then
40 | NSHOT=2
41 | else
42 | NSHOT=1
43 | fi
44 |
45 | for SEED in 1 2 3 4 5; do
46 |
47 | python3 icl.py \
48 | --llm_dir ${LLM_DIR} \
49 | --dataset ${DATASET} \
50 | --data_dir ${DATA_DIR} \
51 | --n_train_shot ${NSHOT} \
52 | --seed ${SEED} \
53 | --output_dir ./output
54 |
55 | done
--------------------------------------------------------------------------------
/run_knnprompting.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | LLM=gpt2-xl
4 | LLM_DIR=./llm/${LLM}
5 | DATA_DIR=./data/
6 |
7 | # Set max demonstration shot w.r.t. context length
8 | if [[ "${LLM}" == "gpt2-xl" ]] || [[ "${LLM}" == "gpt2-large" ]]; then
9 | # max context length = 1024
10 | array1=(mpqa) # maxshot = 32
11 | array2=(sst2) # maxshot = 16
12 | array3=(subj cr mr trec) # maxshot = 8
13 | array4=(rte) # maxshot = 4
14 | array5=(agnews cb) # maxshot = 2
15 | array6=(dbpedia) # maxshot = 1
16 | else
17 | # max context length = 2048
18 | array1=(sst2 mpqa)
19 | array2=(subj cr mr trec)
20 | array3=(rte)
21 | array4=(agnews cb)
22 | array5=(none)
23 | array6=(dbpedia)
24 | fi
25 |
26 | # for DATASET in sst2 subj mpqa agnews cb cr dbpedia mr rte trec; do
27 | DATASET=sst2
28 |
29 | if [[ "${array1[@]}" =~ "${DATASET}" ]]; then
30 | N_DEMO_SHOT=32
31 | elif [[ "${array2[@]}" =~ "${DATASET}" ]]; then
32 | N_DEMO_SHOT=16
33 | elif [[ "${array3[@]}" =~ "${DATASET}" ]]; then
34 | N_DEMO_SHOT=8
35 | elif [[ "${array4[@]}" =~ "${DATASET}" ]]; then
36 | N_DEMO_SHOT=4
37 | elif [[ "${array5[@]}" =~ "${DATASET}" ]]; then
38 | N_DEMO_SHOT=2
39 | else
40 | N_DEMO_SHOT=1
41 | fi
42 |
43 | N_TRAIN_SHOT=1024
44 | KNN=3
45 | for SEED in 1 2 3 4 5; do
46 |
47 | python3 knn_prompting.py \
48 | --llm_dir ${LLM_DIR} \
49 | --dataset ${DATASET} \
50 | --data_dir ${DATA_DIR} \
51 | --n_train_shot ${N_TRAIN_SHOT} \
52 | --n_demo_shot ${N_DEMO_SHOT} \
53 | --seed ${SEED} \
54 | --output_dir ./output \
55 | --knn ${KNN}
56 |
57 | done
58 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BenfengXu/KNNPrompting/050d7e455113c0afa82de1537210007c34e96e57/utils/__init__.py
--------------------------------------------------------------------------------
/utils/anchor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | class AnchorStore(nn.Module):
6 |
7 | def __init__(self, K=1024, dim=50257, knn=1, n_class=2):
8 | super(AnchorStore, self).__init__()
9 |
10 | self.register_buffer("queue_anchor", torch.randn(K, dim))
11 | self.register_buffer("queue_label", torch.zeros(K, dtype=torch.long))
12 | self.queue_label.fill_(-1)
13 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
14 | self.knn = knn
15 | self.n_class = n_class
16 |
17 | def enqueue(self, anchors, labels):
18 |
19 | ptr = int(self.queue_ptr)
20 | bs = anchors.shape[0]
21 |
22 | self.queue_anchor[ptr:ptr + bs, :] = anchors
23 | self.queue_label[ptr:ptr + bs] = labels
24 | self.queue_ptr[0] = ptr + bs
25 |
26 | def knn_infer(self, query):
27 |
28 |         # KL(anchor || query) averaged over the vocabulary; kl_distance.shape = [1, K]
29 | kl_distance = torch.mean(self.queue_anchor[:, None, :] * (self.queue_anchor[:, None, :].log() - query.log()), dim=2).transpose(1, 0)
30 | if self.knn == 1:
31 | # directly return the nearest neighbor
32 | return self.queue_label[kl_distance.argmin(dim=1)].tolist()
33 | else:
34 | values, indices = torch.topk(kl_distance, self.knn, dim=1, largest=False)
35 | # count for each category within k nearest neighbors, and return the dominant category
36 | # knn_cnt.shape = [1, self.n_class]
37 | knn_cnt = torch.zeros((query.shape[0], self.n_class))
38 | for i in range(self.n_class):
39 | knn_cnt[:, i] = (self.queue_label[indices] == i).sum(dim=1)
40 | return knn_cnt.argmax(dim=1).tolist()
41 |
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import codecs
3 | import random
4 | from pathlib import Path
5 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
6 | import os
7 | import math
8 |
9 | import torch
10 | from torch.utils.data import Dataset, Sampler
11 |
12 |
13 | class BASEDataset:
14 | def __init__(
15 | self,
16 | data_dir,
17 | mode
18 | ):
19 | """data key: sentence, label[0/1]"""
20 | super().__init__()
21 | if mode == 'dev':
22 | mode = 'dev_subsample'
23 | data_file = os.path.join(data_dir, mode + '.jsonl')
24 | self.data = []
25 | with open(data_file, 'r') as f:
26 | read_lines = f.readlines()
27 | for line in read_lines:
28 | instance = json.loads(line.strip())
29 | self.data.append(instance)
30 | # customize your own label map in inheritance
31 | self.id2label = {0: 'negative', 1: 'positive'}
32 | self.label2id = {'negative': 0, 'positive': 1}
33 |
34 | def __len__(self):
35 | return len(self.data)
36 |
37 | def subsamplebyshot(self, n_shot, seed, exclude=None):
38 | # exclude
39 | if exclude is not None:
40 | for ins in exclude:
41 | self.data.remove(ins)
42 | # aggregate data by each category
43 | random.seed(seed)
44 | data_by_cls = {}
45 | for i in range(self.__len__()):
46 | if self.label2id[self.data[i]['label']] not in data_by_cls:
47 | data_by_cls[self.label2id[self.data[i]['label']]] = []
48 | data_by_cls[self.label2id[self.data[i]['label']]].append(self.data[i])
49 | # evenly sample n examples from each category
50 | data_subsample = []
51 | for cls in data_by_cls.keys():
52 | data_subsampled_by_cls = random.sample(data_by_cls[cls], min(n_shot, len(data_by_cls[cls])))
53 | data_subsample.extend(data_subsampled_by_cls)
54 | random.shuffle(data_subsample)
55 | self.data = data_subsample
56 |
57 |
58 | class SST2Dataset(BASEDataset):
59 | def __init__(
60 | self,
61 | data_dir,
62 | mode
63 | ):
64 | """data key: sentence, label[0/1]"""
65 | super().__init__(data_dir, mode)
66 | self.label2id = {'0': 0, '1': 1}
67 | self.label2verb = {'0': 'negative', '1': 'positive'}
68 | self.id2verb = ['negative', 'positive']
69 |
70 |
71 | class SUBJDataset(BASEDataset):
72 | def __init__(
73 | self,
74 | data_dir,
75 | mode
76 | ):
77 | """data key: sentence, label[0/1]"""
78 | super().__init__(data_dir, mode)
79 | # subj only has test set
80 | self.label2id = {'0': 0, '1': 1}
81 | self.label2verb = {'0': 'subjective', '1': 'objective'}
82 | self.id2verb = ['subjective', 'objective']
83 |
84 |
85 | class AGNEWSDataset(BASEDataset):
86 | def __init__(
87 | self,
88 | data_dir,
89 | mode
90 | ):
91 | """data key: sentence, label[0/1]"""
92 | super().__init__(data_dir, mode)
93 | self.label2id = {'1': 0, '2': 1, '3': 2, '4': 3}
94 | self.label2verb = {'1': 'world', '2': 'sports', '3': 'business', '4': 'technology'}
95 | self.id2verb = ['world', 'sports', 'business', 'technology']
96 |
97 |
98 | class CBDataset(BASEDataset):
99 | def __init__(
100 | self,
101 | data_dir,
102 | mode
103 | ):
104 | """data key: sentence, label[0/1]"""
105 | super().__init__(data_dir, mode)
106 | self.label2id = {'contradiction': 0, 'entailment': 1, 'neutral': 2}
107 | self.label2verb = {'contradiction': 'false', 'entailment': 'true', 'neutral': 'neither'}
108 | self.id2verb = ['false', 'true', 'neither']
109 |
110 |
111 | class CRDataset(BASEDataset):
112 | def __init__(
113 | self,
114 | data_dir,
115 | mode
116 | ):
117 | """data key: sentence, label[0/1]"""
118 | super().__init__(data_dir, mode)
119 | self.label2id = {'0': 0, '1': 1}
120 | self.label2verb = {'0': 'negative', '1': 'positive'}
121 | self.id2verb = ['negative', 'positive']
122 |
123 |
124 | class DBPEDIADataset(BASEDataset):
125 | def __init__(
126 | self,
127 | data_dir,
128 | mode
129 | ):
130 | """data key: sentence, label[0/1]"""
131 | if mode == 'dev':
132 | mode = 'dev_subsample'
133 | else:
134 | mode = 'train_subset' # this is an exception case
135 | super().__init__(data_dir, mode)
136 | self.label2id = {'1': 0, '2': 1, '3': 2, '4': 3, '5': 4,
137 | '6': 5, '7': 6, '8': 7, '9': 8, '10': 9,
138 | '11': 10, '12': 11, '13': 12, '14': 13}
139 | self.label2verb = {'1': 'company', '2': 'school', '3': 'artist', '4': 'athlete', '5': 'politics',
140 | '6': 'transportation', '7': 'building', '8': 'nature', '9': 'village', '10': 'animal',
141 | '11': 'plant', '12': 'album', '13': 'film', '14': 'book'}
142 | self.id2verb = ['company', 'school', 'artist', 'athlete', 'politics',
143 | 'transportation', 'building', 'nature', 'village', 'animal',
144 | 'plant', 'album', 'film', 'book']
145 |
146 |
147 | class MPQADataset(BASEDataset):
148 | def __init__(
149 | self,
150 | data_dir,
151 | mode
152 | ):
153 | """data key: sentence, label[0/1]"""
154 | super().__init__(data_dir, mode)
155 | self.label2id = {'0': 0, '1': 1}
156 | self.label2verb = {'0': 'negative', '1': 'positive'}
157 | self.id2verb = ['negative', 'positive']
158 |
159 |
160 | class MRDataset(BASEDataset):
161 | def __init__(
162 | self,
163 | data_dir,
164 | mode
165 | ):
166 | """data key: sentence, label[0/1]"""
167 | super().__init__(data_dir, mode)
168 | self.label2id = {'0': 0, '1': 1}
169 | self.label2verb = {'0': 'negative', '1': 'positive'}
170 | self.id2verb = ['negative', 'positive']
171 |
172 |
173 | class RTEDataset(BASEDataset):
174 | def __init__(
175 | self,
176 | data_dir,
177 | mode
178 | ):
179 | """data key: sentence, label[0/1]"""
180 | super().__init__(data_dir, mode)
181 | self.label2id = {'not_entailment': 0, 'entailment': 1}
182 | self.label2verb = {'not_entailment': 'false', 'entailment': 'true'}
183 | self.id2verb = ['false', 'true']
184 |
185 |
186 | class SST5Dataset(BASEDataset):
187 | def __init__(
188 | self,
189 | data_dir,
190 | mode
191 | ):
192 | """data key: sentence, label[0/1]"""
193 | super().__init__(data_dir, mode)
194 | self.label2id = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4}
195 | self.label2verb = {'0': 'terrible', '1': 'bad', '2': 'okay', '3': 'good', '4': 'great'}
196 | self.id2verb = ['terrible', 'bad', 'okay', 'good', 'great']
197 |
198 |
199 | class TRECDataset(BASEDataset):
200 | def __init__(
201 | self,
202 | data_dir,
203 | mode
204 | ):
205 | """data key: sentence, label[0/1]"""
206 | super().__init__(data_dir, mode)
207 | self.label2id = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5}
208 |         self.label2verb = {'0': 'description', '1': 'entity', '2': 'expression', '3': 'human', '4': 'location', '5': 'number'}
209 | self.id2verb = ['description', 'entity', 'expression', 'human', 'location', 'number']
--------------------------------------------------------------------------------
/utils/template.py:
--------------------------------------------------------------------------------
1 | def make_prompt(dataset, dataset_name, mode, indices=None):
2 | if dataset_name == 'sst2':
3 | template_func = template_sst2
4 | elif dataset_name == 'subj':
5 | template_func = template_subj
6 | elif dataset_name == 'agnews':
7 | template_func = template_agnews
8 | elif dataset_name == 'cb':
9 | template_func = template_cb
10 | elif dataset_name == 'cr':
11 | template_func = template_cr
12 | elif dataset_name == 'dbpedia':
13 | template_func = template_dbpedia
14 | elif dataset_name == 'mpqa':
15 | template_func = template_mpqa
16 | elif dataset_name == 'mr':
17 | template_func = template_mr
18 | elif dataset_name == 'rte':
19 | template_func = template_rte
20 | elif dataset_name == 'sst5':
21 | template_func = template_sst5
22 | elif dataset_name == 'trec':
23 | template_func = template_trec
24 | if mode == 'inference':
25 | return template_func(dataset, None, mode)
26 | prompt = ''
27 |     if mode == 'compose':  # compose a prompt from the training examples selected by indices
28 |         for ins in [dataset.data[i] for i in indices]:
29 | prompt += template_func(ins, dataset.label2verb[ins['label']], 'train')
30 | prompt += '\n'
31 | return prompt
32 | for ins in dataset.data:
33 | prompt += template_func(ins, dataset.label2verb[ins['label']], mode)
34 | prompt += '\n'
35 | return prompt
36 |
37 |
38 | def template_sst2(ins, label, mode):
39 | if mode == 'train':
40 | return f"Review: {ins['sentence']}\nSentiment: {label}\n"
41 | else:
42 | return f"Review: {ins['sentence']}\nSentiment:"
43 |
44 |
45 | def template_subj(ins, label, mode):
46 | if mode == 'train':
47 | return f"Input: {ins['sentence']}\nType: {label}\n"
48 | else:
49 | return f"Input: {ins['sentence']}\nType:"
50 |
51 |
52 | def template_agnews(ins, label, mode):
53 | if mode == 'train':
54 | return f"input: {ins['sentence']}\ntype: {label}\n"
55 | else:
56 | return f"input: {ins['sentence']}\ntype:"
57 |
58 |
59 | def template_cb(ins, label, mode):
60 | if mode == 'train':
61 | return f"premise: {ins['premise']}\nhypothesis: {ins['hypothesis']}\nprediction: {label}\n"
62 | else:
63 | return f"premise: {ins['premise']}\nhypothesis: {ins['hypothesis']}\nprediction:"
64 |
65 |
66 | def template_cr(ins, label, mode):
67 | if mode == 'train':
68 | return f"Review: {ins['sentence']}\nSentiment: {label}\n"
69 | else:
70 | return f"Review: {ins['sentence']}\nSentiment:"
71 |
72 |
73 | def template_dbpedia(ins, label, mode):
74 | if mode == 'train':
75 | return f"input: {ins['sentence']}\ntype: {label}\n"
76 | else:
77 | return f"input: {ins['sentence']}\ntype:"
78 |
79 |
80 | def template_mpqa(ins, label, mode):
81 | if mode == 'train':
82 | return f"Review: {ins['sentence']}\nSentiment: {label}\n"
83 | else:
84 | return f"Review: {ins['sentence']}\nSentiment:"
85 |
86 |
87 | def template_mr(ins, label, mode):
88 | if mode == 'train':
89 | return f"Review: {ins['sentence']}\nSentiment: {label}\n"
90 | else:
91 | return f"Review: {ins['sentence']}\nSentiment:"
92 |
93 |
94 | def template_rte(ins, label, mode):
95 | if mode == 'train':
96 | return f"premise: {ins['sentence_1']}\nhypothesis: {ins['sentence_2']}\nprediction: {label}\n"
97 | else:
98 | return f"premise: {ins['sentence_1']}\nhypothesis: {ins['sentence_2']}\nprediction:"
99 |
100 |
101 | def template_sst5(ins, label, mode):
102 | if mode == 'train':
103 | return f"Review: {ins['sentence']}\nSentiment: {label}\n"
104 | else:
105 | return f"Review: {ins['sentence']}\nSentiment:"
106 |
107 |
108 | def template_trec(ins, label, mode):
109 | if mode == 'train':
110 | return f"Question: {ins['sentence']}\nType: {label}\n"
111 | else:
112 | return f"Question: {ins['sentence']}\nType:"
113 |
114 |
115 | def sent_sim_template(ins, dataset_name):
116 | # ['sst2', 'subj', 'mpqa', 'agnews', 'cb', 'cr', 'dbpedia', 'mr', 'rte', 'trec']
117 | if dataset_name in ['sst2', 'subj', 'mpqa', 'agnews', 'cr', 'dbpedia', 'mr', 'trec']:
118 | return ins['sentence']
119 | elif dataset_name in ['cb']:
120 | return f"premise: {ins['premise']}. hypothesis: {ins['hypothesis']}"
121 | elif dataset_name in ['rte']:
122 | return f"premise: {ins['sentence_1']}. hypothesis: {ins['sentence_2']}"
123 |
--------------------------------------------------------------------------------